From 36add1a306980654aaf44fae6210445f5395c381 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Wed, 14 Aug 2024 22:17:03 -0700 Subject: [PATCH] region tests --- nwb_linkml/src/nwb_linkml/includes/hdmf.py | 70 ++- nwb_linkml/tests/test_includes/test_hdmf.py | 549 ++++++++++++-------- 2 files changed, 404 insertions(+), 215 deletions(-) diff --git a/nwb_linkml/src/nwb_linkml/includes/hdmf.py b/nwb_linkml/src/nwb_linkml/includes/hdmf.py index 58cfb8c..9763ab3 100644 --- a/nwb_linkml/src/nwb_linkml/includes/hdmf.py +++ b/nwb_linkml/src/nwb_linkml/includes/hdmf.py @@ -48,9 +48,10 @@ class DynamicTableMixin(BaseModel): but simplifying along the way :) """ - model_config = ConfigDict(extra="allow") + model_config = ConfigDict(extra="allow", validate_assignment=True) __pydantic_extra__: Dict[str, Union["VectorDataMixin", "VectorIndexMixin", "NDArray", list]] NON_COLUMN_FIELDS: ClassVar[tuple[str]] = ( + "id", "name", "colnames", "description", @@ -116,6 +117,7 @@ class DynamicTableMixin(BaseModel): return self._columns[item] if isinstance(item, (int, slice, np.integer, np.ndarray)): data = self._slice_range(item) + index = self.id[item] elif isinstance(item, tuple): if len(item) != 2: raise ValueError( @@ -133,11 +135,15 @@ class DynamicTableMixin(BaseModel): return self._columns[cols][rows] data = self._slice_range(rows, cols) + index = self.id[rows] else: raise ValueError(f"Unsure how to get item with key {item}") # cast to DF - return pd.DataFrame(data) + if not isinstance(index, Iterable): + index = [index] + index = pd.Index(data=index) + return pd.DataFrame(data, index=index) def _slice_range( self, rows: Union[int, slice, np.ndarray], cols: Optional[Union[str, List[str]]] = None @@ -149,31 +155,40 @@ class DynamicTableMixin(BaseModel): data = {} for k in cols: if isinstance(rows, np.ndarray): + # help wanted - this is probably cr*zy slow val = [self._columns[k][i] for i in rows] else: val = self._columns[k][rows] # scalars need to be wrapped in series for pandas + # do this by the iterability of the rows index not the value because + # we want all lengths from this method to be equal, and if the rows are + # scalar, that means length == 1 if not isinstance(rows, (Iterable, slice)): - val = pd.Series([val]) + val = [val] data[k] = val return data def __setitem__(self, key: str, value: Any) -> None: - raise NotImplementedError("TODO") + raise NotImplementedError("TODO") # pragma: no cover def __setattr__(self, key: str, value: Union[list, "NDArray", "VectorData"]): """ Add a column, appending it to ``colnames`` """ # don't use this while building the model - if not getattr(self, "__pydantic_complete__", False): + if not getattr(self, "__pydantic_complete__", False): # pragma: no cover return super().__setattr__(key, value) if key not in self.model_fields_set and not key.endswith("_index"): self.colnames.append(key) + # we get a recursion error if we setattr without having first added to + # extras if we need it to be there + if key not in self.model_fields and key not in self.__pydantic_extra__: + self.__pydantic_extra__[key] = value + return super().__setattr__(key, value) def __getattr__(self, item: str) -> Any: @@ -303,8 +318,8 @@ class DynamicTableMixin(BaseModel): """ Ensure that all columns are equal length """ - lengths = [len(v) for v in self._columns.values()] - assert [length == lengths[0] for length in lengths], ( + lengths = [len(v) for v in self._columns.values()] + [len(self.id)] + assert all([length == lengths[0] for length in lengths]), ( "Columns are not of equal length! " f"Got colnames:\n{self.colnames}\nand lengths: {lengths}" ) @@ -430,9 +445,21 @@ class VectorIndexMixin(BaseModel, Generic[T]): raise AttributeError(f"Could not index with {item}") def __setitem__(self, key: Union[int, slice], value: Any) -> None: - if self._index: - # VectorIndex is the thing that knows how to do the slicing - self._index[key] = value + """ + Set a value on the :attr:`.target` . + + .. note:: + + Even though we correct the indexing logic from HDMF where the + _data_ is the thing that is provided by the API when one accesses + table.data (rather than table.data_index as hdmf does), + we will set to the target here (rather than to the index) + to be consistent. To modify the index, modify `self.value` directly + + """ + if self.target: + # __getitem__ will return the indexed reference to the target + self[key] = value else: self.value[key] = value @@ -463,9 +490,19 @@ class DynamicTableRegionMixin(BaseModel): _index: Optional["VectorIndex"] = None table: "DynamicTableMixin" - value: Optional[NDArray] = None + value: Optional[NDArray[Shape["*"], int]] = None - def __getitem__(self, item: Union[int, slice, Iterable]) -> Any: + @overload + def __getitem__(self, item: int) -> pd.DataFrame: ... + + @overload + def __getitem__( + self, item: Union[slice, Iterable] + ) -> List[pd.DataFrame]: ... + + def __getitem__( + self, item: Union[int, slice, Iterable] + ) -> Union[pd.DataFrame, List[pd.DataFrame]]: """ Use ``value`` to index the table. Works analogously to ``VectorIndex`` despite this being a subclass of ``VectorData`` @@ -486,6 +523,10 @@ class DynamicTableRegionMixin(BaseModel): if isinstance(item, (int, np.integer)): return self.table[self.value[item]] elif isinstance(item, (slice, Iterable)): + # Return a list of dataframe rows because this is most often used + # as a column in a DynamicTable, so while it would normally be + # ideal to just return the slice as above as a single df, + # we need each row to be separate to fill the column if isinstance(item, slice): item = range(*item.indices(len(self.value))) return [self.table[self.value[i]] for i in item] @@ -737,3 +778,8 @@ if "pytest" in sys.modules: """VectorIndex subclass for testing""" pass + + class DynamicTableRegion(DynamicTableRegionMixin, VectorData): + """DynamicTableRegion subclass for testing""" + + pass diff --git a/nwb_linkml/tests/test_includes/test_hdmf.py b/nwb_linkml/tests/test_includes/test_hdmf.py index 100a70f..420a5ae 100644 --- a/nwb_linkml/tests/test_includes/test_hdmf.py +++ b/nwb_linkml/tests/test_includes/test_hdmf.py @@ -1,3 +1,5 @@ +from typing import Optional + import numpy as np import pandas as pd import pytest @@ -9,213 +11,19 @@ from nwb_linkml.includes.hdmf import DynamicTableMixin, VectorDataMixin, VectorI # FIXME: Make this just be the output of the provider by patching into import machinery from nwb_linkml.models.pydantic.core.v2_7_0.namespace import ( - DynamicTable, - DynamicTableRegion, ElectrodeGroup, - VectorIndex, VoltageClampStimulusSeries, ) from .conftest import _ragged_array - -def test_dynamictable_indexing(electrical_series): - """ - Can index values from a dynamictable - """ - series, electrodes = electrical_series - - colnames = [ - "id", - "x", - "y", - "group", - "group_name", - "location", - "extra_column", - ] - dtypes = [ - np.dtype("int64"), - np.dtype("float64"), - np.dtype("float64"), - ] + ([np.dtype("O")] * 4) - - row = electrodes[0] - # successfully get a single row :) - assert row.shape == (1, 7) - assert row.dtypes.values.tolist() == dtypes - assert row.columns.tolist() == colnames - - # slice a range of rows - rows = electrodes[0:3] - assert rows.shape == (3, 7) - assert rows.dtypes.values.tolist() == dtypes - assert rows.columns.tolist() == colnames - - # get a single column - col = electrodes["y"] - assert all(col.value == [5, 6, 7, 8, 9]) - - # get a single cell - val = electrodes[0, "y"] - assert val == 5 - val = electrodes[0, 2] - assert val == 5 - - # get a slice of rows and columns - subsection = electrodes[0:3, 0:3] - assert subsection.shape == (3, 3) - assert subsection.columns.tolist() == colnames[0:3] - assert subsection.dtypes.values.tolist() == dtypes[0:3] - - -def test_dynamictable_ragged(units): - """ - Should be able to index ragged arrays using an implicit _index column - - Also tests: - - passing arrays directly instead of wrapping in vectordata/index specifically, - if the models in the fixture instantiate then this works - """ - units, spike_times, spike_idx = units - - # ensure we don't pivot to long when indexing - assert units[0].shape[0] == 1 - # check that we got the indexing boundaries corrunect - # (and that we are forwarding attr calls to the dataframe by accessing shape - for i in range(units.shape[0]): - assert np.all(units.iloc[i, 0] == spike_times[i]) - - -def test_dynamictable_region_basic(electrical_series): - """ - DynamicTableRegion should be able to refer to a row or rows of another table - itself as a column within a table - """ - series, electrodes = electrical_series - row = series.electrodes[0] - # check that we correctly got the 4th row instead of the 0th row, - # since the indexed table was constructed with inverted indexes because it's a test, ya dummy. - # we will only vaguely check the basic functionality here bc - # a) the indexing behavior of the indexed objects is tested above, and - # b) every other object in the chain is strictly validated, - # so we assume if we got a right shaped df that it is the correct one. - # feel free to @ me when i am wrong about this - assert all(row.id == 4) - assert row.shape == (1, 7) - # and we should still be preserving the model that is the contents of the cell of this row - # so this is a dataframe row with a column "group" that contains an array of ElectrodeGroup - # objects and that's as far as we are going to chase the recursion in this basic indexing test - # ElectrodeGroup is strictly validating so an instance check is all we need. - assert isinstance(row.group.values[0], ElectrodeGroup) - - # getting a list of table rows is actually correct behavior here because - # this list of table rows is actually the cell of another table - rows = series.electrodes[0:3] - assert all([all(row.id == idx) for row, idx in zip(rows, [4, 3, 2])]) - - -def test_dynamictable_region_ragged(): - """ - Dynamictables can also have indexes so that they are ragged arrays of column rows - """ - spike_times, spike_idx = _ragged_array(24) - spike_times_flat = np.concatenate(spike_times) - - # construct a secondary index that selects overlapping segments of the first table - value = np.array([0, 1, 2, 1, 2, 3, 2, 3, 4]) - idx = np.array([3, 6, 9]) - - table = DynamicTable( - name="table", - description="a table what else would it be", - id=np.arange(len(spike_idx)), - timeseries=spike_times_flat, - timeseries_index=spike_idx, - ) - region = DynamicTableRegion( - name="dynamictableregion", - description="this field should be optional", - table=table, - value=value, - ) - index = VectorIndex(name="index", description="hgggggggjjjj", target=region, value=idx) - region._index = index - rows = region[1] - # i guess this is right? - # the region should be a set of three rows of the table, with a ragged array column timeseries - # like... - # - # id timeseries - # 0 1 [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ... - # 1 2 [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, ... - # 2 3 [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, ... - assert rows.shape == (3, 2) - assert all(rows.id == [1, 2, 3]) - assert all([all(row[1].timeseries == i) for i, row in zip([1, 2, 3], rows.iterrows())]) - - -def test_dynamictable_append_column(): - pass - - -def test_dynamictable_append_row(): - pass - - -def test_dynamictable_extra_coercion(): - """ - Extra fields should be coerced to VectorData and have their - indexing relationships handled when passed as plain arrays. - """ - - -def test_aligned_dynamictable(intracellular_recordings_table): - """ - Multiple aligned dynamictables should be indexable with a multiindex - """ - # can get a single row.. (check correctness below) - row = intracellular_recordings_table[0] - # can get a single table with its name - stimuli = intracellular_recordings_table["stimuli"] - assert stimuli.shape == (10, 1) - - # nab a few rows to make the dataframe - rows = intracellular_recordings_table[0:3] - assert all( - rows.columns - == pd.MultiIndex.from_tuples( - [ - ("electrodes", "index"), - ("electrodes", "electrode"), - ("stimuli", "index"), - ("stimuli", "stimulus"), - ("responses", "index"), - ("responses", "response"), - ] - ) - ) - - # ensure that we get the actual values from the TimeSeriesReferenceVectorData - # also tested separately - # each individual cell should be an array of VoltageClampStimulusSeries... - # and then we should be able to index within that as well - stims = rows["stimuli", "stimulus"][0] - for i in range(len(stims)): - assert isinstance(stims[i], VoltageClampStimulusSeries) - assert all([i == val for val in stims[i][:]]) - - # -------------------------------------------------- -# Direct mixin tests +# Unit tests on mixins directly (model tests below) # -------------------------------------------------- -def test_dynamictable_mixin_indexing(): - """ - Can index values from a dynamictable - """ - +@pytest.fixture() +def basic_table() -> tuple[DynamicTableMixin, dict[str, NDArray[Shape["10"], int]]]: class MyData(DynamicTableMixin): col_1: hdmf.VectorData[NDArray[Shape["*"], int]] col_2: hdmf.VectorData[NDArray[Shape["*"], int]] @@ -228,8 +36,18 @@ def test_dynamictable_mixin_indexing(): "col_4": np.arange(10), "col_5": np.arange(10), } + return MyData, cols + + +def test_dynamictable_mixin_indexing(basic_table): + """ + Can index values from a dynamictable + """ + MyData, cols = basic_table + colnames = [c for c in cols] inst = MyData(**cols) + assert len(inst) == 10 row = inst[0] # successfully get a single row :) @@ -251,9 +69,28 @@ def test_dynamictable_mixin_indexing(): assert val == 5 # get a slice of rows and columns - subsection = inst[0:3, 0:3] - assert subsection.shape == (3, 3) - assert subsection.columns.tolist() == colnames[0:3] + val = inst[0:3, 0:3] + assert val.shape == (3, 3) + assert val.columns.tolist() == colnames[0:3] + + # slice of rows with string colname + val = inst[0:2, "col_1"] + assert val.shape == (2, 1) + assert val.columns.tolist() == ["col_1"] + + # array of rows + # crazy slow but we'll work on perf later + val = inst[np.arange(2), "col_1"] + assert val.shape == (2, 1) + assert val.columns.tolist() == ["col_1"] + + # should raise an error on a 3d index + with pytest.raises(ValueError, match=".*2-dimensional.*"): + _ = inst[1, 1, 1] + + # error on unhandled indexing type + with pytest.raises(ValueError, match="Unsure how to get item with key.*"): + _ = inst[5.5] def test_dynamictable_mixin_colnames(): @@ -337,9 +174,12 @@ def test_dynamictable_mixin_getattr(): assert isinstance(inst.existing_col, hdmf.VectorData) assert all(inst.existing_col.value == col.value) - # df lookup for thsoe that don't + # df lookup for those that don't assert isinstance(inst.columns, pd.Index) + with pytest.raises(AttributeError): + _ = inst.really_fake_name_that_pandas_and_pydantic_definitely_dont_define + def test_dynamictable_coercion(): """ @@ -348,15 +188,19 @@ def test_dynamictable_coercion(): class MyDT(DynamicTableMixin): existing_col: hdmf.VectorData[NDArray[Shape["* col"], int]] + optional_col: Optional[hdmf.VectorData[NDArray[Shape["* col"], int]]] cols = { "existing_col": np.arange(10), + "optional_col": np.arange(10), "new_col_1": np.arange(10), } inst = MyDT(**cols) assert isinstance(inst.existing_col, hdmf.VectorData) + assert isinstance(inst.optional_col, hdmf.VectorData) assert isinstance(inst.new_col_1, hdmf.VectorData) assert all(inst.existing_col.value == np.arange(10)) + assert all(inst.optional_col.value == np.arange(10)) assert all(inst.new_col_1.value == np.arange(10)) @@ -409,14 +253,14 @@ def dynamictable_assert_equal_length(): "existing_col": np.arange(10), "new_col_1": hdmf.VectorData(value=np.arange(11)), } - with pytest.raises(ValidationError, pattern="Columns are not of equal length"): + with pytest.raises(ValidationError, match="Columns are not of equal length"): _ = MyDT(**cols) cols = { "existing_col": np.arange(11), "new_col_1": hdmf.VectorData(value=np.arange(10)), } - with pytest.raises(ValidationError, pattern="Columns are not of equal length"): + with pytest.raises(ValidationError, match="Columns are not of equal length"): _ = MyDT(**cols) # wrong lengths are fine as long as the index is good @@ -437,6 +281,32 @@ def dynamictable_assert_equal_length(): _ = MyDT(**cols) +def test_dynamictable_setattr(): + """ + Setting a new column as an attribute adds it to colnames and reruns validations + """ + + class MyDT(DynamicTableMixin): + existing_col: hdmf.VectorData[NDArray[Shape["* col"], int]] + + cols = { + "existing_col": hdmf.VectorData(value=np.arange(10)), + "new_col_1": hdmf.VectorData(value=np.arange(10)), + } + inst = MyDT(existing_col=cols["existing_col"]) + assert inst.colnames == ["existing_col"] + + inst.new_col_1 = cols["new_col_1"] + assert inst.colnames == ["existing_col", "new_col_1"] + assert inst[:].columns.tolist() == ["existing_col", "new_col_1"] + # length unchanged because id should be the same + assert len(inst) == 10 + + # model validators should be called to ensure equal length + with pytest.raises(ValidationError): + inst.new_col_2 = hdmf.VectorData(value=np.arange(11)) + + def test_vectordata_indexing(): """ Vectordata/VectorIndex pairs should know how to index off each other @@ -449,6 +319,10 @@ def test_vectordata_indexing(): # before we have an index, things should work as normal, indexing a 1D array assert data[0] == 0 + # and setting values + data[0] = 1 + assert data[0] == 1 + data[0] = 0 index = hdmf.VectorIndex(value=index_array, target=data) data._index = index @@ -468,6 +342,31 @@ def test_vectordata_indexing(): assert all(data[0] == 5) +def test_vectordata_getattr(): + """ + VectorData and VectorIndex both forward getattr to ``value`` + """ + data = hdmf.VectorData(value=np.arange(100)) + index = hdmf.VectorIndex(value=np.arange(10, 101, 10), target=data) + + # get attrs that we defined on the models + # i.e. no attribute errors here + _ = data.model_fields + _ = index.model_fields + + # but for things that aren't defined, get the numpy method + # note that index should not try and get the sum from the target - + # that would be hella confusing. we only refer to the target when indexing. + assert data.sum() == np.sum(np.arange(100)) + assert index.sum() == np.sum(np.arange(10, 101, 10)) + + # and also raise attribute errors when nothing is found + with pytest.raises(AttributeError): + _ = data.super_fake_attr_name + with pytest.raises(AttributeError): + _ = index.super_fake_attr_name + + def test_vectordata_generic_numpydantic_validation(): """ Using VectorData as a generic with a numpydantic array annotation should still validate @@ -481,3 +380,247 @@ def test_vectordata_generic_numpydantic_validation(): with pytest.raises(ValidationError): _ = MyDT(existing_col=np.zeros((4, 5, 6), dtype=int)) + + +@pytest.mark.xfail +def test_dynamictable_append_row(): + raise NotImplementedError("Reminder to implement row appending") + + +def test_dynamictable_region_indexing(basic_table): + """ + Without an index, DynamicTableRegion should just be a single-row index into + another table + """ + model, cols = basic_table + inst = model(**cols) + + index = np.array([9, 4, 8, 3, 7, 2, 6, 1, 5, 0]) + + table_region = hdmf.DynamicTableRegion(value=index, table=inst) + + row = table_region[1] + assert all(row.iloc[0] == index[1]) + + # slices + rows = table_region[3:5] + assert all(rows[0].iloc[0] == index[3]) + assert all(rows[1].iloc[0] == index[4]) + assert len(rows) == 2 + assert all([row.shape == (1, 5) for row in rows]) + + # out of order fine too + oorder = [2, 5, 4] + rows = table_region[oorder] + assert len(rows) == 3 + assert all([row.shape == (1, 5) for row in rows]) + for i, idx in enumerate(oorder): + assert all(rows[i].iloc[0] == index[idx]) + + # also works when used as a column in a table + class AnotherTable(DynamicTableMixin): + region: hdmf.DynamicTableRegion + another_col: hdmf.VectorData[NDArray[Shape["*"], int]] + + inst2 = AnotherTable(region=table_region, another_col=np.arange(10)) + rows = inst2[0:3] + col = rows.region + for i in range(3): + assert all(col[i].iloc[0] == index[i]) + + +def test_dynamictable_region_ragged(): + """ + Dynamictables can also have indexes so that they are ragged arrays of column rows + """ + spike_times, spike_idx = _ragged_array(24) + spike_times_flat = np.concatenate(spike_times) + + # construct a secondary index that selects overlapping segments of the first table + value = np.array([0, 1, 2, 1, 2, 3, 2, 3, 4]) + idx = np.array([3, 6, 9]) + + table = DynamicTableMixin( + name="table", + description="a table what else would it be", + id=np.arange(len(spike_idx)), + another_column=np.arange(len(spike_idx) - 1, -1, -1), + timeseries=spike_times_flat, + timeseries_index=spike_idx, + ) + region = hdmf.DynamicTableRegion( + table=table, + value=value, + ) + index = hdmf.VectorIndex(name="index", description="hgggggggjjjj", target=region, value=idx) + region._index = index + + rows = region[1] + # i guess this is right? + # the region should be a set of three rows of the table, with a ragged array column timeseries + # like... + # + # id timeseries + # 0 1 [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ... + # 1 2 [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, ... + # 2 3 [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, ... + assert rows.shape == (3, 2) + assert all(rows.index.to_numpy() == [1, 2, 3]) + assert all([all(row[1].timeseries == i) for i, row in zip([1, 2, 3], rows.iterrows())]) + + rows = region[0:2] + for i in range(2): + assert all( + [all(row[1].timeseries == i) for i, row in zip(range(i, i + 3), rows[i].iterrows())] + ) + + # also works when used as a column in a table + class AnotherTable(DynamicTableMixin): + region: hdmf.DynamicTableRegion + yet_another_col: hdmf.VectorData[NDArray[Shape["*"], int]] + + inst2 = AnotherTable(region=region, yet_another_col=np.arange(len(idx))) + row = inst2[0] + assert row.shape == (1, 2) + assert row.iloc[0, 0].equals(region[0]) + + rows = inst2[0:3] + for i, df in enumerate(rows.iloc[:, 0]): + assert df.equals(region[i]) + + +# -------------------------------------------------- +# Model-based tests +# -------------------------------------------------- + + +def test_dynamictable_indexing_electricalseries(electrical_series): + """ + Can index values from a dynamictable + """ + series, electrodes = electrical_series + + colnames = [ + "id", + "x", + "y", + "group", + "group_name", + "location", + "extra_column", + ] + dtypes = [ + np.dtype("int64"), + np.dtype("float64"), + np.dtype("float64"), + ] + ([np.dtype("O")] * 4) + + row = electrodes[0] + # successfully get a single row :) + assert row.shape == (1, 7) + assert row.dtypes.values.tolist() == dtypes + assert row.columns.tolist() == colnames + + # slice a range of rows + rows = electrodes[0:3] + assert rows.shape == (3, 7) + assert rows.dtypes.values.tolist() == dtypes + assert rows.columns.tolist() == colnames + + # get a single column + col = electrodes["y"] + assert all(col.value == [5, 6, 7, 8, 9]) + + # get a single cell + val = electrodes[0, "y"] + assert val == 5 + val = electrodes[0, 2] + assert val == 5 + + # get a slice of rows and columns + subsection = electrodes[0:3, 0:3] + assert subsection.shape == (3, 3) + assert subsection.columns.tolist() == colnames[0:3] + assert subsection.dtypes.values.tolist() == dtypes[0:3] + + +def test_dynamictable_ragged_units(units): + """ + Should be able to index ragged arrays using an implicit _index column + + Also tests: + - passing arrays directly instead of wrapping in vectordata/index specifically, + if the models in the fixture instantiate then this works + """ + units, spike_times, spike_idx = units + + # ensure we don't pivot to long when indexing + assert units[0].shape[0] == 1 + # check that we got the indexing boundaries corrunect + # (and that we are forwarding attr calls to the dataframe by accessing shape + for i in range(units.shape[0]): + assert np.all(units.iloc[i, 0] == spike_times[i]) + + +def test_dynamictable_region_basic_electricalseries(electrical_series): + """ + DynamicTableRegion should be able to refer to a row or rows of another table + itself as a column within a table + """ + series, electrodes = electrical_series + row = series.electrodes[0] + # check that we correctly got the 4th row instead of the 0th row, + # since the indexed table was constructed with inverted indexes because it's a test, ya dummy. + # we will only vaguely check the basic functionality here bc + # a) the indexing behavior of the indexed objects is tested above, and + # b) every other object in the chain is strictly validated, + # so we assume if we got a right shaped df that it is the correct one. + # feel free to @ me when i am wrong about this + assert all(row.id == 4) + assert row.shape == (1, 7) + # and we should still be preserving the model that is the contents of the cell of this row + # so this is a dataframe row with a column "group" that contains an array of ElectrodeGroup + # objects and that's as far as we are going to chase the recursion in this basic indexing test + # ElectrodeGroup is strictly validating so an instance check is all we need. + assert isinstance(row.group.values[0], ElectrodeGroup) + + # getting a list of table rows is actually correct behavior here because + # this list of table rows is actually the cell of another table + rows = series.electrodes[0:3] + assert all([all(row.id == idx) for row, idx in zip(rows, [4, 3, 2])]) + + +def test_aligned_dynamictable_ictable(intracellular_recordings_table): + """ + Multiple aligned dynamictables should be indexable with a multiindex + """ + # can get a single row.. (check correctness below) + row = intracellular_recordings_table[0] + # can get a single table with its name + stimuli = intracellular_recordings_table["stimuli"] + assert stimuli.shape == (10, 1) + + # nab a few rows to make the dataframe + rows = intracellular_recordings_table[0:3] + assert all( + rows.columns + == pd.MultiIndex.from_tuples( + [ + ("electrodes", "index"), + ("electrodes", "electrode"), + ("stimuli", "index"), + ("stimuli", "stimulus"), + ("responses", "index"), + ("responses", "response"), + ] + ) + ) + + # ensure that we get the actual values from the TimeSeriesReferenceVectorData + # also tested separately + # each individual cell should be an array of VoltageClampStimulusSeries... + # and then we should be able to index within that as well + stims = rows["stimuli", "stimulus"][0] + for i in range(len(stims)): + assert isinstance(stims[i], VoltageClampStimulusSeries) + assert all([i == val for val in stims[i][:]])