region tests

sneakers-the-rat 2024-08-14 22:17:03 -07:00
parent 7cb8eea6fe
commit 36add1a306
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
2 changed files with 404 additions and 215 deletions

Changed file 1 of 2:

@@ -48,9 +48,10 @@ class DynamicTableMixin(BaseModel):
     but simplifying along the way :)
     """
 
-    model_config = ConfigDict(extra="allow")
+    model_config = ConfigDict(extra="allow", validate_assignment=True)
     __pydantic_extra__: Dict[str, Union["VectorDataMixin", "VectorIndexMixin", "NDArray", list]]
     NON_COLUMN_FIELDS: ClassVar[tuple[str]] = (
+        "id",
         "name",
         "colnames",
         "description",
@@ -116,6 +117,7 @@ class DynamicTableMixin(BaseModel):
             return self._columns[item]
         if isinstance(item, (int, slice, np.integer, np.ndarray)):
             data = self._slice_range(item)
+            index = self.id[item]
         elif isinstance(item, tuple):
             if len(item) != 2:
                 raise ValueError(
@@ -133,11 +135,15 @@ class DynamicTableMixin(BaseModel):
                 return self._columns[cols][rows]
 
             data = self._slice_range(rows, cols)
+            index = self.id[rows]
         else:
             raise ValueError(f"Unsure how to get item with key {item}")
 
         # cast to DF
-        return pd.DataFrame(data)
+        if not isinstance(index, Iterable):
+            index = [index]
+        index = pd.Index(data=index)
+        return pd.DataFrame(data, index=index)
 
     def _slice_range(
         self, rows: Union[int, slice, np.ndarray], cols: Optional[Union[str, List[str]]] = None
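
The added `index = self.id[...]` lines mean that row selections now carry the table's `id` values as the DataFrame index instead of a default positional index. A rough usage sketch, assuming the mixins are importable the way the test module below uses them (the exact import path and the `MyTable` class are assumptions, not shown in this diff):

```python
# Illustrative sketch only; import paths and MyTable are assumptions.
import numpy as np
from numpydantic import NDArray, Shape
from nwb_linkml.includes import hdmf
from nwb_linkml.includes.hdmf import DynamicTableMixin

class MyTable(DynamicTableMixin):
    col_1: hdmf.VectorData[NDArray[Shape["*"], int]]

inst = MyTable(id=np.array([10, 11, 12]), col_1=np.arange(3))
print(inst[0].index)    # Index([10]) - taken from `id`, not the row position
print(inst[0:2].index)  # Index([10, 11])
```
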
@@ -149,31 +155,40 @@ class DynamicTableMixin(BaseModel):
         data = {}
         for k in cols:
             if isinstance(rows, np.ndarray):
+                # help wanted - this is probably cr*zy slow
                 val = [self._columns[k][i] for i in rows]
             else:
                 val = self._columns[k][rows]
 
             # scalars need to be wrapped in series for pandas
+            # do this by the iterability of the rows index not the value because
+            # we want all lengths from this method to be equal, and if the rows are
+            # scalar, that means length == 1
             if not isinstance(rows, (Iterable, slice)):
-                val = pd.Series([val])
+                val = [val]
 
             data[k] = val
         return data
 
     def __setitem__(self, key: str, value: Any) -> None:
-        raise NotImplementedError("TODO")
+        raise NotImplementedError("TODO")  # pragma: no cover
 
     def __setattr__(self, key: str, value: Union[list, "NDArray", "VectorData"]):
         """
         Add a column, appending it to ``colnames``
         """
         # don't use this while building the model
-        if not getattr(self, "__pydantic_complete__", False):
+        if not getattr(self, "__pydantic_complete__", False):  # pragma: no cover
             return super().__setattr__(key, value)
 
         if key not in self.model_fields_set and not key.endswith("_index"):
             self.colnames.append(key)
 
+        # we get a recursion error if we setattr without having first added to
+        # extras if we need it to be there
+        if key not in self.model_fields and key not in self.__pydantic_extra__:
+            self.__pydantic_extra__[key] = value
+
         return super().__setattr__(key, value)
 
     def __getattr__(self, item: str) -> Any:
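
The `__setattr__` additions are what the new `test_dynamictable_setattr` test below exercises: assigning an unknown attribute appends it to `colnames`, and the value is written into `__pydantic_extra__` first to avoid the recursion the comment mentions. A hedged sketch of the intended usage (import paths assumed as above, column names illustrative):

```python
# Sketch of column-by-assignment; import paths are assumptions mirroring the tests below.
import numpy as np
from numpydantic import NDArray, Shape
from nwb_linkml.includes import hdmf
from nwb_linkml.includes.hdmf import DynamicTableMixin

class MyTable(DynamicTableMixin):
    existing_col: hdmf.VectorData[NDArray[Shape["*"], int]]

inst = MyTable(existing_col=hdmf.VectorData(value=np.arange(10)))
inst.new_col = hdmf.VectorData(value=np.arange(10))  # registered as a column
assert inst.colnames == ["existing_col", "new_col"]
assert inst[:].columns.tolist() == ["existing_col", "new_col"]
```
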
@@ -303,8 +318,8 @@ class DynamicTableMixin(BaseModel):
         """
         Ensure that all columns are equal length
         """
-        lengths = [len(v) for v in self._columns.values()]
-        assert [length == lengths[0] for length in lengths], (
+        lengths = [len(v) for v in self._columns.values()] + [len(self.id)]
+        assert all([length == lengths[0] for length in lengths]), (
             "Columns are not of equal length! "
             f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
         )
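
Two things change here: `len(self.id)` is now counted alongside the columns, and the old `assert [ ... ]` asserted a non-empty list, which is always truthy, so wrapping the comprehension in `all()` makes the check actually fail on mismatched lengths. This mirrors the existing `dynamictable_assert_equal_length` test; a sketch of the expected failure mode (imports assumed as above):

```python
# Sketch: mismatched column lengths should now surface as a ValidationError.
import numpy as np
from pydantic import ValidationError
from numpydantic import NDArray, Shape
from nwb_linkml.includes import hdmf
from nwb_linkml.includes.hdmf import DynamicTableMixin

class MyTable(DynamicTableMixin):
    existing_col: hdmf.VectorData[NDArray[Shape["*"], int]]

try:
    MyTable(existing_col=np.arange(10), new_col_1=hdmf.VectorData(value=np.arange(11)))
except ValidationError as err:
    print(err)  # "Columns are not of equal length! ..."
```
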
@@ -430,9 +445,21 @@ class VectorIndexMixin(BaseModel, Generic[T]):
             raise AttributeError(f"Could not index with {item}")
 
     def __setitem__(self, key: Union[int, slice], value: Any) -> None:
-        if self._index:
-            # VectorIndex is the thing that knows how to do the slicing
-            self._index[key] = value
+        """
+        Set a value on the :attr:`.target` .
+
+        .. note::
+
+            Even though we correct the indexing logic from HDMF where the
+            _data_ is the thing that is provided by the API when one accesses
+            table.data (rather than table.data_index as hdmf does),
+            we will set to the target here (rather than to the index)
+            to be consistent. To modify the index, modify `self.value` directly
+
+        """
+        if self.target:
+            # __getitem__ will return the indexed reference to the target
+            self[key] = value
         else:
             self.value[key] = value
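
The new docstring describes the intended write-through: assignments land on the target data rather than on the index itself. As exercised by `test_vectordata_indexing` below, the observable behavior looks roughly like this (a sketch of the test's own steps, not a new API; import path assumed as in the tests):

```python
# Sketch mirroring test_vectordata_indexing: setting a ragged "row" writes through
# to the underlying VectorData values. Import path is an assumption.
import numpy as np
from nwb_linkml.includes import hdmf

data = hdmf.VectorData(value=np.arange(100))
index = hdmf.VectorIndex(value=np.arange(10, 101, 10), target=data)
data._index = index

data[0] = 5               # first ragged row, i.e. data.value[0:10]
assert all(data[0] == 5)
```
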
@@ -463,9 +490,19 @@ class DynamicTableRegionMixin(BaseModel):
     _index: Optional["VectorIndex"] = None
 
     table: "DynamicTableMixin"
-    value: Optional[NDArray] = None
+    value: Optional[NDArray[Shape["*"], int]] = None
 
-    def __getitem__(self, item: Union[int, slice, Iterable]) -> Any:
+    @overload
+    def __getitem__(self, item: int) -> pd.DataFrame: ...
+
+    @overload
+    def __getitem__(
+        self, item: Union[slice, Iterable]
+    ) -> List[pd.DataFrame]: ...
+
+    def __getitem__(
+        self, item: Union[int, slice, Iterable]
+    ) -> Union[pd.DataFrame, List[pd.DataFrame]]:
         """
         Use ``value`` to index the table. Works analogously to ``VectorIndex`` despite
         this being a subclass of ``VectorData``
@@ -486,6 +523,10 @@ class DynamicTableRegionMixin(BaseModel):
         if isinstance(item, (int, np.integer)):
             return self.table[self.value[item]]
         elif isinstance(item, (slice, Iterable)):
+            # Return a list of dataframe rows because this is most often used
+            # as a column in a DynamicTable, so while it would normally be
+            # ideal to just return the slice as above as a single df,
+            # we need each row to be separate to fill the column
            if isinstance(item, slice):
                 item = range(*item.indices(len(self.value)))
             return [self.table[self.value[i]] for i in item]
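
The overloads above spell out the two return shapes that the comment motivates: an `int` gives a single one-row DataFrame, while a slice or iterable gives a list of one-row DataFrames (one per selected row) so the result can fill a table column. A sketch mirroring `test_dynamictable_region_indexing` below (imports and `MyTable` are assumptions):

```python
# Sketch of the two return shapes; imports and MyTable are assumptions mirroring the tests.
import numpy as np
from numpydantic import NDArray, Shape
from nwb_linkml.includes import hdmf
from nwb_linkml.includes.hdmf import DynamicTableMixin

class MyTable(DynamicTableMixin):
    col_1: hdmf.VectorData[NDArray[Shape["*"], int]]

table = MyTable(col_1=np.arange(10))
region = hdmf.DynamicTableRegion(value=np.array([9, 4, 8]), table=table)

single = region[1]     # int -> one pd.DataFrame (the row of `table` selected by value[1])
several = region[0:2]  # slice -> list of single-row DataFrames
assert len(several) == 2
```
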
@@ -737,3 +778,8 @@ if "pytest" in sys.modules:
         """VectorIndex subclass for testing"""
 
         pass
+
+    class DynamicTableRegion(DynamicTableRegionMixin, VectorData):
+        """DynamicTableRegion subclass for testing"""
+
+        pass

Changed file 2 of 2:

@@ -1,3 +1,5 @@
+from typing import Optional
+
 import numpy as np
 import pandas as pd
 import pytest
@@ -9,213 +11,19 @@ from nwb_linkml.includes.hdmf import DynamicTableMixin, VectorDataMixin, VectorI
 # FIXME: Make this just be the output of the provider by patching into import machinery
 from nwb_linkml.models.pydantic.core.v2_7_0.namespace import (
-    DynamicTable,
-    DynamicTableRegion,
     ElectrodeGroup,
-    VectorIndex,
     VoltageClampStimulusSeries,
 )
 
 from .conftest import _ragged_array
 
 
-def test_dynamictable_indexing(electrical_series):
-    """
-    Can index values from a dynamictable
-    """
-    series, electrodes = electrical_series
-
-    colnames = [
-        "id",
-        "x",
-        "y",
-        "group",
-        "group_name",
-        "location",
-        "extra_column",
-    ]
-    dtypes = [
-        np.dtype("int64"),
-        np.dtype("float64"),
-        np.dtype("float64"),
-    ] + ([np.dtype("O")] * 4)
-
-    row = electrodes[0]
-    # successfully get a single row :)
-    assert row.shape == (1, 7)
-    assert row.dtypes.values.tolist() == dtypes
-    assert row.columns.tolist() == colnames
-
-    # slice a range of rows
-    rows = electrodes[0:3]
-    assert rows.shape == (3, 7)
-    assert rows.dtypes.values.tolist() == dtypes
-    assert rows.columns.tolist() == colnames
-
-    # get a single column
-    col = electrodes["y"]
-    assert all(col.value == [5, 6, 7, 8, 9])
-
-    # get a single cell
-    val = electrodes[0, "y"]
-    assert val == 5
-    val = electrodes[0, 2]
-    assert val == 5
-
-    # get a slice of rows and columns
-    subsection = electrodes[0:3, 0:3]
-    assert subsection.shape == (3, 3)
-    assert subsection.columns.tolist() == colnames[0:3]
-    assert subsection.dtypes.values.tolist() == dtypes[0:3]
-
-
-def test_dynamictable_ragged(units):
-    """
-    Should be able to index ragged arrays using an implicit _index column
-
-    Also tests:
-    - passing arrays directly instead of wrapping in vectordata/index specifically,
-      if the models in the fixture instantiate then this works
-    """
-    units, spike_times, spike_idx = units
-
-    # ensure we don't pivot to long when indexing
-    assert units[0].shape[0] == 1
-    # check that we got the indexing boundaries corrunect
-    # (and that we are forwarding attr calls to the dataframe by accessing shape
-
-    for i in range(units.shape[0]):
-        assert np.all(units.iloc[i, 0] == spike_times[i])
-
-
-def test_dynamictable_region_basic(electrical_series):
-    """
-    DynamicTableRegion should be able to refer to a row or rows of another table
-    itself as a column within a table
-    """
-    series, electrodes = electrical_series
-    row = series.electrodes[0]
-    # check that we correctly got the 4th row instead of the 0th row,
-    # since the indexed table was constructed with inverted indexes because it's a test, ya dummy.
-    # we will only vaguely check the basic functionality here bc
-    # a) the indexing behavior of the indexed objects is tested above, and
-    # b) every other object in the chain is strictly validated,
-    # so we assume if we got a right shaped df that it is the correct one.
-    # feel free to @ me when i am wrong about this
-    assert all(row.id == 4)
-    assert row.shape == (1, 7)
-    # and we should still be preserving the model that is the contents of the cell of this row
-    # so this is a dataframe row with a column "group" that contains an array of ElectrodeGroup
-    # objects and that's as far as we are going to chase the recursion in this basic indexing test
-    # ElectrodeGroup is strictly validating so an instance check is all we need.
-    assert isinstance(row.group.values[0], ElectrodeGroup)
-
-    # getting a list of table rows is actually correct behavior here because
-    # this list of table rows is actually the cell of another table
-    rows = series.electrodes[0:3]
-    assert all([all(row.id == idx) for row, idx in zip(rows, [4, 3, 2])])
-
-
-def test_dynamictable_region_ragged():
-    """
-    Dynamictables can also have indexes so that they are ragged arrays of column rows
-    """
-    spike_times, spike_idx = _ragged_array(24)
-    spike_times_flat = np.concatenate(spike_times)
-
-    # construct a secondary index that selects overlapping segments of the first table
-    value = np.array([0, 1, 2, 1, 2, 3, 2, 3, 4])
-    idx = np.array([3, 6, 9])
-
-    table = DynamicTable(
-        name="table",
-        description="a table what else would it be",
-        id=np.arange(len(spike_idx)),
-        timeseries=spike_times_flat,
-        timeseries_index=spike_idx,
-    )
-    region = DynamicTableRegion(
-        name="dynamictableregion",
-        description="this field should be optional",
-        table=table,
-        value=value,
-    )
-    index = VectorIndex(name="index", description="hgggggggjjjj", target=region, value=idx)
-    region._index = index
-
-    rows = region[1]
-    # i guess this is right?
-    # the region should be a set of three rows of the table, with a ragged array column timeseries
-    # like...
-    #
-    #    id                                         timeseries
-    # 0   1  [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...
-    # 1   2  [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, ...
-    # 2   3  [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, ...
-    assert rows.shape == (3, 2)
-    assert all(rows.id == [1, 2, 3])
-    assert all([all(row[1].timeseries == i) for i, row in zip([1, 2, 3], rows.iterrows())])
-
-
-def test_dynamictable_append_column():
-    pass
-
-
-def test_dynamictable_append_row():
-    pass
-
-
-def test_dynamictable_extra_coercion():
-    """
-    Extra fields should be coerced to VectorData and have their
-    indexing relationships handled when passed as plain arrays.
-    """
-
-
-def test_aligned_dynamictable(intracellular_recordings_table):
-    """
-    Multiple aligned dynamictables should be indexable with a multiindex
-    """
-    # can get a single row.. (check correctness below)
-    row = intracellular_recordings_table[0]
-    # can get a single table with its name
-    stimuli = intracellular_recordings_table["stimuli"]
-    assert stimuli.shape == (10, 1)
-
-    # nab a few rows to make the dataframe
-    rows = intracellular_recordings_table[0:3]
-    assert all(
-        rows.columns
-        == pd.MultiIndex.from_tuples(
-            [
-                ("electrodes", "index"),
-                ("electrodes", "electrode"),
-                ("stimuli", "index"),
-                ("stimuli", "stimulus"),
-                ("responses", "index"),
-                ("responses", "response"),
-            ]
-        )
-    )
-
-    # ensure that we get the actual values from the TimeSeriesReferenceVectorData
-    # also tested separately
-    # each individual cell should be an array of VoltageClampStimulusSeries...
-    # and then we should be able to index within that as well
-    stims = rows["stimuli", "stimulus"][0]
-    for i in range(len(stims)):
-        assert isinstance(stims[i], VoltageClampStimulusSeries)
-        assert all([i == val for val in stims[i][:]])
-
-
 # --------------------------------------------------
-# Direct mixin tests
+# Unit tests on mixins directly (model tests below)
 # --------------------------------------------------
 
 
-def test_dynamictable_mixin_indexing():
-    """
-    Can index values from a dynamictable
-    """
+@pytest.fixture()
+def basic_table() -> tuple[DynamicTableMixin, dict[str, NDArray[Shape["10"], int]]]:
     class MyData(DynamicTableMixin):
         col_1: hdmf.VectorData[NDArray[Shape["*"], int]]
         col_2: hdmf.VectorData[NDArray[Shape["*"], int]]
@@ -228,8 +36,18 @@ def test_dynamictable_mixin_indexing():
         "col_4": np.arange(10),
         "col_5": np.arange(10),
     }
+    return MyData, cols
+
+
+def test_dynamictable_mixin_indexing(basic_table):
+    """
+    Can index values from a dynamictable
+    """
+    MyData, cols = basic_table
+
     colnames = [c for c in cols]
     inst = MyData(**cols)
+    assert len(inst) == 10
 
     row = inst[0]
     # successfully get a single row :)
@@ -251,9 +69,28 @@
     assert val == 5
 
     # get a slice of rows and columns
-    subsection = inst[0:3, 0:3]
-    assert subsection.shape == (3, 3)
-    assert subsection.columns.tolist() == colnames[0:3]
+    val = inst[0:3, 0:3]
+    assert val.shape == (3, 3)
+    assert val.columns.tolist() == colnames[0:3]
+
+    # slice of rows with string colname
+    val = inst[0:2, "col_1"]
+    assert val.shape == (2, 1)
+    assert val.columns.tolist() == ["col_1"]
+
+    # array of rows
+    # crazy slow but we'll work on perf later
+    val = inst[np.arange(2), "col_1"]
+    assert val.shape == (2, 1)
+    assert val.columns.tolist() == ["col_1"]
+
+    # should raise an error on a 3d index
+    with pytest.raises(ValueError, match=".*2-dimensional.*"):
+        _ = inst[1, 1, 1]
+
+    # error on unhandled indexing type
+    with pytest.raises(ValueError, match="Unsure how to get item with key.*"):
+        _ = inst[5.5]
 
 
 def test_dynamictable_mixin_colnames():
@@ -337,9 +174,12 @@ def test_dynamictable_mixin_getattr():
     assert isinstance(inst.existing_col, hdmf.VectorData)
     assert all(inst.existing_col.value == col.value)
 
-    # df lookup for thsoe that don't
+    # df lookup for those that don't
     assert isinstance(inst.columns, pd.Index)
 
+    with pytest.raises(AttributeError):
+        _ = inst.really_fake_name_that_pandas_and_pydantic_definitely_dont_define
+
 
 def test_dynamictable_coercion():
     """
@@ -348,15 +188,19 @@
     class MyDT(DynamicTableMixin):
         existing_col: hdmf.VectorData[NDArray[Shape["* col"], int]]
+        optional_col: Optional[hdmf.VectorData[NDArray[Shape["* col"], int]]]
 
     cols = {
         "existing_col": np.arange(10),
+        "optional_col": np.arange(10),
         "new_col_1": np.arange(10),
     }
 
     inst = MyDT(**cols)
     assert isinstance(inst.existing_col, hdmf.VectorData)
+    assert isinstance(inst.optional_col, hdmf.VectorData)
     assert isinstance(inst.new_col_1, hdmf.VectorData)
     assert all(inst.existing_col.value == np.arange(10))
+    assert all(inst.optional_col.value == np.arange(10))
     assert all(inst.new_col_1.value == np.arange(10))
@@ -409,14 +253,14 @@ def dynamictable_assert_equal_length():
         "existing_col": np.arange(10),
         "new_col_1": hdmf.VectorData(value=np.arange(11)),
     }
-    with pytest.raises(ValidationError, pattern="Columns are not of equal length"):
+    with pytest.raises(ValidationError, match="Columns are not of equal length"):
         _ = MyDT(**cols)
 
     cols = {
         "existing_col": np.arange(11),
         "new_col_1": hdmf.VectorData(value=np.arange(10)),
     }
-    with pytest.raises(ValidationError, pattern="Columns are not of equal length"):
+    with pytest.raises(ValidationError, match="Columns are not of equal length"):
         _ = MyDT(**cols)
 
     # wrong lengths are fine as long as the index is good
@@ -437,6 +281,32 @@
         _ = MyDT(**cols)
 
 
+def test_dynamictable_setattr():
+    """
+    Setting a new column as an attribute adds it to colnames and reruns validations
+    """
+
+    class MyDT(DynamicTableMixin):
+        existing_col: hdmf.VectorData[NDArray[Shape["* col"], int]]
+
+    cols = {
+        "existing_col": hdmf.VectorData(value=np.arange(10)),
+        "new_col_1": hdmf.VectorData(value=np.arange(10)),
+    }
+    inst = MyDT(existing_col=cols["existing_col"])
+    assert inst.colnames == ["existing_col"]
+
+    inst.new_col_1 = cols["new_col_1"]
+    assert inst.colnames == ["existing_col", "new_col_1"]
+    assert inst[:].columns.tolist() == ["existing_col", "new_col_1"]
+    # length unchanged because id should be the same
+    assert len(inst) == 10
+
+    # model validators should be called to ensure equal length
+    with pytest.raises(ValidationError):
+        inst.new_col_2 = hdmf.VectorData(value=np.arange(11))
+
+
 def test_vectordata_indexing():
     """
     Vectordata/VectorIndex pairs should know how to index off each other
@@ -449,6 +319,10 @@
     # before we have an index, things should work as normal, indexing a 1D array
     assert data[0] == 0
+    # and setting values
+    data[0] = 1
+    assert data[0] == 1
+    data[0] = 0
 
     index = hdmf.VectorIndex(value=index_array, target=data)
     data._index = index
@@ -468,6 +342,31 @@
     assert all(data[0] == 5)
 
 
+def test_vectordata_getattr():
+    """
+    VectorData and VectorIndex both forward getattr to ``value``
+    """
+    data = hdmf.VectorData(value=np.arange(100))
+    index = hdmf.VectorIndex(value=np.arange(10, 101, 10), target=data)
+
+    # get attrs that we defined on the models
+    # i.e. no attribute errors here
+    _ = data.model_fields
+    _ = index.model_fields
+
+    # but for things that aren't defined, get the numpy method
+    # note that index should not try and get the sum from the target -
+    # that would be hella confusing. we only refer to the target when indexing.
+    assert data.sum() == np.sum(np.arange(100))
+    assert index.sum() == np.sum(np.arange(10, 101, 10))
+
+    # and also raise attribute errors when nothing is found
+    with pytest.raises(AttributeError):
+        _ = data.super_fake_attr_name
+    with pytest.raises(AttributeError):
+        _ = index.super_fake_attr_name
+
+
 def test_vectordata_generic_numpydantic_validation():
     """
     Using VectorData as a generic with a numpydantic array annotation should still validate
@@ -481,3 +380,247 @@
     with pytest.raises(ValidationError):
         _ = MyDT(existing_col=np.zeros((4, 5, 6), dtype=int))
+
+
+@pytest.mark.xfail
+def test_dynamictable_append_row():
+    raise NotImplementedError("Reminder to implement row appending")
+
+
+def test_dynamictable_region_indexing(basic_table):
+    """
+    Without an index, DynamicTableRegion should just be a single-row index into
+    another table
+    """
+    model, cols = basic_table
+    inst = model(**cols)
+
+    index = np.array([9, 4, 8, 3, 7, 2, 6, 1, 5, 0])
+
+    table_region = hdmf.DynamicTableRegion(value=index, table=inst)
+
+    row = table_region[1]
+    assert all(row.iloc[0] == index[1])
+
+    # slices
+    rows = table_region[3:5]
+    assert all(rows[0].iloc[0] == index[3])
+    assert all(rows[1].iloc[0] == index[4])
+    assert len(rows) == 2
+    assert all([row.shape == (1, 5) for row in rows])
+
+    # out of order fine too
+    oorder = [2, 5, 4]
+    rows = table_region[oorder]
+    assert len(rows) == 3
+    assert all([row.shape == (1, 5) for row in rows])
+    for i, idx in enumerate(oorder):
+        assert all(rows[i].iloc[0] == index[idx])
+
+    # also works when used as a column in a table
+    class AnotherTable(DynamicTableMixin):
+        region: hdmf.DynamicTableRegion
+        another_col: hdmf.VectorData[NDArray[Shape["*"], int]]
+
+    inst2 = AnotherTable(region=table_region, another_col=np.arange(10))
+    rows = inst2[0:3]
+    col = rows.region
+    for i in range(3):
+        assert all(col[i].iloc[0] == index[i])
+
+
+def test_dynamictable_region_ragged():
+    """
+    Dynamictables can also have indexes so that they are ragged arrays of column rows
+    """
+    spike_times, spike_idx = _ragged_array(24)
+    spike_times_flat = np.concatenate(spike_times)
+
+    # construct a secondary index that selects overlapping segments of the first table
+    value = np.array([0, 1, 2, 1, 2, 3, 2, 3, 4])
+    idx = np.array([3, 6, 9])
+
+    table = DynamicTableMixin(
+        name="table",
+        description="a table what else would it be",
+        id=np.arange(len(spike_idx)),
+        another_column=np.arange(len(spike_idx) - 1, -1, -1),
+        timeseries=spike_times_flat,
+        timeseries_index=spike_idx,
+    )
+    region = hdmf.DynamicTableRegion(
+        table=table,
+        value=value,
+    )
+    index = hdmf.VectorIndex(name="index", description="hgggggggjjjj", target=region, value=idx)
+    region._index = index
+
+    rows = region[1]
+    # i guess this is right?
+    # the region should be a set of three rows of the table, with a ragged array column timeseries
+    # like...
+    #
+    #    id                                         timeseries
+    # 0   1  [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...
+    # 1   2  [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, ...
+    # 2   3  [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, ...
+    assert rows.shape == (3, 2)
+    assert all(rows.index.to_numpy() == [1, 2, 3])
+    assert all([all(row[1].timeseries == i) for i, row in zip([1, 2, 3], rows.iterrows())])
+
+    rows = region[0:2]
+    for i in range(2):
+        assert all(
+            [all(row[1].timeseries == i) for i, row in zip(range(i, i + 3), rows[i].iterrows())]
+        )
+
+    # also works when used as a column in a table
+    class AnotherTable(DynamicTableMixin):
+        region: hdmf.DynamicTableRegion
+        yet_another_col: hdmf.VectorData[NDArray[Shape["*"], int]]
+
+    inst2 = AnotherTable(region=region, yet_another_col=np.arange(len(idx)))
+    row = inst2[0]
+    assert row.shape == (1, 2)
+    assert row.iloc[0, 0].equals(region[0])
+
+    rows = inst2[0:3]
+    for i, df in enumerate(rows.iloc[:, 0]):
+        assert df.equals(region[i])
+
+
+# --------------------------------------------------
+# Model-based tests
+# --------------------------------------------------
+
+
+def test_dynamictable_indexing_electricalseries(electrical_series):
+    """
+    Can index values from a dynamictable
+    """
+    series, electrodes = electrical_series
+
+    colnames = [
+        "id",
+        "x",
+        "y",
+        "group",
+        "group_name",
+        "location",
+        "extra_column",
+    ]
+    dtypes = [
+        np.dtype("int64"),
+        np.dtype("float64"),
+        np.dtype("float64"),
+    ] + ([np.dtype("O")] * 4)
+
+    row = electrodes[0]
+    # successfully get a single row :)
+    assert row.shape == (1, 7)
+    assert row.dtypes.values.tolist() == dtypes
+    assert row.columns.tolist() == colnames
+
+    # slice a range of rows
+    rows = electrodes[0:3]
+    assert rows.shape == (3, 7)
+    assert rows.dtypes.values.tolist() == dtypes
+    assert rows.columns.tolist() == colnames
+
+    # get a single column
+    col = electrodes["y"]
+    assert all(col.value == [5, 6, 7, 8, 9])
+
+    # get a single cell
+    val = electrodes[0, "y"]
+    assert val == 5
+    val = electrodes[0, 2]
+    assert val == 5
+
+    # get a slice of rows and columns
+    subsection = electrodes[0:3, 0:3]
+    assert subsection.shape == (3, 3)
+    assert subsection.columns.tolist() == colnames[0:3]
+    assert subsection.dtypes.values.tolist() == dtypes[0:3]
+
+
+def test_dynamictable_ragged_units(units):
+    """
+    Should be able to index ragged arrays using an implicit _index column
+
+    Also tests:
+    - passing arrays directly instead of wrapping in vectordata/index specifically,
+      if the models in the fixture instantiate then this works
+    """
+    units, spike_times, spike_idx = units
+
+    # ensure we don't pivot to long when indexing
+    assert units[0].shape[0] == 1
+    # check that we got the indexing boundaries corrunect
+    # (and that we are forwarding attr calls to the dataframe by accessing shape
+
+    for i in range(units.shape[0]):
+        assert np.all(units.iloc[i, 0] == spike_times[i])
+
+
+def test_dynamictable_region_basic_electricalseries(electrical_series):
+    """
+    DynamicTableRegion should be able to refer to a row or rows of another table
+    itself as a column within a table
+    """
+    series, electrodes = electrical_series
+    row = series.electrodes[0]
+    # check that we correctly got the 4th row instead of the 0th row,
+    # since the indexed table was constructed with inverted indexes because it's a test, ya dummy.
+    # we will only vaguely check the basic functionality here bc
+    # a) the indexing behavior of the indexed objects is tested above, and
+    # b) every other object in the chain is strictly validated,
+    # so we assume if we got a right shaped df that it is the correct one.
+    # feel free to @ me when i am wrong about this
+    assert all(row.id == 4)
+    assert row.shape == (1, 7)
+    # and we should still be preserving the model that is the contents of the cell of this row
+    # so this is a dataframe row with a column "group" that contains an array of ElectrodeGroup
+    # objects and that's as far as we are going to chase the recursion in this basic indexing test
+    # ElectrodeGroup is strictly validating so an instance check is all we need.
+    assert isinstance(row.group.values[0], ElectrodeGroup)
+
+    # getting a list of table rows is actually correct behavior here because
+    # this list of table rows is actually the cell of another table
+    rows = series.electrodes[0:3]
+    assert all([all(row.id == idx) for row, idx in zip(rows, [4, 3, 2])])
+
+
+def test_aligned_dynamictable_ictable(intracellular_recordings_table):
+    """
+    Multiple aligned dynamictables should be indexable with a multiindex
+    """
+    # can get a single row.. (check correctness below)
+    row = intracellular_recordings_table[0]
+    # can get a single table with its name
+    stimuli = intracellular_recordings_table["stimuli"]
+    assert stimuli.shape == (10, 1)
+
+    # nab a few rows to make the dataframe
+    rows = intracellular_recordings_table[0:3]
+    assert all(
+        rows.columns
+        == pd.MultiIndex.from_tuples(
+            [
+                ("electrodes", "index"),
+                ("electrodes", "electrode"),
+                ("stimuli", "index"),
+                ("stimuli", "stimulus"),
+                ("responses", "index"),
+                ("responses", "response"),
+            ]
+        )
+    )
+
+    # ensure that we get the actual values from the TimeSeriesReferenceVectorData
+    # also tested separately
+    # each individual cell should be an array of VoltageClampStimulusSeries...
+    # and then we should be able to index within that as well
+    stims = rows["stimuli", "stimulus"][0]
+    for i in range(len(stims)):
+        assert isinstance(stims[i], VoltageClampStimulusSeries)
+        assert all([i == val for val in stims[i][:]])