diff --git a/docs/meta/todo.md b/docs/meta/todo.md index 6508c62..3d3efcb 100644 --- a/docs/meta/todo.md +++ b/docs/meta/todo.md @@ -7,6 +7,7 @@ NWB schema translation - handle compound `dtype` like in ophys.PlaneSegmentation.pixel_mask - handle compound `dtype` like in TimeSeriesReferenceVectorData - Create a validator that checks if all the lists in a compound dtype dataset are same length +- [ ] Make `target` optional in vectorIndex Cleanup - [ ] Update pydantic generator diff --git a/nwb_linkml/src/nwb_linkml/includes/hdmf.py b/nwb_linkml/src/nwb_linkml/includes/hdmf.py index e4534de..c34f757 100644 --- a/nwb_linkml/src/nwb_linkml/includes/hdmf.py +++ b/nwb_linkml/src/nwb_linkml/includes/hdmf.py @@ -5,9 +5,19 @@ Special types for mimicking HDMF special case behavior from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple, Union, overload from linkml.generators.pydanticgen.template import Import, Imports, ObjectImport -from numpydantic import NDArray +from numpydantic import NDArray, Shape from pandas import DataFrame, Series -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + model_validator, + field_validator, + ValidatorFunctionWrapHandler, + ValidationError, + ValidationInfo, +) +import numpy as np if TYPE_CHECKING: from nwb_linkml.models import VectorData, VectorIndex @@ -31,6 +41,7 @@ class DynamicTableMixin(BaseModel): # overridden by subclass but implemented here for testing and typechecking purposes :) colnames: List[str] = Field(default_factory=list) + id: Optional[NDArray[Shape["* num_rows"], int]] = None @property def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]: @@ -143,7 +154,28 @@ class DynamicTableMixin(BaseModel): @model_validator(mode="before") @classmethod - def create_colnames(cls, model: Dict[str, Any]) -> None: + def create_id(cls, model: Dict[str, Any]) -> Dict: + """ + Create ID column if not provided + """ + if 
"id" not in model: + lengths = [] + for key, val in model.items(): + # don't get lengths of columns with an index + if ( + f"{key}_index" in model + or (isinstance(val, VectorData) and val._index) + or key in cls.NON_COLUMN_FIELDS + ): + continue + lengths.append(len(val)) + model["id"] = np.arange(np.max(lengths)) + + return model + + @model_validator(mode="before") + @classmethod + def create_colnames(cls, model: Dict[str, Any]) -> Dict: """ Construct colnames from arguments. @@ -167,6 +199,12 @@ class DynamicTableMixin(BaseModel): model["colnames"].extend(colnames) return model + @model_validator(mode="before") + def create_id(cls, model: Dict[str, Any]) -> Dict: + """ + If an id column is not given, create one as an arange. + """ + @model_validator(mode="after") def resolve_targets(self) -> "DynamicTableMixin": """ @@ -189,6 +227,38 @@ class DynamicTableMixin(BaseModel): idx.target = col return self + @model_validator(mode="after") + def ensure_equal_length_cols(self) -> "DynamicTableMixin": + """ + Ensure that all columns are equal length + """ + lengths = [len(v) for v in self._columns.values()] + assert all(length == lengths[0] for length in lengths), ( + "Columns are not of equal length! 
" + f"Got colnames:\n{self.colnames}\nand lengths: {lengths}" + ) + return self + + @field_validator("*", mode="wrap") + @classmethod + def cast_columns(cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo): + """ + If columns are supplied as arrays, try casting them to the type before validating + """ + try: + return handler(val) + except ValidationError: + annotation = cls.model_fields[info.field_name].annotation + if type(annotation).__name__ == "_UnionGenericAlias": + annotation = annotation.__args__[0] + return handler( + annotation( + val, + name=info.field_name, + description=cls.model_fields[info.field_name].description, + ) + ) + class VectorDataMixin(BaseModel): """ @@ -200,6 +270,11 @@ class VectorDataMixin(BaseModel): # redefined in `VectorData`, but included here for testing and type checking value: Optional[NDArray] = None + def __init__(self, value: Optional[NDArray] = None, **kwargs): + if value is not None and "value" not in kwargs: + kwargs["value"] = value + super().__init__(**kwargs) + def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any: if self._index: # Following hdmf, VectorIndex is the thing that knows how to do the slicing @@ -214,6 +289,27 @@ class VectorDataMixin(BaseModel): else: self.value[key] = value + def __getattr__(self, item: str) -> Any: + """ + Forward getattr to ``value`` + """ + try: + return BaseModel.__getattr__(self, item) + except AttributeError as e: + try: + return getattr(self.value, item) + except AttributeError: + raise e + + def __len__(self) -> int: + """ + Use index as length, if present + """ + if self._index: + return len(self._index) + else: + return len(self.value) + class VectorIndexMixin(BaseModel): """ @@ -224,6 +320,11 @@ class VectorIndexMixin(BaseModel): value: Optional[NDArray] = None target: Optional["VectorData"] = None + def __init__(self, value: Optional[NDArray] = None, **kwargs): + if value is not None and "value" not in kwargs: + 
kwargs["value"] = value + super().__init__(**kwargs) + def _getitem_helper(self, arg: int) -> Union[list, NDArray]: """ Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper` @@ -231,19 +332,19 @@ class VectorIndexMixin(BaseModel): start = 0 if arg == 0 else self.value[arg - 1] end = self.value[arg] - return self.target.array[slice(start, end)] + return [self.target.value[slice(start, end)]] def __getitem__(self, item: Union[int, slice]) -> Any: if self.target is None: return self.value[item] - elif type(self.target).__name__ == "VectorData": + elif isinstance(self.target, VectorData): if isinstance(item, int): return self._getitem_helper(item) else: idx = range(*item.indices(len(self.value))) return [self._getitem_helper(i) for i in idx] else: - raise NotImplementedError("DynamicTableRange not supported yet") + raise AttributeError(f"Could not index with {item}") def __setitem__(self, key: Union[int, slice], value: Any) -> None: if self._index: @@ -252,6 +353,24 @@ class VectorIndexMixin(BaseModel): else: self.value[key] = value + def __getattr__(self, item: str) -> Any: + """ + Forward getattr to ``value`` + """ + try: + return BaseModel.__getattr__(self, item) + except AttributeError as e: + try: + return getattr(self.value, item) + except AttributeError: + raise e + + def __len__(self) -> int: + """ + Get length from value + """ + return len(self.value) + DYNAMIC_TABLE_IMPORTS = Imports( imports=[ @@ -266,8 +385,20 @@ DYNAMIC_TABLE_IMPORTS = Imports( ObjectImport(name="Tuple"), ], ), - Import(module="numpydantic", objects=[ObjectImport(name="NDArray")]), - Import(module="pydantic", objects=[ObjectImport(name="model_validator")]), + Import( + module="numpydantic", objects=[ObjectImport(name="NDArray"), ObjectImport(name="Shape")] + ), + Import( + module="pydantic", + objects=[ + ObjectImport(name="model_validator"), + ObjectImport(name="field_validator"), + ObjectImport(name="ValidationInfo"), + ObjectImport(name="ValidatorFunctionWrapHandler"), + 
ObjectImport(name="ValidationError"), + ], + ), + Import(module="numpy", alias="np"), ] ) """ diff --git a/nwb_linkml/src/nwb_linkml/includes/types.py b/nwb_linkml/src/nwb_linkml/includes/types.py index 049aa65..2604eb5 100644 --- a/nwb_linkml/src/nwb_linkml/includes/types.py +++ b/nwb_linkml/src/nwb_linkml/includes/types.py @@ -19,7 +19,7 @@ ModelTypeString = """ModelType = TypeVar("ModelType", bound=Type[BaseModel])""" def _get_name(item: ModelType | dict, info: ValidationInfo) -> Union[ModelType, dict]: """Get the name of the slot that refers to this object""" - assert isinstance(item, (BaseModel, dict)) + assert isinstance(item, (BaseModel, dict)), f"{item} was not a BaseModel or a dict!" name = info.field_name if isinstance(item, BaseModel): item.name = name diff --git a/nwb_linkml/src/nwb_linkml/models/pydantic/hdmf_common/v1_8_0/hdmf_common_table.py b/nwb_linkml/src/nwb_linkml/models/pydantic/hdmf_common/v1_8_0/hdmf_common_table.py index ef6ba01..7dbb253 100644 --- a/nwb_linkml/src/nwb_linkml/models/pydantic/hdmf_common/v1_8_0/hdmf_common_table.py +++ b/nwb_linkml/src/nwb_linkml/models/pydantic/hdmf_common/v1_8_0/hdmf_common_table.py @@ -1,15 +1,22 @@ from __future__ import annotations -from datetime import datetime, date -from decimal import Decimal -from enum import Enum -import re -import sys -import numpy as np -from ...hdmf_common.v1_8_0.hdmf_common_base import Data, Container + + +from ...hdmf_common.v1_8_0.hdmf_common_base import Data from pandas import DataFrame, Series -from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple -from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator +from typing import Any, ClassVar, List, Dict, Optional, Union, overload, Tuple +from pydantic import ( + BaseModel, + ConfigDict, + Field, + RootModel, + model_validator, + field_validator, + ValidationInfo, + ValidatorFunctionWrapHandler, + ValidationError, +) from numpydantic import NDArray, Shape 
+import numpy as np metamodel_version = "None" version = "1.8.0" @@ -60,6 +67,11 @@ class VectorDataMixin(BaseModel): # redefined in `VectorData`, but included here for testing and type checking value: Optional[NDArray] = None + def __init__(self, value: Optional[NDArray] = None, **kwargs): + if value is not None and "value" not in kwargs: + kwargs["value"] = value + super().__init__(**kwargs) + def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any: if self._index: # Following hdmf, VectorIndex is the thing that knows how to do the slicing @@ -74,6 +86,27 @@ class VectorDataMixin(BaseModel): else: self.value[key] = value + def __getattr__(self, item: str) -> Any: + """ + Forward getattr to ``value`` + """ + try: + return BaseModel.__getattr__(self, item) + except AttributeError as e: + try: + return getattr(self.value, item) + except AttributeError: + raise e + + def __len__(self) -> int: + """ + Use index as length, if present + """ + if self._index: + return len(self._index) + else: + return len(self.value) + class VectorIndexMixin(BaseModel): """ @@ -84,6 +117,11 @@ class VectorIndexMixin(BaseModel): value: Optional[NDArray] = None target: Optional["VectorData"] = None + def __init__(self, value: Optional[NDArray] = None, **kwargs): + if value is not None and "value" not in kwargs: + kwargs["value"] = value + super().__init__(**kwargs) + def _getitem_helper(self, arg: int) -> Union[list, NDArray]: """ Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper` @@ -91,12 +129,12 @@ class VectorIndexMixin(BaseModel): start = 0 if arg == 0 else self.value[arg - 1] end = self.value[arg] - return self.target.array[slice(start, end)] + return self.target.value[slice(start, end)] def __getitem__(self, item: Union[int, slice]) -> Any: if self.target is None: return self.value[item] - elif type(self.target).__name__ == "VectorData": + elif isinstance(self.target, VectorData): if isinstance(item, int): return 
self._getitem_helper(item) else: @@ -112,6 +150,24 @@ class VectorIndexMixin(BaseModel): else: self.value[key] = value + def __getattr__(self, item: str) -> Any: + """ + Forward getattr to ``value`` + """ + try: + return BaseModel.__getattr__(self, item) + except AttributeError as e: + try: + return getattr(self.value, item) + except AttributeError: + raise e + + def __len__(self) -> int: + """ + Get length from value + """ + return len(self.value) + class DynamicTableMixin(BaseModel): """ @@ -131,6 +187,7 @@ class DynamicTableMixin(BaseModel): # overridden by subclass but implemented here for testing and typechecking purposes :) colnames: List[str] = Field(default_factory=list) + id: Optional[NDArray[Shape["* num_rows"], int]] = None @property def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]: @@ -222,6 +279,10 @@ class DynamicTableMixin(BaseModel): # special case where pandas will unpack a pydantic model # into {n_fields} rows, rather than keeping it in a dict val = Series([val]) + elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1: + # special case where we are returning a row in a ragged array, + # same as above - prevent pandas pivoting to long + val = Series([val]) data[k] = val return data @@ -241,9 +302,40 @@ class DynamicTableMixin(BaseModel): return super().__setattr__(key, value) + def __getattr__(self, item): + """Try and use pandas df attrs if we don't have them""" + try: + return BaseModel.__getattr__(self, item) + except AttributeError as e: + try: + return getattr(self[:, :], item) + except AttributeError: + raise e + @model_validator(mode="before") @classmethod - def create_colnames(cls, model: Dict[str, Any]) -> None: + def create_id(cls, model: Dict[str, Any]) -> Dict: + """ + Create ID column if not provided + """ + if "id" not in model: + lengths = [] + for key, val in model.items(): + # don't get lengths of columns with an index + if ( + f"{key}_index" in model + or (isinstance(val, VectorData) and 
val._index) + or key in cls.NON_COLUMN_FIELDS + ): + continue + lengths.append(len(val)) + model["id"] = np.arange(np.max(lengths)) + + return model + + @model_validator(mode="before") + @classmethod + def create_colnames(cls, model: Dict[str, Any]) -> Dict: """ Construct colnames from arguments. @@ -289,6 +381,38 @@ class DynamicTableMixin(BaseModel): idx.target = col return self + @model_validator(mode="after") + def ensure_equal_length_cols(self) -> "DynamicTableMixin": + """ + Ensure that all columns are equal length + """ + lengths = [len(v) for v in self._columns.values()] + assert all(length == lengths[0] for length in lengths), ( + "Columns are not of equal length! " + f"Got colnames:\n{self.colnames}\nand lengths: {lengths}" + ) + return self + + @field_validator("*", mode="wrap") + @classmethod + def cast_columns(cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo): + """ + If columns are supplied as arrays, try casting them to the type before validating + """ + try: + return handler(val) + except ValidationError: + annotation = cls.model_fields[info.field_name].annotation + if type(annotation).__name__ == "_UnionGenericAlias": + annotation = annotation.__args__[0] + return handler( + annotation( + val, + name=info.field_name, + description=cls.model_fields[info.field_name].description, + ) + ) + linkml_meta = LinkMLMeta( { @@ -335,8 +459,8 @@ class VectorIndex(VectorIndexMixin): ) name: str = Field(...) 
- target: VectorData = Field( - ..., description="""Reference to the target dataset that this index applies to.""" + target: Optional[VectorData] = Field( + None, description="""Reference to the target dataset that this index applies to.""" ) description: str = Field(..., description="""Description of what these vectors represent.""") value: Optional[ diff --git a/nwb_linkml/tests/test_includes/test_hdmf.py b/nwb_linkml/tests/test_includes/test_hdmf.py index 2a3b1d0..557e7db 100644 --- a/nwb_linkml/tests/test_includes/test_hdmf.py +++ b/nwb_linkml/tests/test_includes/test_hdmf.py @@ -10,6 +10,7 @@ from nwb_linkml.models.pydantic.core.v2_7_0.namespace import ( ElectricalSeries, ElectrodeGroup, ExtracellularEphysElectrodes, + Units, ) @@ -56,6 +57,40 @@ def electrical_series() -> Tuple["ElectricalSeries", "ExtracellularEphysElectrod return electrical_series, electrodes +@pytest.fixture(params=[True, False]) +def units(request) -> Tuple[Units, list[np.ndarray], np.ndarray]: + """ + Test case for units + + Parameterized by extra_column because pandas likes to pivot dataframes + to long when there is only one column and it's not len() == 1 + """ + + n_units = 24 + spike_times = [ + np.full(shape=np.random.randint(10, 50), fill_value=i, dtype=float) for i in range(n_units) + ] + spike_idx = [] + for i in range(n_units): + if i == 0: + spike_idx.append(len(spike_times[0])) + else: + spike_idx.append(len(spike_times[i]) + spike_idx[i - 1]) + spike_idx = np.array(spike_idx) + + spike_times_flat = np.concatenate(spike_times) + + kwargs = { + "description": "units!!!!", + "spike_times": spike_times_flat, + "spike_times_index": spike_idx, + } + if request.param: + kwargs["extra_column"] = ["hey!"] * n_units + units = Units(**kwargs) + return units, spike_times, spike_idx + + def test_dynamictable_indexing(electrical_series): """ Can index values from a dynamictable @@ -106,6 +141,24 @@ def test_dynamictable_indexing(electrical_series): assert subsection.dtypes.values.tolist() 
== dtypes[0:3] +def test_dynamictable_ragged_arrays(units): + """ + Should be able to index ragged arrays using an implicit _index column + + Also tests: + - passing arrays directly instead of wrapping in vectordata/index specifically, + if the models in the fixture instantiate then this works + """ + units, spike_times, spike_idx = units + + # ensure we don't pivot to long when indexing + assert units[0].shape[0] == 1 + # check that we got the indexing boundaries correct + # (and that we are forwarding attr calls to the dataframe by accessing shape) + for i in range(units.shape[0]): + assert np.all(units.iloc[i, 0] == spike_times[i]) + + def test_dynamictable_append_column(): pass