reference vector series tests, fix model tests

This commit is contained in:
sneakers-the-rat 2024-08-15 01:43:42 -07:00
parent 10965743eb
commit 24494b8ee4
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
2 changed files with 62 additions and 23 deletions

View file

@ -748,9 +748,9 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin):
for index, span, and timeseries for index, span, and timeseries
""" """
idx_start: NDArray[Any, int] idx_start: NDArray[Shape["*"], int]
count: NDArray[Any, int] count: NDArray[Shape["*"], int]
timeseries: NDArray[Any, BaseModel] timeseries: NDArray
@model_validator(mode="after") @model_validator(mode="after")
def ensure_equal_length(self) -> "TimeSeriesReferenceVectorDataMixin": def ensure_equal_length(self) -> "TimeSeriesReferenceVectorDataMixin":
@ -789,11 +789,11 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin):
) )
if isinstance(item, (int, np.integer)): if isinstance(item, (int, np.integer)):
return self.timeseries[self._slice_helper(item)] return self.timeseries[item][self._slice_helper(item)]
elif isinstance(item, slice): elif isinstance(item, (slice, Iterable)):
return [self.timeseries[subitem] for subitem in self._slice_helper(item)] if isinstance(item, slice):
elif isinstance(item, Iterable): item = range(*item.indices(len(self.idx_start)))
return [self.timeseries[self._slice_helper(subitem)] for subitem in item] return [self.timeseries[subitem][self._slice_helper(subitem)] for subitem in item]
else: else:
raise ValueError( raise ValueError(
f"Dont know how to index with {item}, must be an int, slice, or iterable" f"Dont know how to index with {item}, must be an int, slice, or iterable"
@ -806,13 +806,22 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin):
" never done in the core schema." " never done in the core schema."
) )
if isinstance(key, (int, np.integer)): if isinstance(key, (int, np.integer)):
self.timeseries[self._slice_helper(key)] = value self.timeseries[key][self._slice_helper(key)] = value
elif isinstance(key, slice): elif isinstance(key, (slice, Iterable)):
for subitem in self._slice_helper(key): if isinstance(key, slice):
self.timeseries[subitem] = value key = range(*key.indices(len(self.idx_start)))
elif isinstance(key, Iterable):
for subitem in key: if isinstance(value, Iterable):
self.timeseries[self._slice_helper(subitem)] = value if len(key) != len(value):
raise ValueError(
"Can only assign equal-length iterable to a slice, manually index the"
" target Timeseries object if you need more control"
)
for subitem, subvalue in zip(key, value):
self.timeseries[subitem][self._slice_helper(subitem)] = subvalue
else:
for subitem in key:
self.timeseries[subitem][self._slice_helper(subitem)] = value
else: else:
raise ValueError( raise ValueError(
f"Dont know how to index with {key}, must be an int, slice, or iterable" f"Dont know how to index with {key}, must be an int, slice, or iterable"
@ -898,3 +907,6 @@ if "pytest" in sys.modules:
"""DynamicTableRegion subclass for testing""" """DynamicTableRegion subclass for testing"""
pass pass
class TimeSeriesReferenceVectorData(TimeSeriesReferenceVectorDataMixin):
pass

View file

@ -611,6 +611,35 @@ def test_mixed_aligned_dynamictable(aligned_table):
assert len(array) == index_array[i] assert len(array) == index_array[i]
def test_timeseriesreferencevectordata_index():
"""
TimeSeriesReferenceVectorData should be able to do the thing it does
"""
generator = np.random.default_rng()
timeseries = np.array([np.arange(100)] * 10)
counts = generator.integers(1, 10, (10,))
idx_start = np.arange(0, 100, 10)
response = hdmf.TimeSeriesReferenceVectorData(
idx_start=idx_start,
count=counts,
timeseries=timeseries,
)
for i in range(len(counts)):
assert len(response[i]) == counts[i]
items = response[3:5]
assert all(items[0] == timeseries[3][idx_start[3] : idx_start[3] + counts[3]])
assert all(items[1] == timeseries[4][idx_start[4] : idx_start[4] + counts[4]])
response[0] = np.zeros((counts[0],))
assert all(response[0] == 0)
response[1:3] = [np.zeros((counts[1],)), np.ones((counts[2],))]
assert all(response[1] == 0)
assert all(response[2] == 1)
# -------------------------------------------------- # --------------------------------------------------
# Model-based tests # Model-based tests
# -------------------------------------------------- # --------------------------------------------------
@ -623,7 +652,6 @@ def test_dynamictable_indexing_electricalseries(electrical_series):
series, electrodes = electrical_series series, electrodes = electrical_series
colnames = [ colnames = [
"id",
"x", "x",
"y", "y",
"group", "group",
@ -632,20 +660,19 @@ def test_dynamictable_indexing_electricalseries(electrical_series):
"extra_column", "extra_column",
] ]
dtypes = [ dtypes = [
np.dtype("int64"),
np.dtype("float64"), np.dtype("float64"),
np.dtype("float64"), np.dtype("float64"),
] + ([np.dtype("O")] * 4) ] + ([np.dtype("O")] * 4)
row = electrodes[0] row = electrodes[0]
# successfully get a single row :) # successfully get a single row :)
assert row.shape == (1, 7) assert row.shape == (1, 6)
assert row.dtypes.values.tolist() == dtypes assert row.dtypes.values.tolist() == dtypes
assert row.columns.tolist() == colnames assert row.columns.tolist() == colnames
# slice a range of rows # slice a range of rows
rows = electrodes[0:3] rows = electrodes[0:3]
assert rows.shape == (3, 7) assert rows.shape == (3, 6)
assert rows.dtypes.values.tolist() == dtypes assert rows.dtypes.values.tolist() == dtypes
assert rows.columns.tolist() == colnames assert rows.columns.tolist() == colnames
@ -656,7 +683,7 @@ def test_dynamictable_indexing_electricalseries(electrical_series):
# get a single cell # get a single cell
val = electrodes[0, "y"] val = electrodes[0, "y"]
assert val == 5 assert val == 5
val = electrodes[0, 2] val = electrodes[0, 1]
assert val == 5 assert val == 5
# get a slice of rows and columns # get a slice of rows and columns
@ -698,8 +725,8 @@ def test_dynamictable_region_basic_electricalseries(electrical_series):
# b) every other object in the chain is strictly validated, # b) every other object in the chain is strictly validated,
# so we assume if we got a right shaped df that it is the correct one. # so we assume if we got a right shaped df that it is the correct one.
# feel free to @ me when i am wrong about this # feel free to @ me when i am wrong about this
assert all(row.id == 4) assert all(row.index == 4)
assert row.shape == (1, 7) assert row.shape == (1, 6)
# and we should still be preserving the model that is the contents of the cell of this row # and we should still be preserving the model that is the contents of the cell of this row
# so this is a dataframe row with a column "group" that contains an array of ElectrodeGroup # so this is a dataframe row with a column "group" that contains an array of ElectrodeGroup
# objects and that's as far as we are going to chase the recursion in this basic indexing test # objects and that's as far as we are going to chase the recursion in this basic indexing test
@ -709,7 +736,7 @@ def test_dynamictable_region_basic_electricalseries(electrical_series):
# getting a list of table rows is actually correct behavior here because # getting a list of table rows is actually correct behavior here because
# this list of table rows is actually the cell of another table # this list of table rows is actually the cell of another table
rows = series.electrodes[0:3] rows = series.electrodes[0:3]
assert all([all(row.id == idx) for row, idx in zip(rows, [4, 3, 2])]) assert all([all(row.index == idx) for row, idx in zip(rows, [4, 3, 2])])
def test_aligned_dynamictable_ictable(intracellular_recordings_table): def test_aligned_dynamictable_ictable(intracellular_recordings_table):