mirror of
https://github.com/p2p-ld/nwb-linkml.git
synced 2024-11-10 00:34:29 +00:00
region tests
This commit is contained in:
parent 7cb8eea6fe
commit 36add1a306
2 changed files with 404 additions and 215 deletions
@@ -48,9 +48,10 @@ class DynamicTableMixin(BaseModel):
     but simplifying along the way :)
     """
 
-    model_config = ConfigDict(extra="allow")
+    model_config = ConfigDict(extra="allow", validate_assignment=True)
     __pydantic_extra__: Dict[str, Union["VectorDataMixin", "VectorIndexMixin", "NDArray", list]]
     NON_COLUMN_FIELDS: ClassVar[tuple[str]] = (
+        "id",
         "name",
         "colnames",
         "description",
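With `validate_assignment=True`, attribute assignment after construction goes back through the model validators instead of being accepted silently; combined with the `__setattr__` override further down, assigning a new column also appends it to `colnames`. A minimal sketch of that behaviour, closely mirroring the new `test_dynamictable_setattr` test below (it assumes the test-only `hdmf.VectorData` subclass that the module registers when pytest is importable; the `MyDT` class is illustrative, not from the commit):

import numpy as np
from pydantic import ValidationError
from numpydantic import NDArray, Shape

from nwb_linkml.includes import hdmf
from nwb_linkml.includes.hdmf import DynamicTableMixin


class MyDT(DynamicTableMixin):
    existing_col: hdmf.VectorData[NDArray[Shape["* col"], int]]


inst = MyDT(existing_col=hdmf.VectorData(value=np.arange(10)))
inst.new_col_1 = hdmf.VectorData(value=np.arange(10))  # appended to colnames, validators re-run
try:
    inst.new_col_2 = hdmf.VectorData(value=np.arange(11))  # unequal length is rejected on assignment
except ValidationError:
    pass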
@@ -116,6 +117,7 @@ class DynamicTableMixin(BaseModel):
             return self._columns[item]
         if isinstance(item, (int, slice, np.integer, np.ndarray)):
             data = self._slice_range(item)
+            index = self.id[item]
         elif isinstance(item, tuple):
             if len(item) != 2:
                 raise ValueError(
@@ -133,11 +135,15 @@ class DynamicTableMixin(BaseModel):
                 return self._columns[cols][rows]
 
             data = self._slice_range(rows, cols)
+            index = self.id[rows]
         else:
             raise ValueError(f"Unsure how to get item with key {item}")
 
         # cast to DF
-        return pd.DataFrame(data)
+        if not isinstance(index, Iterable):
+            index = [index]
+        index = pd.Index(data=index)
+        return pd.DataFrame(data, index=index)
 
     def _slice_range(
         self, rows: Union[int, slice, np.ndarray], cols: Optional[Union[str, List[str]]] = None
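The practical effect of threading `index = self.id[...]` through to `pd.DataFrame(data, index=index)` is that row labels on the returned frame come from the table's `id` column rather than a default 0..n range, so out-of-order or region selections stay identifiable. A small self-contained pandas sketch of that behaviour (illustrative values, not from the repo):

import numpy as np
import pandas as pd

ids = np.array([4, 3, 2, 1, 0])    # hypothetical id column
data = {"x": np.arange(5) * 10}    # hypothetical single column
rows = slice(0, 3)

index = ids[rows]
if not isinstance(index, (list, np.ndarray)):  # a scalar row still needs a 1-element index
    index = [index]
frame = pd.DataFrame({k: v[rows] for k, v in data.items()}, index=pd.Index(data=index))
print(frame)  # rows labelled 4, 3, 2 rather than 0, 1, 2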
@@ -149,31 +155,40 @@ class DynamicTableMixin(BaseModel):
         data = {}
         for k in cols:
             if isinstance(rows, np.ndarray):
+                # help wanted - this is probably cr*zy slow
                 val = [self._columns[k][i] for i in rows]
             else:
                 val = self._columns[k][rows]
 
             # scalars need to be wrapped in series for pandas
+            # do this by the iterability of the rows index not the value because
+            # we want all lengths from this method to be equal, and if the rows are
+            # scalar, that means length == 1
             if not isinstance(rows, (Iterable, slice)):
-                val = pd.Series([val])
+                val = [val]
 
             data[k] = val
         return data
 
     def __setitem__(self, key: str, value: Any) -> None:
-        raise NotImplementedError("TODO")
+        raise NotImplementedError("TODO")  # pragma: no cover
 
     def __setattr__(self, key: str, value: Union[list, "NDArray", "VectorData"]):
         """
         Add a column, appending it to ``colnames``
         """
         # don't use this while building the model
-        if not getattr(self, "__pydantic_complete__", False):
+        if not getattr(self, "__pydantic_complete__", False):  # pragma: no cover
             return super().__setattr__(key, value)
 
         if key not in self.model_fields_set and not key.endswith("_index"):
             self.colnames.append(key)
 
+            # we get a recursion error if we setattr without having first added to
+            # extras if we need it to be there
+            if key not in self.model_fields and key not in self.__pydantic_extra__:
+                self.__pydantic_extra__[key] = value
+
         return super().__setattr__(key, value)
 
     def __getattr__(self, item: str) -> Any:
@@ -303,8 +318,8 @@ class DynamicTableMixin(BaseModel):
         """
         Ensure that all columns are equal length
         """
-        lengths = [len(v) for v in self._columns.values()]
-        assert [length == lengths[0] for length in lengths], (
+        lengths = [len(v) for v in self._columns.values()] + [len(self.id)]
+        assert all([length == lengths[0] for length in lengths]), (
             "Columns are not of equal length! "
             f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
         )
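The old assertion compared against a bare list comprehension, and any non-empty list is truthy, so the length check could never fail; wrapping it in `all()` (and appending `len(self.id)` so the id column is checked too) turns it into a real validation. A tiny standalone illustration of the difference (toy values):

lengths = [10, 10, 11]  # one column too long
assert [length == lengths[0] for length in lengths]  # old form: always passes, the list is truthy
assert not all(length == lengths[0] for length in lengths)  # new form: all() is False, mismatch detected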
@@ -430,9 +445,21 @@ class VectorIndexMixin(BaseModel, Generic[T]):
             raise AttributeError(f"Could not index with {item}")
 
     def __setitem__(self, key: Union[int, slice], value: Any) -> None:
-        if self._index:
-            # VectorIndex is the thing that knows how to do the slicing
-            self._index[key] = value
+        """
+        Set a value on the :attr:`.target` .
+
+        .. note::
+
+            Even though we correct the indexing logic from HDMF where the
+            _data_ is the thing that is provided by the API when one accesses
+            table.data (rather than table.data_index as hdmf does),
+            we will set to the target here (rather than to the index)
+            to be consistent. To modify the index, modify `self.value` directly
+
+        """
+        if self.target:
+            # __getitem__ will return the indexed reference to the target
+            self[key] = value
         else:
             self.value[key] = value
 
@@ -463,9 +490,19 @@ class DynamicTableRegionMixin(BaseModel):
     _index: Optional["VectorIndex"] = None
 
     table: "DynamicTableMixin"
-    value: Optional[NDArray] = None
+    value: Optional[NDArray[Shape["*"], int]] = None
 
-    def __getitem__(self, item: Union[int, slice, Iterable]) -> Any:
+    @overload
+    def __getitem__(self, item: int) -> pd.DataFrame: ...
+
+    @overload
+    def __getitem__(
+        self, item: Union[slice, Iterable]
+    ) -> List[pd.DataFrame]: ...
+
+    def __getitem__(
+        self, item: Union[int, slice, Iterable]
+    ) -> Union[pd.DataFrame, List[pd.DataFrame]]:
         """
         Use ``value`` to index the table. Works analogously to ``VectorIndex`` despite
         this being a subclass of ``VectorData``
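The `@overload` declarations change nothing at runtime; they only tell type checkers that an integer key yields a single `pd.DataFrame` while a slice or iterable yields a list of them. A generic, self-contained sketch of the same pattern (toy types, not the real mixin):

from typing import List, Union, overload

class Rows:
    @overload
    def __getitem__(self, item: int) -> str: ...
    @overload
    def __getitem__(self, item: slice) -> List[str]: ...
    def __getitem__(self, item: Union[int, slice]) -> Union[str, List[str]]:
        # one runtime implementation; the overloads above only narrow the static return type
        data = ["row0", "row1", "row2"]
        return data[item]

assert Rows()[0] == "row0"
assert Rows()[0:2] == ["row0", "row1"]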
@@ -486,6 +523,10 @@ class DynamicTableRegionMixin(BaseModel):
         if isinstance(item, (int, np.integer)):
             return self.table[self.value[item]]
         elif isinstance(item, (slice, Iterable)):
+            # Return a list of dataframe rows because this is most often used
+            # as a column in a DynamicTable, so while it would normally be
+            # ideal to just return the slice as above as a single df,
+            # we need each row to be separate to fill the column
             if isinstance(item, slice):
                 item = range(*item.indices(len(self.value)))
             return [self.table[self.value[i]] for i in item]
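The comment block above explains the design choice: a sliced region comes back as a list of single-row DataFrames rather than one concatenated frame, so each element can sit in its own cell of a parent table's column. A rough plain-pandas analogy of the shapes involved (hypothetical table and indices, not from the repo):

import numpy as np
import pandas as pd

table = pd.DataFrame(np.arange(50).reshape(10, 5), columns=[f"col_{i}" for i in range(1, 6)])
value = np.array([9, 4, 8, 3, 7])  # row indices into the parent table

single = table.iloc[[value[1]]]                          # int key -> one single-row frame
several = [table.iloc[[value[i]]] for i in range(3, 5)]  # slice -> list of single-row frames
assert single.shape == (1, 5)
assert len(several) == 2 and all(df.shape == (1, 5) for df in several)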
@@ -737,3 +778,8 @@ if "pytest" in sys.modules:
         """VectorIndex subclass for testing"""
 
         pass
+
+    class DynamicTableRegion(DynamicTableRegionMixin, VectorData):
+        """DynamicTableRegion subclass for testing"""
+
+        pass
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import numpy as np
 import pandas as pd
 import pytest
@@ -9,213 +11,19 @@ from nwb_linkml.includes.hdmf import DynamicTableMixin, VectorDataMixin, VectorI
 
 # FIXME: Make this just be the output of the provider by patching into import machinery
 from nwb_linkml.models.pydantic.core.v2_7_0.namespace import (
-    DynamicTable,
-    DynamicTableRegion,
     ElectrodeGroup,
-    VectorIndex,
     VoltageClampStimulusSeries,
 )
 
 from .conftest import _ragged_array
 
 
-def test_dynamictable_indexing(electrical_series):
-    """
-    Can index values from a dynamictable
-    """
-    series, electrodes = electrical_series
-
-    colnames = [
-        "id",
-        "x",
-        "y",
-        "group",
-        "group_name",
-        "location",
-        "extra_column",
-    ]
-    dtypes = [
-        np.dtype("int64"),
-        np.dtype("float64"),
-        np.dtype("float64"),
-    ] + ([np.dtype("O")] * 4)
-
-    row = electrodes[0]
-    # successfully get a single row :)
-    assert row.shape == (1, 7)
-    assert row.dtypes.values.tolist() == dtypes
-    assert row.columns.tolist() == colnames
-
-    # slice a range of rows
-    rows = electrodes[0:3]
-    assert rows.shape == (3, 7)
-    assert rows.dtypes.values.tolist() == dtypes
-    assert rows.columns.tolist() == colnames
-
-    # get a single column
-    col = electrodes["y"]
-    assert all(col.value == [5, 6, 7, 8, 9])
-
-    # get a single cell
-    val = electrodes[0, "y"]
-    assert val == 5
-    val = electrodes[0, 2]
-    assert val == 5
-
-    # get a slice of rows and columns
-    subsection = electrodes[0:3, 0:3]
-    assert subsection.shape == (3, 3)
-    assert subsection.columns.tolist() == colnames[0:3]
-    assert subsection.dtypes.values.tolist() == dtypes[0:3]
-
-
-def test_dynamictable_ragged(units):
-    """
-    Should be able to index ragged arrays using an implicit _index column
-
-    Also tests:
-    - passing arrays directly instead of wrapping in vectordata/index specifically,
-    if the models in the fixture instantiate then this works
-    """
-    units, spike_times, spike_idx = units
-
-    # ensure we don't pivot to long when indexing
-    assert units[0].shape[0] == 1
-    # check that we got the indexing boundaries corrunect
-    # (and that we are forwarding attr calls to the dataframe by accessing shape
-    for i in range(units.shape[0]):
-        assert np.all(units.iloc[i, 0] == spike_times[i])
-
-
-def test_dynamictable_region_basic(electrical_series):
-    """
-    DynamicTableRegion should be able to refer to a row or rows of another table
-    itself as a column within a table
-    """
-    series, electrodes = electrical_series
-    row = series.electrodes[0]
-    # check that we correctly got the 4th row instead of the 0th row,
-    # since the indexed table was constructed with inverted indexes because it's a test, ya dummy.
-    # we will only vaguely check the basic functionality here bc
-    # a) the indexing behavior of the indexed objects is tested above, and
-    # b) every other object in the chain is strictly validated,
-    # so we assume if we got a right shaped df that it is the correct one.
-    # feel free to @ me when i am wrong about this
-    assert all(row.id == 4)
-    assert row.shape == (1, 7)
-    # and we should still be preserving the model that is the contents of the cell of this row
-    # so this is a dataframe row with a column "group" that contains an array of ElectrodeGroup
-    # objects and that's as far as we are going to chase the recursion in this basic indexing test
-    # ElectrodeGroup is strictly validating so an instance check is all we need.
-    assert isinstance(row.group.values[0], ElectrodeGroup)
-
-    # getting a list of table rows is actually correct behavior here because
-    # this list of table rows is actually the cell of another table
-    rows = series.electrodes[0:3]
-    assert all([all(row.id == idx) for row, idx in zip(rows, [4, 3, 2])])
-
-
-def test_dynamictable_region_ragged():
-    """
-    Dynamictables can also have indexes so that they are ragged arrays of column rows
-    """
-    spike_times, spike_idx = _ragged_array(24)
-    spike_times_flat = np.concatenate(spike_times)
-
-    # construct a secondary index that selects overlapping segments of the first table
-    value = np.array([0, 1, 2, 1, 2, 3, 2, 3, 4])
-    idx = np.array([3, 6, 9])
-
-    table = DynamicTable(
-        name="table",
-        description="a table what else would it be",
-        id=np.arange(len(spike_idx)),
-        timeseries=spike_times_flat,
-        timeseries_index=spike_idx,
-    )
-    region = DynamicTableRegion(
-        name="dynamictableregion",
-        description="this field should be optional",
-        table=table,
-        value=value,
-    )
-    index = VectorIndex(name="index", description="hgggggggjjjj", target=region, value=idx)
-    region._index = index
-    rows = region[1]
-    # i guess this is right?
-    # the region should be a set of three rows of the table, with a ragged array column timeseries
-    # like...
-    #
-    # id timeseries
-    # 0 1 [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...
-    # 1 2 [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, ...
-    # 2 3 [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, ...
-    assert rows.shape == (3, 2)
-    assert all(rows.id == [1, 2, 3])
-    assert all([all(row[1].timeseries == i) for i, row in zip([1, 2, 3], rows.iterrows())])
-
-
-def test_dynamictable_append_column():
-    pass
-
-
-def test_dynamictable_append_row():
-    pass
-
-
-def test_dynamictable_extra_coercion():
-    """
-    Extra fields should be coerced to VectorData and have their
-    indexing relationships handled when passed as plain arrays.
-    """
-
-
-def test_aligned_dynamictable(intracellular_recordings_table):
-    """
-    Multiple aligned dynamictables should be indexable with a multiindex
-    """
-    # can get a single row.. (check correctness below)
-    row = intracellular_recordings_table[0]
-    # can get a single table with its name
-    stimuli = intracellular_recordings_table["stimuli"]
-    assert stimuli.shape == (10, 1)
-
-    # nab a few rows to make the dataframe
-    rows = intracellular_recordings_table[0:3]
-    assert all(
-        rows.columns
-        == pd.MultiIndex.from_tuples(
-            [
-                ("electrodes", "index"),
-                ("electrodes", "electrode"),
-                ("stimuli", "index"),
-                ("stimuli", "stimulus"),
-                ("responses", "index"),
-                ("responses", "response"),
-            ]
-        )
-    )
-
-    # ensure that we get the actual values from the TimeSeriesReferenceVectorData
-    # also tested separately
-    # each individual cell should be an array of VoltageClampStimulusSeries...
-    # and then we should be able to index within that as well
-    stims = rows["stimuli", "stimulus"][0]
-    for i in range(len(stims)):
-        assert isinstance(stims[i], VoltageClampStimulusSeries)
-        assert all([i == val for val in stims[i][:]])
-
-
 # --------------------------------------------------
-# Direct mixin tests
+# Unit tests on mixins directly (model tests below)
 # --------------------------------------------------
 
 
-def test_dynamictable_mixin_indexing():
-    """
-    Can index values from a dynamictable
-    """
-
+@pytest.fixture()
+def basic_table() -> tuple[DynamicTableMixin, dict[str, NDArray[Shape["10"], int]]]:
     class MyData(DynamicTableMixin):
         col_1: hdmf.VectorData[NDArray[Shape["*"], int]]
         col_2: hdmf.VectorData[NDArray[Shape["*"], int]]
@@ -228,8 +36,18 @@ def test_dynamictable_mixin_indexing():
         "col_4": np.arange(10),
         "col_5": np.arange(10),
     }
+    return MyData, cols
+
+
+def test_dynamictable_mixin_indexing(basic_table):
+    """
+    Can index values from a dynamictable
+    """
+    MyData, cols = basic_table
 
     colnames = [c for c in cols]
     inst = MyData(**cols)
+    assert len(inst) == 10
 
     row = inst[0]
     # successfully get a single row :)
@@ -251,9 +69,28 @@ def test_dynamictable_mixin_indexing():
     assert val == 5
 
     # get a slice of rows and columns
-    subsection = inst[0:3, 0:3]
-    assert subsection.shape == (3, 3)
-    assert subsection.columns.tolist() == colnames[0:3]
+    val = inst[0:3, 0:3]
+    assert val.shape == (3, 3)
+    assert val.columns.tolist() == colnames[0:3]
+
+    # slice of rows with string colname
+    val = inst[0:2, "col_1"]
+    assert val.shape == (2, 1)
+    assert val.columns.tolist() == ["col_1"]
+
+    # array of rows
+    # crazy slow but we'll work on perf later
+    val = inst[np.arange(2), "col_1"]
+    assert val.shape == (2, 1)
+    assert val.columns.tolist() == ["col_1"]
+
+    # should raise an error on a 3d index
+    with pytest.raises(ValueError, match=".*2-dimensional.*"):
+        _ = inst[1, 1, 1]
+
+    # error on unhandled indexing type
+    with pytest.raises(ValueError, match="Unsure how to get item with key.*"):
+        _ = inst[5.5]
 
 
 def test_dynamictable_mixin_colnames():
@@ -337,9 +174,12 @@ def test_dynamictable_mixin_getattr():
     assert isinstance(inst.existing_col, hdmf.VectorData)
     assert all(inst.existing_col.value == col.value)
 
-    # df lookup for thsoe that don't
+    # df lookup for those that don't
     assert isinstance(inst.columns, pd.Index)
 
+    with pytest.raises(AttributeError):
+        _ = inst.really_fake_name_that_pandas_and_pydantic_definitely_dont_define
+
 
 def test_dynamictable_coercion():
     """
@@ -348,15 +188,19 @@ def test_dynamictable_coercion():
 
     class MyDT(DynamicTableMixin):
         existing_col: hdmf.VectorData[NDArray[Shape["* col"], int]]
+        optional_col: Optional[hdmf.VectorData[NDArray[Shape["* col"], int]]]
 
     cols = {
         "existing_col": np.arange(10),
+        "optional_col": np.arange(10),
        "new_col_1": np.arange(10),
     }
     inst = MyDT(**cols)
     assert isinstance(inst.existing_col, hdmf.VectorData)
+    assert isinstance(inst.optional_col, hdmf.VectorData)
     assert isinstance(inst.new_col_1, hdmf.VectorData)
     assert all(inst.existing_col.value == np.arange(10))
+    assert all(inst.optional_col.value == np.arange(10))
     assert all(inst.new_col_1.value == np.arange(10))
 
 
@@ -409,14 +253,14 @@ def dynamictable_assert_equal_length():
         "existing_col": np.arange(10),
         "new_col_1": hdmf.VectorData(value=np.arange(11)),
     }
-    with pytest.raises(ValidationError, pattern="Columns are not of equal length"):
+    with pytest.raises(ValidationError, match="Columns are not of equal length"):
         _ = MyDT(**cols)
 
     cols = {
         "existing_col": np.arange(11),
         "new_col_1": hdmf.VectorData(value=np.arange(10)),
     }
-    with pytest.raises(ValidationError, pattern="Columns are not of equal length"):
+    with pytest.raises(ValidationError, match="Columns are not of equal length"):
         _ = MyDT(**cols)
 
     # wrong lengths are fine as long as the index is good
@@ -437,6 +281,32 @@ def dynamictable_assert_equal_length():
     _ = MyDT(**cols)
 
 
+def test_dynamictable_setattr():
+    """
+    Setting a new column as an attribute adds it to colnames and reruns validations
+    """
+
+    class MyDT(DynamicTableMixin):
+        existing_col: hdmf.VectorData[NDArray[Shape["* col"], int]]
+
+    cols = {
+        "existing_col": hdmf.VectorData(value=np.arange(10)),
+        "new_col_1": hdmf.VectorData(value=np.arange(10)),
+    }
+    inst = MyDT(existing_col=cols["existing_col"])
+    assert inst.colnames == ["existing_col"]
+
+    inst.new_col_1 = cols["new_col_1"]
+    assert inst.colnames == ["existing_col", "new_col_1"]
+    assert inst[:].columns.tolist() == ["existing_col", "new_col_1"]
+    # length unchanged because id should be the same
+    assert len(inst) == 10
+
+    # model validators should be called to ensure equal length
+    with pytest.raises(ValidationError):
+        inst.new_col_2 = hdmf.VectorData(value=np.arange(11))
+
+
 def test_vectordata_indexing():
     """
     Vectordata/VectorIndex pairs should know how to index off each other
@@ -449,6 +319,10 @@ def test_vectordata_indexing():
 
     # before we have an index, things should work as normal, indexing a 1D array
     assert data[0] == 0
+    # and setting values
+    data[0] = 1
+    assert data[0] == 1
+    data[0] = 0
 
     index = hdmf.VectorIndex(value=index_array, target=data)
     data._index = index
@@ -468,6 +342,31 @@ def test_vectordata_indexing():
     assert all(data[0] == 5)
 
 
+def test_vectordata_getattr():
+    """
+    VectorData and VectorIndex both forward getattr to ``value``
+    """
+    data = hdmf.VectorData(value=np.arange(100))
+    index = hdmf.VectorIndex(value=np.arange(10, 101, 10), target=data)
+
+    # get attrs that we defined on the models
+    # i.e. no attribute errors here
+    _ = data.model_fields
+    _ = index.model_fields
+
+    # but for things that aren't defined, get the numpy method
+    # note that index should not try and get the sum from the target -
+    # that would be hella confusing. we only refer to the target when indexing.
+    assert data.sum() == np.sum(np.arange(100))
+    assert index.sum() == np.sum(np.arange(10, 101, 10))
+
+    # and also raise attribute errors when nothing is found
+    with pytest.raises(AttributeError):
+        _ = data.super_fake_attr_name
+    with pytest.raises(AttributeError):
+        _ = index.super_fake_attr_name
+
+
 def test_vectordata_generic_numpydantic_validation():
     """
     Using VectorData as a generic with a numpydantic array annotation should still validate
@@ -481,3 +380,247 @@ def test_vectordata_generic_numpydantic_validation():
 
     with pytest.raises(ValidationError):
         _ = MyDT(existing_col=np.zeros((4, 5, 6), dtype=int))
+
+
+@pytest.mark.xfail
+def test_dynamictable_append_row():
+    raise NotImplementedError("Reminder to implement row appending")
+
+
+def test_dynamictable_region_indexing(basic_table):
+    """
+    Without an index, DynamicTableRegion should just be a single-row index into
+    another table
+    """
+    model, cols = basic_table
+    inst = model(**cols)
+
+    index = np.array([9, 4, 8, 3, 7, 2, 6, 1, 5, 0])
+
+    table_region = hdmf.DynamicTableRegion(value=index, table=inst)
+
+    row = table_region[1]
+    assert all(row.iloc[0] == index[1])
+
+    # slices
+    rows = table_region[3:5]
+    assert all(rows[0].iloc[0] == index[3])
+    assert all(rows[1].iloc[0] == index[4])
+    assert len(rows) == 2
+    assert all([row.shape == (1, 5) for row in rows])
+
+    # out of order fine too
+    oorder = [2, 5, 4]
+    rows = table_region[oorder]
+    assert len(rows) == 3
+    assert all([row.shape == (1, 5) for row in rows])
+    for i, idx in enumerate(oorder):
+        assert all(rows[i].iloc[0] == index[idx])
+
+    # also works when used as a column in a table
+    class AnotherTable(DynamicTableMixin):
+        region: hdmf.DynamicTableRegion
+        another_col: hdmf.VectorData[NDArray[Shape["*"], int]]
+
+    inst2 = AnotherTable(region=table_region, another_col=np.arange(10))
+    rows = inst2[0:3]
+    col = rows.region
+    for i in range(3):
+        assert all(col[i].iloc[0] == index[i])
+
+
+def test_dynamictable_region_ragged():
+    """
+    Dynamictables can also have indexes so that they are ragged arrays of column rows
+    """
+    spike_times, spike_idx = _ragged_array(24)
+    spike_times_flat = np.concatenate(spike_times)
+
+    # construct a secondary index that selects overlapping segments of the first table
+    value = np.array([0, 1, 2, 1, 2, 3, 2, 3, 4])
+    idx = np.array([3, 6, 9])
+
+    table = DynamicTableMixin(
+        name="table",
+        description="a table what else would it be",
+        id=np.arange(len(spike_idx)),
+        another_column=np.arange(len(spike_idx) - 1, -1, -1),
+        timeseries=spike_times_flat,
+        timeseries_index=spike_idx,
+    )
+    region = hdmf.DynamicTableRegion(
+        table=table,
+        value=value,
+    )
+    index = hdmf.VectorIndex(name="index", description="hgggggggjjjj", target=region, value=idx)
+    region._index = index
+
+    rows = region[1]
+    # i guess this is right?
+    # the region should be a set of three rows of the table, with a ragged array column timeseries
+    # like...
+    #
+    # id timeseries
+    # 0 1 [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...
+    # 1 2 [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, ...
+    # 2 3 [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, ...
+    assert rows.shape == (3, 2)
+    assert all(rows.index.to_numpy() == [1, 2, 3])
+    assert all([all(row[1].timeseries == i) for i, row in zip([1, 2, 3], rows.iterrows())])
+
+    rows = region[0:2]
+    for i in range(2):
+        assert all(
+            [all(row[1].timeseries == i) for i, row in zip(range(i, i + 3), rows[i].iterrows())]
+        )
+
+    # also works when used as a column in a table
+    class AnotherTable(DynamicTableMixin):
+        region: hdmf.DynamicTableRegion
+        yet_another_col: hdmf.VectorData[NDArray[Shape["*"], int]]
+
+    inst2 = AnotherTable(region=region, yet_another_col=np.arange(len(idx)))
+    row = inst2[0]
+    assert row.shape == (1, 2)
+    assert row.iloc[0, 0].equals(region[0])
+
+    rows = inst2[0:3]
+    for i, df in enumerate(rows.iloc[:, 0]):
+        assert df.equals(region[i])
+
+
+# --------------------------------------------------
+# Model-based tests
+# --------------------------------------------------
+
+
+def test_dynamictable_indexing_electricalseries(electrical_series):
+    """
+    Can index values from a dynamictable
+    """
+    series, electrodes = electrical_series
+
+    colnames = [
+        "id",
+        "x",
+        "y",
+        "group",
+        "group_name",
+        "location",
+        "extra_column",
+    ]
+    dtypes = [
+        np.dtype("int64"),
+        np.dtype("float64"),
+        np.dtype("float64"),
+    ] + ([np.dtype("O")] * 4)
+
+    row = electrodes[0]
+    # successfully get a single row :)
+    assert row.shape == (1, 7)
+    assert row.dtypes.values.tolist() == dtypes
+    assert row.columns.tolist() == colnames
+
+    # slice a range of rows
+    rows = electrodes[0:3]
+    assert rows.shape == (3, 7)
+    assert rows.dtypes.values.tolist() == dtypes
+    assert rows.columns.tolist() == colnames
+
+    # get a single column
+    col = electrodes["y"]
+    assert all(col.value == [5, 6, 7, 8, 9])
+
+    # get a single cell
+    val = electrodes[0, "y"]
+    assert val == 5
+    val = electrodes[0, 2]
+    assert val == 5
+
+    # get a slice of rows and columns
+    subsection = electrodes[0:3, 0:3]
+    assert subsection.shape == (3, 3)
+    assert subsection.columns.tolist() == colnames[0:3]
+    assert subsection.dtypes.values.tolist() == dtypes[0:3]
+
+
+def test_dynamictable_ragged_units(units):
+    """
+    Should be able to index ragged arrays using an implicit _index column
+
+    Also tests:
+    - passing arrays directly instead of wrapping in vectordata/index specifically,
+    if the models in the fixture instantiate then this works
+    """
+    units, spike_times, spike_idx = units
+
+    # ensure we don't pivot to long when indexing
+    assert units[0].shape[0] == 1
+    # check that we got the indexing boundaries corrunect
+    # (and that we are forwarding attr calls to the dataframe by accessing shape
+    for i in range(units.shape[0]):
+        assert np.all(units.iloc[i, 0] == spike_times[i])
+
+
+def test_dynamictable_region_basic_electricalseries(electrical_series):
+    """
+    DynamicTableRegion should be able to refer to a row or rows of another table
+    itself as a column within a table
+    """
+    series, electrodes = electrical_series
+    row = series.electrodes[0]
+    # check that we correctly got the 4th row instead of the 0th row,
+    # since the indexed table was constructed with inverted indexes because it's a test, ya dummy.
+    # we will only vaguely check the basic functionality here bc
+    # a) the indexing behavior of the indexed objects is tested above, and
+    # b) every other object in the chain is strictly validated,
+    # so we assume if we got a right shaped df that it is the correct one.
+    # feel free to @ me when i am wrong about this
+    assert all(row.id == 4)
+    assert row.shape == (1, 7)
+    # and we should still be preserving the model that is the contents of the cell of this row
+    # so this is a dataframe row with a column "group" that contains an array of ElectrodeGroup
+    # objects and that's as far as we are going to chase the recursion in this basic indexing test
+    # ElectrodeGroup is strictly validating so an instance check is all we need.
+    assert isinstance(row.group.values[0], ElectrodeGroup)
+
+    # getting a list of table rows is actually correct behavior here because
+    # this list of table rows is actually the cell of another table
+    rows = series.electrodes[0:3]
+    assert all([all(row.id == idx) for row, idx in zip(rows, [4, 3, 2])])
+
+
+def test_aligned_dynamictable_ictable(intracellular_recordings_table):
+    """
+    Multiple aligned dynamictables should be indexable with a multiindex
+    """
+    # can get a single row.. (check correctness below)
+    row = intracellular_recordings_table[0]
+    # can get a single table with its name
+    stimuli = intracellular_recordings_table["stimuli"]
+    assert stimuli.shape == (10, 1)
+
+    # nab a few rows to make the dataframe
+    rows = intracellular_recordings_table[0:3]
+    assert all(
+        rows.columns
+        == pd.MultiIndex.from_tuples(
+            [
+                ("electrodes", "index"),
+                ("electrodes", "electrode"),
+                ("stimuli", "index"),
+                ("stimuli", "stimulus"),
+                ("responses", "index"),
+                ("responses", "response"),
+            ]
+        )
+    )
+
+    # ensure that we get the actual values from the TimeSeriesReferenceVectorData
+    # also tested separately
+    # each individual cell should be an array of VoltageClampStimulusSeries...
+    # and then we should be able to index within that as well
+    stims = rows["stimuli", "stimulus"][0]
+    for i in range(len(stims)):
+        assert isinstance(stims[i], VoltageClampStimulusSeries)
+        assert all([i == val for val in stims[i][:]])