Make VectorData and VectorIndex generics to ensure coercion to VectorData for declared columns

This commit is contained in:
sneakers-the-rat 2024-08-13 21:25:56 -07:00
parent 50005d33e5
commit 01e46f7531
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
5 changed files with 176 additions and 21 deletions

View file

@ -36,14 +36,10 @@ plot = [
"dash-cytoscape<1.0.0,>=0.3.0", "dash-cytoscape<1.0.0,>=0.3.0",
] ]
tests = [ tests = [
"nwb-linkml[plot]", "nwb-linkml",
"pytest<8.0.0,>=7.4.0", "pytest<8.0.0,>=7.4.0",
"pytest-depends<2.0.0,>=1.0.1", "pytest-depends<2.0.0,>=1.0.1",
"coverage<7.0.0,>=6.1.1",
"pytest-md<1.0.0,>=0.2.0",
"pytest-cov<5.0.0,>=4.1.0", "pytest-cov<5.0.0,>=4.1.0",
"coveralls<4.0.0,>=3.3.1",
"pytest-profiling<2.0.0,>=1.7.0",
"sybil<6.0.0,>=5.0.3", "sybil<6.0.0,>=5.0.3",
"requests-cache>=1.2.1", "requests-cache>=1.2.1",
] ]

View file

@ -116,6 +116,7 @@ class NWBPydanticGenerator(PydanticGenerator):
def after_generate_class(self, cls: ClassResult, sv: SchemaView) -> ClassResult: def after_generate_class(self, cls: ClassResult, sv: SchemaView) -> ClassResult:
"""Customize dynamictable behavior""" """Customize dynamictable behavior"""
cls = AfterGenerateClass.inject_dynamictable(cls) cls = AfterGenerateClass.inject_dynamictable(cls)
cls = AfterGenerateClass.wrap_dynamictable_columns(cls, sv)
return cls return cls
def before_render_template(self, template: PydanticModule, sv: SchemaView) -> PydanticModule: def before_render_template(self, template: PydanticModule, sv: SchemaView) -> PydanticModule:
@ -278,6 +279,28 @@ class AfterGenerateClass:
return cls return cls
@staticmethod
def wrap_dynamictable_columns(cls: ClassResult, sv: SchemaView) -> ClassResult:
"""
Wrap NDArray columns inside of dynamictables with ``VectorData`` or
``VectorIndex``, which are generic classes whose value slot is
parameterized by the NDArray
"""
if cls.source.is_a == "DynamicTable" or "DynamicTable" in sv.class_ancestors(
cls.source.name
):
for an_attr in cls.cls.attributes:
if "NDArray" in (slot_range := cls.cls.attributes[an_attr].range):
if an_attr.endswith("_index"):
cls.cls.attributes[an_attr].range = "".join(
["VectorIndex[", slot_range, "]"]
)
else:
cls.cls.attributes[an_attr].range = "".join(
["VectorData[", slot_range, "]"]
)
return cls
def compile_python( def compile_python(
text_or_fn: str, package_path: Path = None, module_name: str = "test" text_or_fn: str, package_path: Path = None, module_name: str = "test"

View file

@ -2,15 +2,18 @@
Special types for mimicking HDMF special case behavior Special types for mimicking HDMF special case behavior
""" """
import sys
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
ClassVar, ClassVar,
Dict, Dict,
Generic,
Iterable, Iterable,
List, List,
Optional, Optional,
Tuple, Tuple,
TypeVar,
Union, Union,
overload, overload,
) )
@ -33,6 +36,9 @@ from pydantic import (
if TYPE_CHECKING: if TYPE_CHECKING:
from nwb_linkml.models import VectorData, VectorIndex from nwb_linkml.models import VectorData, VectorIndex
T = TypeVar("T", bound=NDArray)
T_INJECT = 'T = TypeVar("T", bound=NDArray)'
class DynamicTableMixin(BaseModel): class DynamicTableMixin(BaseModel):
""" """
@ -219,20 +225,28 @@ class DynamicTableMixin(BaseModel):
anything in :attr:`.NON_COLUMN_FIELDS` to determine order implied from passage order anything in :attr:`.NON_COLUMN_FIELDS` to determine order implied from passage order
""" """
if "colnames" not in model: if "colnames" not in model:
colnames = [
k for k in model if k not in cls.NON_COLUMN_FIELDS and not k.endswith("_index")
]
model["colnames"] = colnames
else:
# add any columns not explicitly given an order at the end
colnames = [ colnames = [
k k
for k in model for k in model
if k not in cls.NON_COLUMN_FIELDS if k not in cls.NON_COLUMN_FIELDS
and not k.endswith("_index") and not k.endswith("_index")
and k not in model["colnames"] and not isinstance(model[k], VectorIndexMixin)
] ]
model["colnames"].extend(colnames) model["colnames"] = colnames
else:
# add any columns not explicitly given an order at the end
colnames = model["colnames"].copy()
colnames.extend(
[
k
for k in model
if k not in cls.NON_COLUMN_FIELDS
and not k.endswith("_index")
and k not in model["colnames"]
and not isinstance(model[k], VectorIndexMixin)
]
)
model["colnames"] = colnames
return model return model
@model_validator(mode="after") @model_validator(mode="after")
@ -322,7 +336,7 @@ class DynamicTableMixin(BaseModel):
) )
class VectorDataMixin(BaseModel): class VectorDataMixin(BaseModel, Generic[T]):
""" """
Mixin class to give VectorData indexing abilities Mixin class to give VectorData indexing abilities
""" """
@ -330,7 +344,7 @@ class VectorDataMixin(BaseModel):
_index: Optional["VectorIndex"] = None _index: Optional["VectorIndex"] = None
# redefined in `VectorData`, but included here for testing and type checking # redefined in `VectorData`, but included here for testing and type checking
value: Optional[NDArray] = None value: Optional[T] = None
def __init__(self, value: Optional[NDArray] = None, **kwargs): def __init__(self, value: Optional[NDArray] = None, **kwargs):
if value is not None and "value" not in kwargs: if value is not None and "value" not in kwargs:
@ -373,13 +387,13 @@ class VectorDataMixin(BaseModel):
return len(self.value) return len(self.value)
class VectorIndexMixin(BaseModel): class VectorIndexMixin(BaseModel, Generic[T]):
""" """
Mixin class to give VectorIndex indexing abilities Mixin class to give VectorIndex indexing abilities
""" """
# redefined in `VectorData`, but included here for testing and type checking # redefined in `VectorData`, but included here for testing and type checking
value: Optional[NDArray] = None value: Optional[T] = None
target: Optional["VectorData"] = None target: Optional["VectorData"] = None
def __init__(self, value: Optional[NDArray] = None, **kwargs): def __init__(self, value: Optional[NDArray] = None, **kwargs):
@ -649,8 +663,10 @@ DYNAMIC_TABLE_IMPORTS = Imports(
module="typing", module="typing",
objects=[ objects=[
ObjectImport(name="ClassVar"), ObjectImport(name="ClassVar"),
ObjectImport(name="Generic"),
ObjectImport(name="Iterable"), ObjectImport(name="Iterable"),
ObjectImport(name="Tuple"), ObjectImport(name="Tuple"),
ObjectImport(name="TypeVar"),
ObjectImport(name="overload"), ObjectImport(name="overload"),
], ],
), ),
@ -677,6 +693,7 @@ VectorData is purposefully excluded as an import or an inject so that it will be
resolved to the VectorData definition in the generated module resolved to the VectorData definition in the generated module
""" """
DYNAMIC_TABLE_INJECTS = [ DYNAMIC_TABLE_INJECTS = [
T_INJECT,
VectorDataMixin, VectorDataMixin,
VectorIndexMixin, VectorIndexMixin,
DynamicTableRegionMixin, DynamicTableRegionMixin,
@ -689,13 +706,27 @@ TSRVD_IMPORTS = Imports(
Import( Import(
module="typing", module="typing",
objects=[ objects=[
ObjectImport(name="overload"), ObjectImport(name="Generic"),
ObjectImport(name="Iterable"), ObjectImport(name="Iterable"),
ObjectImport(name="Tuple"), ObjectImport(name="Tuple"),
ObjectImport(name="TypeVar"),
ObjectImport(name="overload"),
], ],
), ),
Import(module="pydantic", objects=[ObjectImport(name="model_validator")]), Import(module="pydantic", objects=[ObjectImport(name="model_validator")]),
] ]
) )
"""Imports for TimeSeriesReferenceVectorData""" """Imports for TimeSeriesReferenceVectorData"""
TSRVD_INJECTS = [VectorDataMixin, TimeSeriesReferenceVectorDataMixin] TSRVD_INJECTS = [T_INJECT, VectorDataMixin, TimeSeriesReferenceVectorDataMixin]
if "pytest" in sys.modules:
# during testing define concrete subclasses...
class VectorData(VectorDataMixin):
"""VectorData subclass for testing"""
pass
class VectorIndex(VectorIndexMixin):
"""VectorIndex subclass for testing"""
pass

