mirror of
https://github.com/p2p-ld/nwb-linkml.git
synced 2025-01-09 13:44:27 +00:00
regenerate model, lint
This commit is contained in:
parent
a11d3d042e
commit
a309c25c3d
16 changed files with 1570 additions and 120 deletions
|
@ -7,7 +7,7 @@ NWB schema translation
|
|||
- handle compound `dtype` like in ophys.PlaneSegmentation.pixel_mask
|
||||
- handle compound `dtype` like in TimeSeriesReferenceVectorData
|
||||
- Create a validator that checks if all the lists in a compound dtype dataset are same length
|
||||
- [ ] Make `target` optional in vectorIndex
|
||||
- [ ] Move making `target` optional in vectorIndex from pydantic generator to linkml generators!
|
||||
|
||||
Cleanup
|
||||
- [ ] Update pydantic generator
|
||||
|
|
|
@ -87,6 +87,14 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
if not base_range_subsumes_any_of:
|
||||
raise ValueError("Slot cannot have both range and any_of defined")
|
||||
|
||||
def before_generate_slot(self, slot: SlotDefinition, sv: SchemaView) -> SlotDefinition:
|
||||
"""
|
||||
Force some properties to be optional
|
||||
"""
|
||||
if slot.name == "target" and "index" in slot.description:
|
||||
slot.required = False
|
||||
return slot
|
||||
|
||||
def after_generate_slot(self, slot: SlotResult, sv: SchemaView) -> SlotResult:
|
||||
"""
|
||||
- strip unwanted metadata
|
||||
|
|
|
@ -4,6 +4,7 @@ Special types for mimicking HDMF special case behavior
|
|||
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple, Union, overload
|
||||
|
||||
import numpy as np
|
||||
from linkml.generators.pydanticgen.template import Import, Imports, ObjectImport
|
||||
from numpydantic import NDArray, Shape
|
||||
from pandas import DataFrame, Series
|
||||
|
@ -11,13 +12,12 @@ from pydantic import (
|
|||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
model_validator,
|
||||
field_validator,
|
||||
ValidatorFunctionWrapHandler,
|
||||
ValidationError,
|
||||
ValidationInfo,
|
||||
ValidatorFunctionWrapHandler,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
import numpy as np
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nwb_linkml.models import VectorData, VectorIndex
|
||||
|
@ -133,6 +133,10 @@ class DynamicTableMixin(BaseModel):
|
|||
# special case where pandas will unpack a pydantic model
|
||||
# into {n_fields} rows, rather than keeping it in a dict
|
||||
val = Series([val])
|
||||
elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
|
||||
# special case where we are returning a row in a ragged array,
|
||||
# same as above - prevent pandas pivoting to long
|
||||
val = Series([val])
|
||||
data[k] = val
|
||||
return data
|
||||
|
||||
|
@ -152,6 +156,16 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
return super().__setattr__(key, value)
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""Try and use pandas df attrs if we don't have them"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self[:, :], item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_id(cls, model: Dict[str, Any]) -> Dict:
|
||||
|
@ -199,12 +213,6 @@ class DynamicTableMixin(BaseModel):
|
|||
model["colnames"].extend(colnames)
|
||||
return model
|
||||
|
||||
@model_validator(mode="before")
|
||||
def create_id(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
If an id column is not given, create one as an arange.
|
||||
"""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def resolve_targets(self) -> "DynamicTableMixin":
|
||||
"""
|
||||
|
@ -233,7 +241,7 @@ class DynamicTableMixin(BaseModel):
|
|||
Ensure that all columns are equal length
|
||||
"""
|
||||
lengths = [len(v) for v in self._columns.values()]
|
||||
assert [l == lengths[0] for l in lengths], (
|
||||
assert [length == lengths[0] for length in lengths], (
|
||||
"Columns are not of equal length! "
|
||||
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
|
||||
)
|
||||
|
@ -241,7 +249,9 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
@field_validator("*", mode="wrap")
|
||||
@classmethod
|
||||
def cast_columns(cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
|
||||
def cast_columns(
|
||||
cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
|
||||
) -> Any:
|
||||
"""
|
||||
If columns are supplied as arrays, try casting them to the type before validating
|
||||
"""
|
||||
|
@ -299,7 +309,7 @@ class VectorDataMixin(BaseModel):
|
|||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e
|
||||
raise e from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
|
@ -332,7 +342,7 @@ class VectorIndexMixin(BaseModel):
|
|||
|
||||
start = 0 if arg == 0 else self.value[arg - 1]
|
||||
end = self.value[arg]
|
||||
return [self.target.value[slice(start, end)]]
|
||||
return self.target.value[slice(start, end)]
|
||||
|
||||
def __getitem__(self, item: Union[int, slice]) -> Any:
|
||||
if self.target is None:
|
||||
|
@ -363,7 +373,7 @@ class VectorIndexMixin(BaseModel):
|
|||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e
|
||||
raise e from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
|
|
|
@ -4,11 +4,21 @@ from decimal import Decimal
|
|||
from enum import Enum
|
||||
import re
|
||||
import sys
|
||||
import numpy as np
|
||||
from pandas import DataFrame, Series
|
||||
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
||||
from numpydantic import NDArray, Shape
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
RootModel,
|
||||
field_validator,
|
||||
model_validator,
|
||||
ValidationInfo,
|
||||
ValidatorFunctionWrapHandler,
|
||||
ValidationError,
|
||||
)
|
||||
import numpy as np
|
||||
|
||||
metamodel_version = "None"
|
||||
version = "1.1.0"
|
||||
|
@ -59,6 +69,11 @@ class VectorDataMixin(BaseModel):
|
|||
# redefined in `VectorData`, but included here for testing and type checking
|
||||
value: Optional[NDArray] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||
if self._index:
|
||||
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||
|
@ -73,6 +88,27 @@ class VectorDataMixin(BaseModel):
|
|||
else:
|
||||
self.value[key] = value
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""
|
||||
Forward getattr to ``value``
|
||||
"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Use index as length, if present
|
||||
"""
|
||||
if self._index:
|
||||
return len(self._index)
|
||||
else:
|
||||
return len(self.value)
|
||||
|
||||
|
||||
class VectorIndexMixin(BaseModel):
|
||||
"""
|
||||
|
@ -83,6 +119,11 @@ class VectorIndexMixin(BaseModel):
|
|||
value: Optional[NDArray] = None
|
||||
target: Optional["VectorData"] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
||||
"""
|
||||
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||
|
@ -90,19 +131,19 @@ class VectorIndexMixin(BaseModel):
|
|||
|
||||
start = 0 if arg == 0 else self.value[arg - 1]
|
||||
end = self.value[arg]
|
||||
return self.target.array[slice(start, end)]
|
||||
return self.target.value[slice(start, end)]
|
||||
|
||||
def __getitem__(self, item: Union[int, slice]) -> Any:
|
||||
if self.target is None:
|
||||
return self.value[item]
|
||||
elif type(self.target).__name__ == "VectorData":
|
||||
elif isinstance(self.target, VectorData):
|
||||
if isinstance(item, int):
|
||||
return self._getitem_helper(item)
|
||||
else:
|
||||
idx = range(*item.indices(len(self.value)))
|
||||
return [self._getitem_helper(i) for i in idx]
|
||||
else:
|
||||
raise NotImplementedError("DynamicTableRange not supported yet")
|
||||
raise AttributeError(f"Could not index with {item}")
|
||||
|
||||
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
||||
if self._index:
|
||||
|
@ -111,6 +152,24 @@ class VectorIndexMixin(BaseModel):
|
|||
else:
|
||||
self.value[key] = value
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""
|
||||
Forward getattr to ``value``
|
||||
"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Get length from value
|
||||
"""
|
||||
return len(self.value)
|
||||
|
||||
|
||||
class DynamicTableMixin(BaseModel):
|
||||
"""
|
||||
|
@ -130,6 +189,7 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
||||
colnames: List[str] = Field(default_factory=list)
|
||||
id: Optional[NDArray[Shape["* num_rows"], int]] = None
|
||||
|
||||
@property
|
||||
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
||||
|
@ -221,6 +281,10 @@ class DynamicTableMixin(BaseModel):
|
|||
# special case where pandas will unpack a pydantic model
|
||||
# into {n_fields} rows, rather than keeping it in a dict
|
||||
val = Series([val])
|
||||
elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
|
||||
# special case where we are returning a row in a ragged array,
|
||||
# same as above - prevent pandas pivoting to long
|
||||
val = Series([val])
|
||||
data[k] = val
|
||||
return data
|
||||
|
||||
|
@ -240,9 +304,40 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
return super().__setattr__(key, value)
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""Try and use pandas df attrs if we don't have them"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self[:, :], item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_colnames(cls, model: Dict[str, Any]) -> None:
|
||||
def create_id(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Create ID column if not provided
|
||||
"""
|
||||
if "id" not in model:
|
||||
lengths = []
|
||||
for key, val in model.items():
|
||||
# don't get lengths of columns with an index
|
||||
if (
|
||||
f"{key}_index" in model
|
||||
or (isinstance(val, VectorData) and val._index)
|
||||
or key in cls.NON_COLUMN_FIELDS
|
||||
):
|
||||
continue
|
||||
lengths.append(len(val))
|
||||
model["id"] = np.arange(np.max(lengths))
|
||||
|
||||
return model
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Construct colnames from arguments.
|
||||
|
||||
|
@ -288,6 +383,40 @@ class DynamicTableMixin(BaseModel):
|
|||
idx.target = col
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
|
||||
"""
|
||||
Ensure that all columns are equal length
|
||||
"""
|
||||
lengths = [len(v) for v in self._columns.values()]
|
||||
assert [length == lengths[0] for length in lengths], (
|
||||
"Columns are not of equal length! "
|
||||
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
|
||||
)
|
||||
return self
|
||||
|
||||
@field_validator("*", mode="wrap")
|
||||
@classmethod
|
||||
def cast_columns(
|
||||
cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
|
||||
) -> Any:
|
||||
"""
|
||||
If columns are supplied as arrays, try casting them to the type before validating
|
||||
"""
|
||||
try:
|
||||
return handler(val)
|
||||
except ValidationError:
|
||||
annotation = cls.model_fields[info.field_name].annotation
|
||||
if type(annotation).__name__ == "_UnionGenericAlias":
|
||||
annotation = annotation.__args__[0]
|
||||
return handler(
|
||||
annotation(
|
||||
val,
|
||||
name=info.field_name,
|
||||
description=cls.model_fields[info.field_name].description,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
linkml_meta = LinkMLMeta(
|
||||
{
|
||||
|
@ -325,7 +454,9 @@ class Index(Data):
|
|||
)
|
||||
|
||||
name: str = Field(...)
|
||||
target: Data = Field(..., description="""Target dataset that this index applies to.""")
|
||||
target: Optional[Data] = Field(
|
||||
None, description="""Target dataset that this index applies to."""
|
||||
)
|
||||
|
||||
|
||||
class VectorData(VectorDataMixin):
|
||||
|
@ -351,8 +482,8 @@ class VectorIndex(VectorIndexMixin):
|
|||
)
|
||||
|
||||
name: str = Field(...)
|
||||
target: VectorData = Field(
|
||||
..., description="""Reference to the target dataset that this index applies to."""
|
||||
target: Optional[VectorData] = Field(
|
||||
None, description="""Reference to the target dataset that this index applies to."""
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -4,11 +4,21 @@ from decimal import Decimal
|
|||
from enum import Enum
|
||||
import re
|
||||
import sys
|
||||
import numpy as np
|
||||
from pandas import DataFrame, Series
|
||||
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
||||
from numpydantic import NDArray, Shape
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
RootModel,
|
||||
field_validator,
|
||||
model_validator,
|
||||
ValidationInfo,
|
||||
ValidatorFunctionWrapHandler,
|
||||
ValidationError,
|
||||
)
|
||||
import numpy as np
|
||||
|
||||
metamodel_version = "None"
|
||||
version = "1.1.2"
|
||||
|
@ -59,6 +69,11 @@ class VectorDataMixin(BaseModel):
|
|||
# redefined in `VectorData`, but included here for testing and type checking
|
||||
value: Optional[NDArray] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||
if self._index:
|
||||
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||
|
@ -73,6 +88,27 @@ class VectorDataMixin(BaseModel):
|
|||
else:
|
||||
self.value[key] = value
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""
|
||||
Forward getattr to ``value``
|
||||
"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Use index as length, if present
|
||||
"""
|
||||
if self._index:
|
||||
return len(self._index)
|
||||
else:
|
||||
return len(self.value)
|
||||
|
||||
|
||||
class VectorIndexMixin(BaseModel):
|
||||
"""
|
||||
|
@ -83,6 +119,11 @@ class VectorIndexMixin(BaseModel):
|
|||
value: Optional[NDArray] = None
|
||||
target: Optional["VectorData"] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
||||
"""
|
||||
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||
|
@ -90,19 +131,19 @@ class VectorIndexMixin(BaseModel):
|
|||
|
||||
start = 0 if arg == 0 else self.value[arg - 1]
|
||||
end = self.value[arg]
|
||||
return self.target.array[slice(start, end)]
|
||||
return self.target.value[slice(start, end)]
|
||||
|
||||
def __getitem__(self, item: Union[int, slice]) -> Any:
|
||||
if self.target is None:
|
||||
return self.value[item]
|
||||
elif type(self.target).__name__ == "VectorData":
|
||||
elif isinstance(self.target, VectorData):
|
||||
if isinstance(item, int):
|
||||
return self._getitem_helper(item)
|
||||
else:
|
||||
idx = range(*item.indices(len(self.value)))
|
||||
return [self._getitem_helper(i) for i in idx]
|
||||
else:
|
||||
raise NotImplementedError("DynamicTableRange not supported yet")
|
||||
raise AttributeError(f"Could not index with {item}")
|
||||
|
||||
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
||||
if self._index:
|
||||
|
@ -111,6 +152,24 @@ class VectorIndexMixin(BaseModel):
|
|||
else:
|
||||
self.value[key] = value
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""
|
||||
Forward getattr to ``value``
|
||||
"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Get length from value
|
||||
"""
|
||||
return len(self.value)
|
||||
|
||||
|
||||
class DynamicTableMixin(BaseModel):
|
||||
"""
|
||||
|
@ -130,6 +189,7 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
||||
colnames: List[str] = Field(default_factory=list)
|
||||
id: Optional[NDArray[Shape["* num_rows"], int]] = None
|
||||
|
||||
@property
|
||||
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
||||
|
@ -221,6 +281,10 @@ class DynamicTableMixin(BaseModel):
|
|||
# special case where pandas will unpack a pydantic model
|
||||
# into {n_fields} rows, rather than keeping it in a dict
|
||||
val = Series([val])
|
||||
elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
|
||||
# special case where we are returning a row in a ragged array,
|
||||
# same as above - prevent pandas pivoting to long
|
||||
val = Series([val])
|
||||
data[k] = val
|
||||
return data
|
||||
|
||||
|
@ -240,9 +304,40 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
return super().__setattr__(key, value)
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""Try and use pandas df attrs if we don't have them"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self[:, :], item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_colnames(cls, model: Dict[str, Any]) -> None:
|
||||
def create_id(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Create ID column if not provided
|
||||
"""
|
||||
if "id" not in model:
|
||||
lengths = []
|
||||
for key, val in model.items():
|
||||
# don't get lengths of columns with an index
|
||||
if (
|
||||
f"{key}_index" in model
|
||||
or (isinstance(val, VectorData) and val._index)
|
||||
or key in cls.NON_COLUMN_FIELDS
|
||||
):
|
||||
continue
|
||||
lengths.append(len(val))
|
||||
model["id"] = np.arange(np.max(lengths))
|
||||
|
||||
return model
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Construct colnames from arguments.
|
||||
|
||||
|
@ -288,6 +383,40 @@ class DynamicTableMixin(BaseModel):
|
|||
idx.target = col
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
|
||||
"""
|
||||
Ensure that all columns are equal length
|
||||
"""
|
||||
lengths = [len(v) for v in self._columns.values()]
|
||||
assert [length == lengths[0] for length in lengths], (
|
||||
"Columns are not of equal length! "
|
||||
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
|
||||
)
|
||||
return self
|
||||
|
||||
@field_validator("*", mode="wrap")
|
||||
@classmethod
|
||||
def cast_columns(
|
||||
cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
|
||||
) -> Any:
|
||||
"""
|
||||
If columns are supplied as arrays, try casting them to the type before validating
|
||||
"""
|
||||
try:
|
||||
return handler(val)
|
||||
except ValidationError:
|
||||
annotation = cls.model_fields[info.field_name].annotation
|
||||
if type(annotation).__name__ == "_UnionGenericAlias":
|
||||
annotation = annotation.__args__[0]
|
||||
return handler(
|
||||
annotation(
|
||||
val,
|
||||
name=info.field_name,
|
||||
description=cls.model_fields[info.field_name].description,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
linkml_meta = LinkMLMeta(
|
||||
{
|
||||
|
@ -325,7 +454,9 @@ class Index(Data):
|
|||
)
|
||||
|
||||
name: str = Field(...)
|
||||
target: Data = Field(..., description="""Target dataset that this index applies to.""")
|
||||
target: Optional[Data] = Field(
|
||||
None, description="""Target dataset that this index applies to."""
|
||||
)
|
||||
|
||||
|
||||
class VectorData(VectorDataMixin):
|
||||
|
@ -351,8 +482,8 @@ class VectorIndex(VectorIndexMixin):
|
|||
)
|
||||
|
||||
name: str = Field(...)
|
||||
target: VectorData = Field(
|
||||
..., description="""Reference to the target dataset that this index applies to."""
|
||||
target: Optional[VectorData] = Field(
|
||||
None, description="""Reference to the target dataset that this index applies to."""
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -4,11 +4,21 @@ from decimal import Decimal
|
|||
from enum import Enum
|
||||
import re
|
||||
import sys
|
||||
import numpy as np
|
||||
from pandas import DataFrame, Series
|
||||
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
||||
from numpydantic import NDArray, Shape
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
RootModel,
|
||||
field_validator,
|
||||
model_validator,
|
||||
ValidationInfo,
|
||||
ValidatorFunctionWrapHandler,
|
||||
ValidationError,
|
||||
)
|
||||
import numpy as np
|
||||
|
||||
metamodel_version = "None"
|
||||
version = "1.1.3"
|
||||
|
@ -59,6 +69,11 @@ class VectorDataMixin(BaseModel):
|
|||
# redefined in `VectorData`, but included here for testing and type checking
|
||||
value: Optional[NDArray] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||
if self._index:
|
||||
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||
|
@ -73,6 +88,27 @@ class VectorDataMixin(BaseModel):
|
|||
else:
|
||||
self.value[key] = value
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""
|
||||
Forward getattr to ``value``
|
||||
"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Use index as length, if present
|
||||
"""
|
||||
if self._index:
|
||||
return len(self._index)
|
||||
else:
|
||||
return len(self.value)
|
||||
|
||||
|
||||
class VectorIndexMixin(BaseModel):
|
||||
"""
|
||||
|
@ -83,6 +119,11 @@ class VectorIndexMixin(BaseModel):
|
|||
value: Optional[NDArray] = None
|
||||
target: Optional["VectorData"] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
||||
"""
|
||||
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||
|
@ -90,19 +131,19 @@ class VectorIndexMixin(BaseModel):
|
|||
|
||||
start = 0 if arg == 0 else self.value[arg - 1]
|
||||
end = self.value[arg]
|
||||
return self.target.array[slice(start, end)]
|
||||
return self.target.value[slice(start, end)]
|
||||
|
||||
def __getitem__(self, item: Union[int, slice]) -> Any:
|
||||
if self.target is None:
|
||||
return self.value[item]
|
||||
elif type(self.target).__name__ == "VectorData":
|
||||
elif isinstance(self.target, VectorData):
|
||||
if isinstance(item, int):
|
||||
return self._getitem_helper(item)
|
||||
else:
|
||||
idx = range(*item.indices(len(self.value)))
|
||||
return [self._getitem_helper(i) for i in idx]
|
||||
else:
|
||||
raise NotImplementedError("DynamicTableRange not supported yet")
|
||||
raise AttributeError(f"Could not index with {item}")
|
||||
|
||||
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
||||
if self._index:
|
||||
|
@ -111,6 +152,24 @@ class VectorIndexMixin(BaseModel):
|
|||
else:
|
||||
self.value[key] = value
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""
|
||||
Forward getattr to ``value``
|
||||
"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Get length from value
|
||||
"""
|
||||
return len(self.value)
|
||||
|
||||
|
||||
class DynamicTableMixin(BaseModel):
|
||||
"""
|
||||
|
@ -130,6 +189,7 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
||||
colnames: List[str] = Field(default_factory=list)
|
||||
id: Optional[NDArray[Shape["* num_rows"], int]] = None
|
||||
|
||||
@property
|
||||
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
||||
|
@ -221,6 +281,10 @@ class DynamicTableMixin(BaseModel):
|
|||
# special case where pandas will unpack a pydantic model
|
||||
# into {n_fields} rows, rather than keeping it in a dict
|
||||
val = Series([val])
|
||||
elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
|
||||
# special case where we are returning a row in a ragged array,
|
||||
# same as above - prevent pandas pivoting to long
|
||||
val = Series([val])
|
||||
data[k] = val
|
||||
return data
|
||||
|
||||
|
@ -240,9 +304,40 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
return super().__setattr__(key, value)
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""Try and use pandas df attrs if we don't have them"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self[:, :], item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_colnames(cls, model: Dict[str, Any]) -> None:
|
||||
def create_id(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Create ID column if not provided
|
||||
"""
|
||||
if "id" not in model:
|
||||
lengths = []
|
||||
for key, val in model.items():
|
||||
# don't get lengths of columns with an index
|
||||
if (
|
||||
f"{key}_index" in model
|
||||
or (isinstance(val, VectorData) and val._index)
|
||||
or key in cls.NON_COLUMN_FIELDS
|
||||
):
|
||||
continue
|
||||
lengths.append(len(val))
|
||||
model["id"] = np.arange(np.max(lengths))
|
||||
|
||||
return model
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Construct colnames from arguments.
|
||||
|
||||
|
@ -288,6 +383,40 @@ class DynamicTableMixin(BaseModel):
|
|||
idx.target = col
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
|
||||
"""
|
||||
Ensure that all columns are equal length
|
||||
"""
|
||||
lengths = [len(v) for v in self._columns.values()]
|
||||
assert [length == lengths[0] for length in lengths], (
|
||||
"Columns are not of equal length! "
|
||||
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
|
||||
)
|
||||
return self
|
||||
|
||||
@field_validator("*", mode="wrap")
|
||||
@classmethod
|
||||
def cast_columns(
|
||||
cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
|
||||
) -> Any:
|
||||
"""
|
||||
If columns are supplied as arrays, try casting them to the type before validating
|
||||
"""
|
||||
try:
|
||||
return handler(val)
|
||||
except ValidationError:
|
||||
annotation = cls.model_fields[info.field_name].annotation
|
||||
if type(annotation).__name__ == "_UnionGenericAlias":
|
||||
annotation = annotation.__args__[0]
|
||||
return handler(
|
||||
annotation(
|
||||
val,
|
||||
name=info.field_name,
|
||||
description=cls.model_fields[info.field_name].description,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
linkml_meta = LinkMLMeta(
|
||||
{
|
||||
|
@ -325,7 +454,9 @@ class Index(Data):
|
|||
)
|
||||
|
||||
name: str = Field(...)
|
||||
target: Data = Field(..., description="""Target dataset that this index applies to.""")
|
||||
target: Optional[Data] = Field(
|
||||
None, description="""Target dataset that this index applies to."""
|
||||
)
|
||||
|
||||
|
||||
class VectorData(VectorDataMixin):
|
||||
|
@ -359,8 +490,8 @@ class VectorIndex(VectorIndexMixin):
|
|||
)
|
||||
|
||||
name: str = Field(...)
|
||||
target: VectorData = Field(
|
||||
..., description="""Reference to the target dataset that this index applies to."""
|
||||
target: Optional[VectorData] = Field(
|
||||
None, description="""Reference to the target dataset that this index applies to."""
|
||||
)
|
||||
value: Optional[NDArray[Shape["* num_rows"], Any]] = Field(
|
||||
None, json_schema_extra={"linkml_meta": {"array": {"dimensions": [{"alias": "num_rows"}]}}}
|
||||
|
|
|
@ -4,12 +4,22 @@ from decimal import Decimal
|
|||
from enum import Enum
|
||||
import re
|
||||
import sys
|
||||
import numpy as np
|
||||
from ...hdmf_common.v1_2_0.hdmf_common_base import Data, Container
|
||||
from pandas import DataFrame, Series
|
||||
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
||||
from numpydantic import NDArray, Shape
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
RootModel,
|
||||
field_validator,
|
||||
model_validator,
|
||||
ValidationInfo,
|
||||
ValidatorFunctionWrapHandler,
|
||||
ValidationError,
|
||||
)
|
||||
import numpy as np
|
||||
|
||||
metamodel_version = "None"
|
||||
version = "1.2.0"
|
||||
|
@ -60,6 +70,11 @@ class VectorDataMixin(BaseModel):
|
|||
# redefined in `VectorData`, but included here for testing and type checking
|
||||
value: Optional[NDArray] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||
if self._index:
|
||||
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||
|
@ -74,6 +89,27 @@ class VectorDataMixin(BaseModel):
|
|||
else:
|
||||
self.value[key] = value
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""
|
||||
Forward getattr to ``value``
|
||||
"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Use index as length, if present
|
||||
"""
|
||||
if self._index:
|
||||
return len(self._index)
|
||||
else:
|
||||
return len(self.value)
|
||||
|
||||
|
||||
class VectorIndexMixin(BaseModel):
|
||||
"""
|
||||
|
@ -84,6 +120,11 @@ class VectorIndexMixin(BaseModel):
|
|||
value: Optional[NDArray] = None
|
||||
target: Optional["VectorData"] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
||||
"""
|
||||
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||
|
@ -91,19 +132,19 @@ class VectorIndexMixin(BaseModel):
|
|||
|
||||
start = 0 if arg == 0 else self.value[arg - 1]
|
||||
end = self.value[arg]
|
||||
return self.target.array[slice(start, end)]
|
||||
return self.target.value[slice(start, end)]
|
||||
|
||||
def __getitem__(self, item: Union[int, slice]) -> Any:
|
||||
if self.target is None:
|
||||
return self.value[item]
|
||||
elif type(self.target).__name__ == "VectorData":
|
||||
elif isinstance(self.target, VectorData):
|
||||
if isinstance(item, int):
|
||||
return self._getitem_helper(item)
|
||||
else:
|
||||
idx = range(*item.indices(len(self.value)))
|
||||
return [self._getitem_helper(i) for i in idx]
|
||||
else:
|
||||
raise NotImplementedError("DynamicTableRange not supported yet")
|
||||
raise AttributeError(f"Could not index with {item}")
|
||||
|
||||
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
||||
if self._index:
|
||||
|
@ -112,6 +153,24 @@ class VectorIndexMixin(BaseModel):
|
|||
else:
|
||||
self.value[key] = value
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""
|
||||
Forward getattr to ``value``
|
||||
"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Get length from value
|
||||
"""
|
||||
return len(self.value)
|
||||
|
||||
|
||||
class DynamicTableMixin(BaseModel):
|
||||
"""
|
||||
|
@ -131,6 +190,7 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
||||
colnames: List[str] = Field(default_factory=list)
|
||||
id: Optional[NDArray[Shape["* num_rows"], int]] = None
|
||||
|
||||
@property
|
||||
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
||||
|
@ -222,6 +282,10 @@ class DynamicTableMixin(BaseModel):
|
|||
# special case where pandas will unpack a pydantic model
|
||||
# into {n_fields} rows, rather than keeping it in a dict
|
||||
val = Series([val])
|
||||
elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
|
||||
# special case where we are returning a row in a ragged array,
|
||||
# same as above - prevent pandas pivoting to long
|
||||
val = Series([val])
|
||||
data[k] = val
|
||||
return data
|
||||
|
||||
|
@ -241,9 +305,40 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
return super().__setattr__(key, value)
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""Try and use pandas df attrs if we don't have them"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self[:, :], item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_colnames(cls, model: Dict[str, Any]) -> None:
|
||||
def create_id(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Create ID column if not provided
|
||||
"""
|
||||
if "id" not in model:
|
||||
lengths = []
|
||||
for key, val in model.items():
|
||||
# don't get lengths of columns with an index
|
||||
if (
|
||||
f"{key}_index" in model
|
||||
or (isinstance(val, VectorData) and val._index)
|
||||
or key in cls.NON_COLUMN_FIELDS
|
||||
):
|
||||
continue
|
||||
lengths.append(len(val))
|
||||
model["id"] = np.arange(np.max(lengths))
|
||||
|
||||
return model
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Construct colnames from arguments.
|
||||
|
||||
|
@ -289,6 +384,40 @@ class DynamicTableMixin(BaseModel):
|
|||
idx.target = col
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
|
||||
"""
|
||||
Ensure that all columns are equal length
|
||||
"""
|
||||
lengths = [len(v) for v in self._columns.values()]
|
||||
assert [length == lengths[0] for length in lengths], (
|
||||
"Columns are not of equal length! "
|
||||
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
|
||||
)
|
||||
return self
|
||||
|
||||
@field_validator("*", mode="wrap")
|
||||
@classmethod
|
||||
def cast_columns(
|
||||
cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
|
||||
) -> Any:
|
||||
"""
|
||||
If columns are supplied as arrays, try casting them to the type before validating
|
||||
"""
|
||||
try:
|
||||
return handler(val)
|
||||
except ValidationError:
|
||||
annotation = cls.model_fields[info.field_name].annotation
|
||||
if type(annotation).__name__ == "_UnionGenericAlias":
|
||||
annotation = annotation.__args__[0]
|
||||
return handler(
|
||||
annotation(
|
||||
val,
|
||||
name=info.field_name,
|
||||
description=cls.model_fields[info.field_name].description,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
linkml_meta = LinkMLMeta(
|
||||
{
|
||||
|
@ -335,8 +464,8 @@ class VectorIndex(VectorIndexMixin):
|
|||
)
|
||||
|
||||
name: str = Field(...)
|
||||
target: VectorData = Field(
|
||||
..., description="""Reference to the target dataset that this index applies to."""
|
||||
target: Optional[VectorData] = Field(
|
||||
None, description="""Reference to the target dataset that this index applies to."""
|
||||
)
|
||||
description: str = Field(..., description="""Description of what these vectors represent.""")
|
||||
value: Optional[
|
||||
|
|
|
@ -4,12 +4,22 @@ from decimal import Decimal
|
|||
from enum import Enum
|
||||
import re
|
||||
import sys
|
||||
import numpy as np
|
||||
from ...hdmf_common.v1_2_1.hdmf_common_base import Data, Container
|
||||
from pandas import DataFrame, Series
|
||||
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
||||
from numpydantic import NDArray, Shape
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
RootModel,
|
||||
field_validator,
|
||||
model_validator,
|
||||
ValidationInfo,
|
||||
ValidatorFunctionWrapHandler,
|
||||
ValidationError,
|
||||
)
|
||||
import numpy as np
|
||||
|
||||
metamodel_version = "None"
|
||||
version = "1.2.1"
|
||||
|
@ -60,6 +70,11 @@ class VectorDataMixin(BaseModel):
|
|||
# redefined in `VectorData`, but included here for testing and type checking
|
||||
value: Optional[NDArray] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||
if self._index:
|
||||
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||
|
@ -74,6 +89,27 @@ class VectorDataMixin(BaseModel):
|
|||
else:
|
||||
self.value[key] = value
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""
|
||||
Forward getattr to ``value``
|
||||
"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Use index as length, if present
|
||||
"""
|
||||
if self._index:
|
||||
return len(self._index)
|
||||
else:
|
||||
return len(self.value)
|
||||
|
||||
|
||||
class VectorIndexMixin(BaseModel):
|
||||
"""
|
||||
|
@ -84,6 +120,11 @@ class VectorIndexMixin(BaseModel):
|
|||
value: Optional[NDArray] = None
|
||||
target: Optional["VectorData"] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
||||
"""
|
||||
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||
|
@ -91,19 +132,19 @@ class VectorIndexMixin(BaseModel):
|
|||
|
||||
start = 0 if arg == 0 else self.value[arg - 1]
|
||||
end = self.value[arg]
|
||||
return self.target.array[slice(start, end)]
|
||||
return self.target.value[slice(start, end)]
|
||||
|
||||
def __getitem__(self, item: Union[int, slice]) -> Any:
|
||||
if self.target is None:
|
||||
return self.value[item]
|
||||
elif type(self.target).__name__ == "VectorData":
|
||||
elif isinstance(self.target, VectorData):
|
||||
if isinstance(item, int):
|
||||
return self._getitem_helper(item)
|
||||
else:
|
||||
idx = range(*item.indices(len(self.value)))
|
||||
return [self._getitem_helper(i) for i in idx]
|
||||
else:
|
||||
raise NotImplementedError("DynamicTableRange not supported yet")
|
||||
raise AttributeError(f"Could not index with {item}")
|
||||
|
||||
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
||||
if self._index:
|
||||
|
@ -112,6 +153,24 @@ class VectorIndexMixin(BaseModel):
|
|||
else:
|
||||
self.value[key] = value
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""
|
||||
Forward getattr to ``value``
|
||||
"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Get length from value
|
||||
"""
|
||||
return len(self.value)
|
||||
|
||||
|
||||
class DynamicTableMixin(BaseModel):
|
||||
"""
|
||||
|
@ -131,6 +190,7 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
||||
colnames: List[str] = Field(default_factory=list)
|
||||
id: Optional[NDArray[Shape["* num_rows"], int]] = None
|
||||
|
||||
@property
|
||||
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
||||
|
@ -222,6 +282,10 @@ class DynamicTableMixin(BaseModel):
|
|||
# special case where pandas will unpack a pydantic model
|
||||
# into {n_fields} rows, rather than keeping it in a dict
|
||||
val = Series([val])
|
||||
elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
|
||||
# special case where we are returning a row in a ragged array,
|
||||
# same as above - prevent pandas pivoting to long
|
||||
val = Series([val])
|
||||
data[k] = val
|
||||
return data
|
||||
|
||||
|
@ -241,9 +305,40 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
return super().__setattr__(key, value)
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""Try and use pandas df attrs if we don't have them"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self[:, :], item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_colnames(cls, model: Dict[str, Any]) -> None:
|
||||
def create_id(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Create ID column if not provided
|
||||
"""
|
||||
if "id" not in model:
|
||||
lengths = []
|
||||
for key, val in model.items():
|
||||
# don't get lengths of columns with an index
|
||||
if (
|
||||
f"{key}_index" in model
|
||||
or (isinstance(val, VectorData) and val._index)
|
||||
or key in cls.NON_COLUMN_FIELDS
|
||||
):
|
||||
continue
|
||||
lengths.append(len(val))
|
||||
model["id"] = np.arange(np.max(lengths))
|
||||
|
||||
return model
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Construct colnames from arguments.
|
||||
|
||||
|
@ -289,6 +384,40 @@ class DynamicTableMixin(BaseModel):
|
|||
idx.target = col
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
|
||||
"""
|
||||
Ensure that all columns are equal length
|
||||
"""
|
||||
lengths = [len(v) for v in self._columns.values()]
|
||||
assert [length == lengths[0] for length in lengths], (
|
||||
"Columns are not of equal length! "
|
||||
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
|
||||
)
|
||||
return self
|
||||
|
||||
@field_validator("*", mode="wrap")
|
||||
@classmethod
|
||||
def cast_columns(
|
||||
cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
|
||||
) -> Any:
|
||||
"""
|
||||
If columns are supplied as arrays, try casting them to the type before validating
|
||||
"""
|
||||
try:
|
||||
return handler(val)
|
||||
except ValidationError:
|
||||
annotation = cls.model_fields[info.field_name].annotation
|
||||
if type(annotation).__name__ == "_UnionGenericAlias":
|
||||
annotation = annotation.__args__[0]
|
||||
return handler(
|
||||
annotation(
|
||||
val,
|
||||
name=info.field_name,
|
||||
description=cls.model_fields[info.field_name].description,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
linkml_meta = LinkMLMeta(
|
||||
{
|
||||
|
@ -335,8 +464,8 @@ class VectorIndex(VectorIndexMixin):
|
|||
)
|
||||
|
||||
name: str = Field(...)
|
||||
target: VectorData = Field(
|
||||
..., description="""Reference to the target dataset that this index applies to."""
|
||||
target: Optional[VectorData] = Field(
|
||||
None, description="""Reference to the target dataset that this index applies to."""
|
||||
)
|
||||
description: str = Field(..., description="""Description of what these vectors represent.""")
|
||||
value: Optional[
|
||||
|
|
|
@ -4,12 +4,22 @@ from decimal import Decimal
|
|||
from enum import Enum
|
||||
import re
|
||||
import sys
|
||||
import numpy as np
|
||||
from ...hdmf_common.v1_3_0.hdmf_common_base import Data, Container
|
||||
from pandas import DataFrame, Series
|
||||
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
||||
from numpydantic import NDArray, Shape
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
RootModel,
|
||||
field_validator,
|
||||
model_validator,
|
||||
ValidationInfo,
|
||||
ValidatorFunctionWrapHandler,
|
||||
ValidationError,
|
||||
)
|
||||
import numpy as np
|
||||
|
||||
metamodel_version = "None"
|
||||
version = "1.3.0"
|
||||
|
@ -60,6 +70,11 @@ class VectorDataMixin(BaseModel):
|
|||
# redefined in `VectorData`, but included here for testing and type checking
|
||||
value: Optional[NDArray] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||
if self._index:
|
||||
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||
|
@ -74,6 +89,27 @@ class VectorDataMixin(BaseModel):
|
|||
else:
|
||||
self.value[key] = value
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""
|
||||
Forward getattr to ``value``
|
||||
"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Use index as length, if present
|
||||
"""
|
||||
if self._index:
|
||||
return len(self._index)
|
||||
else:
|
||||
return len(self.value)
|
||||
|
||||
|
||||
class VectorIndexMixin(BaseModel):
|
||||
"""
|
||||
|
@ -84,6 +120,11 @@ class VectorIndexMixin(BaseModel):
|
|||
value: Optional[NDArray] = None
|
||||
target: Optional["VectorData"] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
||||
"""
|
||||
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||
|
@ -91,19 +132,19 @@ class VectorIndexMixin(BaseModel):
|
|||
|
||||
start = 0 if arg == 0 else self.value[arg - 1]
|
||||
end = self.value[arg]
|
||||
return self.target.array[slice(start, end)]
|
||||
return self.target.value[slice(start, end)]
|
||||
|
||||
def __getitem__(self, item: Union[int, slice]) -> Any:
|
||||
if self.target is None:
|
||||
return self.value[item]
|
||||
elif type(self.target).__name__ == "VectorData":
|
||||
elif isinstance(self.target, VectorData):
|
||||
if isinstance(item, int):
|
||||
return self._getitem_helper(item)
|
||||
else:
|
||||
idx = range(*item.indices(len(self.value)))
|
||||
return [self._getitem_helper(i) for i in idx]
|
||||
else:
|
||||
raise NotImplementedError("DynamicTableRange not supported yet")
|
||||
raise AttributeError(f"Could not index with {item}")
|
||||
|
||||
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
||||
if self._index:
|
||||
|
@ -112,6 +153,24 @@ class VectorIndexMixin(BaseModel):
|
|||
else:
|
||||
self.value[key] = value
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""
|
||||
Forward getattr to ``value``
|
||||
"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Get length from value
|
||||
"""
|
||||
return len(self.value)
|
||||
|
||||
|
||||
class DynamicTableMixin(BaseModel):
|
||||
"""
|
||||
|
@ -131,6 +190,7 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
||||
colnames: List[str] = Field(default_factory=list)
|
||||
id: Optional[NDArray[Shape["* num_rows"], int]] = None
|
||||
|
||||
@property
|
||||
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
||||
|
@ -222,6 +282,10 @@ class DynamicTableMixin(BaseModel):
|
|||
# special case where pandas will unpack a pydantic model
|
||||
# into {n_fields} rows, rather than keeping it in a dict
|
||||
val = Series([val])
|
||||
elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
|
||||
# special case where we are returning a row in a ragged array,
|
||||
# same as above - prevent pandas pivoting to long
|
||||
val = Series([val])
|
||||
data[k] = val
|
||||
return data
|
||||
|
||||
|
@ -241,9 +305,40 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
return super().__setattr__(key, value)
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""Try and use pandas df attrs if we don't have them"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self[:, :], item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_colnames(cls, model: Dict[str, Any]) -> None:
|
||||
def create_id(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Create ID column if not provided
|
||||
"""
|
||||
if "id" not in model:
|
||||
lengths = []
|
||||
for key, val in model.items():
|
||||
# don't get lengths of columns with an index
|
||||
if (
|
||||
f"{key}_index" in model
|
||||
or (isinstance(val, VectorData) and val._index)
|
||||
or key in cls.NON_COLUMN_FIELDS
|
||||
):
|
||||
continue
|
||||
lengths.append(len(val))
|
||||
model["id"] = np.arange(np.max(lengths))
|
||||
|
||||
return model
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Construct colnames from arguments.
|
||||
|
||||
|
@ -289,6 +384,40 @@ class DynamicTableMixin(BaseModel):
|
|||
idx.target = col
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
|
||||
"""
|
||||
Ensure that all columns are equal length
|
||||
"""
|
||||
lengths = [len(v) for v in self._columns.values()]
|
||||
assert [length == lengths[0] for length in lengths], (
|
||||
"Columns are not of equal length! "
|
||||
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
|
||||
)
|
||||
return self
|
||||
|
||||
@field_validator("*", mode="wrap")
|
||||
@classmethod
|
||||
def cast_columns(
|
||||
cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
|
||||
) -> Any:
|
||||
"""
|
||||
If columns are supplied as arrays, try casting them to the type before validating
|
||||
"""
|
||||
try:
|
||||
return handler(val)
|
||||
except ValidationError:
|
||||
annotation = cls.model_fields[info.field_name].annotation
|
||||
if type(annotation).__name__ == "_UnionGenericAlias":
|
||||
annotation = annotation.__args__[0]
|
||||
return handler(
|
||||
annotation(
|
||||
val,
|
||||
name=info.field_name,
|
||||
description=cls.model_fields[info.field_name].description,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
linkml_meta = LinkMLMeta(
|
||||
{
|
||||
|
@ -335,8 +464,8 @@ class VectorIndex(VectorIndexMixin):
|
|||
)
|
||||
|
||||
name: str = Field(...)
|
||||
target: VectorData = Field(
|
||||
..., description="""Reference to the target dataset that this index applies to."""
|
||||
target: Optional[VectorData] = Field(
|
||||
None, description="""Reference to the target dataset that this index applies to."""
|
||||
)
|
||||
description: str = Field(..., description="""Description of what these vectors represent.""")
|
||||
value: Optional[
|
||||
|
|
|
@ -4,12 +4,22 @@ from decimal import Decimal
|
|||
from enum import Enum
|
||||
import re
|
||||
import sys
|
||||
import numpy as np
|
||||
from ...hdmf_common.v1_4_0.hdmf_common_base import Data, Container
|
||||
from pandas import DataFrame, Series
|
||||
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
||||
from numpydantic import NDArray, Shape
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
RootModel,
|
||||
field_validator,
|
||||
model_validator,
|
||||
ValidationInfo,
|
||||
ValidatorFunctionWrapHandler,
|
||||
ValidationError,
|
||||
)
|
||||
import numpy as np
|
||||
|
||||
metamodel_version = "None"
|
||||
version = "1.4.0"
|
||||
|
@ -60,6 +70,11 @@ class VectorDataMixin(BaseModel):
|
|||
# redefined in `VectorData`, but included here for testing and type checking
|
||||
value: Optional[NDArray] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||
if self._index:
|
||||
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||
|
@ -74,6 +89,27 @@ class VectorDataMixin(BaseModel):
|
|||
else:
|
||||
self.value[key] = value
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""
|
||||
Forward getattr to ``value``
|
||||
"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Use index as length, if present
|
||||
"""
|
||||
if self._index:
|
||||
return len(self._index)
|
||||
else:
|
||||
return len(self.value)
|
||||
|
||||
|
||||
class VectorIndexMixin(BaseModel):
|
||||
"""
|
||||
|
@ -84,6 +120,11 @@ class VectorIndexMixin(BaseModel):
|
|||
value: Optional[NDArray] = None
|
||||
target: Optional["VectorData"] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
||||
"""
|
||||
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||
|
@ -91,19 +132,19 @@ class VectorIndexMixin(BaseModel):
|
|||
|
||||
start = 0 if arg == 0 else self.value[arg - 1]
|
||||
end = self.value[arg]
|
||||
return self.target.array[slice(start, end)]
|
||||
return self.target.value[slice(start, end)]
|
||||
|
||||
def __getitem__(self, item: Union[int, slice]) -> Any:
|
||||
if self.target is None:
|
||||
return self.value[item]
|
||||
elif type(self.target).__name__ == "VectorData":
|
||||
elif isinstance(self.target, VectorData):
|
||||
if isinstance(item, int):
|
||||
return self._getitem_helper(item)
|
||||
else:
|
||||
idx = range(*item.indices(len(self.value)))
|
||||
return [self._getitem_helper(i) for i in idx]
|
||||
else:
|
||||
raise NotImplementedError("DynamicTableRange not supported yet")
|
||||
raise AttributeError(f"Could not index with {item}")
|
||||
|
||||
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
||||
if self._index:
|
||||
|
@ -112,6 +153,24 @@ class VectorIndexMixin(BaseModel):
|
|||
else:
|
||||
self.value[key] = value
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""
|
||||
Forward getattr to ``value``
|
||||
"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Get length from value
|
||||
"""
|
||||
return len(self.value)
|
||||
|
||||
|
||||
class DynamicTableMixin(BaseModel):
|
||||
"""
|
||||
|
@ -131,6 +190,7 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
||||
colnames: List[str] = Field(default_factory=list)
|
||||
id: Optional[NDArray[Shape["* num_rows"], int]] = None
|
||||
|
||||
@property
|
||||
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
||||
|
@ -222,6 +282,10 @@ class DynamicTableMixin(BaseModel):
|
|||
# special case where pandas will unpack a pydantic model
|
||||
# into {n_fields} rows, rather than keeping it in a dict
|
||||
val = Series([val])
|
||||
elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
|
||||
# special case where we are returning a row in a ragged array,
|
||||
# same as above - prevent pandas pivoting to long
|
||||
val = Series([val])
|
||||
data[k] = val
|
||||
return data
|
||||
|
||||
|
@ -241,9 +305,40 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
return super().__setattr__(key, value)
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""Try and use pandas df attrs if we don't have them"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self[:, :], item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_colnames(cls, model: Dict[str, Any]) -> None:
|
||||
def create_id(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Create ID column if not provided
|
||||
"""
|
||||
if "id" not in model:
|
||||
lengths = []
|
||||
for key, val in model.items():
|
||||
# don't get lengths of columns with an index
|
||||
if (
|
||||
f"{key}_index" in model
|
||||
or (isinstance(val, VectorData) and val._index)
|
||||
or key in cls.NON_COLUMN_FIELDS
|
||||
):
|
||||
continue
|
||||
lengths.append(len(val))
|
||||
model["id"] = np.arange(np.max(lengths))
|
||||
|
||||
return model
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Construct colnames from arguments.
|
||||
|
||||
|
@ -289,6 +384,40 @@ class DynamicTableMixin(BaseModel):
|
|||
idx.target = col
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
|
||||
"""
|
||||
Ensure that all columns are equal length
|
||||
"""
|
||||
lengths = [len(v) for v in self._columns.values()]
|
||||
assert [length == lengths[0] for length in lengths], (
|
||||
"Columns are not of equal length! "
|
||||
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
|
||||
)
|
||||
return self
|
||||
|
||||
@field_validator("*", mode="wrap")
|
||||
@classmethod
|
||||
def cast_columns(
|
||||
cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
|
||||
) -> Any:
|
||||
"""
|
||||
If columns are supplied as arrays, try casting them to the type before validating
|
||||
"""
|
||||
try:
|
||||
return handler(val)
|
||||
except ValidationError:
|
||||
annotation = cls.model_fields[info.field_name].annotation
|
||||
if type(annotation).__name__ == "_UnionGenericAlias":
|
||||
annotation = annotation.__args__[0]
|
||||
return handler(
|
||||
annotation(
|
||||
val,
|
||||
name=info.field_name,
|
||||
description=cls.model_fields[info.field_name].description,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
linkml_meta = LinkMLMeta(
|
||||
{
|
||||
|
@ -335,8 +464,8 @@ class VectorIndex(VectorIndexMixin):
|
|||
)
|
||||
|
||||
name: str = Field(...)
|
||||
target: VectorData = Field(
|
||||
..., description="""Reference to the target dataset that this index applies to."""
|
||||
target: Optional[VectorData] = Field(
|
||||
None, description="""Reference to the target dataset that this index applies to."""
|
||||
)
|
||||
description: str = Field(..., description="""Description of what these vectors represent.""")
|
||||
value: Optional[
|
||||
|
|
|
@ -4,11 +4,21 @@ from decimal import Decimal
|
|||
from enum import Enum
|
||||
import re
|
||||
import sys
|
||||
import numpy as np
|
||||
from ...hdmf_common.v1_5_0.hdmf_common_base import Data, Container
|
||||
from pandas import DataFrame, Series
|
||||
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
RootModel,
|
||||
field_validator,
|
||||
model_validator,
|
||||
ValidationInfo,
|
||||
ValidatorFunctionWrapHandler,
|
||||
ValidationError,
|
||||
)
|
||||
import numpy as np
|
||||
from numpydantic import NDArray, Shape
|
||||
|
||||
metamodel_version = "None"
|
||||
|
@ -60,6 +70,11 @@ class VectorDataMixin(BaseModel):
|
|||
# redefined in `VectorData`, but included here for testing and type checking
|
||||
value: Optional[NDArray] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||
if self._index:
|
||||
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||
|
@ -74,6 +89,27 @@ class VectorDataMixin(BaseModel):
|
|||
else:
|
||||
self.value[key] = value
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""
|
||||
Forward getattr to ``value``
|
||||
"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Use index as length, if present
|
||||
"""
|
||||
if self._index:
|
||||
return len(self._index)
|
||||
else:
|
||||
return len(self.value)
|
||||
|
||||
|
||||
class VectorIndexMixin(BaseModel):
|
||||
"""
|
||||
|
@ -84,6 +120,11 @@ class VectorIndexMixin(BaseModel):
|
|||
value: Optional[NDArray] = None
|
||||
target: Optional["VectorData"] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
||||
"""
|
||||
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||
|
@ -91,19 +132,19 @@ class VectorIndexMixin(BaseModel):
|
|||
|
||||
start = 0 if arg == 0 else self.value[arg - 1]
|
||||
end = self.value[arg]
|
||||
return self.target.array[slice(start, end)]
|
||||
return self.target.value[slice(start, end)]
|
||||
|
||||
def __getitem__(self, item: Union[int, slice]) -> Any:
|
||||
if self.target is None:
|
||||
return self.value[item]
|
||||
elif type(self.target).__name__ == "VectorData":
|
||||
elif isinstance(self.target, VectorData):
|
||||
if isinstance(item, int):
|
||||
return self._getitem_helper(item)
|
||||
else:
|
||||
idx = range(*item.indices(len(self.value)))
|
||||
return [self._getitem_helper(i) for i in idx]
|
||||
else:
|
||||
raise NotImplementedError("DynamicTableRange not supported yet")
|
||||
raise AttributeError(f"Could not index with {item}")
|
||||
|
||||
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
||||
if self._index:
|
||||
|
@ -112,6 +153,24 @@ class VectorIndexMixin(BaseModel):
|
|||
else:
|
||||
self.value[key] = value
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""
|
||||
Forward getattr to ``value``
|
||||
"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Get length from value
|
||||
"""
|
||||
return len(self.value)
|
||||
|
||||
|
||||
class DynamicTableMixin(BaseModel):
|
||||
"""
|
||||
|
@ -131,6 +190,7 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
||||
colnames: List[str] = Field(default_factory=list)
|
||||
id: Optional[NDArray[Shape["* num_rows"], int]] = None
|
||||
|
||||
@property
|
||||
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
||||
|
@ -222,6 +282,10 @@ class DynamicTableMixin(BaseModel):
|
|||
# special case where pandas will unpack a pydantic model
|
||||
# into {n_fields} rows, rather than keeping it in a dict
|
||||
val = Series([val])
|
||||
elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
|
||||
# special case where we are returning a row in a ragged array,
|
||||
# same as above - prevent pandas pivoting to long
|
||||
val = Series([val])
|
||||
data[k] = val
|
||||
return data
|
||||
|
||||
|
@ -241,9 +305,40 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
return super().__setattr__(key, value)
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""Try and use pandas df attrs if we don't have them"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self[:, :], item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_colnames(cls, model: Dict[str, Any]) -> None:
|
||||
def create_id(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Create ID column if not provided
|
||||
"""
|
||||
if "id" not in model:
|
||||
lengths = []
|
||||
for key, val in model.items():
|
||||
# don't get lengths of columns with an index
|
||||
if (
|
||||
f"{key}_index" in model
|
||||
or (isinstance(val, VectorData) and val._index)
|
||||
or key in cls.NON_COLUMN_FIELDS
|
||||
):
|
||||
continue
|
||||
lengths.append(len(val))
|
||||
model["id"] = np.arange(np.max(lengths))
|
||||
|
||||
return model
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Construct colnames from arguments.
|
||||
|
||||
|
@ -289,6 +384,40 @@ class DynamicTableMixin(BaseModel):
|
|||
idx.target = col
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
|
||||
"""
|
||||
Ensure that all columns are equal length
|
||||
"""
|
||||
lengths = [len(v) for v in self._columns.values()]
|
||||
assert [length == lengths[0] for length in lengths], (
|
||||
"Columns are not of equal length! "
|
||||
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
|
||||
)
|
||||
return self
|
||||
|
||||
@field_validator("*", mode="wrap")
|
||||
@classmethod
|
||||
def cast_columns(
|
||||
cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
|
||||
) -> Any:
|
||||
"""
|
||||
If columns are supplied as arrays, try casting them to the type before validating
|
||||
"""
|
||||
try:
|
||||
return handler(val)
|
||||
except ValidationError:
|
||||
annotation = cls.model_fields[info.field_name].annotation
|
||||
if type(annotation).__name__ == "_UnionGenericAlias":
|
||||
annotation = annotation.__args__[0]
|
||||
return handler(
|
||||
annotation(
|
||||
val,
|
||||
name=info.field_name,
|
||||
description=cls.model_fields[info.field_name].description,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
linkml_meta = LinkMLMeta(
|
||||
{
|
||||
|
@ -335,8 +464,8 @@ class VectorIndex(VectorIndexMixin):
|
|||
)
|
||||
|
||||
name: str = Field(...)
|
||||
target: VectorData = Field(
|
||||
..., description="""Reference to the target dataset that this index applies to."""
|
||||
target: Optional[VectorData] = Field(
|
||||
None, description="""Reference to the target dataset that this index applies to."""
|
||||
)
|
||||
description: str = Field(..., description="""Description of what these vectors represent.""")
|
||||
value: Optional[
|
||||
|
|
|
@ -4,11 +4,21 @@ from decimal import Decimal
|
|||
from enum import Enum
|
||||
import re
|
||||
import sys
|
||||
import numpy as np
|
||||
from ...hdmf_common.v1_5_1.hdmf_common_base import Data, Container
|
||||
from pandas import DataFrame, Series
|
||||
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
RootModel,
|
||||
field_validator,
|
||||
model_validator,
|
||||
ValidationInfo,
|
||||
ValidatorFunctionWrapHandler,
|
||||
ValidationError,
|
||||
)
|
||||
import numpy as np
|
||||
from numpydantic import NDArray, Shape
|
||||
|
||||
metamodel_version = "None"
|
||||
|
@ -60,6 +70,11 @@ class VectorDataMixin(BaseModel):
|
|||
# redefined in `VectorData`, but included here for testing and type checking
|
||||
value: Optional[NDArray] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||
if self._index:
|
||||
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||
|
@ -74,6 +89,27 @@ class VectorDataMixin(BaseModel):
|
|||
else:
|
||||
self.value[key] = value
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""
|
||||
Forward getattr to ``value``
|
||||
"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Use index as length, if present
|
||||
"""
|
||||
if self._index:
|
||||
return len(self._index)
|
||||
else:
|
||||
return len(self.value)
|
||||
|
||||
|
||||
class VectorIndexMixin(BaseModel):
|
||||
"""
|
||||
|
@ -84,6 +120,11 @@ class VectorIndexMixin(BaseModel):
|
|||
value: Optional[NDArray] = None
|
||||
target: Optional["VectorData"] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
||||
"""
|
||||
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||
|
@ -91,19 +132,19 @@ class VectorIndexMixin(BaseModel):
|
|||
|
||||
start = 0 if arg == 0 else self.value[arg - 1]
|
||||
end = self.value[arg]
|
||||
return self.target.array[slice(start, end)]
|
||||
return self.target.value[slice(start, end)]
|
||||
|
||||
def __getitem__(self, item: Union[int, slice]) -> Any:
|
||||
if self.target is None:
|
||||
return self.value[item]
|
||||
elif type(self.target).__name__ == "VectorData":
|
||||
elif isinstance(self.target, VectorData):
|
||||
if isinstance(item, int):
|
||||
return self._getitem_helper(item)
|
||||
else:
|
||||
idx = range(*item.indices(len(self.value)))
|
||||
return [self._getitem_helper(i) for i in idx]
|
||||
else:
|
||||
raise NotImplementedError("DynamicTableRange not supported yet")
|
||||
raise AttributeError(f"Could not index with {item}")
|
||||
|
||||
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
||||
if self._index:
|
||||
|
@ -112,6 +153,24 @@ class VectorIndexMixin(BaseModel):
|
|||
else:
|
||||
self.value[key] = value
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""
|
||||
Forward getattr to ``value``
|
||||
"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Get length from value
|
||||
"""
|
||||
return len(self.value)
|
||||
|
||||
|
||||
class DynamicTableMixin(BaseModel):
|
||||
"""
|
||||
|
@ -131,6 +190,7 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
||||
colnames: List[str] = Field(default_factory=list)
|
||||
id: Optional[NDArray[Shape["* num_rows"], int]] = None
|
||||
|
||||
@property
|
||||
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
||||
|
@ -222,6 +282,10 @@ class DynamicTableMixin(BaseModel):
|
|||
# special case where pandas will unpack a pydantic model
|
||||
# into {n_fields} rows, rather than keeping it in a dict
|
||||
val = Series([val])
|
||||
elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
|
||||
# special case where we are returning a row in a ragged array,
|
||||
# same as above - prevent pandas pivoting to long
|
||||
val = Series([val])
|
||||
data[k] = val
|
||||
return data
|
||||
|
||||
|
@ -241,9 +305,40 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
return super().__setattr__(key, value)
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""Try and use pandas df attrs if we don't have them"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self[:, :], item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_colnames(cls, model: Dict[str, Any]) -> None:
|
||||
def create_id(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Create ID column if not provided
|
||||
"""
|
||||
if "id" not in model:
|
||||
lengths = []
|
||||
for key, val in model.items():
|
||||
# don't get lengths of columns with an index
|
||||
if (
|
||||
f"{key}_index" in model
|
||||
or (isinstance(val, VectorData) and val._index)
|
||||
or key in cls.NON_COLUMN_FIELDS
|
||||
):
|
||||
continue
|
||||
lengths.append(len(val))
|
||||
model["id"] = np.arange(np.max(lengths))
|
||||
|
||||
return model
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Construct colnames from arguments.
|
||||
|
||||
|
@ -289,6 +384,40 @@ class DynamicTableMixin(BaseModel):
|
|||
idx.target = col
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
|
||||
"""
|
||||
Ensure that all columns are equal length
|
||||
"""
|
||||
lengths = [len(v) for v in self._columns.values()]
|
||||
assert [length == lengths[0] for length in lengths], (
|
||||
"Columns are not of equal length! "
|
||||
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
|
||||
)
|
||||
return self
|
||||
|
||||
@field_validator("*", mode="wrap")
|
||||
@classmethod
|
||||
def cast_columns(
|
||||
cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
|
||||
) -> Any:
|
||||
"""
|
||||
If columns are supplied as arrays, try casting them to the type before validating
|
||||
"""
|
||||
try:
|
||||
return handler(val)
|
||||
except ValidationError:
|
||||
annotation = cls.model_fields[info.field_name].annotation
|
||||
if type(annotation).__name__ == "_UnionGenericAlias":
|
||||
annotation = annotation.__args__[0]
|
||||
return handler(
|
||||
annotation(
|
||||
val,
|
||||
name=info.field_name,
|
||||
description=cls.model_fields[info.field_name].description,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
linkml_meta = LinkMLMeta(
|
||||
{
|
||||
|
@ -335,8 +464,8 @@ class VectorIndex(VectorIndexMixin):
|
|||
)
|
||||
|
||||
name: str = Field(...)
|
||||
target: VectorData = Field(
|
||||
..., description="""Reference to the target dataset that this index applies to."""
|
||||
target: Optional[VectorData] = Field(
|
||||
None, description="""Reference to the target dataset that this index applies to."""
|
||||
)
|
||||
description: str = Field(..., description="""Description of what these vectors represent.""")
|
||||
value: Optional[
|
||||
|
|
|
@ -4,11 +4,21 @@ from decimal import Decimal
|
|||
from enum import Enum
|
||||
import re
|
||||
import sys
|
||||
import numpy as np
|
||||
from ...hdmf_common.v1_6_0.hdmf_common_base import Data, Container
|
||||
from pandas import DataFrame, Series
|
||||
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
RootModel,
|
||||
field_validator,
|
||||
model_validator,
|
||||
ValidationInfo,
|
||||
ValidatorFunctionWrapHandler,
|
||||
ValidationError,
|
||||
)
|
||||
import numpy as np
|
||||
from numpydantic import NDArray, Shape
|
||||
|
||||
metamodel_version = "None"
|
||||
|
@ -60,6 +70,11 @@ class VectorDataMixin(BaseModel):
|
|||
# redefined in `VectorData`, but included here for testing and type checking
|
||||
value: Optional[NDArray] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||
if self._index:
|
||||
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||
|
@ -74,6 +89,27 @@ class VectorDataMixin(BaseModel):
|
|||
else:
|
||||
self.value[key] = value
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""
|
||||
Forward getattr to ``value``
|
||||
"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Use index as length, if present
|
||||
"""
|
||||
if self._index:
|
||||
return len(self._index)
|
||||
else:
|
||||
return len(self.value)
|
||||
|
||||
|
||||
class VectorIndexMixin(BaseModel):
|
||||
"""
|
||||
|
@ -84,6 +120,11 @@ class VectorIndexMixin(BaseModel):
|
|||
value: Optional[NDArray] = None
|
||||
target: Optional["VectorData"] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
||||
"""
|
||||
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||
|
@ -91,19 +132,19 @@ class VectorIndexMixin(BaseModel):
|
|||
|
||||
start = 0 if arg == 0 else self.value[arg - 1]
|
||||
end = self.value[arg]
|
||||
return self.target.array[slice(start, end)]
|
||||
return self.target.value[slice(start, end)]
|
||||
|
||||
def __getitem__(self, item: Union[int, slice]) -> Any:
|
||||
if self.target is None:
|
||||
return self.value[item]
|
||||
elif type(self.target).__name__ == "VectorData":
|
||||
elif isinstance(self.target, VectorData):
|
||||
if isinstance(item, int):
|
||||
return self._getitem_helper(item)
|
||||
else:
|
||||
idx = range(*item.indices(len(self.value)))
|
||||
return [self._getitem_helper(i) for i in idx]
|
||||
else:
|
||||
raise NotImplementedError("DynamicTableRange not supported yet")
|
||||
raise AttributeError(f"Could not index with {item}")
|
||||
|
||||
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
||||
if self._index:
|
||||
|
@ -112,6 +153,24 @@ class VectorIndexMixin(BaseModel):
|
|||
else:
|
||||
self.value[key] = value
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""
|
||||
Forward getattr to ``value``
|
||||
"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Get length from value
|
||||
"""
|
||||
return len(self.value)
|
||||
|
||||
|
||||
class DynamicTableMixin(BaseModel):
|
||||
"""
|
||||
|
@ -131,6 +190,7 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
||||
colnames: List[str] = Field(default_factory=list)
|
||||
id: Optional[NDArray[Shape["* num_rows"], int]] = None
|
||||
|
||||
@property
|
||||
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
||||
|
@ -222,6 +282,10 @@ class DynamicTableMixin(BaseModel):
|
|||
# special case where pandas will unpack a pydantic model
|
||||
# into {n_fields} rows, rather than keeping it in a dict
|
||||
val = Series([val])
|
||||
elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
|
||||
# special case where we are returning a row in a ragged array,
|
||||
# same as above - prevent pandas pivoting to long
|
||||
val = Series([val])
|
||||
data[k] = val
|
||||
return data
|
||||
|
||||
|
@ -241,9 +305,40 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
return super().__setattr__(key, value)
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""Try and use pandas df attrs if we don't have them"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self[:, :], item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_colnames(cls, model: Dict[str, Any]) -> None:
|
||||
def create_id(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Create ID column if not provided
|
||||
"""
|
||||
if "id" not in model:
|
||||
lengths = []
|
||||
for key, val in model.items():
|
||||
# don't get lengths of columns with an index
|
||||
if (
|
||||
f"{key}_index" in model
|
||||
or (isinstance(val, VectorData) and val._index)
|
||||
or key in cls.NON_COLUMN_FIELDS
|
||||
):
|
||||
continue
|
||||
lengths.append(len(val))
|
||||
model["id"] = np.arange(np.max(lengths))
|
||||
|
||||
return model
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Construct colnames from arguments.
|
||||
|
||||
|
@ -289,6 +384,40 @@ class DynamicTableMixin(BaseModel):
|
|||
idx.target = col
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
|
||||
"""
|
||||
Ensure that all columns are equal length
|
||||
"""
|
||||
lengths = [len(v) for v in self._columns.values()]
|
||||
assert [length == lengths[0] for length in lengths], (
|
||||
"Columns are not of equal length! "
|
||||
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
|
||||
)
|
||||
return self
|
||||
|
||||
@field_validator("*", mode="wrap")
|
||||
@classmethod
|
||||
def cast_columns(
|
||||
cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
|
||||
) -> Any:
|
||||
"""
|
||||
If columns are supplied as arrays, try casting them to the type before validating
|
||||
"""
|
||||
try:
|
||||
return handler(val)
|
||||
except ValidationError:
|
||||
annotation = cls.model_fields[info.field_name].annotation
|
||||
if type(annotation).__name__ == "_UnionGenericAlias":
|
||||
annotation = annotation.__args__[0]
|
||||
return handler(
|
||||
annotation(
|
||||
val,
|
||||
name=info.field_name,
|
||||
description=cls.model_fields[info.field_name].description,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
linkml_meta = LinkMLMeta(
|
||||
{
|
||||
|
@ -335,8 +464,8 @@ class VectorIndex(VectorIndexMixin):
|
|||
)
|
||||
|
||||
name: str = Field(...)
|
||||
target: VectorData = Field(
|
||||
..., description="""Reference to the target dataset that this index applies to."""
|
||||
target: Optional[VectorData] = Field(
|
||||
None, description="""Reference to the target dataset that this index applies to."""
|
||||
)
|
||||
description: str = Field(..., description="""Description of what these vectors represent.""")
|
||||
value: Optional[
|
||||
|
|
|
@ -4,11 +4,21 @@ from decimal import Decimal
|
|||
from enum import Enum
|
||||
import re
|
||||
import sys
|
||||
import numpy as np
|
||||
from ...hdmf_common.v1_7_0.hdmf_common_base import Data, Container
|
||||
from pandas import DataFrame, Series
|
||||
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
RootModel,
|
||||
field_validator,
|
||||
model_validator,
|
||||
ValidationInfo,
|
||||
ValidatorFunctionWrapHandler,
|
||||
ValidationError,
|
||||
)
|
||||
import numpy as np
|
||||
from numpydantic import NDArray, Shape
|
||||
|
||||
metamodel_version = "None"
|
||||
|
@ -60,6 +70,11 @@ class VectorDataMixin(BaseModel):
|
|||
# redefined in `VectorData`, but included here for testing and type checking
|
||||
value: Optional[NDArray] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||
if self._index:
|
||||
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||
|
@ -74,6 +89,27 @@ class VectorDataMixin(BaseModel):
|
|||
else:
|
||||
self.value[key] = value
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""
|
||||
Forward getattr to ``value``
|
||||
"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Use index as length, if present
|
||||
"""
|
||||
if self._index:
|
||||
return len(self._index)
|
||||
else:
|
||||
return len(self.value)
|
||||
|
||||
|
||||
class VectorIndexMixin(BaseModel):
|
||||
"""
|
||||
|
@ -84,6 +120,11 @@ class VectorIndexMixin(BaseModel):
|
|||
value: Optional[NDArray] = None
|
||||
target: Optional["VectorData"] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
||||
"""
|
||||
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||
|
@ -91,19 +132,19 @@ class VectorIndexMixin(BaseModel):
|
|||
|
||||
start = 0 if arg == 0 else self.value[arg - 1]
|
||||
end = self.value[arg]
|
||||
return self.target.array[slice(start, end)]
|
||||
return self.target.value[slice(start, end)]
|
||||
|
||||
def __getitem__(self, item: Union[int, slice]) -> Any:
|
||||
if self.target is None:
|
||||
return self.value[item]
|
||||
elif type(self.target).__name__ == "VectorData":
|
||||
elif isinstance(self.target, VectorData):
|
||||
if isinstance(item, int):
|
||||
return self._getitem_helper(item)
|
||||
else:
|
||||
idx = range(*item.indices(len(self.value)))
|
||||
return [self._getitem_helper(i) for i in idx]
|
||||
else:
|
||||
raise NotImplementedError("DynamicTableRange not supported yet")
|
||||
raise AttributeError(f"Could not index with {item}")
|
||||
|
||||
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
||||
if self._index:
|
||||
|
@ -112,6 +153,24 @@ class VectorIndexMixin(BaseModel):
|
|||
else:
|
||||
self.value[key] = value
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""
|
||||
Forward getattr to ``value``
|
||||
"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Get length from value
|
||||
"""
|
||||
return len(self.value)
|
||||
|
||||
|
||||
class DynamicTableMixin(BaseModel):
|
||||
"""
|
||||
|
@ -131,6 +190,7 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
||||
colnames: List[str] = Field(default_factory=list)
|
||||
id: Optional[NDArray[Shape["* num_rows"], int]] = None
|
||||
|
||||
@property
|
||||
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
||||
|
@ -222,6 +282,10 @@ class DynamicTableMixin(BaseModel):
|
|||
# special case where pandas will unpack a pydantic model
|
||||
# into {n_fields} rows, rather than keeping it in a dict
|
||||
val = Series([val])
|
||||
elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
|
||||
# special case where we are returning a row in a ragged array,
|
||||
# same as above - prevent pandas pivoting to long
|
||||
val = Series([val])
|
||||
data[k] = val
|
||||
return data
|
||||
|
||||
|
@ -241,9 +305,40 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
return super().__setattr__(key, value)
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""Try and use pandas df attrs if we don't have them"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self[:, :], item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_colnames(cls, model: Dict[str, Any]) -> None:
|
||||
def create_id(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Create ID column if not provided
|
||||
"""
|
||||
if "id" not in model:
|
||||
lengths = []
|
||||
for key, val in model.items():
|
||||
# don't get lengths of columns with an index
|
||||
if (
|
||||
f"{key}_index" in model
|
||||
or (isinstance(val, VectorData) and val._index)
|
||||
or key in cls.NON_COLUMN_FIELDS
|
||||
):
|
||||
continue
|
||||
lengths.append(len(val))
|
||||
model["id"] = np.arange(np.max(lengths))
|
||||
|
||||
return model
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Construct colnames from arguments.
|
||||
|
||||
|
@ -289,6 +384,40 @@ class DynamicTableMixin(BaseModel):
|
|||
idx.target = col
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
|
||||
"""
|
||||
Ensure that all columns are equal length
|
||||
"""
|
||||
lengths = [len(v) for v in self._columns.values()]
|
||||
assert [length == lengths[0] for length in lengths], (
|
||||
"Columns are not of equal length! "
|
||||
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
|
||||
)
|
||||
return self
|
||||
|
||||
@field_validator("*", mode="wrap")
|
||||
@classmethod
|
||||
def cast_columns(
|
||||
cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
|
||||
) -> Any:
|
||||
"""
|
||||
If columns are supplied as arrays, try casting them to the type before validating
|
||||
"""
|
||||
try:
|
||||
return handler(val)
|
||||
except ValidationError:
|
||||
annotation = cls.model_fields[info.field_name].annotation
|
||||
if type(annotation).__name__ == "_UnionGenericAlias":
|
||||
annotation = annotation.__args__[0]
|
||||
return handler(
|
||||
annotation(
|
||||
val,
|
||||
name=info.field_name,
|
||||
description=cls.model_fields[info.field_name].description,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
linkml_meta = LinkMLMeta(
|
||||
{
|
||||
|
@ -335,8 +464,8 @@ class VectorIndex(VectorIndexMixin):
|
|||
)
|
||||
|
||||
name: str = Field(...)
|
||||
target: VectorData = Field(
|
||||
..., description="""Reference to the target dataset that this index applies to."""
|
||||
target: Optional[VectorData] = Field(
|
||||
None, description="""Reference to the target dataset that this index applies to."""
|
||||
)
|
||||
description: str = Field(..., description="""Description of what these vectors represent.""")
|
||||
value: Optional[
|
||||
|
|
|
@ -1,22 +1,25 @@
|
|||
from __future__ import annotations
|
||||
|
||||
|
||||
from ...hdmf_common.v1_8_0.hdmf_common_base import Data
|
||||
from datetime import datetime, date
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
import re
|
||||
import sys
|
||||
from ...hdmf_common.v1_8_0.hdmf_common_base import Data, Container
|
||||
from pandas import DataFrame, Series
|
||||
from typing import Any, ClassVar, List, Dict, Optional, Union, overload, Tuple
|
||||
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
RootModel,
|
||||
model_validator,
|
||||
field_validator,
|
||||
model_validator,
|
||||
ValidationInfo,
|
||||
ValidatorFunctionWrapHandler,
|
||||
ValidationError,
|
||||
)
|
||||
from numpydantic import NDArray, Shape
|
||||
import numpy as np
|
||||
from numpydantic import NDArray, Shape
|
||||
|
||||
metamodel_version = "None"
|
||||
version = "1.8.0"
|
||||
|
@ -96,7 +99,7 @@ class VectorDataMixin(BaseModel):
|
|||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e
|
||||
raise e from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
|
@ -141,7 +144,7 @@ class VectorIndexMixin(BaseModel):
|
|||
idx = range(*item.indices(len(self.value)))
|
||||
return [self._getitem_helper(i) for i in idx]
|
||||
else:
|
||||
raise NotImplementedError("DynamicTableRange not supported yet")
|
||||
raise AttributeError(f"Could not index with {item}")
|
||||
|
||||
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
||||
if self._index:
|
||||
|
@ -160,7 +163,7 @@ class VectorIndexMixin(BaseModel):
|
|||
try:
|
||||
return getattr(self.value, item)
|
||||
except AttributeError:
|
||||
raise e
|
||||
raise e from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
|
@ -302,7 +305,7 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
return super().__setattr__(key, value)
|
||||
|
||||
def __getattr__(self, item):
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""Try and use pandas df attrs if we don't have them"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
|
@ -310,7 +313,7 @@ class DynamicTableMixin(BaseModel):
|
|||
try:
|
||||
return getattr(self[:, :], item)
|
||||
except AttributeError:
|
||||
raise e
|
||||
raise e from None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
|
@ -387,7 +390,7 @@ class DynamicTableMixin(BaseModel):
|
|||
Ensure that all columns are equal length
|
||||
"""
|
||||
lengths = [len(v) for v in self._columns.values()]
|
||||
assert [l == lengths[0] for l in lengths], (
|
||||
assert [length == lengths[0] for length in lengths], (
|
||||
"Columns are not of equal length! "
|
||||
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
|
||||
)
|
||||
|
@ -395,7 +398,9 @@ class DynamicTableMixin(BaseModel):
|
|||
|
||||
@field_validator("*", mode="wrap")
|
||||
@classmethod
|
||||
def cast_columns(cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
|
||||
def cast_columns(
|
||||
cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
|
||||
) -> Any:
|
||||
"""
|
||||
If columns are supplied as arrays, try casting them to the type before validating
|
||||
"""
|
||||
|
|
|
@ -70,6 +70,8 @@ ignore = [
|
|||
"UP006", "UP035",
|
||||
# | for Union types (only supported >=3.10
|
||||
"UP007", "UP038",
|
||||
# syntax error in forward annotation with numpydantic
|
||||
"F722"
|
||||
]
|
||||
|
||||
fixable = ["ALL"]
|
||||
|
|
Loading…
Reference in a new issue