mirror of
https://github.com/p2p-ld/nwb-linkml.git
synced 2025-01-10 06:04:28 +00:00
working ragged array indexing before rebuilding models
This commit is contained in:
parent
fbb06dac52
commit
a11d3d042e
5 changed files with 332 additions and 23 deletions
|
@ -7,6 +7,7 @@ NWB schema translation
|
||||||
- handle compound `dtype` like in ophys.PlaneSegmentation.pixel_mask
|
- handle compound `dtype` like in ophys.PlaneSegmentation.pixel_mask
|
||||||
- handle compound `dtype` like in TimeSeriesReferenceVectorData
|
- handle compound `dtype` like in TimeSeriesReferenceVectorData
|
||||||
- Create a validator that checks if all the lists in a compound dtype dataset are same length
|
- Create a validator that checks if all the lists in a compound dtype dataset are same length
|
||||||
|
- [ ] Make `target` optional in vectorIndex
|
||||||
|
|
||||||
Cleanup
|
Cleanup
|
||||||
- [ ] Update pydantic generator
|
- [ ] Update pydantic generator
|
||||||
|
|
|
@ -5,9 +5,19 @@ Special types for mimicking HDMF special case behavior
|
||||||
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple, Union, overload
|
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple, Union, overload
|
||||||
|
|
||||||
from linkml.generators.pydanticgen.template import Import, Imports, ObjectImport
|
from linkml.generators.pydanticgen.template import Import, Imports, ObjectImport
|
||||||
from numpydantic import NDArray
|
from numpydantic import NDArray, Shape
|
||||||
from pandas import DataFrame, Series
|
from pandas import DataFrame, Series
|
||||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
ConfigDict,
|
||||||
|
Field,
|
||||||
|
model_validator,
|
||||||
|
field_validator,
|
||||||
|
ValidatorFunctionWrapHandler,
|
||||||
|
ValidationError,
|
||||||
|
ValidationInfo,
|
||||||
|
)
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from nwb_linkml.models import VectorData, VectorIndex
|
from nwb_linkml.models import VectorData, VectorIndex
|
||||||
|
@ -31,6 +41,7 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
||||||
colnames: List[str] = Field(default_factory=list)
|
colnames: List[str] = Field(default_factory=list)
|
||||||
|
id: Optional[NDArray[Shape["* num_rows"], int]] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
||||||
|
@ -143,7 +154,28 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_colnames(cls, model: Dict[str, Any]) -> None:
|
def create_id(cls, model: Dict[str, Any]) -> Dict:
|
||||||
|
"""
|
||||||
|
Create ID column if not provided
|
||||||
|
"""
|
||||||
|
if "id" not in model:
|
||||||
|
lengths = []
|
||||||
|
for key, val in model.items():
|
||||||
|
# don't get lengths of columns with an index
|
||||||
|
if (
|
||||||
|
f"{key}_index" in model
|
||||||
|
or (isinstance(val, VectorData) and val._index)
|
||||||
|
or key in cls.NON_COLUMN_FIELDS
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
lengths.append(len(val))
|
||||||
|
model["id"] = np.arange(np.max(lengths))
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
|
||||||
"""
|
"""
|
||||||
Construct colnames from arguments.
|
Construct colnames from arguments.
|
||||||
|
|
||||||
|
@ -167,6 +199,12 @@ class DynamicTableMixin(BaseModel):
|
||||||
model["colnames"].extend(colnames)
|
model["colnames"].extend(colnames)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
def create_id(cls, model: Dict[str, Any]) -> Dict:
|
||||||
|
"""
|
||||||
|
If an id column is not given, create one as an arange.
|
||||||
|
"""
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def resolve_targets(self) -> "DynamicTableMixin":
|
def resolve_targets(self) -> "DynamicTableMixin":
|
||||||
"""
|
"""
|
||||||
|
@ -189,6 +227,38 @@ class DynamicTableMixin(BaseModel):
|
||||||
idx.target = col
|
idx.target = col
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
|
||||||
|
"""
|
||||||
|
Ensure that all columns are equal length
|
||||||
|
"""
|
||||||
|
lengths = [len(v) for v in self._columns.values()]
|
||||||
|
assert [l == lengths[0] for l in lengths], (
|
||||||
|
"Columns are not of equal length! "
|
||||||
|
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
@field_validator("*", mode="wrap")
|
||||||
|
@classmethod
|
||||||
|
def cast_columns(cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
|
||||||
|
"""
|
||||||
|
If columns are supplied as arrays, try casting them to the type before validating
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return handler(val)
|
||||||
|
except ValidationError:
|
||||||
|
annotation = cls.model_fields[info.field_name].annotation
|
||||||
|
if type(annotation).__name__ == "_UnionGenericAlias":
|
||||||
|
annotation = annotation.__args__[0]
|
||||||
|
return handler(
|
||||||
|
annotation(
|
||||||
|
val,
|
||||||
|
name=info.field_name,
|
||||||
|
description=cls.model_fields[info.field_name].description,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class VectorDataMixin(BaseModel):
|
class VectorDataMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -200,6 +270,11 @@ class VectorDataMixin(BaseModel):
|
||||||
# redefined in `VectorData`, but included here for testing and type checking
|
# redefined in `VectorData`, but included here for testing and type checking
|
||||||
value: Optional[NDArray] = None
|
value: Optional[NDArray] = None
|
||||||
|
|
||||||
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||||
|
if value is not None and "value" not in kwargs:
|
||||||
|
kwargs["value"] = value
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||||
if self._index:
|
if self._index:
|
||||||
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||||
|
@ -214,6 +289,27 @@ class VectorDataMixin(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.value[key] = value
|
self.value[key] = value
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
|
||||||
|
"""
|
||||||
|
Forward getattr to ``value``
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return BaseModel.__getattr__(self, item)
|
||||||
|
except AttributeError as e:
|
||||||
|
try:
|
||||||
|
return getattr(self.value, item)
|
||||||
|
except AttributeError:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""
|
||||||
|
Use index as length, if present
|
||||||
|
"""
|
||||||
|
if self._index:
|
||||||
|
return len(self._index)
|
||||||
|
else:
|
||||||
|
return len(self.value)
|
||||||
|
|
||||||
|
|
||||||
class VectorIndexMixin(BaseModel):
|
class VectorIndexMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -224,6 +320,11 @@ class VectorIndexMixin(BaseModel):
|
||||||
value: Optional[NDArray] = None
|
value: Optional[NDArray] = None
|
||||||
target: Optional["VectorData"] = None
|
target: Optional["VectorData"] = None
|
||||||
|
|
||||||
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||||
|
if value is not None and "value" not in kwargs:
|
||||||
|
kwargs["value"] = value
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
||||||
"""
|
"""
|
||||||
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||||
|
@ -231,19 +332,19 @@ class VectorIndexMixin(BaseModel):
|
||||||
|
|
||||||
start = 0 if arg == 0 else self.value[arg - 1]
|
start = 0 if arg == 0 else self.value[arg - 1]
|
||||||
end = self.value[arg]
|
end = self.value[arg]
|
||||||
return self.target.array[slice(start, end)]
|
return [self.target.value[slice(start, end)]]
|
||||||
|
|
||||||
def __getitem__(self, item: Union[int, slice]) -> Any:
|
def __getitem__(self, item: Union[int, slice]) -> Any:
|
||||||
if self.target is None:
|
if self.target is None:
|
||||||
return self.value[item]
|
return self.value[item]
|
||||||
elif type(self.target).__name__ == "VectorData":
|
elif isinstance(self.target, VectorData):
|
||||||
if isinstance(item, int):
|
if isinstance(item, int):
|
||||||
return self._getitem_helper(item)
|
return self._getitem_helper(item)
|
||||||
else:
|
else:
|
||||||
idx = range(*item.indices(len(self.value)))
|
idx = range(*item.indices(len(self.value)))
|
||||||
return [self._getitem_helper(i) for i in idx]
|
return [self._getitem_helper(i) for i in idx]
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("DynamicTableRange not supported yet")
|
raise AttributeError(f"Could not index with {item}")
|
||||||
|
|
||||||
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
||||||
if self._index:
|
if self._index:
|
||||||
|
@ -252,6 +353,24 @@ class VectorIndexMixin(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.value[key] = value
|
self.value[key] = value
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
|
||||||
|
"""
|
||||||
|
Forward getattr to ``value``
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return BaseModel.__getattr__(self, item)
|
||||||
|
except AttributeError as e:
|
||||||
|
try:
|
||||||
|
return getattr(self.value, item)
|
||||||
|
except AttributeError:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""
|
||||||
|
Get length from value
|
||||||
|
"""
|
||||||
|
return len(self.value)
|
||||||
|
|
||||||
|
|
||||||
DYNAMIC_TABLE_IMPORTS = Imports(
|
DYNAMIC_TABLE_IMPORTS = Imports(
|
||||||
imports=[
|
imports=[
|
||||||
|
@ -266,8 +385,20 @@ DYNAMIC_TABLE_IMPORTS = Imports(
|
||||||
ObjectImport(name="Tuple"),
|
ObjectImport(name="Tuple"),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
Import(module="numpydantic", objects=[ObjectImport(name="NDArray")]),
|
Import(
|
||||||
Import(module="pydantic", objects=[ObjectImport(name="model_validator")]),
|
module="numpydantic", objects=[ObjectImport(name="NDArray"), ObjectImport(name="Shape")]
|
||||||
|
),
|
||||||
|
Import(
|
||||||
|
module="pydantic",
|
||||||
|
objects=[
|
||||||
|
ObjectImport(name="model_validator"),
|
||||||
|
ObjectImport(name="field_validator"),
|
||||||
|
ObjectImport(name="ValidationInfo"),
|
||||||
|
ObjectImport(name="ValidatorFunctionWrapHandler"),
|
||||||
|
ObjectImport(name="ValidationError"),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
Import(module="numpy", alias="np"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -19,7 +19,7 @@ ModelTypeString = """ModelType = TypeVar("ModelType", bound=Type[BaseModel])"""
|
||||||
|
|
||||||
def _get_name(item: ModelType | dict, info: ValidationInfo) -> Union[ModelType, dict]:
|
def _get_name(item: ModelType | dict, info: ValidationInfo) -> Union[ModelType, dict]:
|
||||||
"""Get the name of the slot that refers to this object"""
|
"""Get the name of the slot that refers to this object"""
|
||||||
assert isinstance(item, (BaseModel, dict))
|
assert isinstance(item, (BaseModel, dict)), f"{item} was not a BaseModel or a dict!"
|
||||||
name = info.field_name
|
name = info.field_name
|
||||||
if isinstance(item, BaseModel):
|
if isinstance(item, BaseModel):
|
||||||
item.name = name
|
item.name = name
|
||||||
|
|
|
@ -1,15 +1,22 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from datetime import datetime, date
|
|
||||||
from decimal import Decimal
|
|
||||||
from enum import Enum
|
from ...hdmf_common.v1_8_0.hdmf_common_base import Data
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
import numpy as np
|
|
||||||
from ...hdmf_common.v1_8_0.hdmf_common_base import Data, Container
|
|
||||||
from pandas import DataFrame, Series
|
from pandas import DataFrame, Series
|
||||||
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
from typing import Any, ClassVar, List, Dict, Optional, Union, overload, Tuple
|
||||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
ConfigDict,
|
||||||
|
Field,
|
||||||
|
RootModel,
|
||||||
|
model_validator,
|
||||||
|
field_validator,
|
||||||
|
ValidationInfo,
|
||||||
|
ValidatorFunctionWrapHandler,
|
||||||
|
ValidationError,
|
||||||
|
)
|
||||||
from numpydantic import NDArray, Shape
|
from numpydantic import NDArray, Shape
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
metamodel_version = "None"
|
metamodel_version = "None"
|
||||||
version = "1.8.0"
|
version = "1.8.0"
|
||||||
|
@ -60,6 +67,11 @@ class VectorDataMixin(BaseModel):
|
||||||
# redefined in `VectorData`, but included here for testing and type checking
|
# redefined in `VectorData`, but included here for testing and type checking
|
||||||
value: Optional[NDArray] = None
|
value: Optional[NDArray] = None
|
||||||
|
|
||||||
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||||
|
if value is not None and "value" not in kwargs:
|
||||||
|
kwargs["value"] = value
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||||
if self._index:
|
if self._index:
|
||||||
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||||
|
@ -74,6 +86,27 @@ class VectorDataMixin(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.value[key] = value
|
self.value[key] = value
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
|
||||||
|
"""
|
||||||
|
Forward getattr to ``value``
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return BaseModel.__getattr__(self, item)
|
||||||
|
except AttributeError as e:
|
||||||
|
try:
|
||||||
|
return getattr(self.value, item)
|
||||||
|
except AttributeError:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""
|
||||||
|
Use index as length, if present
|
||||||
|
"""
|
||||||
|
if self._index:
|
||||||
|
return len(self._index)
|
||||||
|
else:
|
||||||
|
return len(self.value)
|
||||||
|
|
||||||
|
|
||||||
class VectorIndexMixin(BaseModel):
|
class VectorIndexMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -84,6 +117,11 @@ class VectorIndexMixin(BaseModel):
|
||||||
value: Optional[NDArray] = None
|
value: Optional[NDArray] = None
|
||||||
target: Optional["VectorData"] = None
|
target: Optional["VectorData"] = None
|
||||||
|
|
||||||
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||||
|
if value is not None and "value" not in kwargs:
|
||||||
|
kwargs["value"] = value
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
||||||
"""
|
"""
|
||||||
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||||
|
@ -91,12 +129,12 @@ class VectorIndexMixin(BaseModel):
|
||||||
|
|
||||||
start = 0 if arg == 0 else self.value[arg - 1]
|
start = 0 if arg == 0 else self.value[arg - 1]
|
||||||
end = self.value[arg]
|
end = self.value[arg]
|
||||||
return self.target.array[slice(start, end)]
|
return self.target.value[slice(start, end)]
|
||||||
|
|
||||||
def __getitem__(self, item: Union[int, slice]) -> Any:
|
def __getitem__(self, item: Union[int, slice]) -> Any:
|
||||||
if self.target is None:
|
if self.target is None:
|
||||||
return self.value[item]
|
return self.value[item]
|
||||||
elif type(self.target).__name__ == "VectorData":
|
elif isinstance(self.target, VectorData):
|
||||||
if isinstance(item, int):
|
if isinstance(item, int):
|
||||||
return self._getitem_helper(item)
|
return self._getitem_helper(item)
|
||||||
else:
|
else:
|
||||||
|
@ -112,6 +150,24 @@ class VectorIndexMixin(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.value[key] = value
|
self.value[key] = value
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
|
||||||
|
"""
|
||||||
|
Forward getattr to ``value``
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return BaseModel.__getattr__(self, item)
|
||||||
|
except AttributeError as e:
|
||||||
|
try:
|
||||||
|
return getattr(self.value, item)
|
||||||
|
except AttributeError:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""
|
||||||
|
Get length from value
|
||||||
|
"""
|
||||||
|
return len(self.value)
|
||||||
|
|
||||||
|
|
||||||
class DynamicTableMixin(BaseModel):
|
class DynamicTableMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -131,6 +187,7 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
||||||
colnames: List[str] = Field(default_factory=list)
|
colnames: List[str] = Field(default_factory=list)
|
||||||
|
id: Optional[NDArray[Shape["* num_rows"], int]] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
||||||
|
@ -222,6 +279,10 @@ class DynamicTableMixin(BaseModel):
|
||||||
# special case where pandas will unpack a pydantic model
|
# special case where pandas will unpack a pydantic model
|
||||||
# into {n_fields} rows, rather than keeping it in a dict
|
# into {n_fields} rows, rather than keeping it in a dict
|
||||||
val = Series([val])
|
val = Series([val])
|
||||||
|
elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
|
||||||
|
# special case where we are returning a row in a ragged array,
|
||||||
|
# same as above - prevent pandas pivoting to long
|
||||||
|
val = Series([val])
|
||||||
data[k] = val
|
data[k] = val
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -241,9 +302,40 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
return super().__setattr__(key, value)
|
return super().__setattr__(key, value)
|
||||||
|
|
||||||
|
def __getattr__(self, item):
|
||||||
|
"""Try and use pandas df attrs if we don't have them"""
|
||||||
|
try:
|
||||||
|
return BaseModel.__getattr__(self, item)
|
||||||
|
except AttributeError as e:
|
||||||
|
try:
|
||||||
|
return getattr(self[:, :], item)
|
||||||
|
except AttributeError:
|
||||||
|
raise e
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_colnames(cls, model: Dict[str, Any]) -> None:
|
def create_id(cls, model: Dict[str, Any]) -> Dict:
|
||||||
|
"""
|
||||||
|
Create ID column if not provided
|
||||||
|
"""
|
||||||
|
if "id" not in model:
|
||||||
|
lengths = []
|
||||||
|
for key, val in model.items():
|
||||||
|
# don't get lengths of columns with an index
|
||||||
|
if (
|
||||||
|
f"{key}_index" in model
|
||||||
|
or (isinstance(val, VectorData) and val._index)
|
||||||
|
or key in cls.NON_COLUMN_FIELDS
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
lengths.append(len(val))
|
||||||
|
model["id"] = np.arange(np.max(lengths))
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
|
||||||
"""
|
"""
|
||||||
Construct colnames from arguments.
|
Construct colnames from arguments.
|
||||||
|
|
||||||
|
@ -289,6 +381,38 @@ class DynamicTableMixin(BaseModel):
|
||||||
idx.target = col
|
idx.target = col
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
|
||||||
|
"""
|
||||||
|
Ensure that all columns are equal length
|
||||||
|
"""
|
||||||
|
lengths = [len(v) for v in self._columns.values()]
|
||||||
|
assert [l == lengths[0] for l in lengths], (
|
||||||
|
"Columns are not of equal length! "
|
||||||
|
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
@field_validator("*", mode="wrap")
|
||||||
|
@classmethod
|
||||||
|
def cast_columns(cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
|
||||||
|
"""
|
||||||
|
If columns are supplied as arrays, try casting them to the type before validating
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return handler(val)
|
||||||
|
except ValidationError:
|
||||||
|
annotation = cls.model_fields[info.field_name].annotation
|
||||||
|
if type(annotation).__name__ == "_UnionGenericAlias":
|
||||||
|
annotation = annotation.__args__[0]
|
||||||
|
return handler(
|
||||||
|
annotation(
|
||||||
|
val,
|
||||||
|
name=info.field_name,
|
||||||
|
description=cls.model_fields[info.field_name].description,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
linkml_meta = LinkMLMeta(
|
linkml_meta = LinkMLMeta(
|
||||||
{
|
{
|
||||||
|
@ -335,8 +459,8 @@ class VectorIndex(VectorIndexMixin):
|
||||||
)
|
)
|
||||||
|
|
||||||
name: str = Field(...)
|
name: str = Field(...)
|
||||||
target: VectorData = Field(
|
target: Optional[VectorData] = Field(
|
||||||
..., description="""Reference to the target dataset that this index applies to."""
|
None, description="""Reference to the target dataset that this index applies to."""
|
||||||
)
|
)
|
||||||
description: str = Field(..., description="""Description of what these vectors represent.""")
|
description: str = Field(..., description="""Description of what these vectors represent.""")
|
||||||
value: Optional[
|
value: Optional[
|
||||||
|
|
|
@ -10,6 +10,7 @@ from nwb_linkml.models.pydantic.core.v2_7_0.namespace import (
|
||||||
ElectricalSeries,
|
ElectricalSeries,
|
||||||
ElectrodeGroup,
|
ElectrodeGroup,
|
||||||
ExtracellularEphysElectrodes,
|
ExtracellularEphysElectrodes,
|
||||||
|
Units,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -56,6 +57,40 @@ def electrical_series() -> Tuple["ElectricalSeries", "ExtracellularEphysElectrod
|
||||||
return electrical_series, electrodes
|
return electrical_series, electrodes
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=[True, False])
|
||||||
|
def units(request) -> Tuple[Units, list[np.ndarray], np.ndarray]:
|
||||||
|
"""
|
||||||
|
Test case for units
|
||||||
|
|
||||||
|
Parameterized by extra_column because pandas likes to pivot dataframes
|
||||||
|
to long when there is only one column and it's not len() == 1
|
||||||
|
"""
|
||||||
|
|
||||||
|
n_units = 24
|
||||||
|
spike_times = [
|
||||||
|
np.full(shape=np.random.randint(10, 50), fill_value=i, dtype=float) for i in range(n_units)
|
||||||
|
]
|
||||||
|
spike_idx = []
|
||||||
|
for i in range(n_units):
|
||||||
|
if i == 0:
|
||||||
|
spike_idx.append(len(spike_times[0]))
|
||||||
|
else:
|
||||||
|
spike_idx.append(len(spike_times[i]) + spike_idx[i - 1])
|
||||||
|
spike_idx = np.array(spike_idx)
|
||||||
|
|
||||||
|
spike_times_flat = np.concatenate(spike_times)
|
||||||
|
|
||||||
|
kwargs = {
|
||||||
|
"description": "units!!!!",
|
||||||
|
"spike_times": spike_times_flat,
|
||||||
|
"spike_times_index": spike_idx,
|
||||||
|
}
|
||||||
|
if request.param:
|
||||||
|
kwargs["extra_column"] = ["hey!"] * n_units
|
||||||
|
units = Units(**kwargs)
|
||||||
|
return units, spike_times, spike_idx
|
||||||
|
|
||||||
|
|
||||||
def test_dynamictable_indexing(electrical_series):
|
def test_dynamictable_indexing(electrical_series):
|
||||||
"""
|
"""
|
||||||
Can index values from a dynamictable
|
Can index values from a dynamictable
|
||||||
|
@ -106,6 +141,24 @@ def test_dynamictable_indexing(electrical_series):
|
||||||
assert subsection.dtypes.values.tolist() == dtypes[0:3]
|
assert subsection.dtypes.values.tolist() == dtypes[0:3]
|
||||||
|
|
||||||
|
|
||||||
|
def test_dynamictable_ragged_arrays(units):
|
||||||
|
"""
|
||||||
|
Should be able to index ragged arrays using an implicit _index column
|
||||||
|
|
||||||
|
Also tests:
|
||||||
|
- passing arrays directly instead of wrapping in vectordata/index specifically,
|
||||||
|
if the models in the fixture instantiate then this works
|
||||||
|
"""
|
||||||
|
units, spike_times, spike_idx = units
|
||||||
|
|
||||||
|
# ensure we don't pivot to long when indexing
|
||||||
|
assert units[0].shape[0] == 1
|
||||||
|
# check that we got the indexing boundaries corrunect
|
||||||
|
# (and that we are forwarding attr calls to the dataframe by accessing shape
|
||||||
|
for i in range(units.shape[0]):
|
||||||
|
assert np.all(units.iloc[i, 0] == spike_times[i])
|
||||||
|
|
||||||
|
|
||||||
def test_dynamictable_append_column():
|
def test_dynamictable_append_column():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue