mirror of
https://github.com/p2p-ld/nwb-linkml.git
synced 2024-11-12 17:54:29 +00:00
working ragged array indexing before rebuilding models
This commit is contained in:
parent
fbb06dac52
commit
a11d3d042e
5 changed files with 332 additions and 23 deletions
|
@ -7,6 +7,7 @@ NWB schema translation
|
|||
- handle compound `dtype` like in ophys.PlaneSegmentation.pixel_mask
|
||||
- handle compound `dtype` like in TimeSeriesReferenceVectorData
|
||||
- Create a validator that checks if all the lists in a compound dtype dataset are same length
|
||||
- [ ] Make `target` optional in vectorIndex
|
||||
|
||||
Cleanup
|
||||
- [ ] Update pydantic generator
|
||||
|
|
|
@ -5,9 +5,19 @@ Special types for mimicking HDMF special case behavior
|
|||
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple, Union, overload
|
||||
|
||||
from linkml.generators.pydanticgen.template import Import, Imports, ObjectImport
|
||||
from numpydantic import NDArray
|
||||
from numpydantic import NDArray, Shape
|
||||
from pandas import DataFrame, Series
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
model_validator,
|
||||
field_validator,
|
||||
ValidatorFunctionWrapHandler,
|
||||
ValidationError,
|
||||
ValidationInfo,
|
||||
)
|
||||
import numpy as np
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nwb_linkml.models import VectorData, VectorIndex
|
||||
|
@ -31,6 +41,7 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
||||
colnames: List[str] = Field(default_factory=list)
|
||||
id: Optional[NDArray[Shape["* num_rows"], int]] = None
|
||||
|
||||
@property
|
||||
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
||||
|
@ -143,7 +154,28 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_colnames(cls, model: Dict[str, Any]) -> None:
|
||||
def create_id(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Create ID column if not provided
|
||||
"""
|
||||
if "id" not in model:
|
||||
lengths = []
|
||||
for key, val in model.items():
|
||||
# don't get lengths of columns with an index
|
||||
if (
|
||||
f"{key}_index" in model
|
||||
or (isinstance(val, VectorData) and val._index)
|
||||
or key in cls.NON_COLUMN_FIELDS
|
||||
):
|
||||
continue
|
||||
lengths.append(len(val))
|
||||
model["id"] = np.arange(np.max(lengths))
|
||||
|
||||
return model
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Construct colnames from arguments.
|
||||
|
||||
|
@ -167,6 +199,12 @@ class DynamicTableMixin(BaseModel):
|
|||
model["colnames"].extend(colnames)
|
||||
return model
|
||||
|
||||
@model_validator(mode="before")
|
||||
def create_id(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
If an id column is not given, create one as an arange.
|
||||
"""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def resolve_targets(self) -> "DynamicTableMixin":
|
||||
"""
|
||||
|
@ -189,6 +227,38 @@ class DynamicTableMixin(BaseModel):
|
|||
idx.target = col
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
|
||||
"""
|
||||
Ensure that all columns are equal length
|
||||
"""
|
||||
lengths = [len(v) for v in self._columns.values()]
|
||||
assert [l == lengths[0] for l in lengths], (
|
||||
"Columns are not of equal length! "
|
||||
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
|
||||
)
|
||||
return self
|
||||
|
||||
@field_validator("*", mode="wrap")
|
||||
@classmethod
|
||||
def cast_columns(cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
|
||||
"""
|
||||
If columns are supplied as arrays, try casting them to the type before validating
|
||||
"""
|
||||
try:
|
||||
return handler(val)
|
||||
except ValidationError:
|
||||
annotation = cls.model_fields[info.field_name].annotation
|
||||
if type(annotation).__name__ == "_UnionGenericAlias":
|
||||
annotation = annotation.__args__[0]
|
||||
return handler(
|
||||
annotation(
|
||||
val,
|
||||
name=info.field_name,
|
||||
description=cls.model_fields[info.field_name].description,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class VectorDataMixin(BaseModel):
|
||||
"""
|
||||
|
@ -200,6 +270,11 @@ class VectorDataMixin(BaseModel):
|
|||
# redefined in `VectorData`, but included here for testing and type checking
|
||||
value: Optional[NDArray] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||
if self._index:
|
||||
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||
|
@ -214,6 +289,27 @@ class VectorDataMixin(BaseModel):
|
|||
else:
|
||||
self.value[key] = value
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""
|
||||
Forward getattr to ``value``
|
||||
"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Use index as length, if present
|
||||
"""
|
||||
if self._index:
|
||||
return len(self._index)
|
||||
else:
|
||||
return len(self.value)
|
||||
|
||||
|
||||
class VectorIndexMixin(BaseModel):
|
||||
"""
|
||||
|
@ -224,6 +320,11 @@ class VectorIndexMixin(BaseModel):
|
|||
value: Optional[NDArray] = None
|
||||
target: Optional["VectorData"] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
||||
"""
|
||||
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||
|
@ -231,19 +332,19 @@ class VectorIndexMixin(BaseModel):
|
|||
|
||||
start = 0 if arg == 0 else self.value[arg - 1]
|
||||
end = self.value[arg]
|
||||
return self.target.array[slice(start, end)]
|
||||
return [self.target.value[slice(start, end)]]
|
||||
|
||||
def __getitem__(self, item: Union[int, slice]) -> Any:
|
||||
if self.target is None:
|
||||
return self.value[item]
|
||||
elif type(self.target).__name__ == "VectorData":
|
||||
elif isinstance(self.target, VectorData):
|
||||
if isinstance(item, int):
|
||||
return self._getitem_helper(item)
|
||||
else:
|
||||
idx = range(*item.indices(len(self.value)))
|
||||
return [self._getitem_helper(i) for i in idx]
|
||||
else:
|
||||
raise NotImplementedError("DynamicTableRange not supported yet")
|
||||
raise AttributeError(f"Could not index with {item}")
|
||||
|
||||
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
||||
if self._index:
|
||||
|
@ -252,6 +353,24 @@ class VectorIndexMixin(BaseModel):
|
|||
else:
|
||||
self.value[key] = value
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""
|
||||
Forward getattr to ``value``
|
||||
"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Get length from value
|
||||
"""
|
||||
return len(self.value)
|
||||
|
||||
|
||||
DYNAMIC_TABLE_IMPORTS = Imports(
|
||||
imports=[
|
||||
|
@ -266,8 +385,20 @@ DYNAMIC_TABLE_IMPORTS = Imports(
|
|||
ObjectImport(name="Tuple"),
|
||||
],
|
||||
),
|
||||
Import(module="numpydantic", objects=[ObjectImport(name="NDArray")]),
|
||||
Import(module="pydantic", objects=[ObjectImport(name="model_validator")]),
|
||||
Import(
|
||||
module="numpydantic", objects=[ObjectImport(name="NDArray"), ObjectImport(name="Shape")]
|
||||
),
|
||||
Import(
|
||||
module="pydantic",
|
||||
objects=[
|
||||
ObjectImport(name="model_validator"),
|
||||
ObjectImport(name="field_validator"),
|
||||
ObjectImport(name="ValidationInfo"),
|
||||
ObjectImport(name="ValidatorFunctionWrapHandler"),
|
||||
ObjectImport(name="ValidationError"),
|
||||
],
|
||||
),
|
||||
Import(module="numpy", alias="np"),
|
||||
]
|
||||
)
|
||||
"""
|
||||
|
|
|
@ -19,7 +19,7 @@ ModelTypeString = """ModelType = TypeVar("ModelType", bound=Type[BaseModel])"""
|
|||
|
||||
def _get_name(item: ModelType | dict, info: ValidationInfo) -> Union[ModelType, dict]:
|
||||
"""Get the name of the slot that refers to this object"""
|
||||
assert isinstance(item, (BaseModel, dict))
|
||||
assert isinstance(item, (BaseModel, dict)), f"{item} was not a BaseModel or a dict!"
|
||||
name = info.field_name
|
||||
if isinstance(item, BaseModel):
|
||||
item.name = name
|
||||
|
|
|
@ -1,15 +1,22 @@
|
|||
from __future__ import annotations
|
||||
from datetime import datetime, date
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
import re
|
||||
import sys
|
||||
import numpy as np
|
||||
from ...hdmf_common.v1_8_0.hdmf_common_base import Data, Container
|
||||
|
||||
|
||||
from ...hdmf_common.v1_8_0.hdmf_common_base import Data
|
||||
from pandas import DataFrame, Series
|
||||
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
||||
from typing import Any, ClassVar, List, Dict, Optional, Union, overload, Tuple
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
RootModel,
|
||||
model_validator,
|
||||
field_validator,
|
||||
ValidationInfo,
|
||||
ValidatorFunctionWrapHandler,
|
||||
ValidationError,
|
||||
)
|
||||
from numpydantic import NDArray, Shape
|
||||
import numpy as np
|
||||
|
||||
metamodel_version = "None"
|
||||
version = "1.8.0"
|
||||
|
@ -60,6 +67,11 @@ class VectorDataMixin(BaseModel):
|
|||
# redefined in `VectorData`, but included here for testing and type checking
|
||||
value: Optional[NDArray] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||
if self._index:
|
||||
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||
|
@ -74,6 +86,27 @@ class VectorDataMixin(BaseModel):
|
|||
else:
|
||||
self.value[key] = value
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""
|
||||
Forward getattr to ``value``
|
||||
"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Use index as length, if present
|
||||
"""
|
||||
if self._index:
|
||||
return len(self._index)
|
||||
else:
|
||||
return len(self.value)
|
||||
|
||||
|
||||
class VectorIndexMixin(BaseModel):
|
||||
"""
|
||||
|
@ -84,6 +117,11 @@ class VectorIndexMixin(BaseModel):
|
|||
value: Optional[NDArray] = None
|
||||
target: Optional["VectorData"] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
||||
"""
|
||||
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||
|
@ -91,12 +129,12 @@ class VectorIndexMixin(BaseModel):
|
|||
|
||||
start = 0 if arg == 0 else self.value[arg - 1]
|
||||
end = self.value[arg]
|
||||
return self.target.array[slice(start, end)]
|
||||
return self.target.value[slice(start, end)]
|
||||
|
||||
def __getitem__(self, item: Union[int, slice]) -> Any:
|
||||
if self.target is None:
|
||||
return self.value[item]
|
||||
elif type(self.target).__name__ == "VectorData":
|
||||
elif isinstance(self.target, VectorData):
|
||||
if isinstance(item, int):
|
||||
return self._getitem_helper(item)
|
||||
else:
|
||||
|
@ -112,6 +150,24 @@ class VectorIndexMixin(BaseModel):
|
|||
else:
|
||||
self.value[key] = value
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""
|
||||
Forward getattr to ``value``
|
||||
"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Get length from value
|
||||
"""
|
||||
return len(self.value)
|
||||
|
||||
|
||||
class DynamicTableMixin(BaseModel):
|
||||
"""
|
||||
|
@ -131,6 +187,7 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
||||
colnames: List[str] = Field(default_factory=list)
|
||||
id: Optional[NDArray[Shape["* num_rows"], int]] = None
|
||||
|
||||
@property
|
||||
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
||||
|
@ -222,6 +279,10 @@ class DynamicTableMixin(BaseModel):
|
|||
# special case where pandas will unpack a pydantic model
|
||||
# into {n_fields} rows, rather than keeping it in a dict
|
||||
val = Series([val])
|
||||
elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
|
||||
# special case where we are returning a row in a ragged array,
|
||||
# same as above - prevent pandas pivoting to long
|
||||
val = Series([val])
|
||||
data[k] = val
|
||||
return data
|
||||
|
||||
|
@ -241,9 +302,40 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
return super().__setattr__(key, value)
|
||||
|
||||
def __getattr__(self, item):
|
||||
"""Try and use pandas df attrs if we don't have them"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self[:, :], item)
|
||||
except AttributeError:
|
||||
raise e
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_colnames(cls, model: Dict[str, Any]) -> None:
|
||||
def create_id(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Create ID column if not provided
|
||||
"""
|
||||
if "id" not in model:
|
||||
lengths = []
|
||||
for key, val in model.items():
|
||||
# don't get lengths of columns with an index
|
||||
if (
|
||||
f"{key}_index" in model
|
||||
or (isinstance(val, VectorData) and val._index)
|
||||
or key in cls.NON_COLUMN_FIELDS
|
||||
):
|
||||
continue
|
||||
lengths.append(len(val))
|
||||
model["id"] = np.arange(np.max(lengths))
|
||||
|
||||
return model
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Construct colnames from arguments.
|
||||
|
||||
|
@ -289,6 +381,38 @@ class DynamicTableMixin(BaseModel):
|
|||
idx.target = col
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
|
||||
"""
|
||||
Ensure that all columns are equal length
|
||||
"""
|
||||
lengths = [len(v) for v in self._columns.values()]
|
||||
assert [l == lengths[0] for l in lengths], (
|
||||
"Columns are not of equal length! "
|
||||
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
|
||||
)
|
||||
return self
|
||||
|
||||
@field_validator("*", mode="wrap")
|
||||
@classmethod
|
||||
def cast_columns(cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
|
||||
"""
|
||||
If columns are supplied as arrays, try casting them to the type before validating
|
||||
"""
|
||||
try:
|
||||
return handler(val)
|
||||
except ValidationError:
|
||||
annotation = cls.model_fields[info.field_name].annotation
|
||||
if type(annotation).__name__ == "_UnionGenericAlias":
|
||||
annotation = annotation.__args__[0]
|
||||
return handler(
|
||||
annotation(
|
||||
val,
|
||||
name=info.field_name,
|
||||
description=cls.model_fields[info.field_name].description,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
linkml_meta = LinkMLMeta(
|
||||
{
|
||||
|
@ -335,8 +459,8 @@ class VectorIndex(VectorIndexMixin):
|
|||
)
|
||||
|
||||
name: str = Field(...)
|
||||
target: VectorData = Field(
|
||||
..., description="""Reference to the target dataset that this index applies to."""
|
||||
target: Optional[VectorData] = Field(
|
||||
None, description="""Reference to the target dataset that this index applies to."""
|
||||
)
|
||||
description: str = Field(..., description="""Description of what these vectors represent.""")
|
||||
value: Optional[
|
||||
|
|
|
@ -10,6 +10,7 @@ from nwb_linkml.models.pydantic.core.v2_7_0.namespace import (
|
|||
ElectricalSeries,
|
||||
ElectrodeGroup,
|
||||
ExtracellularEphysElectrodes,
|
||||
Units,
|
||||
)
|
||||
|
||||
|
||||
|
@ -56,6 +57,40 @@ def electrical_series() -> Tuple["ElectricalSeries", "ExtracellularEphysElectrod
|
|||
return electrical_series, electrodes
|
||||
|
||||
|
||||
@pytest.fixture(params=[True, False])
|
||||
def units(request) -> Tuple[Units, list[np.ndarray], np.ndarray]:
|
||||
"""
|
||||
Test case for units
|
||||
|
||||
Parameterized by extra_column because pandas likes to pivot dataframes
|
||||
to long when there is only one column and it's not len() == 1
|
||||
"""
|
||||
|
||||
n_units = 24
|
||||
spike_times = [
|
||||
np.full(shape=np.random.randint(10, 50), fill_value=i, dtype=float) for i in range(n_units)
|
||||
]
|
||||
spike_idx = []
|
||||
for i in range(n_units):
|
||||
if i == 0:
|
||||
spike_idx.append(len(spike_times[0]))
|
||||
else:
|
||||
spike_idx.append(len(spike_times[i]) + spike_idx[i - 1])
|
||||
spike_idx = np.array(spike_idx)
|
||||
|
||||
spike_times_flat = np.concatenate(spike_times)
|
||||
|
||||
kwargs = {
|
||||
"description": "units!!!!",
|
||||
"spike_times": spike_times_flat,
|
||||
"spike_times_index": spike_idx,
|
||||
}
|
||||
if request.param:
|
||||
kwargs["extra_column"] = ["hey!"] * n_units
|
||||
units = Units(**kwargs)
|
||||
return units, spike_times, spike_idx
|
||||
|
||||
|
||||
def test_dynamictable_indexing(electrical_series):
|
||||
"""
|
||||
Can index values from a dynamictable
|
||||
|
@ -106,6 +141,24 @@ def test_dynamictable_indexing(electrical_series):
|
|||
assert subsection.dtypes.values.tolist() == dtypes[0:3]
|
||||
|
||||
|
||||
def test_dynamictable_ragged_arrays(units):
|
||||
"""
|
||||
Should be able to index ragged arrays using an implicit _index column
|
||||
|
||||
Also tests:
|
||||
- passing arrays directly instead of wrapping in vectordata/index specifically,
|
||||
if the models in the fixture instantiate then this works
|
||||
"""
|
||||
units, spike_times, spike_idx = units
|
||||
|
||||
# ensure we don't pivot to long when indexing
|
||||
assert units[0].shape[0] == 1
|
||||
# check that we got the indexing boundaries corrunect
|
||||
# (and that we are forwarding attr calls to the dataframe by accessing shape
|
||||
for i in range(units.shape[0]):
|
||||
assert np.all(units.iloc[i, 0] == spike_times[i])
|
||||
|
||||
|
||||
def test_dynamictable_append_column():
|
||||
pass
|
||||
|
||||
|
|
Loading…
Reference in a new issue