diff --git a/nwb_linkml/src/nwb_linkml/includes/hdmf.py b/nwb_linkml/src/nwb_linkml/includes/hdmf.py index 0c4e7ce..7fa3c08 100644 --- a/nwb_linkml/src/nwb_linkml/includes/hdmf.py +++ b/nwb_linkml/src/nwb_linkml/includes/hdmf.py @@ -748,9 +748,9 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin): for index, span, and timeseries """ - idx_start: NDArray[Any, int] - count: NDArray[Any, int] - timeseries: NDArray[Any, BaseModel] + idx_start: NDArray[Shape["*"], int] + count: NDArray[Shape["*"], int] + timeseries: NDArray @model_validator(mode="after") def ensure_equal_length(self) -> "TimeSeriesReferenceVectorDataMixin": @@ -789,11 +789,11 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin): ) if isinstance(item, (int, np.integer)): - return self.timeseries[self._slice_helper(item)] - elif isinstance(item, slice): - return [self.timeseries[subitem] for subitem in self._slice_helper(item)] - elif isinstance(item, Iterable): - return [self.timeseries[self._slice_helper(subitem)] for subitem in item] + return self.timeseries[item][self._slice_helper(item)] + elif isinstance(item, (slice, Iterable)): + if isinstance(item, slice): + item = range(*item.indices(len(self.idx_start))) + return [self.timeseries[subitem][self._slice_helper(subitem)] for subitem in item] else: raise ValueError( f"Dont know how to index with {item}, must be an int, slice, or iterable" @@ -806,13 +806,22 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin): " never done in the core schema." ) if isinstance(key, (int, np.integer)): - self.timeseries[self._slice_helper(key)] = value - elif isinstance(key, slice): - for subitem in self._slice_helper(key): - self.timeseries[subitem] = value - elif isinstance(key, Iterable): - for subitem in key: - self.timeseries[self._slice_helper(subitem)] = value + self.timeseries[key][self._slice_helper(key)] = value + elif isinstance(key, (slice, Iterable)): + if isinstance(key, slice): + key = range(*key.indices(len(self.idx_start))) + + if isinstance(value, Iterable): + if len(key) != len(value): + raise ValueError( + "Can only assign equal-length iterable to a slice, manually index the" + " target Timeseries object if you need more control" + ) + for subitem, subvalue in zip(key, value): + self.timeseries[subitem][self._slice_helper(subitem)] = subvalue + else: + for subitem in key: + self.timeseries[subitem][self._slice_helper(subitem)] = value else: raise ValueError( f"Dont know how to index with {key}, must be an int, slice, or iterable" @@ -898,3 +907,6 @@ if "pytest" in sys.modules: """DynamicTableRegion subclass for testing""" pass + + class TimeSeriesReferenceVectorData(TimeSeriesReferenceVectorDataMixin): + pass diff --git a/nwb_linkml/tests/test_includes/test_hdmf.py b/nwb_linkml/tests/test_includes/test_hdmf.py index fb9a3e2..6c5d51a 100644 --- a/nwb_linkml/tests/test_includes/test_hdmf.py +++ b/nwb_linkml/tests/test_includes/test_hdmf.py @@ -611,6 +611,35 @@ def test_mixed_aligned_dynamictable(aligned_table): assert len(array) == index_array[i] +def test_timeseriesreferencevectordata_index(): + """ + TimeSeriesReferenceVectorData should be able to do the thing it does + """ + generator = np.random.default_rng() + timeseries = np.array([np.arange(100)] * 10) + + counts = generator.integers(1, 10, (10,)) + idx_start = np.arange(0, 100, 10) + + response = hdmf.TimeSeriesReferenceVectorData( + idx_start=idx_start, + count=counts, + timeseries=timeseries, + ) + for i in range(len(counts)): + assert len(response[i]) == counts[i] + items = response[3:5] + assert all(items[0] == timeseries[3][idx_start[3] : idx_start[3] + counts[3]]) + assert all(items[1] == timeseries[4][idx_start[4] : idx_start[4] + counts[4]]) + + response[0] = np.zeros((counts[0],)) + assert all(response[0] == 0) + + response[1:3] = [np.zeros((counts[1],)), np.ones((counts[2],))] + assert all(response[1] == 0) + assert all(response[2] == 1) + + # -------------------------------------------------- # Model-based tests # -------------------------------------------------- @@ -623,7 +652,6 @@ def test_dynamictable_indexing_electricalseries(electrical_series): series, electrodes = electrical_series colnames = [ - "id", "x", "y", "group", @@ -632,20 +660,19 @@ def test_dynamictable_indexing_electricalseries(electrical_series): "extra_column", ] dtypes = [ - np.dtype("int64"), np.dtype("float64"), np.dtype("float64"), ] + ([np.dtype("O")] * 4) row = electrodes[0] # successfully get a single row :) - assert row.shape == (1, 7) + assert row.shape == (1, 6) assert row.dtypes.values.tolist() == dtypes assert row.columns.tolist() == colnames # slice a range of rows rows = electrodes[0:3] - assert rows.shape == (3, 7) + assert rows.shape == (3, 6) assert rows.dtypes.values.tolist() == dtypes assert rows.columns.tolist() == colnames @@ -656,7 +683,7 @@ def test_dynamictable_indexing_electricalseries(electrical_series): # get a single cell val = electrodes[0, "y"] assert val == 5 - val = electrodes[0, 2] + val = electrodes[0, 1] assert val == 5 # get a slice of rows and columns @@ -698,8 +725,8 @@ def test_dynamictable_region_basic_electricalseries(electrical_series): # b) every other object in the chain is strictly validated, # so we assume if we got a right shaped df that it is the correct one. # feel free to @ me when i am wrong about this - assert all(row.id == 4) - assert row.shape == (1, 7) + assert all(row.index == 4) + assert row.shape == (1, 6) # and we should still be preserving the model that is the contents of the cell of this row # so this is a dataframe row with a column "group" that contains an array of ElectrodeGroup # objects and that's as far as we are going to chase the recursion in this basic indexing test @@ -709,7 +736,7 @@ def test_dynamictable_region_basic_electricalseries(electrical_series): # getting a list of table rows is actually correct behavior here because # this list of table rows is actually the cell of another table rows = series.electrodes[0:3] - assert all([all(row.id == idx) for row, idx in zip(rows, [4, 3, 2])]) + assert all([all(row.index == idx) for row, idx in zip(rows, [4, 3, 2])]) def test_aligned_dynamictable_ictable(intracellular_recordings_table):