diff --git a/nwb_linkml/src/nwb_linkml/includes/hdmf.py b/nwb_linkml/src/nwb_linkml/includes/hdmf.py index 82022de..0c4e7ce 100644 --- a/nwb_linkml/src/nwb_linkml/includes/hdmf.py +++ b/nwb_linkml/src/nwb_linkml/includes/hdmf.py @@ -33,7 +33,7 @@ from pydantic import ( model_validator, ) -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from nwb_linkml.models import VectorData, VectorIndex T = TypeVar("T", bound=NDArray) @@ -211,6 +211,8 @@ class DynamicTableMixin(BaseModel): """ Create ID column if not provided """ + if not isinstance(model, dict): + return model if "id" not in model: lengths = [] for key, val in model.items(): @@ -235,6 +237,8 @@ class DynamicTableMixin(BaseModel): the model dict is ordered after python3.6, so we can use that minus anything in :attr:`.NON_COLUMN_FIELDS` to determine order implied from passage order """ + if not isinstance(model, dict): + return model if "colnames" not in model: colnames = [ k @@ -270,19 +274,21 @@ class DynamicTableMixin(BaseModel): See :meth:`.cast_specified_columns` for handling columns in the class specification """ # if columns are not in the specification, cast to a generic VectorData - for key, val in model.items(): - if key in cls.model_fields: - continue - if not isinstance(val, (VectorData, VectorIndex)): - try: - if key.endswith("_index"): - model[key] = VectorIndex(name=key, description="", value=val) - else: - model[key] = VectorData(name=key, description="", value=val) - except ValidationError as e: # pragma: no cover - raise ValidationError( - f"field {key} cannot be cast to VectorData from {val}" - ) from e + + if isinstance(model, dict): + for key, val in model.items(): + if key in cls.model_fields: + continue + if not isinstance(val, (VectorData, VectorIndex)): + try: + if key.endswith("_index"): + model[key] = VectorIndex(name=key, description="", value=val) + else: + model[key] = VectorData(name=key, description="", value=val) + except ValidationError as e: # pragma: no cover + raise ValidationError( + f"field {key} cannot be cast to VectorData from {val}" + ) from e return model @model_validator(mode="after") @@ -437,7 +443,7 @@ class VectorIndexMixin(BaseModel, Generic[T]): if isinstance(item, slice): item = range(*item.indices(len(self.value))) return [self.target.value[self._slice(i)] for i in item] - else: + else: # pragma: no cover raise AttributeError(f"Could not index with {item}") def __setitem__(self, key: Union[int, slice], value: Any) -> None: @@ -530,7 +536,7 @@ class DynamicTableRegionMixin(BaseModel): # so we index table with an array to construct # a list of lists of rows return [self.table[idx] for idx in self._index[item]] - else: + else: # pragma: no cover raise ValueError(f"Dont know how to index with {item}, need an int or a slice") else: if isinstance(item, (int, np.integer)): @@ -543,19 +549,26 @@ class DynamicTableRegionMixin(BaseModel): if isinstance(item, slice): item = range(*item.indices(len(self.value))) return [self.table[self.value[i]] for i in item] - else: + else: # pragma: no cover raise ValueError(f"Dont know how to index with {item}, need an int or a slice") def __setitem__(self, key: Union[int, str, slice], value: Any) -> None: - self.table[self.value[key]] = value + # self.table[self.value[key]] = value + raise NotImplementedError( + "Assigning values to tables is not implemented yet!" + ) # pragma: no cover -class AlignedDynamicTableMixin(DynamicTableMixin): +class AlignedDynamicTableMixin(BaseModel): """ Mixin to allow indexing multiple tables that are aligned on a common ID + + A great deal of code duplication because we need to avoid diamond inheritance + and also it's not so easy to copy a pydantic validator method. """ - __pydantic_extra__: Dict[str, "DynamicTableMixin"] + model_config = ConfigDict(extra="allow", validate_assignment=True) + __pydantic_extra__: Dict[str, Union["DynamicTableMixin", "VectorDataMixin", "VectorIndexMixin"]] NON_CATEGORY_FIELDS: ClassVar[tuple[str]] = ( "name", @@ -573,7 +586,7 @@ class AlignedDynamicTableMixin(DynamicTableMixin): return {k: getattr(self, k) for i, k in enumerate(self.categories)} def __getitem__( - self, item: Union[int, str, slice, Tuple[Union[int, slice], str]] + self, item: Union[int, str, slice, NDArray[Shape["*"], int], Tuple[Union[int, slice], str]] ) -> pd.DataFrame: """ Mimic hdmf: @@ -591,25 +604,78 @@ class AlignedDynamicTableMixin(DynamicTableMixin): elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], str): # get a slice of a single table return self._categories[item[1]][item[0]] - elif isinstance(item, (int, slice)): + elif isinstance(item, (int, slice, Iterable)): # get a slice of all the tables ids = self.id[item] if not isinstance(ids, Iterable): ids = pd.Series([ids]) ids = pd.DataFrame({"id": ids}) - tables = [ids] + [table[item].reset_index() for table in self._categories.values()] + tables = [ids] + for category_name, category in self._categories.items(): + table = category[item] + if isinstance(table, pd.DataFrame): + table = table.reset_index() + elif isinstance(table, np.ndarray): + table = pd.DataFrame({category_name: [table]}) + elif isinstance(table, Iterable): + table = pd.DataFrame({category_name: table}) + else: + raise ValueError( + f"Don't know how to construct category table for {category_name}" + ) + tables.append(table) + names = [self.name] + self.categories # construct below in case we need to support array indexing in the future else: raise ValueError( f"Dont know how to index with {item}, " - "need an int, string, slice, or tuple[int | slice, str]" + "need an int, string, slice, ndarray, or tuple[int | slice, str]" ) df = pd.concat(tables, axis=1, keys=names) df.set_index((self.name, "id"), drop=True, inplace=True) return df + def __getattr__(self, item: str) -> Any: + """Try and use pandas df attrs if we don't have them""" + try: + return BaseModel.__getattr__(self, item) + except AttributeError as e: + try: + return getattr(self[:], item) + except AttributeError: + raise e from None + + def __len__(self) -> int: + """ + Use the id column to determine length. + + If the id column doesn't represent length accurately, it's a bug + """ + return len(self.id) + + @model_validator(mode="before") + @classmethod + def create_id(cls, model: Dict[str, Any]) -> Dict: + """ + Create ID column if not provided + """ + if "id" not in model: + lengths = [] + for key, val in model.items(): + # don't get lengths of columns with an index + if ( + f"{key}_index" in model + or (isinstance(val, VectorData) and val._index) + or key in cls.NON_CATEGORY_FIELDS + ): + continue + lengths.append(len(val)) + model["id"] = np.arange(np.max(lengths)) + + return model + @model_validator(mode="before") @classmethod def create_categories(cls, model: Dict[str, Any]) -> Dict: @@ -636,6 +702,42 @@ class AlignedDynamicTableMixin(DynamicTableMixin): model["categories"].extend(categories) return model + @model_validator(mode="after") + def resolve_targets(self) -> "DynamicTableMixin": + """ + Ensure that any implicitly indexed columns are linked, and create backlinks + """ + for key, col in self._categories.items(): + if isinstance(col, VectorData): + # find an index + idx = None + for field_name in self.model_fields_set: + if field_name in self.NON_CATEGORY_FIELDS or field_name == key: + continue + # implicit name-based index + field = getattr(self, field_name) + if isinstance(field, VectorIndex) and ( + field_name == f"{key}_index" or field.target is col + ): + idx = field + break + if idx is not None: + col._index = idx + idx.target = col + return self + + @model_validator(mode="after") + def ensure_equal_length_cols(self) -> "DynamicTableMixin": + """ + Ensure that all columns are equal length + """ + lengths = [len(v) for v in self._categories.values()] + [len(self.id)] + assert all([length == lengths[0] for length in lengths]), ( + "Columns are not of equal length! " + f"Got colnames:\n{self.categories}\nand lengths: {lengths}" + ) + return self + class TimeSeriesReferenceVectorDataMixin(VectorDataMixin): """ diff --git a/nwb_linkml/tests/test_includes/test_hdmf.py b/nwb_linkml/tests/test_includes/test_hdmf.py index 32a4e5f..fb9a3e2 100644 --- a/nwb_linkml/tests/test_includes/test_hdmf.py +++ b/nwb_linkml/tests/test_includes/test_hdmf.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Type import numpy as np import pandas as pd @@ -7,7 +7,12 @@ from numpydantic import NDArray, Shape from pydantic import ValidationError from nwb_linkml.includes import hdmf -from nwb_linkml.includes.hdmf import DynamicTableMixin, VectorDataMixin, VectorIndexMixin +from nwb_linkml.includes.hdmf import ( + AlignedDynamicTableMixin, + DynamicTableMixin, + VectorDataMixin, + VectorIndexMixin, +) # FIXME: Make this just be the output of the provider by patching into import machinery from nwb_linkml.models.pydantic.core.v2_7_0.namespace import ( @@ -39,6 +44,33 @@ def basic_table() -> tuple[DynamicTableMixin, dict[str, NDArray[Shape["10"], int return MyData, cols +@pytest.fixture() +def aligned_table() -> tuple[Type[AlignedDynamicTableMixin], dict[str, DynamicTableMixin]]: + class Table1(DynamicTableMixin): + col1: hdmf.VectorData[NDArray[Shape["*"], int]] + col2: hdmf.VectorData[NDArray[Shape["*"], int]] + + class Table2(DynamicTableMixin): + col3: hdmf.VectorData[NDArray[Shape["*"], int]] + col4: hdmf.VectorData[NDArray[Shape["*"], int]] + + class Table3(DynamicTableMixin): + col5: hdmf.VectorData[NDArray[Shape["*"], int]] + col6: hdmf.VectorData[NDArray[Shape["*"], int]] + + array = np.arange(10) + + table1 = Table1(col1=array, col2=array) + table2 = Table2(col3=array, col4=array) + table3 = Table3(col5=array, col6=array) + + class AlignedTable(AlignedDynamicTableMixin): + table1: Table1 + table2: Table2 + + return AlignedTable, {"table1": table1, "table2": table2, "table3": table3} + + def test_dynamictable_mixin_indexing(basic_table): """ Can index values from a dynamictable @@ -357,6 +389,8 @@ def test_vectordata_indexing(): assert all(data[0] == 6) assert all(data[1] == 6) assert all(data[2] == 6) + with pytest.raises(ValueError, match=".*equal-length.*"): + data[0:3] = [5, 4] def test_vectordata_getattr(): @@ -506,6 +540,77 @@ def test_dynamictable_region_ragged(): assert df.equals(region[i]) +def test_aligned_dynamictable_indexing(aligned_table): + """ + Should be able to index aligned dynamic tables to yield a multi-index df + """ + AlignedTable, tables = aligned_table + atable = AlignedTable(**tables) + + row = atable[0] + assert all( + row.columns + == pd.MultiIndex.from_tuples( + [ + ("table1", "index"), + ("table1", "col1"), + ("table1", "col2"), + ("table2", "index"), + ("table2", "col3"), + ("table2", "col4"), + ("table3", "index"), + ("table3", "col5"), + ("table3", "col6"), + ] + ) + ) + for i in range(len(atable)): + vals = atable[i] + assert vals.shape == (1, 9) + assert all(vals == i) + + # mildly different, indexing with a slice. + rows = atable[0:3] + for i, row in enumerate(rows.iterrows()): + vals = row[1] + assert len(vals) == 9 + assert all(vals == i) + + # index just a single table + row = atable[0:3, "table3"] + assert all(row.columns.to_numpy() == ["col5", "col6"]) + assert row.shape == (3, 2) + + # index out of order + rows = atable[np.array([0, 2, 1])] + assert all(rows.iloc[:, 0] == [0, 2, 1]) + + +def test_mixed_aligned_dynamictable(aligned_table): + """ + Aligned dynamictable should also accept vectordata/vector index pairs + """ + + AlignedTable, cols = aligned_table + value_array, index_array = _ragged_array(10) + value_array = np.concat(value_array) + + data = hdmf.VectorData(value=value_array) + index = hdmf.VectorIndex(value=index_array) + + atable = AlignedTable(**cols, extra_col=data, extra_col_index=index) + atable[0] + assert atable[0].columns[-1] == ("extra_col", "extra_col") + + for i, row in enumerate(atable[:].extra_col.iterrows()): + array = row[1].iloc[0] + assert all(array == i) + if i > 0: + assert len(array) == index_array[i] - index_array[i - 1] + else: + assert len(array) == index_array[i] + + # -------------------------------------------------- # Model-based tests # --------------------------------------------------