mirror of
https://github.com/p2p-ld/nwb-linkml.git
synced 2024-11-10 00:34:29 +00:00
fix setattr for index and data
This commit is contained in:
parent
54409c7b28
commit
7e7cbc1ac1
2 changed files with 46 additions and 14 deletions
|
@ -65,10 +65,6 @@ class DynamicTableMixin(BaseModel):
|
||||||
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
||||||
return {k: getattr(self, k) for i, k in enumerate(self.colnames)}
|
return {k: getattr(self, k) for i, k in enumerate(self.colnames)}
|
||||||
|
|
||||||
@property
|
|
||||||
def _columns_list(self) -> List[Union[list, "NDArray", "VectorDataMixin"]]:
|
|
||||||
return [getattr(self, k) for i, k in enumerate(self.colnames)]
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __getitem__(self, item: str) -> Union[list, "NDArray", "VectorDataMixin"]: ...
|
def __getitem__(self, item: str) -> Union[list, "NDArray", "VectorDataMixin"]: ...
|
||||||
|
|
||||||
|
@ -283,7 +279,7 @@ class DynamicTableMixin(BaseModel):
|
||||||
model[key] = VectorIndex(name=key, description="", value=val)
|
model[key] = VectorIndex(name=key, description="", value=val)
|
||||||
else:
|
else:
|
||||||
model[key] = VectorData(name=key, description="", value=val)
|
model[key] = VectorData(name=key, description="", value=val)
|
||||||
except ValidationError as e:
|
except ValidationError as e: # pragma: no cover
|
||||||
raise ValidationError(
|
raise ValidationError(
|
||||||
f"field {key} cannot be cast to VectorData from {val}"
|
f"field {key} cannot be cast to VectorData from {val}"
|
||||||
) from e
|
) from e
|
||||||
|
@ -423,24 +419,24 @@ class VectorIndexMixin(BaseModel, Generic[T]):
|
||||||
kwargs["value"] = value
|
kwargs["value"] = value
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
def _slice(self, arg: int) -> slice:
|
||||||
"""
|
"""
|
||||||
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||||
"""
|
"""
|
||||||
start = 0 if arg == 0 else self.value[arg - 1]
|
start = 0 if arg == 0 else self.value[arg - 1]
|
||||||
end = self.value[arg]
|
end = self.value[arg]
|
||||||
return self.target.value[slice(start, end)]
|
return slice(start, end)
|
||||||
|
|
||||||
def __getitem__(self, item: Union[int, slice, Iterable]) -> Any:
|
def __getitem__(self, item: Union[int, slice, Iterable]) -> Any:
|
||||||
if self.target is None:
|
if self.target is None:
|
||||||
return self.value[item]
|
return self.value[item]
|
||||||
else:
|
else:
|
||||||
if isinstance(item, (int, np.integer)):
|
if isinstance(item, (int, np.integer)):
|
||||||
return self._getitem_helper(item)
|
return self.target.value[self._slice(item)]
|
||||||
elif isinstance(item, (slice, Iterable)):
|
elif isinstance(item, (slice, Iterable)):
|
||||||
if isinstance(item, slice):
|
if isinstance(item, slice):
|
||||||
item = range(*item.indices(len(self.value)))
|
item = range(*item.indices(len(self.value)))
|
||||||
return [self._getitem_helper(i) for i in item]
|
return [self.target.value[self._slice(i)] for i in item]
|
||||||
else:
|
else:
|
||||||
raise AttributeError(f"Could not index with {item}")
|
raise AttributeError(f"Could not index with {item}")
|
||||||
|
|
||||||
|
@ -458,8 +454,27 @@ class VectorIndexMixin(BaseModel, Generic[T]):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if self.target:
|
if self.target:
|
||||||
# __getitem__ will return the indexed reference to the target
|
if isinstance(key, (int, np.integer)):
|
||||||
self[key] = value
|
self.target.value[self._slice(key)] = value
|
||||||
|
elif isinstance(key, (slice, Iterable)):
|
||||||
|
if isinstance(key, slice):
|
||||||
|
key = range(*key.indices(len(self.value)))
|
||||||
|
|
||||||
|
if isinstance(value, Iterable):
|
||||||
|
if len(key) != len(value):
|
||||||
|
raise ValueError(
|
||||||
|
"Can only assign equal-length iterable to a slice, manually index the"
|
||||||
|
" ragged values of of the target VectorData object if you need more"
|
||||||
|
" control"
|
||||||
|
)
|
||||||
|
for i, subval in zip(key, value):
|
||||||
|
self.target.value[self._slice(i)] = subval
|
||||||
|
else:
|
||||||
|
for i in key:
|
||||||
|
self.target.value[self._slice(i)] = value
|
||||||
|
else: # pragma: no cover
|
||||||
|
raise AttributeError(f"Could not index with {key}")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.value[key] = value
|
self.value[key] = value
|
||||||
|
|
||||||
|
|
|
@ -241,7 +241,7 @@ def test_dynamictable_resolve_index():
|
||||||
assert inst.new_col_2._index is inst.new_col_2_index
|
assert inst.new_col_2._index is inst.new_col_2_index
|
||||||
|
|
||||||
|
|
||||||
def dynamictable_assert_equal_length():
|
def test_dynamictable_assert_equal_length():
|
||||||
"""
|
"""
|
||||||
Dynamictable validates that columns are of equal length
|
Dynamictable validates that columns are of equal length
|
||||||
"""
|
"""
|
||||||
|
@ -277,7 +277,7 @@ def dynamictable_assert_equal_length():
|
||||||
"new_col_1": hdmf.VectorData(value=np.arange(100)),
|
"new_col_1": hdmf.VectorData(value=np.arange(100)),
|
||||||
"new_col_1_index": hdmf.VectorIndex(value=np.arange(0, 100, 5) + 5),
|
"new_col_1_index": hdmf.VectorIndex(value=np.arange(0, 100, 5) + 5),
|
||||||
}
|
}
|
||||||
with pytest.raises(ValidationError, pattern="Columns are not of equal length"):
|
with pytest.raises(ValidationError, match="Columns are not of equal length"):
|
||||||
_ = MyDT(**cols)
|
_ = MyDT(**cols)
|
||||||
|
|
||||||
|
|
||||||
|
@ -324,6 +324,15 @@ def test_vectordata_indexing():
|
||||||
assert data[0] == 1
|
assert data[0] == 1
|
||||||
data[0] = 0
|
data[0] = 0
|
||||||
|
|
||||||
|
# indexes by themselves are the same
|
||||||
|
index_notarget = hdmf.VectorIndex(value=index_array)
|
||||||
|
assert index_notarget[0] == index_array[0]
|
||||||
|
assert all(index_notarget[0:3] == index_array[0:3])
|
||||||
|
oldval = index_array[0]
|
||||||
|
index_notarget[0] = 5
|
||||||
|
assert index_notarget[0] == 5
|
||||||
|
index_notarget[0] = oldval
|
||||||
|
|
||||||
index = hdmf.VectorIndex(value=index_array, target=data)
|
index = hdmf.VectorIndex(value=index_array, target=data)
|
||||||
data._index = index
|
data._index = index
|
||||||
|
|
||||||
|
@ -338,8 +347,16 @@ def test_vectordata_indexing():
|
||||||
assert all(subitem == i)
|
assert all(subitem == i)
|
||||||
|
|
||||||
# setting uses the same indexing logic
|
# setting uses the same indexing logic
|
||||||
data[0][:] = 5
|
data[0] = 5
|
||||||
assert all(data[0] == 5)
|
assert all(data[0] == 5)
|
||||||
|
data[0:3] = [5, 4, 3]
|
||||||
|
assert all(data[0] == 5)
|
||||||
|
assert all(data[1] == 4)
|
||||||
|
assert all(data[2] == 3)
|
||||||
|
data[0:3] = 6
|
||||||
|
assert all(data[0] == 6)
|
||||||
|
assert all(data[1] == 6)
|
||||||
|
assert all(data[2] == 6)
|
||||||
|
|
||||||
|
|
||||||
def test_vectordata_getattr():
|
def test_vectordata_getattr():
|
||||||
|
|
Loading…
Reference in a new issue