View file

@ -114,7 +114,7 @@ def load_namespace_adapter(
for ns in namespaces.namespaces: for ns in namespaces.namespaces:
for schema in ns.schema_: for schema in ns.schema_:
if schema.source is None: if schema.source is None:
if imported is None and schema.namespace == "hdmf-common": if imported is None and schema.namespace == "hdmf-common" and ns.name == "core":
# special case - hdmf-common is imported by name without location or version, # special case - hdmf-common is imported by name without location or version,
# so to get the correct version we have to handle it separately # so to get the correct version we have to handle it separately
imported = _resolve_hdmf(namespace, path) imported = _resolve_hdmf(namespace, path)

View file

@ -1,5 +1,9 @@
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from numpydantic import NDArray, Shape
from nwb_linkml.includes import hdmf
from nwb_linkml.includes.hdmf import DynamicTableMixin, VectorDataMixin, VectorIndexMixin
# FIXME: Make this just be the output of the provider by patching into import machinery # FIXME: Make this just be the output of the provider by patching into import machinery
from nwb_linkml.models.pydantic.core.v2_7_0.namespace import ( from nwb_linkml.models.pydantic.core.v2_7_0.namespace import (
@ -48,7 +52,7 @@ def test_dynamictable_indexing(electrical_series):
# get a single column # get a single column
col = electrodes["y"] col = electrodes["y"]
assert all(col == [5, 6, 7, 8, 9]) assert all(col.value == [5, 6, 7, 8, 9])
# get a single cell # get a single cell
val = electrodes[0, "y"] val = electrodes[0, "y"]
@ -198,3 +202,104 @@ def test_aligned_dynamictable(intracellular_recordings_table):
for i in range(len(stims)): for i in range(len(stims)):
assert isinstance(stims[i], VoltageClampStimulusSeries) assert isinstance(stims[i], VoltageClampStimulusSeries)
assert all([i == val for val in stims[i][:]]) assert all([i == val for val in stims[i][:]])
# --------------------------------------------------
# Direct mixin tests
# --------------------------------------------------
def test_dynamictable_mixin_indexing():
"""
This is just a placeholder test to say that indexing is tested above
with actual model objects in case i ever ctrl+f for this
"""
pass
def test_dynamictable_mixin_colnames():
"""
Should correctly infer colnames
"""
class MyDT(DynamicTableMixin):
existing_col: NDArray[Shape["* col"], int]
new_col_1 = VectorDataMixin(value=np.arange(10))
new_col_2 = VectorDataMixin(value=np.arange(10))
inst = MyDT(existing_col=np.arange(10), new_col_1=new_col_1, new_col_2=new_col_2)
assert inst.colnames == ["existing_col", "new_col_1", "new_col_2"]
def test_dynamictable_mixin_colnames_index():
"""
Exclude index columns in colnames
"""
class MyDT(DynamicTableMixin):
existing_col: NDArray[Shape["* col"], int]
cols = {
"existing_col": np.arange(10),
"new_col_1": hdmf.VectorData(value=np.arange(10)),
"new_col_2": hdmf.VectorData(value=np.arange(10)),
}
# explicit index with mismatching name
cols["weirdname_index"] = VectorIndexMixin(value=np.arange(10), target=cols["new_col_1"])
# implicit index with matching name
cols["new_col_2_index"] = VectorIndexMixin(value=np.arange(10))
inst = MyDT(**cols)
assert inst.colnames == ["existing_col", "new_col_1", "new_col_2"]
def test_dynamictable_mixin_colnames_ordered():
"""
Should be able to pass explicit order to colnames
"""
class MyDT(DynamicTableMixin):
existing_col: NDArray[Shape["* col"], int]
cols = {
"existing_col": np.arange(10),
"new_col_1": hdmf.VectorData(value=np.arange(10)),
"new_col_2": hdmf.VectorData(value=np.arange(10)),
"new_col_3": hdmf.VectorData(value=np.arange(10)),
}
order = ["new_col_2", "existing_col", "new_col_1", "new_col_3"]
inst = MyDT(**cols, colnames=order)
assert inst.colnames == order
# this should get reflected in the columns selector and the df produces
assert all([key1 == key2 for key1, key2 in zip(order, inst._columns)])
assert all(inst[0].columns == order)
# partial lists should append unnamed columsn at the end
partial_order = ["new_col_3", "new_col_2"]
inst = MyDT(**cols, colnames=partial_order)
assert inst.colnames == [*partial_order, "existing_col", "new_col_1"]
def test_dynamictable_mixin_getattr():
"""
Dynamictable should forward unknown getattr requests to the df
"""
class MyDT(DynamicTableMixin):
existing_col: NDArray[Shape["* col"], int]
class AModel(DynamicTableMixin):
col: hdmf.VectorData[NDArray[Shape["3, 3"], int]]
col = hdmf.VectorData(value=np.arange(10))
inst = MyDT(existing_col=col)
# regular lookup for attrs that exist
# pdb.set_trace()
# inst.existing_col
# assert inst.existing_col == col
# df lookup otherwise
# inst.columns