mirror of
https://github.com/p2p-ld/nwb-linkml.git
synced 2024-11-10 00:34:29 +00:00
regenerate model, lint
This commit is contained in:
parent
a11d3d042e
commit
a309c25c3d
16 changed files with 1570 additions and 120 deletions
|
@ -7,7 +7,7 @@ NWB schema translation
|
||||||
- handle compound `dtype` like in ophys.PlaneSegmentation.pixel_mask
|
- handle compound `dtype` like in ophys.PlaneSegmentation.pixel_mask
|
||||||
- handle compound `dtype` like in TimeSeriesReferenceVectorData
|
- handle compound `dtype` like in TimeSeriesReferenceVectorData
|
||||||
- Create a validator that checks if all the lists in a compound dtype dataset are same length
|
- Create a validator that checks if all the lists in a compound dtype dataset are same length
|
||||||
- [ ] Make `target` optional in vectorIndex
|
- [ ] Move making `target` optional in vectorIndex from pydantic generator to linkml generators!
|
||||||
|
|
||||||
Cleanup
|
Cleanup
|
||||||
- [ ] Update pydantic generator
|
- [ ] Update pydantic generator
|
||||||
|
|
|
@ -87,6 +87,14 @@ class NWBPydanticGenerator(PydanticGenerator):
|
||||||
if not base_range_subsumes_any_of:
|
if not base_range_subsumes_any_of:
|
||||||
raise ValueError("Slot cannot have both range and any_of defined")
|
raise ValueError("Slot cannot have both range and any_of defined")
|
||||||
|
|
||||||
|
def before_generate_slot(self, slot: SlotDefinition, sv: SchemaView) -> SlotDefinition:
|
||||||
|
"""
|
||||||
|
Force some properties to be optional
|
||||||
|
"""
|
||||||
|
if slot.name == "target" and "index" in slot.description:
|
||||||
|
slot.required = False
|
||||||
|
return slot
|
||||||
|
|
||||||
def after_generate_slot(self, slot: SlotResult, sv: SchemaView) -> SlotResult:
|
def after_generate_slot(self, slot: SlotResult, sv: SchemaView) -> SlotResult:
|
||||||
"""
|
"""
|
||||||
- strip unwanted metadata
|
- strip unwanted metadata
|
||||||
|
|
|
@ -4,6 +4,7 @@ Special types for mimicking HDMF special case behavior
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple, Union, overload
|
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple, Union, overload
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from linkml.generators.pydanticgen.template import Import, Imports, ObjectImport
|
from linkml.generators.pydanticgen.template import Import, Imports, ObjectImport
|
||||||
from numpydantic import NDArray, Shape
|
from numpydantic import NDArray, Shape
|
||||||
from pandas import DataFrame, Series
|
from pandas import DataFrame, Series
|
||||||
|
@ -11,13 +12,12 @@ from pydantic import (
|
||||||
BaseModel,
|
BaseModel,
|
||||||
ConfigDict,
|
ConfigDict,
|
||||||
Field,
|
Field,
|
||||||
model_validator,
|
|
||||||
field_validator,
|
|
||||||
ValidatorFunctionWrapHandler,
|
|
||||||
ValidationError,
|
ValidationError,
|
||||||
ValidationInfo,
|
ValidationInfo,
|
||||||
|
ValidatorFunctionWrapHandler,
|
||||||
|
field_validator,
|
||||||
|
model_validator,
|
||||||
)
|
)
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from nwb_linkml.models import VectorData, VectorIndex
|
from nwb_linkml.models import VectorData, VectorIndex
|
||||||
|
@ -133,6 +133,10 @@ class DynamicTableMixin(BaseModel):
|
||||||
# special case where pandas will unpack a pydantic model
|
# special case where pandas will unpack a pydantic model
|
||||||
# into {n_fields} rows, rather than keeping it in a dict
|
# into {n_fields} rows, rather than keeping it in a dict
|
||||||
val = Series([val])
|
val = Series([val])
|
||||||
|
elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
|
||||||
|
# special case where we are returning a row in a ragged array,
|
||||||
|
# same as above - prevent pandas pivoting to long
|
||||||
|
val = Series([val])
|
||||||
data[k] = val
|
data[k] = val
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -152,6 +156,16 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
return super().__setattr__(key, value)
|
return super().__setattr__(key, value)
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
|
||||||
|
"""Try and use pandas df attrs if we don't have them"""
|
||||||
|
try:
|
||||||
|
return BaseModel.__getattr__(self, item)
|
||||||
|
except AttributeError as e:
|
||||||
|
try:
|
||||||
|
return getattr(self[:, :], item)
|
||||||
|
except AttributeError:
|
||||||
|
raise e from None
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_id(cls, model: Dict[str, Any]) -> Dict:
|
def create_id(cls, model: Dict[str, Any]) -> Dict:
|
||||||
|
@ -199,12 +213,6 @@ class DynamicTableMixin(BaseModel):
|
||||||
model["colnames"].extend(colnames)
|
model["colnames"].extend(colnames)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
def create_id(cls, model: Dict[str, Any]) -> Dict:
|
|
||||||
"""
|
|
||||||
If an id column is not given, create one as an arange.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def resolve_targets(self) -> "DynamicTableMixin":
|
def resolve_targets(self) -> "DynamicTableMixin":
|
||||||
"""
|
"""
|
||||||
|
@ -233,7 +241,7 @@ class DynamicTableMixin(BaseModel):
|
||||||
Ensure that all columns are equal length
|
Ensure that all columns are equal length
|
||||||
"""
|
"""
|
||||||
lengths = [len(v) for v in self._columns.values()]
|
lengths = [len(v) for v in self._columns.values()]
|
||||||
assert [l == lengths[0] for l in lengths], (
|
assert [length == lengths[0] for length in lengths], (
|
||||||
"Columns are not of equal length! "
|
"Columns are not of equal length! "
|
||||||
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
|
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
|
||||||
)
|
)
|
||||||
|
@ -241,7 +249,9 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
@field_validator("*", mode="wrap")
|
@field_validator("*", mode="wrap")
|
||||||
@classmethod
|
@classmethod
|
||||||
def cast_columns(cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
|
def cast_columns(
|
||||||
|
cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
|
||||||
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
If columns are supplied as arrays, try casting them to the type before validating
|
If columns are supplied as arrays, try casting them to the type before validating
|
||||||
"""
|
"""
|
||||||
|
@ -299,7 +309,7 @@ class VectorDataMixin(BaseModel):
|
||||||
try:
|
try:
|
||||||
return getattr(self.value, item)
|
return getattr(self.value, item)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise e
|
raise e from None
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
"""
|
"""
|
||||||
|
@ -332,7 +342,7 @@ class VectorIndexMixin(BaseModel):
|
||||||
|
|
||||||
start = 0 if arg == 0 else self.value[arg - 1]
|
start = 0 if arg == 0 else self.value[arg - 1]
|
||||||
end = self.value[arg]
|
end = self.value[arg]
|
||||||
return [self.target.value[slice(start, end)]]
|
return self.target.value[slice(start, end)]
|
||||||
|
|
||||||
def __getitem__(self, item: Union[int, slice]) -> Any:
|
def __getitem__(self, item: Union[int, slice]) -> Any:
|
||||||
if self.target is None:
|
if self.target is None:
|
||||||
|
@ -363,7 +373,7 @@ class VectorIndexMixin(BaseModel):
|
||||||
try:
|
try:
|
||||||
return getattr(self.value, item)
|
return getattr(self.value, item)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise e
|
raise e from None
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -4,11 +4,21 @@ from decimal import Decimal
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import numpy as np
|
|
||||||
from pandas import DataFrame, Series
|
from pandas import DataFrame, Series
|
||||||
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
||||||
from numpydantic import NDArray, Shape
|
from numpydantic import NDArray, Shape
|
||||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
ConfigDict,
|
||||||
|
Field,
|
||||||
|
RootModel,
|
||||||
|
field_validator,
|
||||||
|
model_validator,
|
||||||
|
ValidationInfo,
|
||||||
|
ValidatorFunctionWrapHandler,
|
||||||
|
ValidationError,
|
||||||
|
)
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
metamodel_version = "None"
|
metamodel_version = "None"
|
||||||
version = "1.1.0"
|
version = "1.1.0"
|
||||||
|
@ -59,6 +69,11 @@ class VectorDataMixin(BaseModel):
|
||||||
# redefined in `VectorData`, but included here for testing and type checking
|
# redefined in `VectorData`, but included here for testing and type checking
|
||||||
value: Optional[NDArray] = None
|
value: Optional[NDArray] = None
|
||||||
|
|
||||||
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||||
|
if value is not None and "value" not in kwargs:
|
||||||
|
kwargs["value"] = value
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||||
if self._index:
|
if self._index:
|
||||||
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||||
|
@ -73,6 +88,27 @@ class VectorDataMixin(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.value[key] = value
|
self.value[key] = value
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
|
||||||
|
"""
|
||||||
|
Forward getattr to ``value``
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return BaseModel.__getattr__(self, item)
|
||||||
|
except AttributeError as e:
|
||||||
|
try:
|
||||||
|
return getattr(self.value, item)
|
||||||
|
except AttributeError:
|
||||||
|
raise e from None
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""
|
||||||
|
Use index as length, if present
|
||||||
|
"""
|
||||||
|
if self._index:
|
||||||
|
return len(self._index)
|
||||||
|
else:
|
||||||
|
return len(self.value)
|
||||||
|
|
||||||
|
|
||||||
class VectorIndexMixin(BaseModel):
|
class VectorIndexMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -83,6 +119,11 @@ class VectorIndexMixin(BaseModel):
|
||||||
value: Optional[NDArray] = None
|
value: Optional[NDArray] = None
|
||||||
target: Optional["VectorData"] = None
|
target: Optional["VectorData"] = None
|
||||||
|
|
||||||
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||||
|
if value is not None and "value" not in kwargs:
|
||||||
|
kwargs["value"] = value
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
||||||
"""
|
"""
|
||||||
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||||
|
@ -90,19 +131,19 @@ class VectorIndexMixin(BaseModel):
|
||||||
|
|
||||||
start = 0 if arg == 0 else self.value[arg - 1]
|
start = 0 if arg == 0 else self.value[arg - 1]
|
||||||
end = self.value[arg]
|
end = self.value[arg]
|
||||||
return self.target.array[slice(start, end)]
|
return self.target.value[slice(start, end)]
|
||||||
|
|
||||||
def __getitem__(self, item: Union[int, slice]) -> Any:
|
def __getitem__(self, item: Union[int, slice]) -> Any:
|
||||||
if self.target is None:
|
if self.target is None:
|
||||||
return self.value[item]
|
return self.value[item]
|
||||||
elif type(self.target).__name__ == "VectorData":
|
elif isinstance(self.target, VectorData):
|
||||||
if isinstance(item, int):
|
if isinstance(item, int):
|
||||||
return self._getitem_helper(item)
|
return self._getitem_helper(item)
|
||||||
else:
|
else:
|
||||||
idx = range(*item.indices(len(self.value)))
|
idx = range(*item.indices(len(self.value)))
|
||||||
return [self._getitem_helper(i) for i in idx]
|
return [self._getitem_helper(i) for i in idx]
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("DynamicTableRange not supported yet")
|
raise AttributeError(f"Could not index with {item}")
|
||||||
|
|
||||||
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
||||||
if self._index:
|
if self._index:
|
||||||
|
@ -111,6 +152,24 @@ class VectorIndexMixin(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.value[key] = value
|
self.value[key] = value
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
|
||||||
|
"""
|
||||||
|
Forward getattr to ``value``
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return BaseModel.__getattr__(self, item)
|
||||||
|
except AttributeError as e:
|
||||||
|
try:
|
||||||
|
return getattr(self.value, item)
|
||||||
|
except AttributeError:
|
||||||
|
raise e from None
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""
|
||||||
|
Get length from value
|
||||||
|
"""
|
||||||
|
return len(self.value)
|
||||||
|
|
||||||
|
|
||||||
class DynamicTableMixin(BaseModel):
|
class DynamicTableMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -130,6 +189,7 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
||||||
colnames: List[str] = Field(default_factory=list)
|
colnames: List[str] = Field(default_factory=list)
|
||||||
|
id: Optional[NDArray[Shape["* num_rows"], int]] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
||||||
|
@ -221,6 +281,10 @@ class DynamicTableMixin(BaseModel):
|
||||||
# special case where pandas will unpack a pydantic model
|
# special case where pandas will unpack a pydantic model
|
||||||
# into {n_fields} rows, rather than keeping it in a dict
|
# into {n_fields} rows, rather than keeping it in a dict
|
||||||
val = Series([val])
|
val = Series([val])
|
||||||
|
elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
|
||||||
|
# special case where we are returning a row in a ragged array,
|
||||||
|
# same as above - prevent pandas pivoting to long
|
||||||
|
val = Series([val])
|
||||||
data[k] = val
|
data[k] = val
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -240,9 +304,40 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
return super().__setattr__(key, value)
|
return super().__setattr__(key, value)
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
|
||||||
|
"""Try and use pandas df attrs if we don't have them"""
|
||||||
|
try:
|
||||||
|
return BaseModel.__getattr__(self, item)
|
||||||
|
except AttributeError as e:
|
||||||
|
try:
|
||||||
|
return getattr(self[:, :], item)
|
||||||
|
except AttributeError:
|
||||||
|
raise e from None
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_colnames(cls, model: Dict[str, Any]) -> None:
|
def create_id(cls, model: Dict[str, Any]) -> Dict:
|
||||||
|
"""
|
||||||
|
Create ID column if not provided
|
||||||
|
"""
|
||||||
|
if "id" not in model:
|
||||||
|
lengths = []
|
||||||
|
for key, val in model.items():
|
||||||
|
# don't get lengths of columns with an index
|
||||||
|
if (
|
||||||
|
f"{key}_index" in model
|
||||||
|
or (isinstance(val, VectorData) and val._index)
|
||||||
|
or key in cls.NON_COLUMN_FIELDS
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
lengths.append(len(val))
|
||||||
|
model["id"] = np.arange(np.max(lengths))
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
|
||||||
"""
|
"""
|
||||||
Construct colnames from arguments.
|
Construct colnames from arguments.
|
||||||
|
|
||||||
|
@ -288,6 +383,40 @@ class DynamicTableMixin(BaseModel):
|
||||||
idx.target = col
|
idx.target = col
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
|
||||||
|
"""
|
||||||
|
Ensure that all columns are equal length
|
||||||
|
"""
|
||||||
|
lengths = [len(v) for v in self._columns.values()]
|
||||||
|
assert [length == lengths[0] for length in lengths], (
|
||||||
|
"Columns are not of equal length! "
|
||||||
|
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
@field_validator("*", mode="wrap")
|
||||||
|
@classmethod
|
||||||
|
def cast_columns(
|
||||||
|
cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
If columns are supplied as arrays, try casting them to the type before validating
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return handler(val)
|
||||||
|
except ValidationError:
|
||||||
|
annotation = cls.model_fields[info.field_name].annotation
|
||||||
|
if type(annotation).__name__ == "_UnionGenericAlias":
|
||||||
|
annotation = annotation.__args__[0]
|
||||||
|
return handler(
|
||||||
|
annotation(
|
||||||
|
val,
|
||||||
|
name=info.field_name,
|
||||||
|
description=cls.model_fields[info.field_name].description,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
linkml_meta = LinkMLMeta(
|
linkml_meta = LinkMLMeta(
|
||||||
{
|
{
|
||||||
|
@ -325,7 +454,9 @@ class Index(Data):
|
||||||
)
|
)
|
||||||
|
|
||||||
name: str = Field(...)
|
name: str = Field(...)
|
||||||
target: Data = Field(..., description="""Target dataset that this index applies to.""")
|
target: Optional[Data] = Field(
|
||||||
|
None, description="""Target dataset that this index applies to."""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class VectorData(VectorDataMixin):
|
class VectorData(VectorDataMixin):
|
||||||
|
@ -351,8 +482,8 @@ class VectorIndex(VectorIndexMixin):
|
||||||
)
|
)
|
||||||
|
|
||||||
name: str = Field(...)
|
name: str = Field(...)
|
||||||
target: VectorData = Field(
|
target: Optional[VectorData] = Field(
|
||||||
..., description="""Reference to the target dataset that this index applies to."""
|
None, description="""Reference to the target dataset that this index applies to."""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -4,11 +4,21 @@ from decimal import Decimal
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import numpy as np
|
|
||||||
from pandas import DataFrame, Series
|
from pandas import DataFrame, Series
|
||||||
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
||||||
from numpydantic import NDArray, Shape
|
from numpydantic import NDArray, Shape
|
||||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
ConfigDict,
|
||||||
|
Field,
|
||||||
|
RootModel,
|
||||||
|
field_validator,
|
||||||
|
model_validator,
|
||||||
|
ValidationInfo,
|
||||||
|
ValidatorFunctionWrapHandler,
|
||||||
|
ValidationError,
|
||||||
|
)
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
metamodel_version = "None"
|
metamodel_version = "None"
|
||||||
version = "1.1.2"
|
version = "1.1.2"
|
||||||
|
@ -59,6 +69,11 @@ class VectorDataMixin(BaseModel):
|
||||||
# redefined in `VectorData`, but included here for testing and type checking
|
# redefined in `VectorData`, but included here for testing and type checking
|
||||||
value: Optional[NDArray] = None
|
value: Optional[NDArray] = None
|
||||||
|
|
||||||
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||||
|
if value is not None and "value" not in kwargs:
|
||||||
|
kwargs["value"] = value
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||||
if self._index:
|
if self._index:
|
||||||
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||||
|
@ -73,6 +88,27 @@ class VectorDataMixin(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.value[key] = value
|
self.value[key] = value
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
|
||||||
|
"""
|
||||||
|
Forward getattr to ``value``
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return BaseModel.__getattr__(self, item)
|
||||||
|
except AttributeError as e:
|
||||||
|
try:
|
||||||
|
return getattr(self.value, item)
|
||||||
|
except AttributeError:
|
||||||
|
raise e from None
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""
|
||||||
|
Use index as length, if present
|
||||||
|
"""
|
||||||
|
if self._index:
|
||||||
|
return len(self._index)
|
||||||
|
else:
|
||||||
|
return len(self.value)
|
||||||
|
|
||||||
|
|
||||||
class VectorIndexMixin(BaseModel):
|
class VectorIndexMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -83,6 +119,11 @@ class VectorIndexMixin(BaseModel):
|
||||||
value: Optional[NDArray] = None
|
value: Optional[NDArray] = None
|
||||||
target: Optional["VectorData"] = None
|
target: Optional["VectorData"] = None
|
||||||
|
|
||||||
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||||
|
if value is not None and "value" not in kwargs:
|
||||||
|
kwargs["value"] = value
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
||||||
"""
|
"""
|
||||||
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||||
|
@ -90,19 +131,19 @@ class VectorIndexMixin(BaseModel):
|
||||||
|
|
||||||
start = 0 if arg == 0 else self.value[arg - 1]
|
start = 0 if arg == 0 else self.value[arg - 1]
|
||||||
end = self.value[arg]
|
end = self.value[arg]
|
||||||
return self.target.array[slice(start, end)]
|
return self.target.value[slice(start, end)]
|
||||||
|
|
||||||
def __getitem__(self, item: Union[int, slice]) -> Any:
|
def __getitem__(self, item: Union[int, slice]) -> Any:
|
||||||
if self.target is None:
|
if self.target is None:
|
||||||
return self.value[item]
|
return self.value[item]
|
||||||
elif type(self.target).__name__ == "VectorData":
|
elif isinstance(self.target, VectorData):
|
||||||
if isinstance(item, int):
|
if isinstance(item, int):
|
||||||
return self._getitem_helper(item)
|
return self._getitem_helper(item)
|
||||||
else:
|
else:
|
||||||
idx = range(*item.indices(len(self.value)))
|
idx = range(*item.indices(len(self.value)))
|
||||||
return [self._getitem_helper(i) for i in idx]
|
return [self._getitem_helper(i) for i in idx]
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("DynamicTableRange not supported yet")
|
raise AttributeError(f"Could not index with {item}")
|
||||||
|
|
||||||
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
||||||
if self._index:
|
if self._index:
|
||||||
|
@ -111,6 +152,24 @@ class VectorIndexMixin(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.value[key] = value
|
self.value[key] = value
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
|
||||||
|
"""
|
||||||
|
Forward getattr to ``value``
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return BaseModel.__getattr__(self, item)
|
||||||
|
except AttributeError as e:
|
||||||
|
try:
|
||||||
|
return getattr(self.value, item)
|
||||||
|
except AttributeError:
|
||||||
|
raise e from None
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""
|
||||||
|
Get length from value
|
||||||
|
"""
|
||||||
|
return len(self.value)
|
||||||
|
|
||||||
|
|
||||||
class DynamicTableMixin(BaseModel):
|
class DynamicTableMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -130,6 +189,7 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
||||||
colnames: List[str] = Field(default_factory=list)
|
colnames: List[str] = Field(default_factory=list)
|
||||||
|
id: Optional[NDArray[Shape["* num_rows"], int]] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
||||||
|
@ -221,6 +281,10 @@ class DynamicTableMixin(BaseModel):
|
||||||
# special case where pandas will unpack a pydantic model
|
# special case where pandas will unpack a pydantic model
|
||||||
# into {n_fields} rows, rather than keeping it in a dict
|
# into {n_fields} rows, rather than keeping it in a dict
|
||||||
val = Series([val])
|
val = Series([val])
|
||||||
|
elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
|
||||||
|
# special case where we are returning a row in a ragged array,
|
||||||
|
# same as above - prevent pandas pivoting to long
|
||||||
|
val = Series([val])
|
||||||
data[k] = val
|
data[k] = val
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -240,9 +304,40 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
return super().__setattr__(key, value)
|
return super().__setattr__(key, value)
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
|
||||||
|
"""Try and use pandas df attrs if we don't have them"""
|
||||||
|
try:
|
||||||
|
return BaseModel.__getattr__(self, item)
|
||||||
|
except AttributeError as e:
|
||||||
|
try:
|
||||||
|
return getattr(self[:, :], item)
|
||||||
|
except AttributeError:
|
||||||
|
raise e from None
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_colnames(cls, model: Dict[str, Any]) -> None:
|
def create_id(cls, model: Dict[str, Any]) -> Dict:
|
||||||
|
"""
|
||||||
|
Create ID column if not provided
|
||||||
|
"""
|
||||||
|
if "id" not in model:
|
||||||
|
lengths = []
|
||||||
|
for key, val in model.items():
|
||||||
|
# don't get lengths of columns with an index
|
||||||
|
if (
|
||||||
|
f"{key}_index" in model
|
||||||
|
or (isinstance(val, VectorData) and val._index)
|
||||||
|
or key in cls.NON_COLUMN_FIELDS
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
lengths.append(len(val))
|
||||||
|
model["id"] = np.arange(np.max(lengths))
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
|
||||||
"""
|
"""
|
||||||
Construct colnames from arguments.
|
Construct colnames from arguments.
|
||||||
|
|
||||||
|
@ -288,6 +383,40 @@ class DynamicTableMixin(BaseModel):
|
||||||
idx.target = col
|
idx.target = col
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
|
||||||
|
"""
|
||||||
|
Ensure that all columns are equal length
|
||||||
|
"""
|
||||||
|
lengths = [len(v) for v in self._columns.values()]
|
||||||
|
assert [length == lengths[0] for length in lengths], (
|
||||||
|
"Columns are not of equal length! "
|
||||||
|
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
@field_validator("*", mode="wrap")
|
||||||
|
@classmethod
|
||||||
|
def cast_columns(
|
||||||
|
cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
If columns are supplied as arrays, try casting them to the type before validating
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return handler(val)
|
||||||
|
except ValidationError:
|
||||||
|
annotation = cls.model_fields[info.field_name].annotation
|
||||||
|
if type(annotation).__name__ == "_UnionGenericAlias":
|
||||||
|
annotation = annotation.__args__[0]
|
||||||
|
return handler(
|
||||||
|
annotation(
|
||||||
|
val,
|
||||||
|
name=info.field_name,
|
||||||
|
description=cls.model_fields[info.field_name].description,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
linkml_meta = LinkMLMeta(
|
linkml_meta = LinkMLMeta(
|
||||||
{
|
{
|
||||||
|
@ -325,7 +454,9 @@ class Index(Data):
|
||||||
)
|
)
|
||||||
|
|
||||||
name: str = Field(...)
|
name: str = Field(...)
|
||||||
target: Data = Field(..., description="""Target dataset that this index applies to.""")
|
target: Optional[Data] = Field(
|
||||||
|
None, description="""Target dataset that this index applies to."""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class VectorData(VectorDataMixin):
|
class VectorData(VectorDataMixin):
|
||||||
|
@ -351,8 +482,8 @@ class VectorIndex(VectorIndexMixin):
|
||||||
)
|
)
|
||||||
|
|
||||||
name: str = Field(...)
|
name: str = Field(...)
|
||||||
target: VectorData = Field(
|
target: Optional[VectorData] = Field(
|
||||||
..., description="""Reference to the target dataset that this index applies to."""
|
None, description="""Reference to the target dataset that this index applies to."""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -4,11 +4,21 @@ from decimal import Decimal
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import numpy as np
|
|
||||||
from pandas import DataFrame, Series
|
from pandas import DataFrame, Series
|
||||||
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
||||||
from numpydantic import NDArray, Shape
|
from numpydantic import NDArray, Shape
|
||||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
ConfigDict,
|
||||||
|
Field,
|
||||||
|
RootModel,
|
||||||
|
field_validator,
|
||||||
|
model_validator,
|
||||||
|
ValidationInfo,
|
||||||
|
ValidatorFunctionWrapHandler,
|
||||||
|
ValidationError,
|
||||||
|
)
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
metamodel_version = "None"
|
metamodel_version = "None"
|
||||||
version = "1.1.3"
|
version = "1.1.3"
|
||||||
|
@ -59,6 +69,11 @@ class VectorDataMixin(BaseModel):
|
||||||
# redefined in `VectorData`, but included here for testing and type checking
|
# redefined in `VectorData`, but included here for testing and type checking
|
||||||
value: Optional[NDArray] = None
|
value: Optional[NDArray] = None
|
||||||
|
|
||||||
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||||
|
if value is not None and "value" not in kwargs:
|
||||||
|
kwargs["value"] = value
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||||
if self._index:
|
if self._index:
|
||||||
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||||
|
@ -73,6 +88,27 @@ class VectorDataMixin(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.value[key] = value
|
self.value[key] = value
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
    """
    Fall back to the wrapped ``value`` for attributes the model lacks.

    ``BaseModel.__getattr__`` is consulted first. Only when it raises
    ``AttributeError`` is the lookup retried on ``self.value``. If that
    also fails, the model's own error is re-raised without chaining.
    """
    try:
        return BaseModel.__getattr__(self, item)
    except AttributeError as model_err:
        # not a model attribute: retry on the wrapped value
        try:
            return getattr(self.value, item)
        except AttributeError:
            # report the model's error rather than the fallback's
            raise model_err from None
|
||||||
|
|
||||||
|
def __len__(self) -> int:
    """
    Length of the column.

    Measured on ``self._index`` when it is truthy, otherwise on ``self.value``.
    """
    sized = self._index if self._index else self.value
    return len(sized)
|
||||||
|
|
||||||
|
|
||||||
class VectorIndexMixin(BaseModel):
|
class VectorIndexMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -83,6 +119,11 @@ class VectorIndexMixin(BaseModel):
|
||||||
value: Optional[NDArray] = None
|
value: Optional[NDArray] = None
|
||||||
target: Optional["VectorData"] = None
|
target: Optional["VectorData"] = None
|
||||||
|
|
||||||
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
    """
    Accept ``value`` positionally and forward it to the base initializer.

    A ``value`` that is not ``None`` is placed into the keyword arguments
    under ``"value"``; a ``None`` value is not forwarded at all.
    """
    if value is not None:
        kwargs.setdefault("value", value)
    super().__init__(**kwargs)
|
||||||
|
|
||||||
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
||||||
"""
|
"""
|
||||||
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||||
|
@ -90,19 +131,19 @@ class VectorIndexMixin(BaseModel):
|
||||||
|
|
||||||
start = 0 if arg == 0 else self.value[arg - 1]
|
start = 0 if arg == 0 else self.value[arg - 1]
|
||||||
end = self.value[arg]
|
end = self.value[arg]
|
||||||
return self.target.array[slice(start, end)]
|
return self.target.value[slice(start, end)]
|
||||||
|
|
||||||
def __getitem__(self, item: Union[int, slice]) -> Any:
    """
    Index through to the target column, one ragged row per index entry.

    Args:
        item: A row number (Python or numpy integer) or a slice of row numbers.

    Returns:
        Without a target, ``self.value[item]``. With a ``VectorData`` target,
        the result of ``_getitem_helper`` for a single row, or a list of such
        results for a slice.

    Raises:
        AttributeError: If ``target`` is set but is not a ``VectorData``, or
            (from ``item.indices``) if ``item`` is neither an integer nor a slice.
    """
    if self.target is None:
        return self.value[item]
    if not isinstance(self.target, VectorData):
        raise AttributeError(f"Could not index with {item}")
    # numpy integers (e.g. taken from another index array) are not ``int``
    # instances, so accept them explicitly instead of treating them as slices
    if isinstance(item, (int, np.integer)):
        return self._getitem_helper(item)
    rows = range(*item.indices(len(self.value)))
    return [self._getitem_helper(row) for row in rows]
|
||||||
|
|
||||||
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
||||||
if self._index:
|
if self._index:
|
||||||
|
@ -111,6 +152,24 @@ class VectorIndexMixin(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.value[key] = value
|
self.value[key] = value
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
    """
    Fall back to the wrapped ``value`` for attributes the model lacks.

    ``BaseModel.__getattr__`` is consulted first. Only when it raises
    ``AttributeError`` is the lookup retried on ``self.value``. If that
    also fails, the model's own error is re-raised without chaining.
    """
    try:
        return BaseModel.__getattr__(self, item)
    except AttributeError as model_err:
        # not a model attribute: retry on the wrapped value
        try:
            return getattr(self.value, item)
        except AttributeError:
            # report the model's error rather than the fallback's
            raise model_err from None
|
||||||
|
|
||||||
|
def __len__(self) -> int:
    """Return the number of entries in ``value``."""
    n_entries = len(self.value)
    return n_entries
|
||||||
|
|
||||||
|
|
||||||
class DynamicTableMixin(BaseModel):
|
class DynamicTableMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -130,6 +189,7 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
||||||
colnames: List[str] = Field(default_factory=list)
|
colnames: List[str] = Field(default_factory=list)
|
||||||
|
id: Optional[NDArray[Shape["* num_rows"], int]] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
||||||
|
@ -221,6 +281,10 @@ class DynamicTableMixin(BaseModel):
|
||||||
# special case where pandas will unpack a pydantic model
|
# special case where pandas will unpack a pydantic model
|
||||||
# into {n_fields} rows, rather than keeping it in a dict
|
# into {n_fields} rows, rather than keeping it in a dict
|
||||||
val = Series([val])
|
val = Series([val])
|
||||||
|
elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
|
||||||
|
# special case where we are returning a row in a ragged array,
|
||||||
|
# same as above - prevent pandas pivoting to long
|
||||||
|
val = Series([val])
|
||||||
data[k] = val
|
data[k] = val
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -240,9 +304,40 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
return super().__setattr__(key, value)
|
return super().__setattr__(key, value)
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
    """
    Fall back to attributes of the full-table selection ``self[:, :]``.

    ``BaseModel.__getattr__`` is tried first. On ``AttributeError`` the
    lookup is retried on ``self[:, :]`` (per the original note, a pandas
    DataFrame). If that fails too, the model's own error is re-raised
    without chaining.
    """
    try:
        return BaseModel.__getattr__(self, item)
    except AttributeError as model_err:
        try:
            whole_table = self[:, :]
            return getattr(whole_table, item)
        except AttributeError:
            # report the model's error rather than the fallback's
            raise model_err from None
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_colnames(cls, model: Dict[str, Any]) -> None:
|
def create_id(cls, model: Dict[str, Any]) -> Dict:
    """
    Create ID column if not provided

    Args:
        model: Raw keyword arguments for the table, before validation.

    Returns:
        The same mapping. When it had no ``"id"`` key, one is added holding
        ``np.arange(n)`` where ``n`` is the largest length found among the
        measured entries, or ``0`` when nothing could be measured.
    """
    if "id" in model:
        return model

    lengths = []
    for key, val in model.items():
        # skip non-column fields, and don't get lengths of columns with an index
        if (
            key in cls.NON_COLUMN_FIELDS
            or f"{key}_index" in model
            or (isinstance(val, VectorData) and val._index)
        ):
            continue
        lengths.append(len(val))

    # ``np.max`` raises ValueError on an empty list, so a table with nothing
    # to measure gets an empty id column instead of a crash
    n_rows = np.max(lengths) if lengths else 0
    model["id"] = np.arange(n_rows)
    return model
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
|
||||||
"""
|
"""
|
||||||
Construct colnames from arguments.
|
Construct colnames from arguments.
|
||||||
|
|
||||||
|
@ -288,6 +383,40 @@ class DynamicTableMixin(BaseModel):
|
||||||
idx.target = col
|
idx.target = col
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
    """
    Ensure that all columns are equal length

    Returns:
        The validated instance, unchanged.

    Raises:
        ValueError: If the columns do not all share one length.
    """
    lengths = [len(v) for v in self._columns.values()]
    # Compare via a set: asserting on the list of per-column comparisons was
    # truthy for every table with columns and falsy for a table with none.
    # An explicit raise also survives ``python -O``, unlike ``assert``.
    if len(set(lengths)) > 1:
        raise ValueError(
            "Columns are not of equal length! "
            f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
        )
    return self
|
||||||
|
|
||||||
|
@field_validator("*", mode="wrap")
|
||||||
|
@classmethod
|
||||||
|
def cast_columns(
    cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
) -> Any:
    """
    If columns are supplied as arrays, try casting them to the type before validating

    Args:
        val: Raw value supplied for the field.
        handler: Validation handler for the field.
        info: Validation context; ``info.field_name`` selects the field.

    Returns:
        The validated value, either directly or after wrapping ``val`` in
        the field's annotated type.

    Raises:
        ValidationError: The original error when the value cannot be wrapped
            in the annotated type, or the error from validating the wrapped value.
    """
    try:
        return handler(val)
    except ValidationError as e:
        field = cls.model_fields[info.field_name]
        annotation = field.annotation
        if type(annotation).__name__ == "_UnionGenericAlias":
            # e.g. Optional[X]: cast to the first member of the union
            annotation = annotation.__args__[0]
        try:
            cast = annotation(val, name=info.field_name, description=field.description)
        except Exception:
            # the annotated type could not wrap the raw value, so surface the
            # original validation failure instead of an unrelated error
            raise e from None
        return handler(cast)
|
||||||
|
|
||||||
|
|
||||||
linkml_meta = LinkMLMeta(
|
linkml_meta = LinkMLMeta(
|
||||||
{
|
{
|
||||||
|
@ -325,7 +454,9 @@ class Index(Data):
|
||||||
)
|
)
|
||||||
|
|
||||||
name: str = Field(...)
|
name: str = Field(...)
|
||||||
target: Data = Field(..., description="""Target dataset that this index applies to.""")
|
target: Optional[Data] = Field(
|
||||||
|
None, description="""Target dataset that this index applies to."""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class VectorData(VectorDataMixin):
|
class VectorData(VectorDataMixin):
|
||||||
|
@ -359,8 +490,8 @@ class VectorIndex(VectorIndexMixin):
|
||||||
)
|
)
|
||||||
|
|
||||||
name: str = Field(...)
|
name: str = Field(...)
|
||||||
target: VectorData = Field(
|
target: Optional[VectorData] = Field(
|
||||||
..., description="""Reference to the target dataset that this index applies to."""
|
None, description="""Reference to the target dataset that this index applies to."""
|
||||||
)
|
)
|
||||||
value: Optional[NDArray[Shape["* num_rows"], Any]] = Field(
|
value: Optional[NDArray[Shape["* num_rows"], Any]] = Field(
|
||||||
None, json_schema_extra={"linkml_meta": {"array": {"dimensions": [{"alias": "num_rows"}]}}}
|
None, json_schema_extra={"linkml_meta": {"array": {"dimensions": [{"alias": "num_rows"}]}}}
|
||||||
|
|
|
@ -4,12 +4,22 @@ from decimal import Decimal
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import numpy as np
|
|
||||||
from ...hdmf_common.v1_2_0.hdmf_common_base import Data, Container
|
from ...hdmf_common.v1_2_0.hdmf_common_base import Data, Container
|
||||||
from pandas import DataFrame, Series
|
from pandas import DataFrame, Series
|
||||||
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
||||||
from numpydantic import NDArray, Shape
|
from numpydantic import NDArray, Shape
|
||||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
ConfigDict,
|
||||||
|
Field,
|
||||||
|
RootModel,
|
||||||
|
field_validator,
|
||||||
|
model_validator,
|
||||||
|
ValidationInfo,
|
||||||
|
ValidatorFunctionWrapHandler,
|
||||||
|
ValidationError,
|
||||||
|
)
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
metamodel_version = "None"
|
metamodel_version = "None"
|
||||||
version = "1.2.0"
|
version = "1.2.0"
|
||||||
|
@ -60,6 +70,11 @@ class VectorDataMixin(BaseModel):
|
||||||
# redefined in `VectorData`, but included here for testing and type checking
|
# redefined in `VectorData`, but included here for testing and type checking
|
||||||
value: Optional[NDArray] = None
|
value: Optional[NDArray] = None
|
||||||
|
|
||||||
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
    """
    Accept ``value`` positionally and forward it to the base initializer.

    A ``value`` that is not ``None`` is placed into the keyword arguments
    under ``"value"``; a ``None`` value is not forwarded at all.
    """
    if value is not None:
        kwargs.setdefault("value", value)
    super().__init__(**kwargs)
|
||||||
|
|
||||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||||
if self._index:
|
if self._index:
|
||||||
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||||
|
@ -74,6 +89,27 @@ class VectorDataMixin(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.value[key] = value
|
self.value[key] = value
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
    """
    Fall back to the wrapped ``value`` for attributes the model lacks.

    ``BaseModel.__getattr__`` is consulted first. Only when it raises
    ``AttributeError`` is the lookup retried on ``self.value``. If that
    also fails, the model's own error is re-raised without chaining.
    """
    try:
        return BaseModel.__getattr__(self, item)
    except AttributeError as model_err:
        # not a model attribute: retry on the wrapped value
        try:
            return getattr(self.value, item)
        except AttributeError:
            # report the model's error rather than the fallback's
            raise model_err from None
|
||||||
|
|
||||||
|
def __len__(self) -> int:
    """
    Length of the column.

    Measured on ``self._index`` when it is truthy, otherwise on ``self.value``.
    """
    sized = self._index if self._index else self.value
    return len(sized)
|
||||||
|
|
||||||
|
|
||||||
class VectorIndexMixin(BaseModel):
|
class VectorIndexMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -84,6 +120,11 @@ class VectorIndexMixin(BaseModel):
|
||||||
value: Optional[NDArray] = None
|
value: Optional[NDArray] = None
|
||||||
target: Optional["VectorData"] = None
|
target: Optional["VectorData"] = None
|
||||||
|
|
||||||
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
    """
    Accept ``value`` positionally and forward it to the base initializer.

    A ``value`` that is not ``None`` is placed into the keyword arguments
    under ``"value"``; a ``None`` value is not forwarded at all.
    """
    if value is not None:
        kwargs.setdefault("value", value)
    super().__init__(**kwargs)
|
||||||
|
|
||||||
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
||||||
"""
|
"""
|
||||||
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||||
|
@ -91,19 +132,19 @@ class VectorIndexMixin(BaseModel):
|
||||||
|
|
||||||
start = 0 if arg == 0 else self.value[arg - 1]
|
start = 0 if arg == 0 else self.value[arg - 1]
|
||||||
end = self.value[arg]
|
end = self.value[arg]
|
||||||
return self.target.array[slice(start, end)]
|
return self.target.value[slice(start, end)]
|
||||||
|
|
||||||
def __getitem__(self, item: Union[int, slice]) -> Any:
    """
    Index through to the target column, one ragged row per index entry.

    Args:
        item: A row number (Python or numpy integer) or a slice of row numbers.

    Returns:
        Without a target, ``self.value[item]``. With a ``VectorData`` target,
        the result of ``_getitem_helper`` for a single row, or a list of such
        results for a slice.

    Raises:
        AttributeError: If ``target`` is set but is not a ``VectorData``, or
            (from ``item.indices``) if ``item`` is neither an integer nor a slice.
    """
    if self.target is None:
        return self.value[item]
    if not isinstance(self.target, VectorData):
        raise AttributeError(f"Could not index with {item}")
    # numpy integers (e.g. taken from another index array) are not ``int``
    # instances, so accept them explicitly instead of treating them as slices
    if isinstance(item, (int, np.integer)):
        return self._getitem_helper(item)
    rows = range(*item.indices(len(self.value)))
    return [self._getitem_helper(row) for row in rows]
|
||||||
|
|
||||||
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
||||||
if self._index:
|
if self._index:
|
||||||
|
@ -112,6 +153,24 @@ class VectorIndexMixin(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.value[key] = value
|
self.value[key] = value
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
    """
    Fall back to the wrapped ``value`` for attributes the model lacks.

    ``BaseModel.__getattr__`` is consulted first. Only when it raises
    ``AttributeError`` is the lookup retried on ``self.value``. If that
    also fails, the model's own error is re-raised without chaining.
    """
    try:
        return BaseModel.__getattr__(self, item)
    except AttributeError as model_err:
        # not a model attribute: retry on the wrapped value
        try:
            return getattr(self.value, item)
        except AttributeError:
            # report the model's error rather than the fallback's
            raise model_err from None
|
||||||
|
|
||||||
|
def __len__(self) -> int:
    """Return the number of entries in ``value``."""
    n_entries = len(self.value)
    return n_entries
|
||||||
|
|
||||||
|
|
||||||
class DynamicTableMixin(BaseModel):
|
class DynamicTableMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -131,6 +190,7 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
||||||
colnames: List[str] = Field(default_factory=list)
|
colnames: List[str] = Field(default_factory=list)
|
||||||
|
id: Optional[NDArray[Shape["* num_rows"], int]] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
||||||
|
@ -222,6 +282,10 @@ class DynamicTableMixin(BaseModel):
|
||||||
# special case where pandas will unpack a pydantic model
|
# special case where pandas will unpack a pydantic model
|
||||||
# into {n_fields} rows, rather than keeping it in a dict
|
# into {n_fields} rows, rather than keeping it in a dict
|
||||||
val = Series([val])
|
val = Series([val])
|
||||||
|
elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
|
||||||
|
# special case where we are returning a row in a ragged array,
|
||||||
|
# same as above - prevent pandas pivoting to long
|
||||||
|
val = Series([val])
|
||||||
data[k] = val
|
data[k] = val
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -241,9 +305,40 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
return super().__setattr__(key, value)
|
return super().__setattr__(key, value)
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
    """
    Fall back to attributes of the full-table selection ``self[:, :]``.

    ``BaseModel.__getattr__`` is tried first. On ``AttributeError`` the
    lookup is retried on ``self[:, :]`` (per the original note, a pandas
    DataFrame). If that fails too, the model's own error is re-raised
    without chaining.
    """
    try:
        return BaseModel.__getattr__(self, item)
    except AttributeError as model_err:
        try:
            whole_table = self[:, :]
            return getattr(whole_table, item)
        except AttributeError:
            # report the model's error rather than the fallback's
            raise model_err from None
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_colnames(cls, model: Dict[str, Any]) -> None:
|
def create_id(cls, model: Dict[str, Any]) -> Dict:
    """
    Create ID column if not provided

    Args:
        model: Raw keyword arguments for the table, before validation.

    Returns:
        The same mapping. When it had no ``"id"`` key, one is added holding
        ``np.arange(n)`` where ``n`` is the largest length found among the
        measured entries, or ``0`` when nothing could be measured.
    """
    if "id" in model:
        return model

    lengths = []
    for key, val in model.items():
        # skip non-column fields, and don't get lengths of columns with an index
        if (
            key in cls.NON_COLUMN_FIELDS
            or f"{key}_index" in model
            or (isinstance(val, VectorData) and val._index)
        ):
            continue
        lengths.append(len(val))

    # ``np.max`` raises ValueError on an empty list, so a table with nothing
    # to measure gets an empty id column instead of a crash
    n_rows = np.max(lengths) if lengths else 0
    model["id"] = np.arange(n_rows)
    return model
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
|
||||||
"""
|
"""
|
||||||
Construct colnames from arguments.
|
Construct colnames from arguments.
|
||||||
|
|
||||||
|
@ -289,6 +384,40 @@ class DynamicTableMixin(BaseModel):
|
||||||
idx.target = col
|
idx.target = col
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
    """
    Ensure that all columns are equal length

    Returns:
        The validated instance, unchanged.

    Raises:
        ValueError: If the columns do not all share one length.
    """
    lengths = [len(v) for v in self._columns.values()]
    # Compare via a set: asserting on the list of per-column comparisons was
    # truthy for every table with columns and falsy for a table with none.
    # An explicit raise also survives ``python -O``, unlike ``assert``.
    if len(set(lengths)) > 1:
        raise ValueError(
            "Columns are not of equal length! "
            f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
        )
    return self
|
||||||
|
|
||||||
|
@field_validator("*", mode="wrap")
|
||||||
|
@classmethod
|
||||||
|
def cast_columns(
    cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
) -> Any:
    """
    If columns are supplied as arrays, try casting them to the type before validating

    Args:
        val: Raw value supplied for the field.
        handler: Validation handler for the field.
        info: Validation context; ``info.field_name`` selects the field.

    Returns:
        The validated value, either directly or after wrapping ``val`` in
        the field's annotated type.

    Raises:
        ValidationError: The original error when the value cannot be wrapped
            in the annotated type, or the error from validating the wrapped value.
    """
    try:
        return handler(val)
    except ValidationError as e:
        field = cls.model_fields[info.field_name]
        annotation = field.annotation
        if type(annotation).__name__ == "_UnionGenericAlias":
            # e.g. Optional[X]: cast to the first member of the union
            annotation = annotation.__args__[0]
        try:
            cast = annotation(val, name=info.field_name, description=field.description)
        except Exception:
            # the annotated type could not wrap the raw value, so surface the
            # original validation failure instead of an unrelated error
            raise e from None
        return handler(cast)
|
||||||
|
|
||||||
|
|
||||||
linkml_meta = LinkMLMeta(
|
linkml_meta = LinkMLMeta(
|
||||||
{
|
{
|
||||||
|
@ -335,8 +464,8 @@ class VectorIndex(VectorIndexMixin):
|
||||||
)
|
)
|
||||||
|
|
||||||
name: str = Field(...)
|
name: str = Field(...)
|
||||||
target: VectorData = Field(
|
target: Optional[VectorData] = Field(
|
||||||
..., description="""Reference to the target dataset that this index applies to."""
|
None, description="""Reference to the target dataset that this index applies to."""
|
||||||
)
|
)
|
||||||
description: str = Field(..., description="""Description of what these vectors represent.""")
|
description: str = Field(..., description="""Description of what these vectors represent.""")
|
||||||
value: Optional[
|
value: Optional[
|
||||||
|
|
|
@ -4,12 +4,22 @@ from decimal import Decimal
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import numpy as np
|
|
||||||
from ...hdmf_common.v1_2_1.hdmf_common_base import Data, Container
|
from ...hdmf_common.v1_2_1.hdmf_common_base import Data, Container
|
||||||
from pandas import DataFrame, Series
|
from pandas import DataFrame, Series
|
||||||
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
||||||
from numpydantic import NDArray, Shape
|
from numpydantic import NDArray, Shape
|
||||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
ConfigDict,
|
||||||
|
Field,
|
||||||
|
RootModel,
|
||||||
|
field_validator,
|
||||||
|
model_validator,
|
||||||
|
ValidationInfo,
|
||||||
|
ValidatorFunctionWrapHandler,
|
||||||
|
ValidationError,
|
||||||
|
)
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
metamodel_version = "None"
|
metamodel_version = "None"
|
||||||
version = "1.2.1"
|
version = "1.2.1"
|
||||||
|
@ -60,6 +70,11 @@ class VectorDataMixin(BaseModel):
|
||||||
# redefined in `VectorData`, but included here for testing and type checking
|
# redefined in `VectorData`, but included here for testing and type checking
|
||||||
value: Optional[NDArray] = None
|
value: Optional[NDArray] = None
|
||||||
|
|
||||||
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
    """
    Accept ``value`` positionally and forward it to the base initializer.

    A ``value`` that is not ``None`` is placed into the keyword arguments
    under ``"value"``; a ``None`` value is not forwarded at all.
    """
    if value is not None:
        kwargs.setdefault("value", value)
    super().__init__(**kwargs)
|
||||||
|
|
||||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||||
if self._index:
|
if self._index:
|
||||||
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||||
|
@ -74,6 +89,27 @@ class VectorDataMixin(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.value[key] = value
|
self.value[key] = value
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
    """
    Fall back to the wrapped ``value`` for attributes the model lacks.

    ``BaseModel.__getattr__`` is consulted first. Only when it raises
    ``AttributeError`` is the lookup retried on ``self.value``. If that
    also fails, the model's own error is re-raised without chaining.
    """
    try:
        return BaseModel.__getattr__(self, item)
    except AttributeError as model_err:
        # not a model attribute: retry on the wrapped value
        try:
            return getattr(self.value, item)
        except AttributeError:
            # report the model's error rather than the fallback's
            raise model_err from None
|
||||||
|
|
||||||
|
def __len__(self) -> int:
    """
    Length of the column.

    Measured on ``self._index`` when it is truthy, otherwise on ``self.value``.
    """
    sized = self._index if self._index else self.value
    return len(sized)
|
||||||
|
|
||||||
|
|
||||||
class VectorIndexMixin(BaseModel):
|
class VectorIndexMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -84,6 +120,11 @@ class VectorIndexMixin(BaseModel):
|
||||||
value: Optional[NDArray] = None
|
value: Optional[NDArray] = None
|
||||||
target: Optional["VectorData"] = None
|
target: Optional["VectorData"] = None
|
||||||
|
|
||||||
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
    """
    Accept ``value`` positionally and forward it to the base initializer.

    A ``value`` that is not ``None`` is placed into the keyword arguments
    under ``"value"``; a ``None`` value is not forwarded at all.
    """
    if value is not None:
        kwargs.setdefault("value", value)
    super().__init__(**kwargs)
|
||||||
|
|
||||||
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
||||||
"""
|
"""
|
||||||
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||||
|
@ -91,19 +132,19 @@ class VectorIndexMixin(BaseModel):
|
||||||
|
|
||||||
start = 0 if arg == 0 else self.value[arg - 1]
|
start = 0 if arg == 0 else self.value[arg - 1]
|
||||||
end = self.value[arg]
|
end = self.value[arg]
|
||||||
return self.target.array[slice(start, end)]
|
return self.target.value[slice(start, end)]
|
||||||
|
|
||||||
def __getitem__(self, item: Union[int, slice]) -> Any:
    """
    Index through to the target column, one ragged row per index entry.

    Args:
        item: A row number (Python or numpy integer) or a slice of row numbers.

    Returns:
        Without a target, ``self.value[item]``. With a ``VectorData`` target,
        the result of ``_getitem_helper`` for a single row, or a list of such
        results for a slice.

    Raises:
        AttributeError: If ``target`` is set but is not a ``VectorData``, or
            (from ``item.indices``) if ``item`` is neither an integer nor a slice.
    """
    if self.target is None:
        return self.value[item]
    if not isinstance(self.target, VectorData):
        raise AttributeError(f"Could not index with {item}")
    # numpy integers (e.g. taken from another index array) are not ``int``
    # instances, so accept them explicitly instead of treating them as slices
    if isinstance(item, (int, np.integer)):
        return self._getitem_helper(item)
    rows = range(*item.indices(len(self.value)))
    return [self._getitem_helper(row) for row in rows]
|
||||||
|
|
||||||
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
||||||
if self._index:
|
if self._index:
|
||||||
|
@ -112,6 +153,24 @@ class VectorIndexMixin(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.value[key] = value
|
self.value[key] = value
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
    """
    Fall back to the wrapped ``value`` for attributes the model lacks.

    ``BaseModel.__getattr__`` is consulted first. Only when it raises
    ``AttributeError`` is the lookup retried on ``self.value``. If that
    also fails, the model's own error is re-raised without chaining.
    """
    try:
        return BaseModel.__getattr__(self, item)
    except AttributeError as model_err:
        # not a model attribute: retry on the wrapped value
        try:
            return getattr(self.value, item)
        except AttributeError:
            # report the model's error rather than the fallback's
            raise model_err from None
|
||||||
|
|
||||||
|
def __len__(self) -> int:
    """Return the number of entries in ``value``."""
    n_entries = len(self.value)
    return n_entries
|
||||||
|
|
||||||
|
|
||||||
class DynamicTableMixin(BaseModel):
|
class DynamicTableMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -131,6 +190,7 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
||||||
colnames: List[str] = Field(default_factory=list)
|
colnames: List[str] = Field(default_factory=list)
|
||||||
|
id: Optional[NDArray[Shape["* num_rows"], int]] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
||||||
|
@ -222,6 +282,10 @@ class DynamicTableMixin(BaseModel):
|
||||||
# special case where pandas will unpack a pydantic model
|
# special case where pandas will unpack a pydantic model
|
||||||
# into {n_fields} rows, rather than keeping it in a dict
|
# into {n_fields} rows, rather than keeping it in a dict
|
||||||
val = Series([val])
|
val = Series([val])
|
||||||
|
elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
|
||||||
|
# special case where we are returning a row in a ragged array,
|
||||||
|
# same as above - prevent pandas pivoting to long
|
||||||
|
val = Series([val])
|
||||||
data[k] = val
|
data[k] = val
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -241,9 +305,40 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
return super().__setattr__(key, value)
|
return super().__setattr__(key, value)
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
    """
    Fall back to attributes of the full-table selection ``self[:, :]``.

    ``BaseModel.__getattr__`` is tried first. On ``AttributeError`` the
    lookup is retried on ``self[:, :]`` (per the original note, a pandas
    DataFrame). If that fails too, the model's own error is re-raised
    without chaining.
    """
    try:
        return BaseModel.__getattr__(self, item)
    except AttributeError as model_err:
        try:
            whole_table = self[:, :]
            return getattr(whole_table, item)
        except AttributeError:
            # report the model's error rather than the fallback's
            raise model_err from None
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_colnames(cls, model: Dict[str, Any]) -> None:
|
def create_id(cls, model: Dict[str, Any]) -> Dict:
    """
    Create ID column if not provided

    Args:
        model: Raw keyword arguments for the table, before validation.

    Returns:
        The same mapping. When it had no ``"id"`` key, one is added holding
        ``np.arange(n)`` where ``n`` is the largest length found among the
        measured entries, or ``0`` when nothing could be measured.
    """
    if "id" in model:
        return model

    lengths = []
    for key, val in model.items():
        # skip non-column fields, and don't get lengths of columns with an index
        if (
            key in cls.NON_COLUMN_FIELDS
            or f"{key}_index" in model
            or (isinstance(val, VectorData) and val._index)
        ):
            continue
        lengths.append(len(val))

    # ``np.max`` raises ValueError on an empty list, so a table with nothing
    # to measure gets an empty id column instead of a crash
    n_rows = np.max(lengths) if lengths else 0
    model["id"] = np.arange(n_rows)
    return model
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
|
||||||
"""
|
"""
|
||||||
Construct colnames from arguments.
|
Construct colnames from arguments.
|
||||||
|
|
||||||
|
@ -289,6 +384,40 @@ class DynamicTableMixin(BaseModel):
|
||||||
idx.target = col
|
idx.target = col
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
    """
    Ensure that all columns are equal length

    Returns:
        The validated instance, unchanged.

    Raises:
        ValueError: If the columns do not all share one length.
    """
    lengths = [len(v) for v in self._columns.values()]
    # Compare via a set: asserting on the list of per-column comparisons was
    # truthy for every table with columns and falsy for a table with none.
    # An explicit raise also survives ``python -O``, unlike ``assert``.
    if len(set(lengths)) > 1:
        raise ValueError(
            "Columns are not of equal length! "
            f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
        )
    return self
|
||||||
|
|
||||||
|
@field_validator("*", mode="wrap")
|
||||||
|
@classmethod
|
||||||
|
def cast_columns(
    cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
) -> Any:
    """
    If columns are supplied as arrays, try casting them to the type before validating

    Args:
        val: Raw value supplied for the field.
        handler: Validation handler for the field.
        info: Validation context; ``info.field_name`` selects the field.

    Returns:
        The validated value, either directly or after wrapping ``val`` in
        the field's annotated type.

    Raises:
        ValidationError: The original error when the value cannot be wrapped
            in the annotated type, or the error from validating the wrapped value.
    """
    try:
        return handler(val)
    except ValidationError as e:
        field = cls.model_fields[info.field_name]
        annotation = field.annotation
        if type(annotation).__name__ == "_UnionGenericAlias":
            # e.g. Optional[X]: cast to the first member of the union
            annotation = annotation.__args__[0]
        try:
            cast = annotation(val, name=info.field_name, description=field.description)
        except Exception:
            # the annotated type could not wrap the raw value, so surface the
            # original validation failure instead of an unrelated error
            raise e from None
        return handler(cast)
|
||||||
|
|
||||||
|
|
||||||
linkml_meta = LinkMLMeta(
|
linkml_meta = LinkMLMeta(
|
||||||
{
|
{
|
||||||
|
@ -335,8 +464,8 @@ class VectorIndex(VectorIndexMixin):
|
||||||
)
|
)
|
||||||
|
|
||||||
name: str = Field(...)
|
name: str = Field(...)
|
||||||
target: VectorData = Field(
|
target: Optional[VectorData] = Field(
|
||||||
..., description="""Reference to the target dataset that this index applies to."""
|
None, description="""Reference to the target dataset that this index applies to."""
|
||||||
)
|
)
|
||||||
description: str = Field(..., description="""Description of what these vectors represent.""")
|
description: str = Field(..., description="""Description of what these vectors represent.""")
|
||||||
value: Optional[
|
value: Optional[
|
||||||
|
|
|
@ -4,12 +4,22 @@ from decimal import Decimal
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import numpy as np
|
|
||||||
from ...hdmf_common.v1_3_0.hdmf_common_base import Data, Container
|
from ...hdmf_common.v1_3_0.hdmf_common_base import Data, Container
|
||||||
from pandas import DataFrame, Series
|
from pandas import DataFrame, Series
|
||||||
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
||||||
from numpydantic import NDArray, Shape
|
from numpydantic import NDArray, Shape
|
||||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
ConfigDict,
|
||||||
|
Field,
|
||||||
|
RootModel,
|
||||||
|
field_validator,
|
||||||
|
model_validator,
|
||||||
|
ValidationInfo,
|
||||||
|
ValidatorFunctionWrapHandler,
|
||||||
|
ValidationError,
|
||||||
|
)
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
metamodel_version = "None"
|
metamodel_version = "None"
|
||||||
version = "1.3.0"
|
version = "1.3.0"
|
||||||
|
@ -60,6 +70,11 @@ class VectorDataMixin(BaseModel):
|
||||||
# redefined in `VectorData`, but included here for testing and type checking
|
# redefined in `VectorData`, but included here for testing and type checking
|
||||||
value: Optional[NDArray] = None
|
value: Optional[NDArray] = None
|
||||||
|
|
||||||
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
    """
    Accept ``value`` as the first positional argument.

    A ``value`` that is not ``None`` is forwarded to the parent initializer as
    the ``value`` keyword, unless ``kwargs`` already carries one.
    """
    if value is not None:
        kwargs.setdefault("value", value)
    super().__init__(**kwargs)
|
||||||
|
|
||||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||||
if self._index:
|
if self._index:
|
||||||
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||||
|
@ -74,6 +89,27 @@ class VectorDataMixin(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.value[key] = value
|
self.value[key] = value
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
    """
    Look ``item`` up on the model first, then forward the lookup to ``value``.

    When neither has the attribute, the ``AttributeError`` from the model
    lookup is the one that is raised.
    """
    try:
        return BaseModel.__getattr__(self, item)
    except AttributeError as model_error:
        try:
            found = getattr(self.value, item)
        except AttributeError:
            # report the miss against the model rather than against ``value``
            raise model_error from None
        return found
|
||||||
|
|
||||||
|
def __len__(self) -> int:
    """
    Length of the index when ``_index`` is truthy, otherwise length of ``value``
    """
    sized = self._index if self._index else self.value
    return len(sized)
|
||||||
|
|
||||||
|
|
||||||
class VectorIndexMixin(BaseModel):
|
class VectorIndexMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -84,6 +120,11 @@ class VectorIndexMixin(BaseModel):
|
||||||
value: Optional[NDArray] = None
|
value: Optional[NDArray] = None
|
||||||
target: Optional["VectorData"] = None
|
target: Optional["VectorData"] = None
|
||||||
|
|
||||||
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
    """
    Accept ``value`` as the first positional argument.

    A ``value`` that is not ``None`` is forwarded to the parent initializer as
    the ``value`` keyword, unless ``kwargs`` already carries one.
    """
    if value is not None:
        kwargs.setdefault("value", value)
    super().__init__(**kwargs)
|
||||||
|
|
||||||
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
||||||
"""
|
"""
|
||||||
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||||
|
@ -91,19 +132,19 @@ class VectorIndexMixin(BaseModel):
|
||||||
|
|
||||||
start = 0 if arg == 0 else self.value[arg - 1]
|
start = 0 if arg == 0 else self.value[arg - 1]
|
||||||
end = self.value[arg]
|
end = self.value[arg]
|
||||||
return self.target.array[slice(start, end)]
|
return self.target.value[slice(start, end)]
|
||||||
|
|
||||||
def __getitem__(self, item: Union[int, slice]) -> Any:
    """
    Index ``value`` directly when there is no target, otherwise index through it.

    With a ``VectorData`` target, an ``int`` selects one row via
    ``_getitem_helper`` and a slice returns a list with one such row per
    selected position.

    Raises:
        AttributeError: if ``target`` is set but is not a ``VectorData``.
    """
    if self.target is None:
        return self.value[item]
    if not isinstance(self.target, VectorData):
        raise AttributeError(f"Could not index with {item}")
    if isinstance(item, int):
        return self._getitem_helper(item)
    rows = range(*item.indices(len(self.value)))
    return [self._getitem_helper(row) for row in rows]
|
||||||
|
|
||||||
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
||||||
if self._index:
|
if self._index:
|
||||||
|
@ -112,6 +153,24 @@ class VectorIndexMixin(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.value[key] = value
|
self.value[key] = value
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
    """
    Look ``item`` up on the model first, then forward the lookup to ``value``.

    When neither has the attribute, the ``AttributeError`` from the model
    lookup is the one that is raised.
    """
    try:
        return BaseModel.__getattr__(self, item)
    except AttributeError as model_error:
        try:
            found = getattr(self.value, item)
        except AttributeError:
            # report the miss against the model rather than against ``value``
            raise model_error from None
        return found
|
||||||
|
|
||||||
|
def __len__(self) -> int:
    """
    Number of entries in ``value``
    """
    entries = self.value
    return len(entries)
|
||||||
|
|
||||||
|
|
||||||
class DynamicTableMixin(BaseModel):
|
class DynamicTableMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -131,6 +190,7 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
||||||
colnames: List[str] = Field(default_factory=list)
|
colnames: List[str] = Field(default_factory=list)
|
||||||
|
id: Optional[NDArray[Shape["* num_rows"], int]] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
||||||
|
@ -222,6 +282,10 @@ class DynamicTableMixin(BaseModel):
|
||||||
# special case where pandas will unpack a pydantic model
|
# special case where pandas will unpack a pydantic model
|
||||||
# into {n_fields} rows, rather than keeping it in a dict
|
# into {n_fields} rows, rather than keeping it in a dict
|
||||||
val = Series([val])
|
val = Series([val])
|
||||||
|
elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
|
||||||
|
# special case where we are returning a row in a ragged array,
|
||||||
|
# same as above - prevent pandas pivoting to long
|
||||||
|
val = Series([val])
|
||||||
data[k] = val
|
data[k] = val
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -241,9 +305,40 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
return super().__setattr__(key, value)
|
return super().__setattr__(key, value)
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
    """
    Look ``item`` up on the model, falling back to the full selection ``self[:, :]``

    When neither has the attribute, the ``AttributeError`` from the model
    lookup is the one that is raised.
    """
    try:
        return BaseModel.__getattr__(self, item)
    except AttributeError as model_error:
        try:
            found = getattr(self[:, :], item)
        except AttributeError:
            # report the miss against the model rather than against the selection
            raise model_error from None
        return found
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_colnames(cls, model: Dict[str, Any]) -> None:
|
def create_id(cls, model: Dict[str, Any]) -> Dict:
    """
    Fill in ``id`` as ``0..n-1`` when the input has none.

    ``n`` is the largest ``len`` among the entries of ``model``, ignoring any
    entry that has a ``{key}_index`` companion in ``model``, is a ``VectorData``
    with a truthy ``_index``, or is named in ``cls.NON_COLUMN_FIELDS``.
    """
    if "id" in model:
        return model

    lengths = [
        len(val)
        for key, val in model.items()
        # don't get lengths of columns with an index
        if not (
            f"{key}_index" in model
            or (isinstance(val, VectorData) and val._index)
            or key in cls.NON_COLUMN_FIELDS
        )
    ]
    model["id"] = np.arange(np.max(lengths))
    return model
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
|
||||||
"""
|
"""
|
||||||
Construct colnames from arguments.
|
Construct colnames from arguments.
|
||||||
|
|
||||||
|
@ -289,6 +384,40 @@ class DynamicTableMixin(BaseModel):
|
||||||
idx.target = col
|
idx.target = col
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
    """
    Ensure that all columns are equal length

    Returns:
        ``self``, unchanged.

    Raises:
        AssertionError: if any entry of ``self._columns`` has a different
            ``len`` than the first entry.
    """
    lengths = [len(v) for v in self._columns.values()]
    # ``all`` is essential: a bare list of comparisons is truthy whenever it is
    # non-empty, so asserting on the list itself can never detect a mismatch.
    assert all(length == lengths[0] for length in lengths), (
        "Columns are not of equal length! "
        f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
    )
    return self
|
||||||
|
|
||||||
|
@field_validator("*", mode="wrap")
|
||||||
|
@classmethod
|
||||||
|
def cast_columns(
    cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
) -> Any:
    """
    Validate a field, retrying with the value cast to the field's annotated type.

    ``val`` is first handed to ``handler`` as-is. Only when that raises a
    ``ValidationError`` is ``val`` used to construct an instance of the field's
    annotation (passing the field name and description along), which is then
    validated in its place.
    """
    try:
        return handler(val)
    except ValidationError:
        field = cls.model_fields[info.field_name]
        col_type = field.annotation
        # for a union annotation, cast to its first member
        if type(col_type).__name__ == "_UnionGenericAlias":
            col_type = col_type.__args__[0]
        wrapped = col_type(val, name=info.field_name, description=field.description)
        return handler(wrapped)
|
||||||
|
|
||||||
|
|
||||||
linkml_meta = LinkMLMeta(
|
linkml_meta = LinkMLMeta(
|
||||||
{
|
{
|
||||||
|
@ -335,8 +464,8 @@ class VectorIndex(VectorIndexMixin):
|
||||||
)
|
)
|
||||||
|
|
||||||
name: str = Field(...)
|
name: str = Field(...)
|
||||||
target: VectorData = Field(
|
target: Optional[VectorData] = Field(
|
||||||
..., description="""Reference to the target dataset that this index applies to."""
|
None, description="""Reference to the target dataset that this index applies to."""
|
||||||
)
|
)
|
||||||
description: str = Field(..., description="""Description of what these vectors represent.""")
|
description: str = Field(..., description="""Description of what these vectors represent.""")
|
||||||
value: Optional[
|
value: Optional[
|
||||||
|
|
|
@ -4,12 +4,22 @@ from decimal import Decimal
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import numpy as np
|
|
||||||
from ...hdmf_common.v1_4_0.hdmf_common_base import Data, Container
|
from ...hdmf_common.v1_4_0.hdmf_common_base import Data, Container
|
||||||
from pandas import DataFrame, Series
|
from pandas import DataFrame, Series
|
||||||
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
||||||
from numpydantic import NDArray, Shape
|
from numpydantic import NDArray, Shape
|
||||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
ConfigDict,
|
||||||
|
Field,
|
||||||
|
RootModel,
|
||||||
|
field_validator,
|
||||||
|
model_validator,
|
||||||
|
ValidationInfo,
|
||||||
|
ValidatorFunctionWrapHandler,
|
||||||
|
ValidationError,
|
||||||
|
)
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
metamodel_version = "None"
|
metamodel_version = "None"
|
||||||
version = "1.4.0"
|
version = "1.4.0"
|
||||||
|
@ -60,6 +70,11 @@ class VectorDataMixin(BaseModel):
|
||||||
# redefined in `VectorData`, but included here for testing and type checking
|
# redefined in `VectorData`, but included here for testing and type checking
|
||||||
value: Optional[NDArray] = None
|
value: Optional[NDArray] = None
|
||||||
|
|
||||||
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
    """
    Accept ``value`` as the first positional argument.

    A ``value`` that is not ``None`` is forwarded to the parent initializer as
    the ``value`` keyword, unless ``kwargs`` already carries one.
    """
    if value is not None:
        kwargs.setdefault("value", value)
    super().__init__(**kwargs)
|
||||||
|
|
||||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||||
if self._index:
|
if self._index:
|
||||||
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||||
|
@ -74,6 +89,27 @@ class VectorDataMixin(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.value[key] = value
|
self.value[key] = value
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
    """
    Look ``item`` up on the model first, then forward the lookup to ``value``.

    When neither has the attribute, the ``AttributeError`` from the model
    lookup is the one that is raised.
    """
    try:
        return BaseModel.__getattr__(self, item)
    except AttributeError as model_error:
        try:
            found = getattr(self.value, item)
        except AttributeError:
            # report the miss against the model rather than against ``value``
            raise model_error from None
        return found
|
||||||
|
|
||||||
|
def __len__(self) -> int:
    """
    Length of the index when ``_index`` is truthy, otherwise length of ``value``
    """
    sized = self._index if self._index else self.value
    return len(sized)
|
||||||
|
|
||||||
|
|
||||||
class VectorIndexMixin(BaseModel):
|
class VectorIndexMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -84,6 +120,11 @@ class VectorIndexMixin(BaseModel):
|
||||||
value: Optional[NDArray] = None
|
value: Optional[NDArray] = None
|
||||||
target: Optional["VectorData"] = None
|
target: Optional["VectorData"] = None
|
||||||
|
|
||||||
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
    """
    Accept ``value`` as the first positional argument.

    A ``value`` that is not ``None`` is forwarded to the parent initializer as
    the ``value`` keyword, unless ``kwargs`` already carries one.
    """
    if value is not None:
        kwargs.setdefault("value", value)
    super().__init__(**kwargs)
|
||||||
|
|
||||||
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
||||||
"""
|
"""
|
||||||
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||||
|
@ -91,19 +132,19 @@ class VectorIndexMixin(BaseModel):
|
||||||
|
|
||||||
start = 0 if arg == 0 else self.value[arg - 1]
|
start = 0 if arg == 0 else self.value[arg - 1]
|
||||||
end = self.value[arg]
|
end = self.value[arg]
|
||||||
return self.target.array[slice(start, end)]
|
return self.target.value[slice(start, end)]
|
||||||
|
|
||||||
def __getitem__(self, item: Union[int, slice]) -> Any:
    """
    Index ``value`` directly when there is no target, otherwise index through it.

    With a ``VectorData`` target, an ``int`` selects one row via
    ``_getitem_helper`` and a slice returns a list with one such row per
    selected position.

    Raises:
        AttributeError: if ``target`` is set but is not a ``VectorData``.
    """
    if self.target is None:
        return self.value[item]
    if not isinstance(self.target, VectorData):
        raise AttributeError(f"Could not index with {item}")
    if isinstance(item, int):
        return self._getitem_helper(item)
    rows = range(*item.indices(len(self.value)))
    return [self._getitem_helper(row) for row in rows]
|
||||||
|
|
||||||
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
||||||
if self._index:
|
if self._index:
|
||||||
|
@ -112,6 +153,24 @@ class VectorIndexMixin(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.value[key] = value
|
self.value[key] = value
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
    """
    Look ``item`` up on the model first, then forward the lookup to ``value``.

    When neither has the attribute, the ``AttributeError`` from the model
    lookup is the one that is raised.
    """
    try:
        return BaseModel.__getattr__(self, item)
    except AttributeError as model_error:
        try:
            found = getattr(self.value, item)
        except AttributeError:
            # report the miss against the model rather than against ``value``
            raise model_error from None
        return found
|
||||||
|
|
||||||
|
def __len__(self) -> int:
    """
    Number of entries in ``value``
    """
    entries = self.value
    return len(entries)
|
||||||
|
|
||||||
|
|
||||||
class DynamicTableMixin(BaseModel):
|
class DynamicTableMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -131,6 +190,7 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
||||||
colnames: List[str] = Field(default_factory=list)
|
colnames: List[str] = Field(default_factory=list)
|
||||||
|
id: Optional[NDArray[Shape["* num_rows"], int]] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
||||||
|
@ -222,6 +282,10 @@ class DynamicTableMixin(BaseModel):
|
||||||
# special case where pandas will unpack a pydantic model
|
# special case where pandas will unpack a pydantic model
|
||||||
# into {n_fields} rows, rather than keeping it in a dict
|
# into {n_fields} rows, rather than keeping it in a dict
|
||||||
val = Series([val])
|
val = Series([val])
|
||||||
|
elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
|
||||||
|
# special case where we are returning a row in a ragged array,
|
||||||
|
# same as above - prevent pandas pivoting to long
|
||||||
|
val = Series([val])
|
||||||
data[k] = val
|
data[k] = val
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -241,9 +305,40 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
return super().__setattr__(key, value)
|
return super().__setattr__(key, value)
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
    """
    Look ``item`` up on the model, falling back to the full selection ``self[:, :]``

    When neither has the attribute, the ``AttributeError`` from the model
    lookup is the one that is raised.
    """
    try:
        return BaseModel.__getattr__(self, item)
    except AttributeError as model_error:
        try:
            found = getattr(self[:, :], item)
        except AttributeError:
            # report the miss against the model rather than against the selection
            raise model_error from None
        return found
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_colnames(cls, model: Dict[str, Any]) -> None:
|
def create_id(cls, model: Dict[str, Any]) -> Dict:
    """
    Fill in ``id`` as ``0..n-1`` when the input has none.

    ``n`` is the largest ``len`` among the entries of ``model``, ignoring any
    entry that has a ``{key}_index`` companion in ``model``, is a ``VectorData``
    with a truthy ``_index``, or is named in ``cls.NON_COLUMN_FIELDS``.
    """
    if "id" in model:
        return model

    lengths = [
        len(val)
        for key, val in model.items()
        # don't get lengths of columns with an index
        if not (
            f"{key}_index" in model
            or (isinstance(val, VectorData) and val._index)
            or key in cls.NON_COLUMN_FIELDS
        )
    ]
    model["id"] = np.arange(np.max(lengths))
    return model
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
|
||||||
"""
|
"""
|
||||||
Construct colnames from arguments.
|
Construct colnames from arguments.
|
||||||
|
|
||||||
|
@ -289,6 +384,40 @@ class DynamicTableMixin(BaseModel):
|
||||||
idx.target = col
|
idx.target = col
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
    """
    Ensure that all columns are equal length

    Returns:
        ``self``, unchanged.

    Raises:
        AssertionError: if any entry of ``self._columns`` has a different
            ``len`` than the first entry.
    """
    lengths = [len(v) for v in self._columns.values()]
    # ``all`` is essential: a bare list of comparisons is truthy whenever it is
    # non-empty, so asserting on the list itself can never detect a mismatch.
    assert all(length == lengths[0] for length in lengths), (
        "Columns are not of equal length! "
        f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
    )
    return self
|
||||||
|
|
||||||
|
@field_validator("*", mode="wrap")
|
||||||
|
@classmethod
|
||||||
|
def cast_columns(
    cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
) -> Any:
    """
    Validate a field, retrying with the value cast to the field's annotated type.

    ``val`` is first handed to ``handler`` as-is. Only when that raises a
    ``ValidationError`` is ``val`` used to construct an instance of the field's
    annotation (passing the field name and description along), which is then
    validated in its place.
    """
    try:
        return handler(val)
    except ValidationError:
        field = cls.model_fields[info.field_name]
        col_type = field.annotation
        # for a union annotation, cast to its first member
        if type(col_type).__name__ == "_UnionGenericAlias":
            col_type = col_type.__args__[0]
        wrapped = col_type(val, name=info.field_name, description=field.description)
        return handler(wrapped)
|
||||||
|
|
||||||
|
|
||||||
linkml_meta = LinkMLMeta(
|
linkml_meta = LinkMLMeta(
|
||||||
{
|
{
|
||||||
|
@ -335,8 +464,8 @@ class VectorIndex(VectorIndexMixin):
|
||||||
)
|
)
|
||||||
|
|
||||||
name: str = Field(...)
|
name: str = Field(...)
|
||||||
target: VectorData = Field(
|
target: Optional[VectorData] = Field(
|
||||||
..., description="""Reference to the target dataset that this index applies to."""
|
None, description="""Reference to the target dataset that this index applies to."""
|
||||||
)
|
)
|
||||||
description: str = Field(..., description="""Description of what these vectors represent.""")
|
description: str = Field(..., description="""Description of what these vectors represent.""")
|
||||||
value: Optional[
|
value: Optional[
|
||||||
|
|
|
@ -4,11 +4,21 @@ from decimal import Decimal
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import numpy as np
|
|
||||||
from ...hdmf_common.v1_5_0.hdmf_common_base import Data, Container
|
from ...hdmf_common.v1_5_0.hdmf_common_base import Data, Container
|
||||||
from pandas import DataFrame, Series
|
from pandas import DataFrame, Series
|
||||||
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
||||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
ConfigDict,
|
||||||
|
Field,
|
||||||
|
RootModel,
|
||||||
|
field_validator,
|
||||||
|
model_validator,
|
||||||
|
ValidationInfo,
|
||||||
|
ValidatorFunctionWrapHandler,
|
||||||
|
ValidationError,
|
||||||
|
)
|
||||||
|
import numpy as np
|
||||||
from numpydantic import NDArray, Shape
|
from numpydantic import NDArray, Shape
|
||||||
|
|
||||||
metamodel_version = "None"
|
metamodel_version = "None"
|
||||||
|
@ -60,6 +70,11 @@ class VectorDataMixin(BaseModel):
|
||||||
# redefined in `VectorData`, but included here for testing and type checking
|
# redefined in `VectorData`, but included here for testing and type checking
|
||||||
value: Optional[NDArray] = None
|
value: Optional[NDArray] = None
|
||||||
|
|
||||||
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
    """
    Accept ``value`` as the first positional argument.

    A ``value`` that is not ``None`` is forwarded to the parent initializer as
    the ``value`` keyword, unless ``kwargs`` already carries one.
    """
    if value is not None:
        kwargs.setdefault("value", value)
    super().__init__(**kwargs)
|
||||||
|
|
||||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||||
if self._index:
|
if self._index:
|
||||||
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||||
|
@ -74,6 +89,27 @@ class VectorDataMixin(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.value[key] = value
|
self.value[key] = value
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
    """
    Look ``item`` up on the model first, then forward the lookup to ``value``.

    When neither has the attribute, the ``AttributeError`` from the model
    lookup is the one that is raised.
    """
    try:
        return BaseModel.__getattr__(self, item)
    except AttributeError as model_error:
        try:
            found = getattr(self.value, item)
        except AttributeError:
            # report the miss against the model rather than against ``value``
            raise model_error from None
        return found
|
||||||
|
|
||||||
|
def __len__(self) -> int:
    """
    Length of the index when ``_index`` is truthy, otherwise length of ``value``
    """
    sized = self._index if self._index else self.value
    return len(sized)
|
||||||
|
|
||||||
|
|
||||||
class VectorIndexMixin(BaseModel):
|
class VectorIndexMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -84,6 +120,11 @@ class VectorIndexMixin(BaseModel):
|
||||||
value: Optional[NDArray] = None
|
value: Optional[NDArray] = None
|
||||||
target: Optional["VectorData"] = None
|
target: Optional["VectorData"] = None
|
||||||
|
|
||||||
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
    """
    Accept ``value`` as the first positional argument.

    A ``value`` that is not ``None`` is forwarded to the parent initializer as
    the ``value`` keyword, unless ``kwargs`` already carries one.
    """
    if value is not None:
        kwargs.setdefault("value", value)
    super().__init__(**kwargs)
|
||||||
|
|
||||||
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
||||||
"""
|
"""
|
||||||
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||||
|
@ -91,19 +132,19 @@ class VectorIndexMixin(BaseModel):
|
||||||
|
|
||||||
start = 0 if arg == 0 else self.value[arg - 1]
|
start = 0 if arg == 0 else self.value[arg - 1]
|
||||||
end = self.value[arg]
|
end = self.value[arg]
|
||||||
return self.target.array[slice(start, end)]
|
return self.target.value[slice(start, end)]
|
||||||
|
|
||||||
def __getitem__(self, item: Union[int, slice]) -> Any:
    """
    Index ``value`` directly when there is no target, otherwise index through it.

    With a ``VectorData`` target, an ``int`` selects one row via
    ``_getitem_helper`` and a slice returns a list with one such row per
    selected position.

    Raises:
        AttributeError: if ``target`` is set but is not a ``VectorData``.
    """
    if self.target is None:
        return self.value[item]
    if not isinstance(self.target, VectorData):
        raise AttributeError(f"Could not index with {item}")
    if isinstance(item, int):
        return self._getitem_helper(item)
    rows = range(*item.indices(len(self.value)))
    return [self._getitem_helper(row) for row in rows]
|
||||||
|
|
||||||
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
||||||
if self._index:
|
if self._index:
|
||||||
|
@ -112,6 +153,24 @@ class VectorIndexMixin(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.value[key] = value
|
self.value[key] = value
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
    """
    Look ``item`` up on the model first, then forward the lookup to ``value``.

    When neither has the attribute, the ``AttributeError`` from the model
    lookup is the one that is raised.
    """
    try:
        return BaseModel.__getattr__(self, item)
    except AttributeError as model_error:
        try:
            found = getattr(self.value, item)
        except AttributeError:
            # report the miss against the model rather than against ``value``
            raise model_error from None
        return found
|
||||||
|
|
||||||
|
def __len__(self) -> int:
    """
    Number of entries in ``value``
    """
    entries = self.value
    return len(entries)
|
||||||
|
|
||||||
|
|
||||||
class DynamicTableMixin(BaseModel):
|
class DynamicTableMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -131,6 +190,7 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
||||||
colnames: List[str] = Field(default_factory=list)
|
colnames: List[str] = Field(default_factory=list)
|
||||||
|
id: Optional[NDArray[Shape["* num_rows"], int]] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
||||||
|
@ -222,6 +282,10 @@ class DynamicTableMixin(BaseModel):
|
||||||
# special case where pandas will unpack a pydantic model
|
# special case where pandas will unpack a pydantic model
|
||||||
# into {n_fields} rows, rather than keeping it in a dict
|
# into {n_fields} rows, rather than keeping it in a dict
|
||||||
val = Series([val])
|
val = Series([val])
|
||||||
|
elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
|
||||||
|
# special case where we are returning a row in a ragged array,
|
||||||
|
# same as above - prevent pandas pivoting to long
|
||||||
|
val = Series([val])
|
||||||
data[k] = val
|
data[k] = val
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -241,9 +305,40 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
return super().__setattr__(key, value)
|
return super().__setattr__(key, value)
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
    """
    Look ``item`` up on the model, falling back to the full selection ``self[:, :]``

    When neither has the attribute, the ``AttributeError`` from the model
    lookup is the one that is raised.
    """
    try:
        return BaseModel.__getattr__(self, item)
    except AttributeError as model_error:
        try:
            found = getattr(self[:, :], item)
        except AttributeError:
            # report the miss against the model rather than against the selection
            raise model_error from None
        return found
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_colnames(cls, model: Dict[str, Any]) -> None:
|
def create_id(cls, model: Dict[str, Any]) -> Dict:
    """
    Fill in ``id`` as ``0..n-1`` when the input has none.

    ``n`` is the largest ``len`` among the entries of ``model``, ignoring any
    entry that has a ``{key}_index`` companion in ``model``, is a ``VectorData``
    with a truthy ``_index``, or is named in ``cls.NON_COLUMN_FIELDS``.
    """
    if "id" in model:
        return model

    lengths = [
        len(val)
        for key, val in model.items()
        # don't get lengths of columns with an index
        if not (
            f"{key}_index" in model
            or (isinstance(val, VectorData) and val._index)
            or key in cls.NON_COLUMN_FIELDS
        )
    ]
    model["id"] = np.arange(np.max(lengths))
    return model
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
|
||||||
"""
|
"""
|
||||||
Construct colnames from arguments.
|
Construct colnames from arguments.
|
||||||
|
|
||||||
|
@ -289,6 +384,40 @@ class DynamicTableMixin(BaseModel):
|
||||||
idx.target = col
|
idx.target = col
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
    """
    Ensure that all columns are equal length

    Returns:
        ``self``, unchanged.

    Raises:
        AssertionError: if any entry of ``self._columns`` has a different
            ``len`` than the first entry.
    """
    lengths = [len(v) for v in self._columns.values()]
    # ``all`` is essential: a bare list of comparisons is truthy whenever it is
    # non-empty, so asserting on the list itself can never detect a mismatch.
    assert all(length == lengths[0] for length in lengths), (
        "Columns are not of equal length! "
        f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
    )
    return self
|
||||||
|
|
||||||
|
@field_validator("*", mode="wrap")
|
||||||
|
@classmethod
|
||||||
|
def cast_columns(
    cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
) -> Any:
    """
    If columns are supplied as arrays, try casting them to the type before validating

    Args:
        val: Raw value supplied for the field.
        handler: The wrapped validator for the field.
        info: Validation context; ``info.field_name`` selects the field definition.

    Returns:
        The result of ``handler`` for ``val`` itself or, when that fails
        validation, for ``val`` wrapped in the field's annotated class.
    """
    import types
    from typing import Union, get_args, get_origin

    try:
        return handler(val)
    except ValidationError:
        field = cls.model_fields[info.field_name]
        annotation = field.annotation
        # unwrap ``Optional[X]`` / ``Union[X, ...]`` (and ``X | None``) to its first
        # non-None member; checking the origin avoids relying on typing's private
        # class names
        if get_origin(annotation) in (Union, getattr(types, "UnionType", Union)):
            members = [arg for arg in get_args(annotation) if arg is not type(None)]
            annotation = members[0]
        return handler(
            annotation(
                val,
                name=info.field_name,
                description=field.description,
            )
        )
|
||||||
|
|
||||||
|
|
||||||
linkml_meta = LinkMLMeta(
|
linkml_meta = LinkMLMeta(
|
||||||
{
|
{
|
||||||
|
@ -335,8 +464,8 @@ class VectorIndex(VectorIndexMixin):
|
||||||
)
|
)
|
||||||
|
|
||||||
name: str = Field(...)
|
name: str = Field(...)
|
||||||
target: VectorData = Field(
|
target: Optional[VectorData] = Field(
|
||||||
..., description="""Reference to the target dataset that this index applies to."""
|
None, description="""Reference to the target dataset that this index applies to."""
|
||||||
)
|
)
|
||||||
description: str = Field(..., description="""Description of what these vectors represent.""")
|
description: str = Field(..., description="""Description of what these vectors represent.""")
|
||||||
value: Optional[
|
value: Optional[
|
||||||
|
|
|
@ -4,11 +4,21 @@ from decimal import Decimal
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import numpy as np
|
|
||||||
from ...hdmf_common.v1_5_1.hdmf_common_base import Data, Container
|
from ...hdmf_common.v1_5_1.hdmf_common_base import Data, Container
|
||||||
from pandas import DataFrame, Series
|
from pandas import DataFrame, Series
|
||||||
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
||||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
ConfigDict,
|
||||||
|
Field,
|
||||||
|
RootModel,
|
||||||
|
field_validator,
|
||||||
|
model_validator,
|
||||||
|
ValidationInfo,
|
||||||
|
ValidatorFunctionWrapHandler,
|
||||||
|
ValidationError,
|
||||||
|
)
|
||||||
|
import numpy as np
|
||||||
from numpydantic import NDArray, Shape
|
from numpydantic import NDArray, Shape
|
||||||
|
|
||||||
metamodel_version = "None"
|
metamodel_version = "None"
|
||||||
|
@ -60,6 +70,11 @@ class VectorDataMixin(BaseModel):
|
||||||
# redefined in `VectorData`, but included here for testing and type checking
|
# redefined in `VectorData`, but included here for testing and type checking
|
||||||
value: Optional[NDArray] = None
|
value: Optional[NDArray] = None
|
||||||
|
|
||||||
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||||
|
if value is not None and "value" not in kwargs:
|
||||||
|
kwargs["value"] = value
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||||
if self._index:
|
if self._index:
|
||||||
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||||
|
@ -74,6 +89,27 @@ class VectorDataMixin(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.value[key] = value
|
self.value[key] = value
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
|
||||||
|
"""
|
||||||
|
Forward getattr to ``value``
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return BaseModel.__getattr__(self, item)
|
||||||
|
except AttributeError as e:
|
||||||
|
try:
|
||||||
|
return getattr(self.value, item)
|
||||||
|
except AttributeError:
|
||||||
|
raise e from None
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""
|
||||||
|
Use index as length, if present
|
||||||
|
"""
|
||||||
|
if self._index:
|
||||||
|
return len(self._index)
|
||||||
|
else:
|
||||||
|
return len(self.value)
|
||||||
|
|
||||||
|
|
||||||
class VectorIndexMixin(BaseModel):
|
class VectorIndexMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -84,6 +120,11 @@ class VectorIndexMixin(BaseModel):
|
||||||
value: Optional[NDArray] = None
|
value: Optional[NDArray] = None
|
||||||
target: Optional["VectorData"] = None
|
target: Optional["VectorData"] = None
|
||||||
|
|
||||||
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||||
|
if value is not None and "value" not in kwargs:
|
||||||
|
kwargs["value"] = value
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
||||||
"""
|
"""
|
||||||
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||||
|
@ -91,19 +132,19 @@ class VectorIndexMixin(BaseModel):
|
||||||
|
|
||||||
start = 0 if arg == 0 else self.value[arg - 1]
|
start = 0 if arg == 0 else self.value[arg - 1]
|
||||||
end = self.value[arg]
|
end = self.value[arg]
|
||||||
return self.target.array[slice(start, end)]
|
return self.target.value[slice(start, end)]
|
||||||
|
|
||||||
def __getitem__(self, item: Union[int, slice]) -> Any:
|
def __getitem__(self, item: Union[int, slice]) -> Any:
|
||||||
if self.target is None:
|
if self.target is None:
|
||||||
return self.value[item]
|
return self.value[item]
|
||||||
elif type(self.target).__name__ == "VectorData":
|
elif isinstance(self.target, VectorData):
|
||||||
if isinstance(item, int):
|
if isinstance(item, int):
|
||||||
return self._getitem_helper(item)
|
return self._getitem_helper(item)
|
||||||
else:
|
else:
|
||||||
idx = range(*item.indices(len(self.value)))
|
idx = range(*item.indices(len(self.value)))
|
||||||
return [self._getitem_helper(i) for i in idx]
|
return [self._getitem_helper(i) for i in idx]
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("DynamicTableRange not supported yet")
|
raise AttributeError(f"Could not index with {item}")
|
||||||
|
|
||||||
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
||||||
if self._index:
|
if self._index:
|
||||||
|
@ -112,6 +153,24 @@ class VectorIndexMixin(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.value[key] = value
|
self.value[key] = value
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
|
||||||
|
"""
|
||||||
|
Forward getattr to ``value``
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return BaseModel.__getattr__(self, item)
|
||||||
|
except AttributeError as e:
|
||||||
|
try:
|
||||||
|
return getattr(self.value, item)
|
||||||
|
except AttributeError:
|
||||||
|
raise e from None
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""
|
||||||
|
Get length from value
|
||||||
|
"""
|
||||||
|
return len(self.value)
|
||||||
|
|
||||||
|
|
||||||
class DynamicTableMixin(BaseModel):
|
class DynamicTableMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -131,6 +190,7 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
||||||
colnames: List[str] = Field(default_factory=list)
|
colnames: List[str] = Field(default_factory=list)
|
||||||
|
id: Optional[NDArray[Shape["* num_rows"], int]] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
||||||
|
@ -222,6 +282,10 @@ class DynamicTableMixin(BaseModel):
|
||||||
# special case where pandas will unpack a pydantic model
|
# special case where pandas will unpack a pydantic model
|
||||||
# into {n_fields} rows, rather than keeping it in a dict
|
# into {n_fields} rows, rather than keeping it in a dict
|
||||||
val = Series([val])
|
val = Series([val])
|
||||||
|
elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
|
||||||
|
# special case where we are returning a row in a ragged array,
|
||||||
|
# same as above - prevent pandas pivoting to long
|
||||||
|
val = Series([val])
|
||||||
data[k] = val
|
data[k] = val
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -241,9 +305,40 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
return super().__setattr__(key, value)
|
return super().__setattr__(key, value)
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
|
||||||
|
"""Try and use pandas df attrs if we don't have them"""
|
||||||
|
try:
|
||||||
|
return BaseModel.__getattr__(self, item)
|
||||||
|
except AttributeError as e:
|
||||||
|
try:
|
||||||
|
return getattr(self[:, :], item)
|
||||||
|
except AttributeError:
|
||||||
|
raise e from None
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_colnames(cls, model: Dict[str, Any]) -> None:
|
def create_id(cls, model: Dict[str, Any]) -> Dict:
    """
    Create ID column if not provided

    Runs before field validation. When ``model`` is a dict without an ``"id"``
    key, ``"id"`` is set to ``np.arange`` over the largest length found among
    the entries that are not skipped below.

    Args:
        model: Raw input handed to the model constructor.

    Returns:
        The same ``model`` object, with ``"id"`` filled in when it could be inferred.
    """
    # "before" validators receive the raw input, which is not guaranteed to be a
    # dict - leave anything else for the later validation steps to handle
    if not isinstance(model, dict) or "id" in model:
        return model

    lengths = []
    for key, val in model.items():
        # don't get lengths of columns with an index
        if (
            f"{key}_index" in model
            or (isinstance(val, VectorData) and val._index)
            or key in cls.NON_COLUMN_FIELDS
        ):
            continue
        lengths.append(len(val))

    # nothing to number when no lengths were collected; np.max([]) would raise
    if lengths:
        model["id"] = np.arange(np.max(lengths))

    return model
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
|
||||||
"""
|
"""
|
||||||
Construct colnames from arguments.
|
Construct colnames from arguments.
|
||||||
|
|
||||||
|
@ -289,6 +384,40 @@ class DynamicTableMixin(BaseModel):
|
||||||
idx.target = col
|
idx.target = col
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
    """
    Ensure that all columns are equal length

    Returns:
        ``self``, unchanged, when every entry of ``self._columns`` has the same
        length (a table with no columns passes trivially).

    Raises:
        ValueError: if the column lengths differ.
    """
    lengths = [len(v) for v in self._columns.values()]
    # compare the lengths themselves: asserting on a list of comparison results
    # only checks that the list is non-empty, and ``assert`` is removed under -O
    if len(set(lengths)) > 1:
        raise ValueError(
            "Columns are not of equal length! "
            f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
        )
    return self
|
||||||
|
|
||||||
|
@field_validator("*", mode="wrap")
|
||||||
|
@classmethod
|
||||||
|
def cast_columns(
    cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
) -> Any:
    """
    If columns are supplied as arrays, try casting them to the type before validating

    Args:
        val: Raw value supplied for the field.
        handler: The wrapped validator for the field.
        info: Validation context; ``info.field_name`` selects the field definition.

    Returns:
        The result of ``handler`` for ``val`` itself or, when that fails
        validation, for ``val`` wrapped in the field's annotated class.
    """
    import types
    from typing import Union, get_args, get_origin

    try:
        return handler(val)
    except ValidationError:
        field = cls.model_fields[info.field_name]
        annotation = field.annotation
        # unwrap ``Optional[X]`` / ``Union[X, ...]`` (and ``X | None``) to its first
        # non-None member; checking the origin avoids relying on typing's private
        # class names
        if get_origin(annotation) in (Union, getattr(types, "UnionType", Union)):
            members = [arg for arg in get_args(annotation) if arg is not type(None)]
            annotation = members[0]
        return handler(
            annotation(
                val,
                name=info.field_name,
                description=field.description,
            )
        )
|
||||||
|
|
||||||
|
|
||||||
linkml_meta = LinkMLMeta(
|
linkml_meta = LinkMLMeta(
|
||||||
{
|
{
|
||||||
|
@ -335,8 +464,8 @@ class VectorIndex(VectorIndexMixin):
|
||||||
)
|
)
|
||||||
|
|
||||||
name: str = Field(...)
|
name: str = Field(...)
|
||||||
target: VectorData = Field(
|
target: Optional[VectorData] = Field(
|
||||||
..., description="""Reference to the target dataset that this index applies to."""
|
None, description="""Reference to the target dataset that this index applies to."""
|
||||||
)
|
)
|
||||||
description: str = Field(..., description="""Description of what these vectors represent.""")
|
description: str = Field(..., description="""Description of what these vectors represent.""")
|
||||||
value: Optional[
|
value: Optional[
|
||||||
|
|
|
@ -4,11 +4,21 @@ from decimal import Decimal
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import numpy as np
|
|
||||||
from ...hdmf_common.v1_6_0.hdmf_common_base import Data, Container
|
from ...hdmf_common.v1_6_0.hdmf_common_base import Data, Container
|
||||||
from pandas import DataFrame, Series
|
from pandas import DataFrame, Series
|
||||||
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
||||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
ConfigDict,
|
||||||
|
Field,
|
||||||
|
RootModel,
|
||||||
|
field_validator,
|
||||||
|
model_validator,
|
||||||
|
ValidationInfo,
|
||||||
|
ValidatorFunctionWrapHandler,
|
||||||
|
ValidationError,
|
||||||
|
)
|
||||||
|
import numpy as np
|
||||||
from numpydantic import NDArray, Shape
|
from numpydantic import NDArray, Shape
|
||||||
|
|
||||||
metamodel_version = "None"
|
metamodel_version = "None"
|
||||||
|
@ -60,6 +70,11 @@ class VectorDataMixin(BaseModel):
|
||||||
# redefined in `VectorData`, but included here for testing and type checking
|
# redefined in `VectorData`, but included here for testing and type checking
|
||||||
value: Optional[NDArray] = None
|
value: Optional[NDArray] = None
|
||||||
|
|
||||||
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||||
|
if value is not None and "value" not in kwargs:
|
||||||
|
kwargs["value"] = value
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||||
if self._index:
|
if self._index:
|
||||||
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||||
|
@ -74,6 +89,27 @@ class VectorDataMixin(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.value[key] = value
|
self.value[key] = value
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
|
||||||
|
"""
|
||||||
|
Forward getattr to ``value``
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return BaseModel.__getattr__(self, item)
|
||||||
|
except AttributeError as e:
|
||||||
|
try:
|
||||||
|
return getattr(self.value, item)
|
||||||
|
except AttributeError:
|
||||||
|
raise e from None
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""
|
||||||
|
Use index as length, if present
|
||||||
|
"""
|
||||||
|
if self._index:
|
||||||
|
return len(self._index)
|
||||||
|
else:
|
||||||
|
return len(self.value)
|
||||||
|
|
||||||
|
|
||||||
class VectorIndexMixin(BaseModel):
|
class VectorIndexMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -84,6 +120,11 @@ class VectorIndexMixin(BaseModel):
|
||||||
value: Optional[NDArray] = None
|
value: Optional[NDArray] = None
|
||||||
target: Optional["VectorData"] = None
|
target: Optional["VectorData"] = None
|
||||||
|
|
||||||
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||||
|
if value is not None and "value" not in kwargs:
|
||||||
|
kwargs["value"] = value
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
||||||
"""
|
"""
|
||||||
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||||
|
@ -91,19 +132,19 @@ class VectorIndexMixin(BaseModel):
|
||||||
|
|
||||||
start = 0 if arg == 0 else self.value[arg - 1]
|
start = 0 if arg == 0 else self.value[arg - 1]
|
||||||
end = self.value[arg]
|
end = self.value[arg]
|
||||||
return self.target.array[slice(start, end)]
|
return self.target.value[slice(start, end)]
|
||||||
|
|
||||||
def __getitem__(self, item: Union[int, slice]) -> Any:
|
def __getitem__(self, item: Union[int, slice]) -> Any:
|
||||||
if self.target is None:
|
if self.target is None:
|
||||||
return self.value[item]
|
return self.value[item]
|
||||||
elif type(self.target).__name__ == "VectorData":
|
elif isinstance(self.target, VectorData):
|
||||||
if isinstance(item, int):
|
if isinstance(item, int):
|
||||||
return self._getitem_helper(item)
|
return self._getitem_helper(item)
|
||||||
else:
|
else:
|
||||||
idx = range(*item.indices(len(self.value)))
|
idx = range(*item.indices(len(self.value)))
|
||||||
return [self._getitem_helper(i) for i in idx]
|
return [self._getitem_helper(i) for i in idx]
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("DynamicTableRange not supported yet")
|
raise AttributeError(f"Could not index with {item}")
|
||||||
|
|
||||||
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
||||||
if self._index:
|
if self._index:
|
||||||
|
@ -112,6 +153,24 @@ class VectorIndexMixin(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.value[key] = value
|
self.value[key] = value
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
|
||||||
|
"""
|
||||||
|
Forward getattr to ``value``
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return BaseModel.__getattr__(self, item)
|
||||||
|
except AttributeError as e:
|
||||||
|
try:
|
||||||
|
return getattr(self.value, item)
|
||||||
|
except AttributeError:
|
||||||
|
raise e from None
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""
|
||||||
|
Get length from value
|
||||||
|
"""
|
||||||
|
return len(self.value)
|
||||||
|
|
||||||
|
|
||||||
class DynamicTableMixin(BaseModel):
|
class DynamicTableMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -131,6 +190,7 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
||||||
colnames: List[str] = Field(default_factory=list)
|
colnames: List[str] = Field(default_factory=list)
|
||||||
|
id: Optional[NDArray[Shape["* num_rows"], int]] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
||||||
|
@ -222,6 +282,10 @@ class DynamicTableMixin(BaseModel):
|
||||||
# special case where pandas will unpack a pydantic model
|
# special case where pandas will unpack a pydantic model
|
||||||
# into {n_fields} rows, rather than keeping it in a dict
|
# into {n_fields} rows, rather than keeping it in a dict
|
||||||
val = Series([val])
|
val = Series([val])
|
||||||
|
elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
|
||||||
|
# special case where we are returning a row in a ragged array,
|
||||||
|
# same as above - prevent pandas pivoting to long
|
||||||
|
val = Series([val])
|
||||||
data[k] = val
|
data[k] = val
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -241,9 +305,40 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
return super().__setattr__(key, value)
|
return super().__setattr__(key, value)
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
|
||||||
|
"""Try and use pandas df attrs if we don't have them"""
|
||||||
|
try:
|
||||||
|
return BaseModel.__getattr__(self, item)
|
||||||
|
except AttributeError as e:
|
||||||
|
try:
|
||||||
|
return getattr(self[:, :], item)
|
||||||
|
except AttributeError:
|
||||||
|
raise e from None
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_colnames(cls, model: Dict[str, Any]) -> None:
|
def create_id(cls, model: Dict[str, Any]) -> Dict:
    """
    Create ID column if not provided

    Runs before field validation. When ``model`` is a dict without an ``"id"``
    key, ``"id"`` is set to ``np.arange`` over the largest length found among
    the entries that are not skipped below.

    Args:
        model: Raw input handed to the model constructor.

    Returns:
        The same ``model`` object, with ``"id"`` filled in when it could be inferred.
    """
    # "before" validators receive the raw input, which is not guaranteed to be a
    # dict - leave anything else for the later validation steps to handle
    if not isinstance(model, dict) or "id" in model:
        return model

    lengths = []
    for key, val in model.items():
        # don't get lengths of columns with an index
        if (
            f"{key}_index" in model
            or (isinstance(val, VectorData) and val._index)
            or key in cls.NON_COLUMN_FIELDS
        ):
            continue
        lengths.append(len(val))

    # nothing to number when no lengths were collected; np.max([]) would raise
    if lengths:
        model["id"] = np.arange(np.max(lengths))

    return model
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
|
||||||
"""
|
"""
|
||||||
Construct colnames from arguments.
|
Construct colnames from arguments.
|
||||||
|
|
||||||
|
@ -289,6 +384,40 @@ class DynamicTableMixin(BaseModel):
|
||||||
idx.target = col
|
idx.target = col
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
    """
    Ensure that all columns are equal length

    Returns:
        ``self``, unchanged, when every entry of ``self._columns`` has the same
        length (a table with no columns passes trivially).

    Raises:
        ValueError: if the column lengths differ.
    """
    lengths = [len(v) for v in self._columns.values()]
    # compare the lengths themselves: asserting on a list of comparison results
    # only checks that the list is non-empty, and ``assert`` is removed under -O
    if len(set(lengths)) > 1:
        raise ValueError(
            "Columns are not of equal length! "
            f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
        )
    return self
|
||||||
|
|
||||||
|
@field_validator("*", mode="wrap")
|
||||||
|
@classmethod
|
||||||
|
def cast_columns(
    cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
) -> Any:
    """
    If columns are supplied as arrays, try casting them to the type before validating

    Args:
        val: Raw value supplied for the field.
        handler: The wrapped validator for the field.
        info: Validation context; ``info.field_name`` selects the field definition.

    Returns:
        The result of ``handler`` for ``val`` itself or, when that fails
        validation, for ``val`` wrapped in the field's annotated class.
    """
    import types
    from typing import Union, get_args, get_origin

    try:
        return handler(val)
    except ValidationError:
        field = cls.model_fields[info.field_name]
        annotation = field.annotation
        # unwrap ``Optional[X]`` / ``Union[X, ...]`` (and ``X | None``) to its first
        # non-None member; checking the origin avoids relying on typing's private
        # class names
        if get_origin(annotation) in (Union, getattr(types, "UnionType", Union)):
            members = [arg for arg in get_args(annotation) if arg is not type(None)]
            annotation = members[0]
        return handler(
            annotation(
                val,
                name=info.field_name,
                description=field.description,
            )
        )
|
||||||
|
|
||||||
|
|
||||||
linkml_meta = LinkMLMeta(
|
linkml_meta = LinkMLMeta(
|
||||||
{
|
{
|
||||||
|
@ -335,8 +464,8 @@ class VectorIndex(VectorIndexMixin):
|
||||||
)
|
)
|
||||||
|
|
||||||
name: str = Field(...)
|
name: str = Field(...)
|
||||||
target: VectorData = Field(
|
target: Optional[VectorData] = Field(
|
||||||
..., description="""Reference to the target dataset that this index applies to."""
|
None, description="""Reference to the target dataset that this index applies to."""
|
||||||
)
|
)
|
||||||
description: str = Field(..., description="""Description of what these vectors represent.""")
|
description: str = Field(..., description="""Description of what these vectors represent.""")
|
||||||
value: Optional[
|
value: Optional[
|
||||||
|
|
|
@ -4,11 +4,21 @@ from decimal import Decimal
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import numpy as np
|
|
||||||
from ...hdmf_common.v1_7_0.hdmf_common_base import Data, Container
|
from ...hdmf_common.v1_7_0.hdmf_common_base import Data, Container
|
||||||
from pandas import DataFrame, Series
|
from pandas import DataFrame, Series
|
||||||
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
||||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
ConfigDict,
|
||||||
|
Field,
|
||||||
|
RootModel,
|
||||||
|
field_validator,
|
||||||
|
model_validator,
|
||||||
|
ValidationInfo,
|
||||||
|
ValidatorFunctionWrapHandler,
|
||||||
|
ValidationError,
|
||||||
|
)
|
||||||
|
import numpy as np
|
||||||
from numpydantic import NDArray, Shape
|
from numpydantic import NDArray, Shape
|
||||||
|
|
||||||
metamodel_version = "None"
|
metamodel_version = "None"
|
||||||
|
@ -60,6 +70,11 @@ class VectorDataMixin(BaseModel):
|
||||||
# redefined in `VectorData`, but included here for testing and type checking
|
# redefined in `VectorData`, but included here for testing and type checking
|
||||||
value: Optional[NDArray] = None
|
value: Optional[NDArray] = None
|
||||||
|
|
||||||
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||||
|
if value is not None and "value" not in kwargs:
|
||||||
|
kwargs["value"] = value
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||||
if self._index:
|
if self._index:
|
||||||
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||||
|
@ -74,6 +89,27 @@ class VectorDataMixin(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.value[key] = value
|
self.value[key] = value
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
|
||||||
|
"""
|
||||||
|
Forward getattr to ``value``
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return BaseModel.__getattr__(self, item)
|
||||||
|
except AttributeError as e:
|
||||||
|
try:
|
||||||
|
return getattr(self.value, item)
|
||||||
|
except AttributeError:
|
||||||
|
raise e from None
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""
|
||||||
|
Use index as length, if present
|
||||||
|
"""
|
||||||
|
if self._index:
|
||||||
|
return len(self._index)
|
||||||
|
else:
|
||||||
|
return len(self.value)
|
||||||
|
|
||||||
|
|
||||||
class VectorIndexMixin(BaseModel):
|
class VectorIndexMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -84,6 +120,11 @@ class VectorIndexMixin(BaseModel):
|
||||||
value: Optional[NDArray] = None
|
value: Optional[NDArray] = None
|
||||||
target: Optional["VectorData"] = None
|
target: Optional["VectorData"] = None
|
||||||
|
|
||||||
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||||
|
if value is not None and "value" not in kwargs:
|
||||||
|
kwargs["value"] = value
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
|
||||||
"""
|
"""
|
||||||
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||||
|
@ -91,19 +132,19 @@ class VectorIndexMixin(BaseModel):
|
||||||
|
|
||||||
start = 0 if arg == 0 else self.value[arg - 1]
|
start = 0 if arg == 0 else self.value[arg - 1]
|
||||||
end = self.value[arg]
|
end = self.value[arg]
|
||||||
return self.target.array[slice(start, end)]
|
return self.target.value[slice(start, end)]
|
||||||
|
|
||||||
def __getitem__(self, item: Union[int, slice]) -> Any:
|
def __getitem__(self, item: Union[int, slice]) -> Any:
|
||||||
if self.target is None:
|
if self.target is None:
|
||||||
return self.value[item]
|
return self.value[item]
|
||||||
elif type(self.target).__name__ == "VectorData":
|
elif isinstance(self.target, VectorData):
|
||||||
if isinstance(item, int):
|
if isinstance(item, int):
|
||||||
return self._getitem_helper(item)
|
return self._getitem_helper(item)
|
||||||
else:
|
else:
|
||||||
idx = range(*item.indices(len(self.value)))
|
idx = range(*item.indices(len(self.value)))
|
||||||
return [self._getitem_helper(i) for i in idx]
|
return [self._getitem_helper(i) for i in idx]
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("DynamicTableRange not supported yet")
|
raise AttributeError(f"Could not index with {item}")
|
||||||
|
|
||||||
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
||||||
if self._index:
|
if self._index:
|
||||||
|
@ -112,6 +153,24 @@ class VectorIndexMixin(BaseModel):
|
||||||
else:
|
else:
|
||||||
self.value[key] = value
|
self.value[key] = value
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
|
||||||
|
"""
|
||||||
|
Forward getattr to ``value``
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return BaseModel.__getattr__(self, item)
|
||||||
|
except AttributeError as e:
|
||||||
|
try:
|
||||||
|
return getattr(self.value, item)
|
||||||
|
except AttributeError:
|
||||||
|
raise e from None
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""
|
||||||
|
Get length from value
|
||||||
|
"""
|
||||||
|
return len(self.value)
|
||||||
|
|
||||||
|
|
||||||
class DynamicTableMixin(BaseModel):
|
class DynamicTableMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -131,6 +190,7 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
||||||
colnames: List[str] = Field(default_factory=list)
|
colnames: List[str] = Field(default_factory=list)
|
||||||
|
id: Optional[NDArray[Shape["* num_rows"], int]] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorDataMixin"]]:
|
||||||
|
@ -222,6 +282,10 @@ class DynamicTableMixin(BaseModel):
|
||||||
# special case where pandas will unpack a pydantic model
|
# special case where pandas will unpack a pydantic model
|
||||||
# into {n_fields} rows, rather than keeping it in a dict
|
# into {n_fields} rows, rather than keeping it in a dict
|
||||||
val = Series([val])
|
val = Series([val])
|
||||||
|
elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
|
||||||
|
# special case where we are returning a row in a ragged array,
|
||||||
|
# same as above - prevent pandas pivoting to long
|
||||||
|
val = Series([val])
|
||||||
data[k] = val
|
data[k] = val
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -241,9 +305,40 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
return super().__setattr__(key, value)
|
return super().__setattr__(key, value)
|
||||||
|
|
||||||
|
def __getattr__(self, item: str) -> Any:
|
||||||
|
"""Try and use pandas df attrs if we don't have them"""
|
||||||
|
try:
|
||||||
|
return BaseModel.__getattr__(self, item)
|
||||||
|
except AttributeError as e:
|
||||||
|
try:
|
||||||
|
return getattr(self[:, :], item)
|
||||||
|
except AttributeError:
|
||||||
|
raise e from None
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_colnames(cls, model: Dict[str, Any]) -> None:
|
def create_id(cls, model: Dict[str, Any]) -> Dict:
    """
    Create ID column if not provided

    Runs before field validation. When ``model`` is a dict without an ``"id"``
    key, ``"id"`` is set to ``np.arange`` over the largest length found among
    the entries that are not skipped below.

    Args:
        model: Raw input handed to the model constructor.

    Returns:
        The same ``model`` object, with ``"id"`` filled in when it could be inferred.
    """
    # "before" validators receive the raw input, which is not guaranteed to be a
    # dict - leave anything else for the later validation steps to handle
    if not isinstance(model, dict) or "id" in model:
        return model

    lengths = []
    for key, val in model.items():
        # don't get lengths of columns with an index
        if (
            f"{key}_index" in model
            or (isinstance(val, VectorData) and val._index)
            or key in cls.NON_COLUMN_FIELDS
        ):
            continue
        lengths.append(len(val))

    # nothing to number when no lengths were collected; np.max([]) would raise
    if lengths:
        model["id"] = np.arange(np.max(lengths))

    return model
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def create_colnames(cls, model: Dict[str, Any]) -> Dict:
|
||||||
"""
|
"""
|
||||||
Construct colnames from arguments.
|
Construct colnames from arguments.
|
||||||
|
|
||||||
|
@ -289,6 +384,40 @@ class DynamicTableMixin(BaseModel):
|
||||||
idx.target = col
|
idx.target = col
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
    """
    Ensure that all columns are equal length

    Returns:
        ``self``, unchanged, when every value in ``self._columns`` has the
        same ``len()`` (or when there are no columns).

    Raises:
        ValueError: if two or more columns differ in length. The message
            lists ``self.colnames`` and the observed lengths.
    """
    lengths = [len(v) for v in self._columns.values()]
    # Compare the distinct lengths rather than asserting on a list of
    # booleans: a non-empty list is always truthy, so `assert [...]` can
    # never fail. Raising explicitly also survives `python -O`.
    if len(set(lengths)) > 1:
        raise ValueError(
            "Columns are not of equal length! "
            f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
        )
    return self
|
||||||
|
|
||||||
|
@field_validator("*", mode="wrap")
|
||||||
|
@classmethod
|
||||||
|
def cast_columns(
    cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
) -> Any:
    """
    If columns are supplied as arrays, try casting them to the type before validating

    ``val`` is first validated as-is through ``handler``. Only when that
    raises ``ValidationError`` is the field's annotated type instantiated as
    ``annotation(val, name=<field name>, description=<field description>)``
    and the result validated instead. For a union annotation (e.g.
    ``Optional[X]``) the first non-``None`` member is used.

    Raises:
        ValidationError: the original error when the annotated type cannot
            be built from ``val``, or the error from validating the cast
            value.
    """
    import types
    from typing import Union, get_args, get_origin

    try:
        return handler(val)
    except ValidationError as original_error:
        field = cls.model_fields[info.field_name]
        annotation = field.annotation

        # Detect unions through the public typing API instead of the private
        # class name, which differs between Python versions. Also covers
        # `X | None` where `types.UnionType` exists.
        union_origins = (Union, getattr(types, "UnionType", Union))
        if get_origin(annotation) in union_origins:
            annotation = next(
                (arg for arg in get_args(annotation) if arg is not type(None)),
                annotation,
            )

        try:
            column = annotation(val, name=info.field_name, description=field.description)
        except (TypeError, ValueError):
            # The annotation is not a column-like type that accepts these
            # arguments (e.g. `str`), so report the real validation failure
            # rather than an unrelated constructor error.
            raise original_error from None
        return handler(column)
|
||||||
|
|
||||||
|
|
||||||
linkml_meta = LinkMLMeta(
|
linkml_meta = LinkMLMeta(
|
||||||
{
|
{
|
||||||
|
@ -335,8 +464,8 @@ class VectorIndex(VectorIndexMixin):
|
||||||
)
|
)
|
||||||
|
|
||||||
name: str = Field(...)
|
name: str = Field(...)
|
||||||
target: VectorData = Field(
|
target: Optional[VectorData] = Field(
|
||||||
..., description="""Reference to the target dataset that this index applies to."""
|
None, description="""Reference to the target dataset that this index applies to."""
|
||||||
)
|
)
|
||||||
description: str = Field(..., description="""Description of what these vectors represent.""")
|
description: str = Field(..., description="""Description of what these vectors represent.""")
|
||||||
value: Optional[
|
value: Optional[
|
||||||
|
|
|
@ -1,22 +1,25 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
from datetime import datetime, date
|
||||||
|
from decimal import Decimal
|
||||||
from ...hdmf_common.v1_8_0.hdmf_common_base import Data
|
from enum import Enum
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
from ...hdmf_common.v1_8_0.hdmf_common_base import Data, Container
|
||||||
from pandas import DataFrame, Series
|
from pandas import DataFrame, Series
|
||||||
from typing import Any, ClassVar, List, Dict, Optional, Union, overload, Tuple
|
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
BaseModel,
|
BaseModel,
|
||||||
ConfigDict,
|
ConfigDict,
|
||||||
Field,
|
Field,
|
||||||
RootModel,
|
RootModel,
|
||||||
model_validator,
|
|
||||||
field_validator,
|
field_validator,
|
||||||
|
model_validator,
|
||||||
ValidationInfo,
|
ValidationInfo,
|
||||||
ValidatorFunctionWrapHandler,
|
ValidatorFunctionWrapHandler,
|
||||||
ValidationError,
|
ValidationError,
|
||||||
)
|
)
|
||||||
from numpydantic import NDArray, Shape
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from numpydantic import NDArray, Shape
|
||||||
|
|
||||||
metamodel_version = "None"
|
metamodel_version = "None"
|
||||||
version = "1.8.0"
|
version = "1.8.0"
|
||||||
|
@ -96,7 +99,7 @@ class VectorDataMixin(BaseModel):
|
||||||
try:
|
try:
|
||||||
return getattr(self.value, item)
|
return getattr(self.value, item)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise e
|
raise e from None
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
"""
|
"""
|
||||||
|
@ -141,7 +144,7 @@ class VectorIndexMixin(BaseModel):
|
||||||
idx = range(*item.indices(len(self.value)))
|
idx = range(*item.indices(len(self.value)))
|
||||||
return [self._getitem_helper(i) for i in idx]
|
return [self._getitem_helper(i) for i in idx]
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("DynamicTableRange not supported yet")
|
raise AttributeError(f"Could not index with {item}")
|
||||||
|
|
||||||
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
||||||
if self._index:
|
if self._index:
|
||||||
|
@ -160,7 +163,7 @@ class VectorIndexMixin(BaseModel):
|
||||||
try:
|
try:
|
||||||
return getattr(self.value, item)
|
return getattr(self.value, item)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise e
|
raise e from None
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
"""
|
"""
|
||||||
|
@ -302,7 +305,7 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
return super().__setattr__(key, value)
|
return super().__setattr__(key, value)
|
||||||
|
|
||||||
def __getattr__(self, item):
|
def __getattr__(self, item: str) -> Any:
|
||||||
"""Try and use pandas df attrs if we don't have them"""
|
"""Try and use pandas df attrs if we don't have them"""
|
||||||
try:
|
try:
|
||||||
return BaseModel.__getattr__(self, item)
|
return BaseModel.__getattr__(self, item)
|
||||||
|
@ -310,7 +313,7 @@ class DynamicTableMixin(BaseModel):
|
||||||
try:
|
try:
|
||||||
return getattr(self[:, :], item)
|
return getattr(self[:, :], item)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise e
|
raise e from None
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -387,7 +390,7 @@ class DynamicTableMixin(BaseModel):
|
||||||
Ensure that all columns are equal length
|
Ensure that all columns are equal length
|
||||||
"""
|
"""
|
||||||
lengths = [len(v) for v in self._columns.values()]
|
lengths = [len(v) for v in self._columns.values()]
|
||||||
assert [l == lengths[0] for l in lengths], (
|
assert [length == lengths[0] for length in lengths], (
|
||||||
"Columns are not of equal length! "
|
"Columns are not of equal length! "
|
||||||
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
|
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
|
||||||
)
|
)
|
||||||
|
@ -395,7 +398,9 @@ class DynamicTableMixin(BaseModel):
|
||||||
|
|
||||||
@field_validator("*", mode="wrap")
|
@field_validator("*", mode="wrap")
|
||||||
@classmethod
|
@classmethod
|
||||||
def cast_columns(cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo):
|
def cast_columns(
|
||||||
|
cls, val: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
|
||||||
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
If columns are supplied as arrays, try casting them to the type before validating
|
If columns are supplied as arrays, try casting them to the type before validating
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -70,6 +70,8 @@ ignore = [
|
||||||
"UP006", "UP035",
|
"UP006", "UP035",
|
||||||
# | for Union types (only supported >=3.10
|
# | for Union types (only supported >=3.10
|
||||||
"UP007", "UP038",
|
"UP007", "UP038",
|
||||||
|
# syntax error in forward annotation with numpydantic
|
||||||
|
"F722"
|
||||||
]
|
]
|
||||||
|
|
||||||
fixable = ["ALL"]
|
fixable = ["ALL"]
|
||||||
|
|
Loading…
Reference in a new issue