fix setattr for index and data

This commit is contained in:
sneakers-the-rat 2024-08-14 23:03:03 -07:00
parent 54409c7b28
commit 7e7cbc1ac1
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
2 changed files with 46 additions and 14 deletions

View file

@ -65,10 +65,6 @@ class DynamicTableMixin(BaseModel):
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
return {k: getattr(self, k) for i, k in enumerate(self.colnames)}
@property
def _columns_list(self) -> List[Union[list, "NDArray", "VectorDataMixin"]]:
return [getattr(self, k) for i, k in enumerate(self.colnames)]
@overload
def __getitem__(self, item: str) -> Union[list, "NDArray", "VectorDataMixin"]: ...
@ -283,7 +279,7 @@ class DynamicTableMixin(BaseModel):
model[key] = VectorIndex(name=key, description="", value=val)
else:
model[key] = VectorData(name=key, description="", value=val)
except ValidationError as e:
except ValidationError as e: # pragma: no cover
raise ValidationError(
f"field {key} cannot be cast to VectorData from {val}"
) from e
@ -423,24 +419,24 @@ class VectorIndexMixin(BaseModel, Generic[T]):
kwargs["value"] = value
super().__init__(**kwargs)
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
def _slice(self, arg: int) -> slice:
"""
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
"""
start = 0 if arg == 0 else self.value[arg - 1]
end = self.value[arg]
return self.target.value[slice(start, end)]
return slice(start, end)
def __getitem__(self, item: Union[int, slice, Iterable]) -> Any:
if self.target is None:
return self.value[item]
else:
if isinstance(item, (int, np.integer)):
return self._getitem_helper(item)
return self.target.value[self._slice(item)]
elif isinstance(item, (slice, Iterable)):
if isinstance(item, slice):
item = range(*item.indices(len(self.value)))
return [self._getitem_helper(i) for i in item]
return [self.target.value[self._slice(i)] for i in item]
else:
raise AttributeError(f"Could not index with {item}")
@ -458,8 +454,27 @@ class VectorIndexMixin(BaseModel, Generic[T]):
"""
if self.target:
# __getitem__ will return the indexed reference to the target
self[key] = value
if isinstance(key, (int, np.integer)):
self.target.value[self._slice(key)] = value
elif isinstance(key, (slice, Iterable)):
if isinstance(key, slice):
key = range(*key.indices(len(self.value)))
if isinstance(value, Iterable):
if len(key) != len(value):
raise ValueError(
"Can only assign equal-length iterable to a slice, manually index the"
" ragged values of of the target VectorData object if you need more"
" control"
)
for i, subval in zip(key, value):
self.target.value[self._slice(i)] = subval
else:
for i in key:
self.target.value[self._slice(i)] = value
else: # pragma: no cover
raise AttributeError(f"Could not index with {key}")
else:
self.value[key] = value

View file

@ -241,7 +241,7 @@ def test_dynamictable_resolve_index():
assert inst.new_col_2._index is inst.new_col_2_index
def dynamictable_assert_equal_length():
def test_dynamictable_assert_equal_length():
"""
Dynamictable validates that columns are of equal length
"""
@ -277,7 +277,7 @@ def dynamictable_assert_equal_length():
"new_col_1": hdmf.VectorData(value=np.arange(100)),
"new_col_1_index": hdmf.VectorIndex(value=np.arange(0, 100, 5) + 5),
}
with pytest.raises(ValidationError, pattern="Columns are not of equal length"):
with pytest.raises(ValidationError, match="Columns are not of equal length"):
_ = MyDT(**cols)
@ -324,6 +324,15 @@ def test_vectordata_indexing():
assert data[0] == 1
data[0] = 0
# indexes by themselves are the same
index_notarget = hdmf.VectorIndex(value=index_array)
assert index_notarget[0] == index_array[0]
assert all(index_notarget[0:3] == index_array[0:3])
oldval = index_array[0]
index_notarget[0] = 5
assert index_notarget[0] == 5
index_notarget[0] = oldval
index = hdmf.VectorIndex(value=index_array, target=data)
data._index = index
@ -338,8 +347,16 @@ def test_vectordata_indexing():
assert all(subitem == i)
# setting uses the same indexing logic
data[0][:] = 5
data[0] = 5
assert all(data[0] == 5)
data[0:3] = [5, 4, 3]
assert all(data[0] == 5)
assert all(data[1] == 4)
assert all(data[2] == 3)
data[0:3] = 6
assert all(data[0] == 6)
assert all(data[1] == 6)
assert all(data[2] == 6)
def test_vectordata_getattr():