From 7e7cbc1ac16f06e4f2528308aecb669219b1e75d Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Wed, 14 Aug 2024 23:03:03 -0700 Subject: [PATCH] fix setattr for index and data --- nwb_linkml/src/nwb_linkml/includes/hdmf.py | 37 +++++++++++++++------ nwb_linkml/tests/test_includes/test_hdmf.py | 23 +++++++++++-- 2 files changed, 46 insertions(+), 14 deletions(-) diff --git a/nwb_linkml/src/nwb_linkml/includes/hdmf.py b/nwb_linkml/src/nwb_linkml/includes/hdmf.py index 9383be9..82022de 100644 --- a/nwb_linkml/src/nwb_linkml/includes/hdmf.py +++ b/nwb_linkml/src/nwb_linkml/includes/hdmf.py @@ -65,10 +65,6 @@ class DynamicTableMixin(BaseModel): def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]: return {k: getattr(self, k) for i, k in enumerate(self.colnames)} - @property - def _columns_list(self) -> List[Union[list, "NDArray", "VectorDataMixin"]]: - return [getattr(self, k) for i, k in enumerate(self.colnames)] - @overload def __getitem__(self, item: str) -> Union[list, "NDArray", "VectorDataMixin"]: ... @@ -283,7 +279,7 @@ class DynamicTableMixin(BaseModel): model[key] = VectorIndex(name=key, description="", value=val) else: model[key] = VectorData(name=key, description="", value=val) - except ValidationError as e: + except ValidationError as e: # pragma: no cover raise ValidationError( f"field {key} cannot be cast to VectorData from {val}" ) from e @@ -423,24 +419,24 @@ class VectorIndexMixin(BaseModel, Generic[T]): kwargs["value"] = value super().__init__(**kwargs) - def _getitem_helper(self, arg: int) -> Union[list, NDArray]: + def _slice(self, arg: int) -> slice: """ Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper` """ start = 0 if arg == 0 else self.value[arg - 1] end = self.value[arg] - return self.target.value[slice(start, end)] + return slice(start, end) def __getitem__(self, item: Union[int, slice, Iterable]) -> Any: if self.target is None: return self.value[item] else: if isinstance(item, (int, np.integer)): - return self._getitem_helper(item) + return self.target.value[self._slice(item)] elif isinstance(item, (slice, Iterable)): if isinstance(item, slice): item = range(*item.indices(len(self.value))) - return [self._getitem_helper(i) for i in item] + return [self.target.value[self._slice(i)] for i in item] else: raise AttributeError(f"Could not index with {item}") @@ -458,8 +454,27 @@ class VectorIndexMixin(BaseModel, Generic[T]): """ if self.target: - # __getitem__ will return the indexed reference to the target - self[key] = value + if isinstance(key, (int, np.integer)): + self.target.value[self._slice(key)] = value + elif isinstance(key, (slice, Iterable)): + if isinstance(key, slice): + key = range(*key.indices(len(self.value))) + + if isinstance(value, Iterable): + if len(key) != len(value): + raise ValueError( + "Can only assign equal-length iterable to a slice, manually index the" + " ragged values of of the target VectorData object if you need more" + " control" + ) + for i, subval in zip(key, value): + self.target.value[self._slice(i)] = subval + else: + for i in key: + self.target.value[self._slice(i)] = value + else: # pragma: no cover + raise AttributeError(f"Could not index with {key}") + else: self.value[key] = value diff --git a/nwb_linkml/tests/test_includes/test_hdmf.py b/nwb_linkml/tests/test_includes/test_hdmf.py index 7e34c66..32a4e5f 100644 --- a/nwb_linkml/tests/test_includes/test_hdmf.py +++ b/nwb_linkml/tests/test_includes/test_hdmf.py @@ -241,7 +241,7 @@ def test_dynamictable_resolve_index(): assert inst.new_col_2._index is inst.new_col_2_index -def dynamictable_assert_equal_length(): +def test_dynamictable_assert_equal_length(): """ Dynamictable validates that columns are of equal length """ @@ -277,7 +277,7 @@ def dynamictable_assert_equal_length(): "new_col_1": hdmf.VectorData(value=np.arange(100)), "new_col_1_index": hdmf.VectorIndex(value=np.arange(0, 100, 5) + 5), } - with pytest.raises(ValidationError, pattern="Columns are not of equal length"): + with pytest.raises(ValidationError, match="Columns are not of equal length"): _ = MyDT(**cols) @@ -324,6 +324,15 @@ def test_vectordata_indexing(): assert data[0] == 1 data[0] = 0 + # indexes by themselves are the same + index_notarget = hdmf.VectorIndex(value=index_array) + assert index_notarget[0] == index_array[0] + assert all(index_notarget[0:3] == index_array[0:3]) + oldval = index_array[0] + index_notarget[0] = 5 + assert index_notarget[0] == 5 + index_notarget[0] = oldval + index = hdmf.VectorIndex(value=index_array, target=data) data._index = index @@ -338,8 +347,16 @@ def test_vectordata_indexing(): assert all(subitem == i) # setting uses the same indexing logic - data[0][:] = 5 + data[0] = 5 assert all(data[0] == 5) + data[0:3] = [5, 4, 3] + assert all(data[0] == 5) + assert all(data[1] == 4) + assert all(data[2] == 3) + data[0:3] = 6 + assert all(data[0] == 6) + assert all(data[1] == 6) + assert all(data[2] == 6) def test_vectordata_getattr():