working ragged array indexing before rebuilding models

This commit is contained in:
sneakers-the-rat 2024-08-06 19:44:04 -07:00
parent fbb06dac52
commit a11d3d042e
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
5 changed files with 332 additions and 23 deletions

View file

@ -7,6 +7,7 @@ NWB schema translation
- handle compound `dtype` like in ophys.PlaneSegmentation.pixel_mask
- handle compound `dtype` like in TimeSeriesReferenceVectorData
- Create a validator that checks if all the lists in a compound dtype dataset are same length
- [ ] Make `target` optional in vectorIndex
Cleanup
- [ ] Update pydantic generator

View file

@ -5,9 +5,19 @@ Special types for mimicking HDMF special case behavior
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple, Union, overload
from linkml.generators.pydanticgen.template import Import, Imports, ObjectImport
from numpydantic import NDArray
from numpydantic import NDArray, Shape
from pandas import DataFrame, Series
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import (
BaseModel,
ConfigDict,
Field,
model_validator,
field_validator,
ValidatorFunctionWrapHandler,
ValidationError,
ValidationInfo,
)
import numpy as np
if TYPE_CHECKING:
from nwb_linkml.models import VectorData, VectorIndex
@ -31,6 +41,7 @@ class DynamicTableMixin(BaseModel):
# overridden by subclass but implemented here for testing and typechecking purposes :)
colnames: List[str] = Field(default_factory=list)
id: Optional[NDArray[Shape["* num_rows"], int]] = None
@property
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
@ -143,7 +154,28 @@ class DynamicTableMixin(BaseModel):
@model_validator(mode="before")
@classmethod
def create_colnames(cls, model: Dict[str, Any]) -> None:
def create_id(cls, model: Dict[str, Any]) -> Dict:
"""
Create ID column if not provided
"""
if "id" not in model:
lengths = []
for key, val in model.items():
# don't get lengths of columns with an index
if (
f"{key}_index" in model
or (isinstance(val, VectorData) and val._index)
or key in cls.NON_COLUMN_FIELDS
):
continue
lengths.append(len(val))
model["id"] = np.arange(np.max(lengths))
return model
@model_validator(mode="before")
@classmethod
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
"""
Construct colnames from arguments.
@ -167,6 +199,12 @@ class DynamicTableMixin(BaseModel):
model["colnames"].extend(colnames)
return model
@model_validator(mode="before")
def create_id(cls, model: Dict[str, Any]) -> Dict:
"""
If an id column is not given, create one as an arange.
"""
@model_validator(mode="after")
def resolve_targets(self) -> "DynamicTableMixin":
"""
@ -189,6 +227,38 @@ class DynamicTableMixin(BaseModel):
idx.target = col
return self
@model_validator(mode="after")
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
"""
Ensure that all columns are equal length
"""
lengths = [len(v) for v in self._columns.values()]
assert [l == lengths[0] for l in lengths], (
"Columns are not of equal length! "
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
)
return self
@field_validator("*", mode="wrap")
@classmethod
def cast_columns(cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
"""
If columns are supplied as arrays, try casting them to the type before validating
"""
try:
return handler(val)
except ValidationError:
annotation = cls.model_fields[info.field_name].annotation
if type(annotation).__name__ == "_UnionGenericAlias":
annotation = annotation.__args__[0]
return handler(
annotation(
val,
name=info.field_name,
description=cls.model_fields[info.field_name].description,
)
)
class VectorDataMixin(BaseModel):
"""
@ -200,6 +270,11 @@ class VectorDataMixin(BaseModel):
# redefined in `VectorData`, but included here for testing and type checking
value: Optional[NDArray] = None
def __init__(self, value: Optional[NDArray] = None, **kwargs):
if value is not None and "value" not in kwargs:
kwargs["value"] = value
super().__init__(**kwargs)
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
if self._index:
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
@ -214,6 +289,27 @@ class VectorDataMixin(BaseModel):
else:
self.value[key] = value
def __getattr__(self, item: str) -> Any:
"""
Forward getattr to ``value``
"""
try:
return BaseModel.__getattr__(self, item)
except AttributeError as e:
try:
return getattr(self.value, item)
except AttributeError:
raise e
def __len__(self) -> int:
"""
Use index as length, if present
"""
if self._index:
return len(self._index)
else:
return len(self.value)
class VectorIndexMixin(BaseModel):
"""
@ -224,6 +320,11 @@ class VectorIndexMixin(BaseModel):
value: Optional[NDArray] = None
target: Optional["VectorData"] = None
def __init__(self, value: Optional[NDArray] = None, **kwargs):
if value is not None and "value" not in kwargs:
kwargs["value"] = value
super().__init__(**kwargs)
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
"""
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
@ -231,19 +332,19 @@ class VectorIndexMixin(BaseModel):
start = 0 if arg == 0 else self.value[arg - 1]
end = self.value[arg]
return self.target.array[slice(start, end)]
return [self.target.value[slice(start, end)]]
def __getitem__(self, item: Union[int, slice]) -> Any:
if self.target is None:
return self.value[item]
elif type(self.target).__name__ == "VectorData":
elif isinstance(self.target, VectorData):
if isinstance(item, int):
return self._getitem_helper(item)
else:
idx = range(*item.indices(len(self.value)))
return [self._getitem_helper(i) for i in idx]
else:
raise NotImplementedError("DynamicTableRange not supported yet")
raise AttributeError(f"Could not index with {item}")
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
if self._index:
@ -252,6 +353,24 @@ class VectorIndexMixin(BaseModel):
else:
self.value[key] = value
def __getattr__(self, item: str) -> Any:
"""
Forward getattr to ``value``
"""
try:
return BaseModel.__getattr__(self, item)
except AttributeError as e:
try:
return getattr(self.value, item)
except AttributeError:
raise e
def __len__(self) -> int:
"""
Get length from value
"""
return len(self.value)
DYNAMIC_TABLE_IMPORTS = Imports(
imports=[
@ -266,8 +385,20 @@ DYNAMIC_TABLE_IMPORTS = Imports(
ObjectImport(name="Tuple"),
],
),
Import(module="numpydantic", objects=[ObjectImport(name="NDArray")]),
Import(module="pydantic", objects=[ObjectImport(name="model_validator")]),
Import(
module="numpydantic", objects=[ObjectImport(name="NDArray"), ObjectImport(name="Shape")]
),
Import(
module="pydantic",
objects=[
ObjectImport(name="model_validator"),
ObjectImport(name="field_validator"),
ObjectImport(name="ValidationInfo"),
ObjectImport(name="ValidatorFunctionWrapHandler"),
ObjectImport(name="ValidationError"),
],
),
Import(module="numpy", alias="np"),
]
)
"""

View file

@ -19,7 +19,7 @@ ModelTypeString = """ModelType = TypeVar("ModelType", bound=Type[BaseModel])"""
def _get_name(item: ModelType | dict, info: ValidationInfo) -> Union[ModelType, dict]:
"""Get the name of the slot that refers to this object"""
assert isinstance(item, (BaseModel, dict))
assert isinstance(item, (BaseModel, dict)), f"{item} was not a BaseModel or a dict!"
name = info.field_name
if isinstance(item, BaseModel):
item.name = name

View file

@ -1,15 +1,22 @@
from __future__ import annotations
from datetime import datetime, date
from decimal import Decimal
from enum import Enum
import re
import sys
import numpy as np
from ...hdmf_common.v1_8_0.hdmf_common_base import Data, Container
from ...hdmf_common.v1_8_0.hdmf_common_base import Data
from pandas import DataFrame, Series
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
from typing import Any, ClassVar, List, Dict, Optional, Union, overload, Tuple
from pydantic import (
BaseModel,
ConfigDict,
Field,
RootModel,
model_validator,
field_validator,
ValidationInfo,
ValidatorFunctionWrapHandler,
ValidationError,
)
from numpydantic import NDArray, Shape
import numpy as np
metamodel_version = "None"
version = "1.8.0"
@ -60,6 +67,11 @@ class VectorDataMixin(BaseModel):
# redefined in `VectorData`, but included here for testing and type checking
value: Optional[NDArray] = None
def __init__(self, value: Optional[NDArray] = None, **kwargs):
if value is not None and "value" not in kwargs:
kwargs["value"] = value
super().__init__(**kwargs)
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
if self._index:
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
@ -74,6 +86,27 @@ class VectorDataMixin(BaseModel):
else:
self.value[key] = value
def __getattr__(self, item: str) -> Any:
"""
Forward getattr to ``value``
"""
try:
return BaseModel.__getattr__(self, item)
except AttributeError as e:
try:
return getattr(self.value, item)
except AttributeError:
raise e
def __len__(self) -> int:
"""
Use index as length, if present
"""
if self._index:
return len(self._index)
else:
return len(self.value)
class VectorIndexMixin(BaseModel):
"""
@ -84,6 +117,11 @@ class VectorIndexMixin(BaseModel):
value: Optional[NDArray] = None
target: Optional["VectorData"] = None
def __init__(self, value: Optional[NDArray] = None, **kwargs):
if value is not None and "value" not in kwargs:
kwargs["value"] = value
super().__init__(**kwargs)
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
"""
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
@ -91,12 +129,12 @@ class VectorIndexMixin(BaseModel):
start = 0 if arg == 0 else self.value[arg - 1]
end = self.value[arg]
return self.target.array[slice(start, end)]
return self.target.value[slice(start, end)]
def __getitem__(self, item: Union[int, slice]) -> Any:
if self.target is None:
return self.value[item]
elif type(self.target).__name__ == "VectorData":
elif isinstance(self.target, VectorData):
if isinstance(item, int):
return self._getitem_helper(item)
else:
@ -112,6 +150,24 @@ class VectorIndexMixin(BaseModel):
else:
self.value[key] = value
def __getattr__(self, item: str) -> Any:
"""
Forward getattr to ``value``
"""
try:
return BaseModel.__getattr__(self, item)
except AttributeError as e:
try:
return getattr(self.value, item)
except AttributeError:
raise e
def __len__(self) -> int:
"""
Get length from value
"""
return len(self.value)
class DynamicTableMixin(BaseModel):
"""
@ -131,6 +187,7 @@ class DynamicTableMixin(BaseModel):
# overridden by subclass but implemented here for testing and typechecking purposes :)
colnames: List[str] = Field(default_factory=list)
id: Optional[NDArray[Shape["* num_rows"], int]] = None
@property
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
@ -222,6 +279,10 @@ class DynamicTableMixin(BaseModel):
# special case where pandas will unpack a pydantic model
# into {n_fields} rows, rather than keeping it in a dict
val = Series([val])
elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
# special case where we are returning a row in a ragged array,
# same as above - prevent pandas pivoting to long
val = Series([val])
data[k] = val
return data
@ -241,9 +302,40 @@ class DynamicTableMixin(BaseModel):
return super().__setattr__(key, value)
def __getattr__(self, item):
"""Try and use pandas df attrs if we don't have them"""
try:
return BaseModel.__getattr__(self, item)
except AttributeError as e:
try:
return getattr(self[:, :], item)
except AttributeError:
raise e
@model_validator(mode="before")
@classmethod
def create_colnames(cls, model: Dict[str, Any]) -> None:
def create_id(cls, model: Dict[str, Any]) -> Dict:
"""
Create ID column if not provided
"""
if "id" not in model:
lengths = []
for key, val in model.items():
# don't get lengths of columns with an index
if (
f"{key}_index" in model
or (isinstance(val, VectorData) and val._index)
or key in cls.NON_COLUMN_FIELDS
):
continue
lengths.append(len(val))
model["id"] = np.arange(np.max(lengths))
return model
@model_validator(mode="before")
@classmethod
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
"""
Construct colnames from arguments.
@ -289,6 +381,38 @@ class DynamicTableMixin(BaseModel):
idx.target = col
return self
@model_validator(mode="after")
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
"""
Ensure that all columns are equal length
"""
lengths = [len(v) for v in self._columns.values()]
assert [l == lengths[0] for l in lengths], (
"Columns are not of equal length! "
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
)
return self
@field_validator("*", mode="wrap")
@classmethod
def cast_columns(cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
"""
If columns are supplied as arrays, try casting them to the type before validating
"""
try:
return handler(val)
except ValidationError:
annotation = cls.model_fields[info.field_name].annotation
if type(annotation).__name__ == "_UnionGenericAlias":
annotation = annotation.__args__[0]
return handler(
annotation(
val,
name=info.field_name,
description=cls.model_fields[info.field_name].description,
)
)
linkml_meta = LinkMLMeta(
{
@ -335,8 +459,8 @@ class VectorIndex(VectorIndexMixin):
)
name: str = Field(...)
target: VectorData = Field(
..., description="""Reference to the target dataset that this index applies to."""
target: Optional[VectorData] = Field(
None, description="""Reference to the target dataset that this index applies to."""
)
description: str = Field(..., description="""Description of what these vectors represent.""")
value: Optional[

View file

@ -10,6 +10,7 @@ from nwb_linkml.models.pydantic.core.v2_7_0.namespace import (
ElectricalSeries,
ElectrodeGroup,
ExtracellularEphysElectrodes,
Units,
)
@ -56,6 +57,40 @@ def electrical_series() -> Tuple["ElectricalSeries", "ExtracellularEphysElectrod
return electrical_series, electrodes
@pytest.fixture(params=[True, False])
def units(request) -> Tuple[Units, list[np.ndarray], np.ndarray]:
"""
Test case for units
Parameterized by extra_column because pandas likes to pivot dataframes
to long when there is only one column and it's not len() == 1
"""
n_units = 24
spike_times = [
np.full(shape=np.random.randint(10, 50), fill_value=i, dtype=float) for i in range(n_units)
]
spike_idx = []
for i in range(n_units):
if i == 0:
spike_idx.append(len(spike_times[0]))
else:
spike_idx.append(len(spike_times[i]) + spike_idx[i - 1])
spike_idx = np.array(spike_idx)
spike_times_flat = np.concatenate(spike_times)
kwargs = {
"description": "units!!!!",
"spike_times": spike_times_flat,
"spike_times_index": spike_idx,
}
if request.param:
kwargs["extra_column"] = ["hey!"] * n_units
units = Units(**kwargs)
return units, spike_times, spike_idx
def test_dynamictable_indexing(electrical_series):
"""
Can index values from a dynamictable
@ -106,6 +141,24 @@ def test_dynamictable_indexing(electrical_series):
assert subsection.dtypes.values.tolist() == dtypes[0:3]
def test_dynamictable_ragged_arrays(units):
"""
Should be able to index ragged arrays using an implicit _index column
Also tests:
- passing arrays directly instead of wrapping in vectordata/index specifically,
if the models in the fixture instantiate then this works
"""
units, spike_times, spike_idx = units
# ensure we don't pivot to long when indexing
assert units[0].shape[0] == 1
# check that we got the indexing boundaries corrunect
# (and that we are forwarding attr calls to the dataframe by accessing shape
for i in range(units.shape[0]):
assert np.all(units.iloc[i, 0] == spike_times[i])
def test_dynamictable_append_column():
pass