working ragged array indexing before rebuilding models

sneakers-the-rat 2024-08-06 19:44:04 -07:00
parent fbb06dac52
commit a11d3d042e
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
5 changed files with 332 additions and 23 deletions

View file

@@ -7,6 +7,7 @@ NWB schema translation
 - handle compound `dtype` like in ophys.PlaneSegmentation.pixel_mask
 - handle compound `dtype` like in TimeSeriesReferenceVectorData
 - Create a validator that checks if all the lists in a compound dtype dataset are same length
+- [ ] Make `target` optional in vectorIndex
 Cleanup
 - [ ] Update pydantic generator
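Note: the equal-length validator in the TODO above is, for table columns, what the `ensure_equal_length_cols` validator added further down in this commit does. A minimal standalone sketch of the same idea, using a hypothetical `CompoundDataset` model that is not part of this commit:

from typing import List
from pydantic import BaseModel, model_validator

class CompoundDataset(BaseModel):
    # hypothetical compound-dtype dataset with two parallel lists
    xs: List[float]
    ys: List[str]

    @model_validator(mode="after")
    def ensure_equal_length(self) -> "CompoundDataset":
        # all lists in the compound dtype must have the same number of rows
        lengths = {name: len(getattr(self, name)) for name in ("xs", "ys")}
        assert len(set(lengths.values())) == 1, f"Unequal lengths: {lengths}"
        return self

CompoundDataset(xs=[1.0, 2.0], ys=["a", "b"])  # ok; unequal lengths would raise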

View file

@@ -5,9 +5,19 @@ Special types for mimicking HDMF special case behavior
 from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple, Union, overload
 from linkml.generators.pydanticgen.template import Import, Imports, ObjectImport
-from numpydantic import NDArray
+from numpydantic import NDArray, Shape
 from pandas import DataFrame, Series
-from pydantic import BaseModel, ConfigDict, Field, model_validator
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    model_validator,
+    field_validator,
+    ValidatorFunctionWrapHandler,
+    ValidationError,
+    ValidationInfo,
+)
+import numpy as np
 
 if TYPE_CHECKING:
     from nwb_linkml.models import VectorData, VectorIndex
@@ -31,6 +41,7 @@ class DynamicTableMixin(BaseModel):
     # overridden by subclass but implemented here for testing and typechecking purposes :)
     colnames: List[str] = Field(default_factory=list)
+    id: Optional[NDArray[Shape["* num_rows"], int]] = None
 
     @property
     def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
@@ -143,7 +154,28 @@ class DynamicTableMixin(BaseModel):
     @model_validator(mode="before")
     @classmethod
-    def create_colnames(cls, model: Dict[str, Any]) -> None:
+    def create_id(cls, model: Dict[str, Any]) -> Dict:
+        """
+        Create ID column if not provided
+        """
+        if "id" not in model:
+            lengths = []
+            for key, val in model.items():
+                # don't get lengths of columns with an index
+                if (
+                    f"{key}_index" in model
+                    or (isinstance(val, VectorData) and val._index)
+                    or key in cls.NON_COLUMN_FIELDS
+                ):
+                    continue
+                lengths.append(len(val))
+            model["id"] = np.arange(np.max(lengths))
+        return model
+
+    @model_validator(mode="before")
+    @classmethod
+    def create_colnames(cls, model: Dict[str, Any]) -> Dict:
         """
         Construct colnames from arguments.
@@ -167,6 +199,12 @@ class DynamicTableMixin(BaseModel):
             model["colnames"].extend(colnames)
         return model
 
+    @model_validator(mode="before")
+    def create_id(cls, model: Dict[str, Any]) -> Dict:
+        """
+        If an id column is not given, create one as an arange.
+        """
+
     @model_validator(mode="after")
     def resolve_targets(self) -> "DynamicTableMixin":
         """
@@ -189,6 +227,38 @@ class DynamicTableMixin(BaseModel):
                     idx.target = col
         return self
 
+    @model_validator(mode="after")
+    def ensure_equal_length_cols(self) -> "DynamicTableMixin":
+        """
+        Ensure that all columns are equal length
+        """
+        lengths = [len(v) for v in self._columns.values()]
+        assert all(length == lengths[0] for length in lengths), (
+            "Columns are not of equal length! "
+            f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
+        )
+        return self
+
+    @field_validator("*", mode="wrap")
+    @classmethod
+    def cast_columns(cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
+        """
+        If columns are supplied as arrays, try casting them to the type before validating
+        """
+        try:
+            return handler(val)
+        except ValidationError:
+            annotation = cls.model_fields[info.field_name].annotation
+            if type(annotation).__name__ == "_UnionGenericAlias":
+                annotation = annotation.__args__[0]
+            return handler(
+                annotation(
+                    val,
+                    name=info.field_name,
+                    description=cls.model_fields[info.field_name].description,
+                )
+            )
+
 
 class VectorDataMixin(BaseModel):
     """
@@ -200,6 +270,11 @@ class VectorDataMixin(BaseModel):
     # redefined in `VectorData`, but included here for testing and type checking
     value: Optional[NDArray] = None
 
+    def __init__(self, value: Optional[NDArray] = None, **kwargs):
+        if value is not None and "value" not in kwargs:
+            kwargs["value"] = value
+        super().__init__(**kwargs)
+
     def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
         if self._index:
             # Following hdmf, VectorIndex is the thing that knows how to do the slicing
@@ -214,6 +289,27 @@ class VectorDataMixin(BaseModel):
         else:
             self.value[key] = value
 
+    def __getattr__(self, item: str) -> Any:
+        """
+        Forward getattr to ``value``
+        """
+        try:
+            return BaseModel.__getattr__(self, item)
+        except AttributeError as e:
+            try:
+                return getattr(self.value, item)
+            except AttributeError:
+                raise e
+
+    def __len__(self) -> int:
+        """
+        Use index as length, if present
+        """
+        if self._index:
+            return len(self._index)
+        else:
+            return len(self.value)
+
 
 class VectorIndexMixin(BaseModel):
     """
@@ -224,6 +320,11 @@ class VectorIndexMixin(BaseModel):
     value: Optional[NDArray] = None
     target: Optional["VectorData"] = None
 
+    def __init__(self, value: Optional[NDArray] = None, **kwargs):
+        if value is not None and "value" not in kwargs:
+            kwargs["value"] = value
+        super().__init__(**kwargs)
+
     def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
         """
         Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
@@ -231,19 +332,19 @@ class VectorIndexMixin(BaseModel):
         start = 0 if arg == 0 else self.value[arg - 1]
         end = self.value[arg]
-        return self.target.array[slice(start, end)]
+        return [self.target.value[slice(start, end)]]
 
     def __getitem__(self, item: Union[int, slice]) -> Any:
         if self.target is None:
             return self.value[item]
-        elif type(self.target).__name__ == "VectorData":
+        elif isinstance(self.target, VectorData):
             if isinstance(item, int):
                 return self._getitem_helper(item)
             else:
                 idx = range(*item.indices(len(self.value)))
                 return [self._getitem_helper(i) for i in idx]
         else:
-            raise NotImplementedError("DynamicTableRange not supported yet")
+            raise AttributeError(f"Could not index with {item}")
 
     def __setitem__(self, key: Union[int, slice], value: Any) -> None:
         if self._index:
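The slicing rule in `_getitem_helper` above is the usual HDMF ragged-array convention: row `i` of the target spans `value[i - 1]:value[i]`, with an implicit 0 for the first row. A toy example of that arithmetic, assuming an index of [3, 5, 9] (illustrative values, not from the commit):

import numpy as np

flat = np.arange(9)    # flattened ragged data (the VectorData values)
index = [3, 5, 9]      # cumulative end offsets (the VectorIndex values)

def row(i: int) -> np.ndarray:
    start = 0 if i == 0 else index[i - 1]
    return flat[start:index[i]]

assert row(0).tolist() == [0, 1, 2]
assert row(1).tolist() == [3, 4]
assert row(2).tolist() == [5, 6, 7, 8]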
@@ -252,6 +353,24 @@ class VectorIndexMixin(BaseModel):
         else:
             self.value[key] = value
 
+    def __getattr__(self, item: str) -> Any:
+        """
+        Forward getattr to ``value``
+        """
+        try:
+            return BaseModel.__getattr__(self, item)
+        except AttributeError as e:
+            try:
+                return getattr(self.value, item)
+            except AttributeError:
+                raise e
+
+    def __len__(self) -> int:
+        """
+        Get length from value
+        """
+        return len(self.value)
+
 
 DYNAMIC_TABLE_IMPORTS = Imports(
     imports=[
@@ -266,8 +385,20 @@ DYNAMIC_TABLE_IMPORTS = Imports(
                 ObjectImport(name="Tuple"),
             ],
         ),
-        Import(module="numpydantic", objects=[ObjectImport(name="NDArray")]),
-        Import(module="pydantic", objects=[ObjectImport(name="model_validator")]),
+        Import(
+            module="numpydantic", objects=[ObjectImport(name="NDArray"), ObjectImport(name="Shape")]
+        ),
+        Import(
+            module="pydantic",
+            objects=[
+                ObjectImport(name="model_validator"),
+                ObjectImport(name="field_validator"),
+                ObjectImport(name="ValidationInfo"),
+                ObjectImport(name="ValidatorFunctionWrapHandler"),
+                ObjectImport(name="ValidationError"),
+            ],
+        ),
+        Import(module="numpy", alias="np"),
     ]
 )
 """

View file

@@ -19,7 +19,7 @@ ModelTypeString = """ModelType = TypeVar("ModelType", bound=Type[BaseModel])"""
 def _get_name(item: ModelType | dict, info: ValidationInfo) -> Union[ModelType, dict]:
     """Get the name of the slot that refers to this object"""
-    assert isinstance(item, (BaseModel, dict))
+    assert isinstance(item, (BaseModel, dict)), f"{item} was not a BaseModel or a dict!"
     name = info.field_name
     if isinstance(item, BaseModel):
         item.name = name

View file

@@ -1,15 +1,22 @@
 from __future__ import annotations
-from datetime import datetime, date
-from decimal import Decimal
-from enum import Enum
-import re
-import sys
-import numpy as np
-from ...hdmf_common.v1_8_0.hdmf_common_base import Data, Container
+from ...hdmf_common.v1_8_0.hdmf_common_base import Data
 from pandas import DataFrame, Series
-from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
-from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
+from typing import Any, ClassVar, List, Dict, Optional, Union, overload, Tuple
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    RootModel,
+    model_validator,
+    field_validator,
+    ValidationInfo,
+    ValidatorFunctionWrapHandler,
+    ValidationError,
+)
 from numpydantic import NDArray, Shape
+import numpy as np
 
 metamodel_version = "None"
 version = "1.8.0"
@@ -60,6 +67,11 @@ class VectorDataMixin(BaseModel):
     # redefined in `VectorData`, but included here for testing and type checking
     value: Optional[NDArray] = None
 
+    def __init__(self, value: Optional[NDArray] = None, **kwargs):
+        if value is not None and "value" not in kwargs:
+            kwargs["value"] = value
+        super().__init__(**kwargs)
+
     def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
         if self._index:
             # Following hdmf, VectorIndex is the thing that knows how to do the slicing
@@ -74,6 +86,27 @@ class VectorDataMixin(BaseModel):
         else:
             self.value[key] = value
 
+    def __getattr__(self, item: str) -> Any:
+        """
+        Forward getattr to ``value``
+        """
+        try:
+            return BaseModel.__getattr__(self, item)
+        except AttributeError as e:
+            try:
+                return getattr(self.value, item)
+            except AttributeError:
+                raise e
+
+    def __len__(self) -> int:
+        """
+        Use index as length, if present
+        """
+        if self._index:
+            return len(self._index)
+        else:
+            return len(self.value)
+
 
 class VectorIndexMixin(BaseModel):
     """
@@ -84,6 +117,11 @@ class VectorIndexMixin(BaseModel):
     value: Optional[NDArray] = None
     target: Optional["VectorData"] = None
 
+    def __init__(self, value: Optional[NDArray] = None, **kwargs):
+        if value is not None and "value" not in kwargs:
+            kwargs["value"] = value
+        super().__init__(**kwargs)
+
     def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
         """
         Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
@@ -91,12 +129,12 @@ class VectorIndexMixin(BaseModel):
         start = 0 if arg == 0 else self.value[arg - 1]
         end = self.value[arg]
-        return self.target.array[slice(start, end)]
+        return self.target.value[slice(start, end)]
 
     def __getitem__(self, item: Union[int, slice]) -> Any:
         if self.target is None:
             return self.value[item]
-        elif type(self.target).__name__ == "VectorData":
+        elif isinstance(self.target, VectorData):
             if isinstance(item, int):
                 return self._getitem_helper(item)
             else:
@@ -112,6 +150,24 @@ class VectorIndexMixin(BaseModel):
         else:
             self.value[key] = value
 
+    def __getattr__(self, item: str) -> Any:
+        """
+        Forward getattr to ``value``
+        """
+        try:
+            return BaseModel.__getattr__(self, item)
+        except AttributeError as e:
+            try:
+                return getattr(self.value, item)
+            except AttributeError:
+                raise e
+
+    def __len__(self) -> int:
+        """
+        Get length from value
+        """
+        return len(self.value)
+
 
 class DynamicTableMixin(BaseModel):
     """
@@ -131,6 +187,7 @@ class DynamicTableMixin(BaseModel):
     # overridden by subclass but implemented here for testing and typechecking purposes :)
     colnames: List[str] = Field(default_factory=list)
+    id: Optional[NDArray[Shape["* num_rows"], int]] = None
 
     @property
     def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
@@ -222,6 +279,10 @@ class DynamicTableMixin(BaseModel):
                 # special case where pandas will unpack a pydantic model
                 # into {n_fields} rows, rather than keeping it in a dict
                 val = Series([val])
+            elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
+                # special case where we are returning a row in a ragged array,
+                # same as above - prevent pandas pivoting to long
+                val = Series([val])
             data[k] = val
         return data
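The reason for wrapping the ragged row in a Series: handing pandas a bare length-n array for a single cell makes the DataFrame pivot it into n rows. A small demonstration of the difference (illustrative only, not from the commit):

import numpy as np
from pandas import DataFrame, Series

cell = np.arange(4)
long = DataFrame({"spike_times": cell})            # 4 rows: pivoted "long"
wide = DataFrame({"spike_times": Series([cell])})  # 1 row: the cell holds the array

assert long.shape == (4, 1)
assert wide.shape == (1, 1)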
@@ -241,9 +302,40 @@ class DynamicTableMixin(BaseModel):
         return super().__setattr__(key, value)
 
+    def __getattr__(self, item):
+        """Try and use pandas df attrs if we don't have them"""
+        try:
+            return BaseModel.__getattr__(self, item)
+        except AttributeError as e:
+            try:
+                return getattr(self[:, :], item)
+            except AttributeError:
+                raise e
+
     @model_validator(mode="before")
     @classmethod
-    def create_colnames(cls, model: Dict[str, Any]) -> None:
+    def create_id(cls, model: Dict[str, Any]) -> Dict:
+        """
+        Create ID column if not provided
+        """
+        if "id" not in model:
+            lengths = []
+            for key, val in model.items():
+                # don't get lengths of columns with an index
+                if (
+                    f"{key}_index" in model
+                    or (isinstance(val, VectorData) and val._index)
+                    or key in cls.NON_COLUMN_FIELDS
+                ):
+                    continue
+                lengths.append(len(val))
+            model["id"] = np.arange(np.max(lengths))
+        return model
+
+    @model_validator(mode="before")
+    @classmethod
+    def create_colnames(cls, model: Dict[str, Any]) -> Dict:
         """
         Construct colnames from arguments.
@@ -289,6 +381,38 @@ class DynamicTableMixin(BaseModel):
                     idx.target = col
         return self
 
+    @model_validator(mode="after")
+    def ensure_equal_length_cols(self) -> "DynamicTableMixin":
+        """
+        Ensure that all columns are equal length
+        """
+        lengths = [len(v) for v in self._columns.values()]
+        assert all(length == lengths[0] for length in lengths), (
+            "Columns are not of equal length! "
+            f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
+        )
+        return self
+
+    @field_validator("*", mode="wrap")
+    @classmethod
+    def cast_columns(cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
+        """
+        If columns are supplied as arrays, try casting them to the type before validating
+        """
+        try:
+            return handler(val)
+        except ValidationError:
+            annotation = cls.model_fields[info.field_name].annotation
+            if type(annotation).__name__ == "_UnionGenericAlias":
+                annotation = annotation.__args__[0]
+            return handler(
+                annotation(
+                    val,
+                    name=info.field_name,
+                    description=cls.model_fields[info.field_name].description,
+                )
+            )
+
 
 linkml_meta = LinkMLMeta(
     {
@@ -335,8 +459,8 @@ class VectorIndex(VectorIndexMixin):
     )
     name: str = Field(...)
-    target: VectorData = Field(
-        ..., description="""Reference to the target dataset that this index applies to."""
+    target: Optional[VectorData] = Field(
+        None, description="""Reference to the target dataset that this index applies to."""
     )
     description: str = Field(..., description="""Description of what these vectors represent.""")
     value: Optional[
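With `target` now optional, a `VectorIndex` can be instantiated from offsets alone and linked to its column later by `resolve_targets`. A minimal sketch, assuming the generated model accepts a plain integer array for `value` (import path as used in the mixin module above):

import numpy as np
from nwb_linkml.models import VectorIndex

idx = VectorIndex(
    name="spike_times_index",
    description="index into spike_times",
    value=np.array([3, 5, 9]),  # cumulative end offsets; target left unset
)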

View file

@@ -10,6 +10,7 @@ from nwb_linkml.models.pydantic.core.v2_7_0.namespace import (
     ElectricalSeries,
     ElectrodeGroup,
     ExtracellularEphysElectrodes,
+    Units,
 )
@@ -56,6 +57,40 @@ def electrical_series() -> Tuple["ElectricalSeries", "ExtracellularEphysElectrodes"]:
     return electrical_series, electrodes
 
 
+@pytest.fixture(params=[True, False])
+def units(request) -> Tuple[Units, list[np.ndarray], np.ndarray]:
+    """
+    Test case for units
+
+    Parameterized by extra_column because pandas likes to pivot dataframes
+    to long when there is only one column and it's not len() == 1
+    """
+    n_units = 24
+    spike_times = [
+        np.full(shape=np.random.randint(10, 50), fill_value=i, dtype=float) for i in range(n_units)
+    ]
+    spike_idx = []
+    for i in range(n_units):
+        if i == 0:
+            spike_idx.append(len(spike_times[0]))
+        else:
+            spike_idx.append(len(spike_times[i]) + spike_idx[i - 1])
+    spike_idx = np.array(spike_idx)
+    spike_times_flat = np.concatenate(spike_times)
+
+    kwargs = {
+        "description": "units!!!!",
+        "spike_times": spike_times_flat,
+        "spike_times_index": spike_idx,
+    }
+    if request.param:
+        kwargs["extra_column"] = ["hey!"] * n_units
+
+    units = Units(**kwargs)
+    return units, spike_times, spike_idx
+
+
 def test_dynamictable_indexing(electrical_series):
     """
     Can index values from a dynamictable
@@ -106,6 +141,24 @@ def test_dynamictable_indexing(electrical_series):
     assert subsection.dtypes.values.tolist() == dtypes[0:3]
 
 
+def test_dynamictable_ragged_arrays(units):
+    """
+    Should be able to index ragged arrays using an implicit _index column
+
+    Also tests:
+    - passing arrays directly instead of wrapping in vectordata/index specifically,
+      if the models in the fixture instantiate then this works
+    """
+    units, spike_times, spike_idx = units
+
+    # ensure we don't pivot to long when indexing
+    assert units[0].shape[0] == 1
+
+    # check that we got the indexing boundaries correct
+    # (and that we are forwarding attr calls to the dataframe by accessing shape)
+    for i in range(units.shape[0]):
+        assert np.all(units.iloc[i, 0] == spike_times[i])
+
+
 def test_dynamictable_append_column():
     pass
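Aside: the loop that builds `spike_idx` in the `units` fixture above is just a cumulative sum of the per-unit spike counts; an equivalent check (illustrative lengths, not the fixture's random ones):

import numpy as np

spike_times = [np.zeros(3), np.zeros(5), np.zeros(2)]
spike_idx = np.cumsum([len(s) for s in spike_times])
assert spike_idx.tolist() == [3, 8, 10]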