mirror of
https://github.com/p2p-ld/nwb-linkml.git
synced 2025-01-10 06:04:28 +00:00
reference vector series tests, fix model tests
This commit is contained in:
parent
10965743eb
commit
24494b8ee4
2 changed files with 62 additions and 23 deletions
|
@ -748,9 +748,9 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin):
|
|||
for index, span, and timeseries
|
||||
"""
|
||||
|
||||
idx_start: NDArray[Any, int]
|
||||
count: NDArray[Any, int]
|
||||
timeseries: NDArray[Any, BaseModel]
|
||||
idx_start: NDArray[Shape["*"], int]
|
||||
count: NDArray[Shape["*"], int]
|
||||
timeseries: NDArray
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_equal_length(self) -> "TimeSeriesReferenceVectorDataMixin":
|
||||
|
@ -789,11 +789,11 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin):
|
|||
)
|
||||
|
||||
if isinstance(item, (int, np.integer)):
|
||||
return self.timeseries[self._slice_helper(item)]
|
||||
elif isinstance(item, slice):
|
||||
return [self.timeseries[subitem] for subitem in self._slice_helper(item)]
|
||||
elif isinstance(item, Iterable):
|
||||
return [self.timeseries[self._slice_helper(subitem)] for subitem in item]
|
||||
return self.timeseries[item][self._slice_helper(item)]
|
||||
elif isinstance(item, (slice, Iterable)):
|
||||
if isinstance(item, slice):
|
||||
item = range(*item.indices(len(self.idx_start)))
|
||||
return [self.timeseries[subitem][self._slice_helper(subitem)] for subitem in item]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Dont know how to index with {item}, must be an int, slice, or iterable"
|
||||
|
@ -806,13 +806,22 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin):
|
|||
" never done in the core schema."
|
||||
)
|
||||
if isinstance(key, (int, np.integer)):
|
||||
self.timeseries[self._slice_helper(key)] = value
|
||||
elif isinstance(key, slice):
|
||||
for subitem in self._slice_helper(key):
|
||||
self.timeseries[subitem] = value
|
||||
elif isinstance(key, Iterable):
|
||||
self.timeseries[key][self._slice_helper(key)] = value
|
||||
elif isinstance(key, (slice, Iterable)):
|
||||
if isinstance(key, slice):
|
||||
key = range(*key.indices(len(self.idx_start)))
|
||||
|
||||
if isinstance(value, Iterable):
|
||||
if len(key) != len(value):
|
||||
raise ValueError(
|
||||
"Can only assign equal-length iterable to a slice, manually index the"
|
||||
" target Timeseries object if you need more control"
|
||||
)
|
||||
for subitem, subvalue in zip(key, value):
|
||||
self.timeseries[subitem][self._slice_helper(subitem)] = subvalue
|
||||
else:
|
||||
for subitem in key:
|
||||
self.timeseries[self._slice_helper(subitem)] = value
|
||||
self.timeseries[subitem][self._slice_helper(subitem)] = value
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Dont know how to index with {key}, must be an int, slice, or iterable"
|
||||
|
@ -898,3 +907,6 @@ if "pytest" in sys.modules:
|
|||
"""DynamicTableRegion subclass for testing"""
|
||||
|
||||
pass
|
||||
|
||||
class TimeSeriesReferenceVectorData(TimeSeriesReferenceVectorDataMixin):
|
||||
pass
|
||||
|
|
|
@ -611,6 +611,35 @@ def test_mixed_aligned_dynamictable(aligned_table):
|
|||
assert len(array) == index_array[i]
|
||||
|
||||
|
||||
def test_timeseriesreferencevectordata_index():
|
||||
"""
|
||||
TimeSeriesReferenceVectorData should be able to do the thing it does
|
||||
"""
|
||||
generator = np.random.default_rng()
|
||||
timeseries = np.array([np.arange(100)] * 10)
|
||||
|
||||
counts = generator.integers(1, 10, (10,))
|
||||
idx_start = np.arange(0, 100, 10)
|
||||
|
||||
response = hdmf.TimeSeriesReferenceVectorData(
|
||||
idx_start=idx_start,
|
||||
count=counts,
|
||||
timeseries=timeseries,
|
||||
)
|
||||
for i in range(len(counts)):
|
||||
assert len(response[i]) == counts[i]
|
||||
items = response[3:5]
|
||||
assert all(items[0] == timeseries[3][idx_start[3] : idx_start[3] + counts[3]])
|
||||
assert all(items[1] == timeseries[4][idx_start[4] : idx_start[4] + counts[4]])
|
||||
|
||||
response[0] = np.zeros((counts[0],))
|
||||
assert all(response[0] == 0)
|
||||
|
||||
response[1:3] = [np.zeros((counts[1],)), np.ones((counts[2],))]
|
||||
assert all(response[1] == 0)
|
||||
assert all(response[2] == 1)
|
||||
|
||||
|
||||
# --------------------------------------------------
|
||||
# Model-based tests
|
||||
# --------------------------------------------------
|
||||
|
@ -623,7 +652,6 @@ def test_dynamictable_indexing_electricalseries(electrical_series):
|
|||
series, electrodes = electrical_series
|
||||
|
||||
colnames = [
|
||||
"id",
|
||||
"x",
|
||||
"y",
|
||||
"group",
|
||||
|
@ -632,20 +660,19 @@ def test_dynamictable_indexing_electricalseries(electrical_series):
|
|||
"extra_column",
|
||||
]
|
||||
dtypes = [
|
||||
np.dtype("int64"),
|
||||
np.dtype("float64"),
|
||||
np.dtype("float64"),
|
||||
] + ([np.dtype("O")] * 4)
|
||||
|
||||
row = electrodes[0]
|
||||
# successfully get a single row :)
|
||||
assert row.shape == (1, 7)
|
||||
assert row.shape == (1, 6)
|
||||
assert row.dtypes.values.tolist() == dtypes
|
||||
assert row.columns.tolist() == colnames
|
||||
|
||||
# slice a range of rows
|
||||
rows = electrodes[0:3]
|
||||
assert rows.shape == (3, 7)
|
||||
assert rows.shape == (3, 6)
|
||||
assert rows.dtypes.values.tolist() == dtypes
|
||||
assert rows.columns.tolist() == colnames
|
||||
|
||||
|
@ -656,7 +683,7 @@ def test_dynamictable_indexing_electricalseries(electrical_series):
|
|||
# get a single cell
|
||||
val = electrodes[0, "y"]
|
||||
assert val == 5
|
||||
val = electrodes[0, 2]
|
||||
val = electrodes[0, 1]
|
||||
assert val == 5
|
||||
|
||||
# get a slice of rows and columns
|
||||
|
@ -698,8 +725,8 @@ def test_dynamictable_region_basic_electricalseries(electrical_series):
|
|||
# b) every other object in the chain is strictly validated,
|
||||
# so we assume if we got a right shaped df that it is the correct one.
|
||||
# feel free to @ me when i am wrong about this
|
||||
assert all(row.id == 4)
|
||||
assert row.shape == (1, 7)
|
||||
assert all(row.index == 4)
|
||||
assert row.shape == (1, 6)
|
||||
# and we should still be preserving the model that is the contents of the cell of this row
|
||||
# so this is a dataframe row with a column "group" that contains an array of ElectrodeGroup
|
||||
# objects and that's as far as we are going to chase the recursion in this basic indexing test
|
||||
|
@ -709,7 +736,7 @@ def test_dynamictable_region_basic_electricalseries(electrical_series):
|
|||
# getting a list of table rows is actually correct behavior here because
|
||||
# this list of table rows is actually the cell of another table
|
||||
rows = series.electrodes[0:3]
|
||||
assert all([all(row.id == idx) for row, idx in zip(rows, [4, 3, 2])])
|
||||
assert all([all(row.index == idx) for row, idx in zip(rows, [4, 3, 2])])
|
||||
|
||||
|
||||
def test_aligned_dynamictable_ictable(intracellular_recordings_table):
|
||||
|
|
Loading…
Reference in a new issue