mirror of
https://github.com/p2p-ld/nwb-linkml.git
synced 2025-01-10 06:04:28 +00:00
reference vector series tests, fix model tests
This commit is contained in:
parent
10965743eb
commit
24494b8ee4
2 changed files with 62 additions and 23 deletions
|
@ -748,9 +748,9 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin):
|
||||||
for index, span, and timeseries
|
for index, span, and timeseries
|
||||||
"""
|
"""
|
||||||
|
|
||||||
idx_start: NDArray[Any, int]
|
idx_start: NDArray[Shape["*"], int]
|
||||||
count: NDArray[Any, int]
|
count: NDArray[Shape["*"], int]
|
||||||
timeseries: NDArray[Any, BaseModel]
|
timeseries: NDArray
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def ensure_equal_length(self) -> "TimeSeriesReferenceVectorDataMixin":
|
def ensure_equal_length(self) -> "TimeSeriesReferenceVectorDataMixin":
|
||||||
|
@ -789,11 +789,11 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin):
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(item, (int, np.integer)):
|
if isinstance(item, (int, np.integer)):
|
||||||
return self.timeseries[self._slice_helper(item)]
|
return self.timeseries[item][self._slice_helper(item)]
|
||||||
elif isinstance(item, slice):
|
elif isinstance(item, (slice, Iterable)):
|
||||||
return [self.timeseries[subitem] for subitem in self._slice_helper(item)]
|
if isinstance(item, slice):
|
||||||
elif isinstance(item, Iterable):
|
item = range(*item.indices(len(self.idx_start)))
|
||||||
return [self.timeseries[self._slice_helper(subitem)] for subitem in item]
|
return [self.timeseries[subitem][self._slice_helper(subitem)] for subitem in item]
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Dont know how to index with {item}, must be an int, slice, or iterable"
|
f"Dont know how to index with {item}, must be an int, slice, or iterable"
|
||||||
|
@ -806,13 +806,22 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin):
|
||||||
" never done in the core schema."
|
" never done in the core schema."
|
||||||
)
|
)
|
||||||
if isinstance(key, (int, np.integer)):
|
if isinstance(key, (int, np.integer)):
|
||||||
self.timeseries[self._slice_helper(key)] = value
|
self.timeseries[key][self._slice_helper(key)] = value
|
||||||
elif isinstance(key, slice):
|
elif isinstance(key, (slice, Iterable)):
|
||||||
for subitem in self._slice_helper(key):
|
if isinstance(key, slice):
|
||||||
self.timeseries[subitem] = value
|
key = range(*key.indices(len(self.idx_start)))
|
||||||
elif isinstance(key, Iterable):
|
|
||||||
for subitem in key:
|
if isinstance(value, Iterable):
|
||||||
self.timeseries[self._slice_helper(subitem)] = value
|
if len(key) != len(value):
|
||||||
|
raise ValueError(
|
||||||
|
"Can only assign equal-length iterable to a slice, manually index the"
|
||||||
|
" target Timeseries object if you need more control"
|
||||||
|
)
|
||||||
|
for subitem, subvalue in zip(key, value):
|
||||||
|
self.timeseries[subitem][self._slice_helper(subitem)] = subvalue
|
||||||
|
else:
|
||||||
|
for subitem in key:
|
||||||
|
self.timeseries[subitem][self._slice_helper(subitem)] = value
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Dont know how to index with {key}, must be an int, slice, or iterable"
|
f"Dont know how to index with {key}, must be an int, slice, or iterable"
|
||||||
|
@ -898,3 +907,6 @@ if "pytest" in sys.modules:
|
||||||
"""DynamicTableRegion subclass for testing"""
|
"""DynamicTableRegion subclass for testing"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class TimeSeriesReferenceVectorData(TimeSeriesReferenceVectorDataMixin):
|
||||||
|
pass
|
||||||
|
|
|
@ -611,6 +611,35 @@ def test_mixed_aligned_dynamictable(aligned_table):
|
||||||
assert len(array) == index_array[i]
|
assert len(array) == index_array[i]
|
||||||
|
|
||||||
|
|
||||||
|
def test_timeseriesreferencevectordata_index():
|
||||||
|
"""
|
||||||
|
TimeSeriesReferenceVectorData should be able to do the thing it does
|
||||||
|
"""
|
||||||
|
generator = np.random.default_rng()
|
||||||
|
timeseries = np.array([np.arange(100)] * 10)
|
||||||
|
|
||||||
|
counts = generator.integers(1, 10, (10,))
|
||||||
|
idx_start = np.arange(0, 100, 10)
|
||||||
|
|
||||||
|
response = hdmf.TimeSeriesReferenceVectorData(
|
||||||
|
idx_start=idx_start,
|
||||||
|
count=counts,
|
||||||
|
timeseries=timeseries,
|
||||||
|
)
|
||||||
|
for i in range(len(counts)):
|
||||||
|
assert len(response[i]) == counts[i]
|
||||||
|
items = response[3:5]
|
||||||
|
assert all(items[0] == timeseries[3][idx_start[3] : idx_start[3] + counts[3]])
|
||||||
|
assert all(items[1] == timeseries[4][idx_start[4] : idx_start[4] + counts[4]])
|
||||||
|
|
||||||
|
response[0] = np.zeros((counts[0],))
|
||||||
|
assert all(response[0] == 0)
|
||||||
|
|
||||||
|
response[1:3] = [np.zeros((counts[1],)), np.ones((counts[2],))]
|
||||||
|
assert all(response[1] == 0)
|
||||||
|
assert all(response[2] == 1)
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------------
|
# --------------------------------------------------
|
||||||
# Model-based tests
|
# Model-based tests
|
||||||
# --------------------------------------------------
|
# --------------------------------------------------
|
||||||
|
@ -623,7 +652,6 @@ def test_dynamictable_indexing_electricalseries(electrical_series):
|
||||||
series, electrodes = electrical_series
|
series, electrodes = electrical_series
|
||||||
|
|
||||||
colnames = [
|
colnames = [
|
||||||
"id",
|
|
||||||
"x",
|
"x",
|
||||||
"y",
|
"y",
|
||||||
"group",
|
"group",
|
||||||
|
@ -632,20 +660,19 @@ def test_dynamictable_indexing_electricalseries(electrical_series):
|
||||||
"extra_column",
|
"extra_column",
|
||||||
]
|
]
|
||||||
dtypes = [
|
dtypes = [
|
||||||
np.dtype("int64"),
|
|
||||||
np.dtype("float64"),
|
np.dtype("float64"),
|
||||||
np.dtype("float64"),
|
np.dtype("float64"),
|
||||||
] + ([np.dtype("O")] * 4)
|
] + ([np.dtype("O")] * 4)
|
||||||
|
|
||||||
row = electrodes[0]
|
row = electrodes[0]
|
||||||
# successfully get a single row :)
|
# successfully get a single row :)
|
||||||
assert row.shape == (1, 7)
|
assert row.shape == (1, 6)
|
||||||
assert row.dtypes.values.tolist() == dtypes
|
assert row.dtypes.values.tolist() == dtypes
|
||||||
assert row.columns.tolist() == colnames
|
assert row.columns.tolist() == colnames
|
||||||
|
|
||||||
# slice a range of rows
|
# slice a range of rows
|
||||||
rows = electrodes[0:3]
|
rows = electrodes[0:3]
|
||||||
assert rows.shape == (3, 7)
|
assert rows.shape == (3, 6)
|
||||||
assert rows.dtypes.values.tolist() == dtypes
|
assert rows.dtypes.values.tolist() == dtypes
|
||||||
assert rows.columns.tolist() == colnames
|
assert rows.columns.tolist() == colnames
|
||||||
|
|
||||||
|
@ -656,7 +683,7 @@ def test_dynamictable_indexing_electricalseries(electrical_series):
|
||||||
# get a single cell
|
# get a single cell
|
||||||
val = electrodes[0, "y"]
|
val = electrodes[0, "y"]
|
||||||
assert val == 5
|
assert val == 5
|
||||||
val = electrodes[0, 2]
|
val = electrodes[0, 1]
|
||||||
assert val == 5
|
assert val == 5
|
||||||
|
|
||||||
# get a slice of rows and columns
|
# get a slice of rows and columns
|
||||||
|
@ -698,8 +725,8 @@ def test_dynamictable_region_basic_electricalseries(electrical_series):
|
||||||
# b) every other object in the chain is strictly validated,
|
# b) every other object in the chain is strictly validated,
|
||||||
# so we assume if we got a right shaped df that it is the correct one.
|
# so we assume if we got a right shaped df that it is the correct one.
|
||||||
# feel free to @ me when i am wrong about this
|
# feel free to @ me when i am wrong about this
|
||||||
assert all(row.id == 4)
|
assert all(row.index == 4)
|
||||||
assert row.shape == (1, 7)
|
assert row.shape == (1, 6)
|
||||||
# and we should still be preserving the model that is the contents of the cell of this row
|
# and we should still be preserving the model that is the contents of the cell of this row
|
||||||
# so this is a dataframe row with a column "group" that contains an array of ElectrodeGroup
|
# so this is a dataframe row with a column "group" that contains an array of ElectrodeGroup
|
||||||
# objects and that's as far as we are going to chase the recursion in this basic indexing test
|
# objects and that's as far as we are going to chase the recursion in this basic indexing test
|
||||||
|
@ -709,7 +736,7 @@ def test_dynamictable_region_basic_electricalseries(electrical_series):
|
||||||
# getting a list of table rows is actually correct behavior here because
|
# getting a list of table rows is actually correct behavior here because
|
||||||
# this list of table rows is actually the cell of another table
|
# this list of table rows is actually the cell of another table
|
||||||
rows = series.electrodes[0:3]
|
rows = series.electrodes[0:3]
|
||||||
assert all([all(row.id == idx) for row, idx in zip(rows, [4, 3, 2])])
|
assert all([all(row.index == idx) for row, idx in zip(rows, [4, 3, 2])])
|
||||||
|
|
||||||
|
|
||||||
def test_aligned_dynamictable_ictable(intracellular_recordings_table):
|
def test_aligned_dynamictable_ictable(intracellular_recordings_table):
|
||||||
|
|
Loading…
Reference in a new issue