mirror of https://github.com/p2p-ld/nwb-linkml.git
region tests

commit 36add1a306
parent 7cb8eea6fe

2 changed files with 404 additions and 215 deletions
@@ -48,9 +48,10 @@ class DynamicTableMixin(BaseModel):
    but simplifying along the way :)
    """

    model_config = ConfigDict(extra="allow")
    model_config = ConfigDict(extra="allow", validate_assignment=True)
    __pydantic_extra__: Dict[str, Union["VectorDataMixin", "VectorIndexMixin", "NDArray", list]]
    NON_COLUMN_FIELDS: ClassVar[tuple[str]] = (
        "id",
        "name",
        "colnames",
        "description",
@@ -116,6 +117,7 @@ class DynamicTableMixin(BaseModel):
            return self._columns[item]
        if isinstance(item, (int, slice, np.integer, np.ndarray)):
            data = self._slice_range(item)
            index = self.id[item]
        elif isinstance(item, tuple):
            if len(item) != 2:
                raise ValueError(
@@ -133,11 +135,15 @@ class DynamicTableMixin(BaseModel):
                return self._columns[cols][rows]

            data = self._slice_range(rows, cols)
            index = self.id[rows]
        else:
            raise ValueError(f"Unsure how to get item with key {item}")

        # cast to DF
        return pd.DataFrame(data)
        if not isinstance(index, Iterable):
            index = [index]
        index = pd.Index(data=index)
        return pd.DataFrame(data, index=index)

    def _slice_range(
        self, rows: Union[int, slice, np.ndarray], cols: Optional[Union[str, List[str]]] = None
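
Not part of the diff: a minimal sketch of the indexing contract the __getitem__ hunks above implement, modelled on the mixin-only tests added later in this commit (the column names and the direct DynamicTableMixin construction are illustrative):

import numpy as np
from nwb_linkml.includes.hdmf import DynamicTableMixin

# extra keyword columns become table columns; an id column is generated if absent
inst = DynamicTableMixin(col_1=np.arange(10), col_2=np.arange(10) * 2)

inst[0]           # single row   -> one-row pd.DataFrame, indexed by id
inst[0:3]         # row slice    -> pd.DataFrame
inst["col_1"]     # column name  -> the column object itself
inst[0, "col_2"]  # (row, col)   -> scalar cell value
inst[0:3, 0:2]    # (rows, cols) -> 3x2 pd.DataFrame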
@@ -149,31 +155,40 @@ class DynamicTableMixin(BaseModel):
        data = {}
        for k in cols:
            if isinstance(rows, np.ndarray):
                # help wanted - this is probably cr*zy slow
                val = [self._columns[k][i] for i in rows]
            else:
                val = self._columns[k][rows]

            # scalars need to be wrapped in series for pandas
            # do this by the iterability of the rows index not the value because
            # we want all lengths from this method to be equal, and if the rows are
            # scalar, that means length == 1
            if not isinstance(rows, (Iterable, slice)):
                val = pd.Series([val])
                val = [val]

            data[k] = val
        return data

    def __setitem__(self, key: str, value: Any) -> None:
        raise NotImplementedError("TODO")
        raise NotImplementedError("TODO")  # pragma: no cover

    def __setattr__(self, key: str, value: Union[list, "NDArray", "VectorData"]):
        """
        Add a column, appending it to ``colnames``
        """
        # don't use this while building the model
        if not getattr(self, "__pydantic_complete__", False):
        if not getattr(self, "__pydantic_complete__", False):  # pragma: no cover
            return super().__setattr__(key, value)

        if key not in self.model_fields_set and not key.endswith("_index"):
            self.colnames.append(key)

        # we get a recursion error if we setattr without having first added to
        # extras if we need it to be there
        if key not in self.model_fields and key not in self.__pydantic_extra__:
            self.__pydantic_extra__[key] = value

        return super().__setattr__(key, value)

    def __getattr__(self, item: str) -> Any:
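
Also not from the diff: a compressed version of test_dynamictable_setattr (added below in this commit) showing what the __setattr__ override plus validate_assignment=True buys. The hdmf.VectorData helper and the `from nwb_linkml.includes import hdmf` import are assumptions taken from the test module:

import numpy as np
from nwb_linkml.includes import hdmf
from nwb_linkml.includes.hdmf import DynamicTableMixin

table = DynamicTableMixin(existing_col=np.arange(10))
assert table.colnames == ["existing_col"]

# assigning a new attribute appends it to colnames and re-runs model validation
table.new_col = hdmf.VectorData(value=np.arange(10))
assert table.colnames == ["existing_col", "new_col"]

# a column of the wrong length is rejected because assignment re-validates
# table.bad_col = hdmf.VectorData(value=np.arange(11))  # -> ValidationError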
@@ -303,8 +318,8 @@ class DynamicTableMixin(BaseModel):
        """
        Ensure that all columns are equal length
        """
        lengths = [len(v) for v in self._columns.values()]
        assert [length == lengths[0] for length in lengths], (
        lengths = [len(v) for v in self._columns.values()] + [len(self.id)]
        assert all([length == lengths[0] for length in lengths]), (
            "Columns are not of equal length! "
            f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
        )
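
An aside, not diff content: wrapping the comprehension in all() is the substantive fix here, since a bare non-empty list is always truthy and the old assert could never fire:

lengths = [10, 11]
checks = [length == lengths[0] for length in lengths]  # [True, False]
assert checks           # passes regardless: any non-empty list is truthy
assert not all(checks)  # all() is what actually detects the length mismatch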
@@ -430,9 +445,21 @@ class VectorIndexMixin(BaseModel, Generic[T]):
            raise AttributeError(f"Could not index with {item}")

    def __setitem__(self, key: Union[int, slice], value: Any) -> None:
        if self._index:
            # VectorIndex is the thing that knows how to do the slicing
            self._index[key] = value
        """
        Set a value on the :attr:`.target` .

        .. note::

            Even though we correct the indexing logic from HDMF where the
            _data_ is the thing that is provided by the API when one accesses
            table.data (rather than table.data_index as hdmf does),
            we will set to the target here (rather than to the index)
            to be consistent. To modify the index, modify `self.value` directly

        """
        if self.target:
            # __getitem__ will return the indexed reference to the target
            self[key] = value
        else:
            self.value[key] = value
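
A brief sketch (not in the diff) of the write-through behaviour the docstring above describes, patterned on test_vectordata_indexing in this commit; the hdmf.VectorData / hdmf.VectorIndex helpers are the pytest-only classes defined at the bottom of this file:

import numpy as np
from nwb_linkml.includes import hdmf

data = hdmf.VectorData(value=np.arange(24))
index = hdmf.VectorIndex(value=np.array([6, 12, 18, 24]), target=data)
data._index = index

assert len(data[0]) == 6   # data[0] is the first ragged row, i.e. value[0:6]
data[0] = np.zeros(6) + 5  # the write is applied to the underlying values...
assert all(data[0] == 5)   # ...so reading the row back reflects it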
@@ -463,9 +490,19 @@ class DynamicTableRegionMixin(BaseModel):
    _index: Optional["VectorIndex"] = None

    table: "DynamicTableMixin"
    value: Optional[NDArray] = None
    value: Optional[NDArray[Shape["*"], int]] = None

    def __getitem__(self, item: Union[int, slice, Iterable]) -> Any:
    @overload
    def __getitem__(self, item: int) -> pd.DataFrame: ...

    @overload
    def __getitem__(
        self, item: Union[slice, Iterable]
    ) -> List[pd.DataFrame]: ...

    def __getitem__(
        self, item: Union[int, slice, Iterable]
    ) -> Union[pd.DataFrame, List[pd.DataFrame]]:
        """
        Use ``value`` to index the table. Works analogously to ``VectorIndex`` despite
        this being a subclass of ``VectorData``
@@ -486,6 +523,10 @@ class DynamicTableRegionMixin(BaseModel):
        if isinstance(item, (int, np.integer)):
            return self.table[self.value[item]]
        elif isinstance(item, (slice, Iterable)):
            # Return a list of dataframe rows because this is most often used
            # as a column in a DynamicTable, so while it would normally be
            # ideal to just return the slice as above as a single df,
            # we need each row to be separate to fill the column
            if isinstance(item, slice):
                item = range(*item.indices(len(self.value)))
            return [self.table[self.value[i]] for i in item]
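
Not part of the diff: a condensed form of test_dynamictable_region_indexing (added below) showing the int -> DataFrame and slice -> list-of-DataFrames contract spelled out by the overloads above; imports and helper classes are assumed as in the test module:

import numpy as np
from nwb_linkml.includes import hdmf
from nwb_linkml.includes.hdmf import DynamicTableMixin

table = DynamicTableMixin(col_1=np.arange(10), col_2=np.arange(10) * 2)
region = hdmf.DynamicTableRegion(value=np.array([9, 4, 8]), table=table)

row = region[1]     # one-row DataFrame: the table row with id 4
rows = region[0:2]  # list of one-row DataFrames: table rows 9 and 4
assert isinstance(rows, list) and len(rows) == 2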
@@ -737,3 +778,8 @@ if "pytest" in sys.modules:
        """VectorIndex subclass for testing"""

        pass

    class DynamicTableRegion(DynamicTableRegionMixin, VectorData):
        """DynamicTableRegion subclass for testing"""

        pass
@@ -1,3 +1,5 @@
from typing import Optional

import numpy as np
import pandas as pd
import pytest
@@ -9,213 +11,19 @@ from nwb_linkml.includes.hdmf import DynamicTableMixin, VectorDataMixin, VectorI

# FIXME: Make this just be the output of the provider by patching into import machinery
from nwb_linkml.models.pydantic.core.v2_7_0.namespace import (
    DynamicTable,
    DynamicTableRegion,
    ElectrodeGroup,
    VectorIndex,
    VoltageClampStimulusSeries,
)

from .conftest import _ragged_array


def test_dynamictable_indexing(electrical_series):
    """
    Can index values from a dynamictable
    """
    series, electrodes = electrical_series

    colnames = [
        "id",
        "x",
        "y",
        "group",
        "group_name",
        "location",
        "extra_column",
    ]
    dtypes = [
        np.dtype("int64"),
        np.dtype("float64"),
        np.dtype("float64"),
    ] + ([np.dtype("O")] * 4)

    row = electrodes[0]
    # successfully get a single row :)
    assert row.shape == (1, 7)
    assert row.dtypes.values.tolist() == dtypes
    assert row.columns.tolist() == colnames

    # slice a range of rows
    rows = electrodes[0:3]
    assert rows.shape == (3, 7)
    assert rows.dtypes.values.tolist() == dtypes
    assert rows.columns.tolist() == colnames

    # get a single column
    col = electrodes["y"]
    assert all(col.value == [5, 6, 7, 8, 9])

    # get a single cell
    val = electrodes[0, "y"]
    assert val == 5
    val = electrodes[0, 2]
    assert val == 5

    # get a slice of rows and columns
    subsection = electrodes[0:3, 0:3]
    assert subsection.shape == (3, 3)
    assert subsection.columns.tolist() == colnames[0:3]
    assert subsection.dtypes.values.tolist() == dtypes[0:3]


def test_dynamictable_ragged(units):
    """
    Should be able to index ragged arrays using an implicit _index column

    Also tests:
    - passing arrays directly instead of wrapping in vectordata/index specifically,
      if the models in the fixture instantiate then this works
    """
    units, spike_times, spike_idx = units

    # ensure we don't pivot to long when indexing
    assert units[0].shape[0] == 1
    # check that we got the indexing boundaries correct
    # (and that we are forwarding attr calls to the dataframe by accessing shape)
    for i in range(units.shape[0]):
        assert np.all(units.iloc[i, 0] == spike_times[i])


def test_dynamictable_region_basic(electrical_series):
    """
    DynamicTableRegion should be able to refer to a row or rows of another table
    itself as a column within a table
    """
    series, electrodes = electrical_series
    row = series.electrodes[0]
    # check that we correctly got the 4th row instead of the 0th row,
    # since the indexed table was constructed with inverted indexes because it's a test, ya dummy.
    # we will only vaguely check the basic functionality here bc
    # a) the indexing behavior of the indexed objects is tested above, and
    # b) every other object in the chain is strictly validated,
    # so we assume if we got a right shaped df that it is the correct one.
    # feel free to @ me when i am wrong about this
    assert all(row.id == 4)
    assert row.shape == (1, 7)
    # and we should still be preserving the model that is the contents of the cell of this row
    # so this is a dataframe row with a column "group" that contains an array of ElectrodeGroup
    # objects and that's as far as we are going to chase the recursion in this basic indexing test
    # ElectrodeGroup is strictly validating so an instance check is all we need.
    assert isinstance(row.group.values[0], ElectrodeGroup)

    # getting a list of table rows is actually correct behavior here because
    # this list of table rows is actually the cell of another table
    rows = series.electrodes[0:3]
    assert all([all(row.id == idx) for row, idx in zip(rows, [4, 3, 2])])


def test_dynamictable_region_ragged():
    """
    Dynamictables can also have indexes so that they are ragged arrays of column rows
    """
    spike_times, spike_idx = _ragged_array(24)
    spike_times_flat = np.concatenate(spike_times)

    # construct a secondary index that selects overlapping segments of the first table
    value = np.array([0, 1, 2, 1, 2, 3, 2, 3, 4])
    idx = np.array([3, 6, 9])

    table = DynamicTable(
        name="table",
        description="a table what else would it be",
        id=np.arange(len(spike_idx)),
        timeseries=spike_times_flat,
        timeseries_index=spike_idx,
    )
    region = DynamicTableRegion(
        name="dynamictableregion",
        description="this field should be optional",
        table=table,
        value=value,
    )
    index = VectorIndex(name="index", description="hgggggggjjjj", target=region, value=idx)
    region._index = index
    rows = region[1]
    # i guess this is right?
    # the region should be a set of three rows of the table, with a ragged array column timeseries
    # like...
    #
    # id timeseries
    # 0 1 [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...
    # 1 2 [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, ...
    # 2 3 [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, ...
    assert rows.shape == (3, 2)
    assert all(rows.id == [1, 2, 3])
    assert all([all(row[1].timeseries == i) for i, row in zip([1, 2, 3], rows.iterrows())])


def test_dynamictable_append_column():
    pass


def test_dynamictable_append_row():
    pass


def test_dynamictable_extra_coercion():
    """
    Extra fields should be coerced to VectorData and have their
    indexing relationships handled when passed as plain arrays.
    """


def test_aligned_dynamictable(intracellular_recordings_table):
    """
    Multiple aligned dynamictables should be indexable with a multiindex
    """
    # can get a single row.. (check correctness below)
    row = intracellular_recordings_table[0]
    # can get a single table with its name
    stimuli = intracellular_recordings_table["stimuli"]
    assert stimuli.shape == (10, 1)

    # nab a few rows to make the dataframe
    rows = intracellular_recordings_table[0:3]
    assert all(
        rows.columns
        == pd.MultiIndex.from_tuples(
            [
                ("electrodes", "index"),
                ("electrodes", "electrode"),
                ("stimuli", "index"),
                ("stimuli", "stimulus"),
                ("responses", "index"),
                ("responses", "response"),
            ]
        )
    )

    # ensure that we get the actual values from the TimeSeriesReferenceVectorData
    # also tested separately
    # each individual cell should be an array of VoltageClampStimulusSeries...
    # and then we should be able to index within that as well
    stims = rows["stimuli", "stimulus"][0]
    for i in range(len(stims)):
        assert isinstance(stims[i], VoltageClampStimulusSeries)
        assert all([i == val for val in stims[i][:]])


# --------------------------------------------------
# Direct mixin tests
# Unit tests on mixins directly (model tests below)
# --------------------------------------------------


def test_dynamictable_mixin_indexing():
    """
    Can index values from a dynamictable
    """

@pytest.fixture()
def basic_table() -> tuple[DynamicTableMixin, dict[str, NDArray[Shape["10"], int]]]:
    class MyData(DynamicTableMixin):
        col_1: hdmf.VectorData[NDArray[Shape["*"], int]]
        col_2: hdmf.VectorData[NDArray[Shape["*"], int]]
@@ -228,8 +36,18 @@ def test_dynamictable_mixin_indexing():
        "col_4": np.arange(10),
        "col_5": np.arange(10),
    }
    return MyData, cols


def test_dynamictable_mixin_indexing(basic_table):
    """
    Can index values from a dynamictable
    """
    MyData, cols = basic_table

    colnames = [c for c in cols]
    inst = MyData(**cols)
    assert len(inst) == 10

    row = inst[0]
    # successfully get a single row :)
@@ -251,9 +69,28 @@ def test_dynamictable_mixin_indexing():
    assert val == 5

    # get a slice of rows and columns
    subsection = inst[0:3, 0:3]
    assert subsection.shape == (3, 3)
    assert subsection.columns.tolist() == colnames[0:3]
    val = inst[0:3, 0:3]
    assert val.shape == (3, 3)
    assert val.columns.tolist() == colnames[0:3]

    # slice of rows with string colname
    val = inst[0:2, "col_1"]
    assert val.shape == (2, 1)
    assert val.columns.tolist() == ["col_1"]

    # array of rows
    # crazy slow but we'll work on perf later
    val = inst[np.arange(2), "col_1"]
    assert val.shape == (2, 1)
    assert val.columns.tolist() == ["col_1"]

    # should raise an error on a 3d index
    with pytest.raises(ValueError, match=".*2-dimensional.*"):
        _ = inst[1, 1, 1]

    # error on unhandled indexing type
    with pytest.raises(ValueError, match="Unsure how to get item with key.*"):
        _ = inst[5.5]


def test_dynamictable_mixin_colnames():
@@ -337,9 +174,12 @@ def test_dynamictable_mixin_getattr():
    assert isinstance(inst.existing_col, hdmf.VectorData)
    assert all(inst.existing_col.value == col.value)

    # df lookup for thsoe that don't
    # df lookup for those that don't
    assert isinstance(inst.columns, pd.Index)

    with pytest.raises(AttributeError):
        _ = inst.really_fake_name_that_pandas_and_pydantic_definitely_dont_define


def test_dynamictable_coercion():
    """
@@ -348,15 +188,19 @@ def test_dynamictable_coercion():

    class MyDT(DynamicTableMixin):
        existing_col: hdmf.VectorData[NDArray[Shape["* col"], int]]
        optional_col: Optional[hdmf.VectorData[NDArray[Shape["* col"], int]]]

    cols = {
        "existing_col": np.arange(10),
        "optional_col": np.arange(10),
        "new_col_1": np.arange(10),
    }
    inst = MyDT(**cols)
    assert isinstance(inst.existing_col, hdmf.VectorData)
    assert isinstance(inst.optional_col, hdmf.VectorData)
    assert isinstance(inst.new_col_1, hdmf.VectorData)
    assert all(inst.existing_col.value == np.arange(10))
    assert all(inst.optional_col.value == np.arange(10))
    assert all(inst.new_col_1.value == np.arange(10))
@@ -409,14 +253,14 @@ def dynamictable_assert_equal_length():
        "existing_col": np.arange(10),
        "new_col_1": hdmf.VectorData(value=np.arange(11)),
    }
    with pytest.raises(ValidationError, pattern="Columns are not of equal length"):
    with pytest.raises(ValidationError, match="Columns are not of equal length"):
        _ = MyDT(**cols)

    cols = {
        "existing_col": np.arange(11),
        "new_col_1": hdmf.VectorData(value=np.arange(10)),
    }
    with pytest.raises(ValidationError, pattern="Columns are not of equal length"):
    with pytest.raises(ValidationError, match="Columns are not of equal length"):
        _ = MyDT(**cols)

    # wrong lengths are fine as long as the index is good
@@ -437,6 +281,32 @@ def dynamictable_assert_equal_length():
        _ = MyDT(**cols)


def test_dynamictable_setattr():
    """
    Setting a new column as an attribute adds it to colnames and reruns validations
    """

    class MyDT(DynamicTableMixin):
        existing_col: hdmf.VectorData[NDArray[Shape["* col"], int]]

    cols = {
        "existing_col": hdmf.VectorData(value=np.arange(10)),
        "new_col_1": hdmf.VectorData(value=np.arange(10)),
    }
    inst = MyDT(existing_col=cols["existing_col"])
    assert inst.colnames == ["existing_col"]

    inst.new_col_1 = cols["new_col_1"]
    assert inst.colnames == ["existing_col", "new_col_1"]
    assert inst[:].columns.tolist() == ["existing_col", "new_col_1"]
    # length unchanged because id should be the same
    assert len(inst) == 10

    # model validators should be called to ensure equal length
    with pytest.raises(ValidationError):
        inst.new_col_2 = hdmf.VectorData(value=np.arange(11))


def test_vectordata_indexing():
    """
    Vectordata/VectorIndex pairs should know how to index off each other
@@ -449,6 +319,10 @@ def test_vectordata_indexing():

    # before we have an index, things should work as normal, indexing a 1D array
    assert data[0] == 0
    # and setting values
    data[0] = 1
    assert data[0] == 1
    data[0] = 0

    index = hdmf.VectorIndex(value=index_array, target=data)
    data._index = index
@@ -468,6 +342,31 @@ def test_vectordata_indexing():
    assert all(data[0] == 5)


def test_vectordata_getattr():
    """
    VectorData and VectorIndex both forward getattr to ``value``
    """
    data = hdmf.VectorData(value=np.arange(100))
    index = hdmf.VectorIndex(value=np.arange(10, 101, 10), target=data)

    # get attrs that we defined on the models
    # i.e. no attribute errors here
    _ = data.model_fields
    _ = index.model_fields

    # but for things that aren't defined, get the numpy method
    # note that index should not try and get the sum from the target -
    # that would be hella confusing. we only refer to the target when indexing.
    assert data.sum() == np.sum(np.arange(100))
    assert index.sum() == np.sum(np.arange(10, 101, 10))

    # and also raise attribute errors when nothing is found
    with pytest.raises(AttributeError):
        _ = data.super_fake_attr_name
    with pytest.raises(AttributeError):
        _ = index.super_fake_attr_name


def test_vectordata_generic_numpydantic_validation():
    """
    Using VectorData as a generic with a numpydantic array annotation should still validate
@@ -481,3 +380,247 @@ def test_vectordata_generic_numpydantic_validation():

    with pytest.raises(ValidationError):
        _ = MyDT(existing_col=np.zeros((4, 5, 6), dtype=int))


@pytest.mark.xfail
def test_dynamictable_append_row():
    raise NotImplementedError("Reminder to implement row appending")


def test_dynamictable_region_indexing(basic_table):
    """
    Without an index, DynamicTableRegion should just be a single-row index into
    another table
    """
    model, cols = basic_table
    inst = model(**cols)

    index = np.array([9, 4, 8, 3, 7, 2, 6, 1, 5, 0])

    table_region = hdmf.DynamicTableRegion(value=index, table=inst)

    row = table_region[1]
    assert all(row.iloc[0] == index[1])

    # slices
    rows = table_region[3:5]
    assert all(rows[0].iloc[0] == index[3])
    assert all(rows[1].iloc[0] == index[4])
    assert len(rows) == 2
    assert all([row.shape == (1, 5) for row in rows])

    # out of order fine too
    oorder = [2, 5, 4]
    rows = table_region[oorder]
    assert len(rows) == 3
    assert all([row.shape == (1, 5) for row in rows])
    for i, idx in enumerate(oorder):
        assert all(rows[i].iloc[0] == index[idx])

    # also works when used as a column in a table
    class AnotherTable(DynamicTableMixin):
        region: hdmf.DynamicTableRegion
        another_col: hdmf.VectorData[NDArray[Shape["*"], int]]

    inst2 = AnotherTable(region=table_region, another_col=np.arange(10))
    rows = inst2[0:3]
    col = rows.region
    for i in range(3):
        assert all(col[i].iloc[0] == index[i])


def test_dynamictable_region_ragged():
    """
    Dynamictables can also have indexes so that they are ragged arrays of column rows
    """
    spike_times, spike_idx = _ragged_array(24)
    spike_times_flat = np.concatenate(spike_times)

    # construct a secondary index that selects overlapping segments of the first table
    value = np.array([0, 1, 2, 1, 2, 3, 2, 3, 4])
    idx = np.array([3, 6, 9])

    table = DynamicTableMixin(
        name="table",
        description="a table what else would it be",
        id=np.arange(len(spike_idx)),
        another_column=np.arange(len(spike_idx) - 1, -1, -1),
        timeseries=spike_times_flat,
        timeseries_index=spike_idx,
    )
    region = hdmf.DynamicTableRegion(
        table=table,
        value=value,
    )
    index = hdmf.VectorIndex(name="index", description="hgggggggjjjj", target=region, value=idx)
    region._index = index

    rows = region[1]
    # i guess this is right?
    # the region should be a set of three rows of the table, with a ragged array column timeseries
    # like...
    #
    # id timeseries
    # 0 1 [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...
    # 1 2 [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, ...
    # 2 3 [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, ...
    assert rows.shape == (3, 2)
    assert all(rows.index.to_numpy() == [1, 2, 3])
    assert all([all(row[1].timeseries == i) for i, row in zip([1, 2, 3], rows.iterrows())])

    rows = region[0:2]
    for i in range(2):
        assert all(
            [all(row[1].timeseries == i) for i, row in zip(range(i, i + 3), rows[i].iterrows())]
        )

    # also works when used as a column in a table
    class AnotherTable(DynamicTableMixin):
        region: hdmf.DynamicTableRegion
        yet_another_col: hdmf.VectorData[NDArray[Shape["*"], int]]

    inst2 = AnotherTable(region=region, yet_another_col=np.arange(len(idx)))
    row = inst2[0]
    assert row.shape == (1, 2)
    assert row.iloc[0, 0].equals(region[0])

    rows = inst2[0:3]
    for i, df in enumerate(rows.iloc[:, 0]):
        assert df.equals(region[i])


# --------------------------------------------------
# Model-based tests
# --------------------------------------------------


def test_dynamictable_indexing_electricalseries(electrical_series):
    """
    Can index values from a dynamictable
    """
    series, electrodes = electrical_series

    colnames = [
        "id",
        "x",
        "y",
        "group",
        "group_name",
        "location",
        "extra_column",
    ]
    dtypes = [
        np.dtype("int64"),
        np.dtype("float64"),
        np.dtype("float64"),
    ] + ([np.dtype("O")] * 4)

    row = electrodes[0]
    # successfully get a single row :)
    assert row.shape == (1, 7)
    assert row.dtypes.values.tolist() == dtypes
    assert row.columns.tolist() == colnames

    # slice a range of rows
    rows = electrodes[0:3]
    assert rows.shape == (3, 7)
    assert rows.dtypes.values.tolist() == dtypes
    assert rows.columns.tolist() == colnames

    # get a single column
    col = electrodes["y"]
    assert all(col.value == [5, 6, 7, 8, 9])

    # get a single cell
    val = electrodes[0, "y"]
    assert val == 5
    val = electrodes[0, 2]
    assert val == 5

    # get a slice of rows and columns
    subsection = electrodes[0:3, 0:3]
    assert subsection.shape == (3, 3)
    assert subsection.columns.tolist() == colnames[0:3]
    assert subsection.dtypes.values.tolist() == dtypes[0:3]


def test_dynamictable_ragged_units(units):
    """
    Should be able to index ragged arrays using an implicit _index column

    Also tests:
    - passing arrays directly instead of wrapping in vectordata/index specifically,
      if the models in the fixture instantiate then this works
    """
    units, spike_times, spike_idx = units

    # ensure we don't pivot to long when indexing
    assert units[0].shape[0] == 1
    # check that we got the indexing boundaries correct
    # (and that we are forwarding attr calls to the dataframe by accessing shape)
    for i in range(units.shape[0]):
        assert np.all(units.iloc[i, 0] == spike_times[i])


def test_dynamictable_region_basic_electricalseries(electrical_series):
    """
    DynamicTableRegion should be able to refer to a row or rows of another table
    itself as a column within a table
    """
    series, electrodes = electrical_series
    row = series.electrodes[0]
    # check that we correctly got the 4th row instead of the 0th row,
    # since the indexed table was constructed with inverted indexes because it's a test, ya dummy.
    # we will only vaguely check the basic functionality here bc
    # a) the indexing behavior of the indexed objects is tested above, and
    # b) every other object in the chain is strictly validated,
    # so we assume if we got a right shaped df that it is the correct one.
    # feel free to @ me when i am wrong about this
    assert all(row.id == 4)
    assert row.shape == (1, 7)
    # and we should still be preserving the model that is the contents of the cell of this row
    # so this is a dataframe row with a column "group" that contains an array of ElectrodeGroup
    # objects and that's as far as we are going to chase the recursion in this basic indexing test
    # ElectrodeGroup is strictly validating so an instance check is all we need.
    assert isinstance(row.group.values[0], ElectrodeGroup)

    # getting a list of table rows is actually correct behavior here because
    # this list of table rows is actually the cell of another table
    rows = series.electrodes[0:3]
    assert all([all(row.id == idx) for row, idx in zip(rows, [4, 3, 2])])


def test_aligned_dynamictable_ictable(intracellular_recordings_table):
    """
    Multiple aligned dynamictables should be indexable with a multiindex
    """
    # can get a single row.. (check correctness below)
    row = intracellular_recordings_table[0]
    # can get a single table with its name
    stimuli = intracellular_recordings_table["stimuli"]
    assert stimuli.shape == (10, 1)

    # nab a few rows to make the dataframe
    rows = intracellular_recordings_table[0:3]
    assert all(
        rows.columns
        == pd.MultiIndex.from_tuples(
            [
                ("electrodes", "index"),
                ("electrodes", "electrode"),
                ("stimuli", "index"),
                ("stimuli", "stimulus"),
                ("responses", "index"),
                ("responses", "response"),
            ]
        )
    )

    # ensure that we get the actual values from the TimeSeriesReferenceVectorData
    # also tested separately
    # each individual cell should be an array of VoltageClampStimulusSeries...
    # and then we should be able to index within that as well
    stims = rows["stimuli", "stimulus"][0]
    for i in range(len(stims)):
        assert isinstance(stims[i], VoltageClampStimulusSeries)
        assert all([i == val for val in stims[i][:]])
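
For reference, the index arithmetic behind the two ragged-region tests above (an aside, not diff content): a VectorIndex stores end offsets into its target, so idx = [3, 6, 9] over the nine-element value array selects three overlapping groups of parent-table rows:

import numpy as np

value = np.array([0, 1, 2, 1, 2, 3, 2, 3, 4])  # row ids into the parent table
idx = np.array([3, 6, 9])                      # end offsets into `value`

starts = np.concatenate(([0], idx[:-1]))
groups = [value[start:end] for start, end in zip(starts, idx)]
# groups == [array([0, 1, 2]), array([1, 2, 3]), array([2, 3, 4])]
# so region[1] resolves to table rows 1, 2, 3, matching the assertions above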