Make VectorData and VectorIndex generics to ensure coercion to VectorData for declared columns

This commit is contained in:
sneakers-the-rat 2024-08-13 21:25:56 -07:00
parent 50005d33e5
commit 01e46f7531
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
5 changed files with 176 additions and 21 deletions

View file

@ -36,14 +36,10 @@ plot = [
"dash-cytoscape<1.0.0,>=0.3.0",
]
tests = [
"nwb-linkml[plot]",
"nwb-linkml",
"pytest<8.0.0,>=7.4.0",
"pytest-depends<2.0.0,>=1.0.1",
"coverage<7.0.0,>=6.1.1",
"pytest-md<1.0.0,>=0.2.0",
"pytest-cov<5.0.0,>=4.1.0",
"coveralls<4.0.0,>=3.3.1",
"pytest-profiling<2.0.0,>=1.7.0",
"sybil<6.0.0,>=5.0.3",
"requests-cache>=1.2.1",
]

View file

@ -116,6 +116,7 @@ class NWBPydanticGenerator(PydanticGenerator):
def after_generate_class(self, cls: ClassResult, sv: SchemaView) -> ClassResult:
    """
    Customize dynamictable behavior

    Passes the generated class through the ``AfterGenerateClass`` steps in order:
    first ``inject_dynamictable``, then ``wrap_dynamictable_columns`` (which also
    receives the schema view), and returns the final result.
    """
    injected = AfterGenerateClass.inject_dynamictable(cls)
    return AfterGenerateClass.wrap_dynamictable_columns(injected, sv)
def before_render_template(self, template: PydanticModule, sv: SchemaView) -> PydanticModule:
@ -278,6 +279,28 @@ class AfterGenerateClass:
return cls
@staticmethod
def wrap_dynamictable_columns(cls: ClassResult, sv: SchemaView) -> ClassResult:
    """
    Wrap NDArray columns inside of dynamictables with ``VectorData`` or
    ``VectorIndex``, which are generic classes whose value slot is
    parameterized by the NDArray

    Only classes whose ``is_a`` is ``DynamicTable``, or that have ``DynamicTable``
    among their schema ancestors, are touched. Within those, every attribute whose
    range string mentions ``NDArray`` is rewritten in place: attributes named
    ``*_index`` become ``VectorIndex[<range>]``, all others ``VectorData[<range>]``.

    Args:
        cls: The generated class result; its ``cls.attributes`` ranges are
            mutated in place.
        sv: Schema view used to look up the ancestors of the source class.

    Returns:
        The same ``cls`` object that was passed in.
    """
    is_table = cls.source.is_a == "DynamicTable" or "DynamicTable" in sv.class_ancestors(
        cls.source.name
    )
    if not is_table:
        return cls

    # Neither the attributes mapping nor an attribute's range is guaranteed to be
    # set, so guard both rather than raising TypeError on ``in None``
    for attr_name, attr in (cls.cls.attributes or {}).items():
        slot_range = attr.range
        if not slot_range or "NDArray" not in slot_range:
            continue
        wrapper = "VectorIndex" if attr_name.endswith("_index") else "VectorData"
        attr.range = f"{wrapper}[{slot_range}]"
    return cls
def compile_python(
text_or_fn: str, package_path: Path = None, module_name: str = "test"

View file

@ -2,15 +2,18 @@
Special types for mimicking HDMF special case behavior
"""
import sys
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
Generic,
Iterable,
List,
Optional,
Tuple,
TypeVar,
Union,
overload,
)
@ -33,6 +36,9 @@ from pydantic import (
if TYPE_CHECKING:
from nwb_linkml.models import VectorData, VectorIndex
T = TypeVar("T", bound=NDArray)
T_INJECT = 'T = TypeVar("T", bound=NDArray)'
class DynamicTableMixin(BaseModel):
"""
@ -219,20 +225,28 @@ class DynamicTableMixin(BaseModel):
anything in :attr:`.NON_COLUMN_FIELDS` to determine order implied from passage order
"""
if "colnames" not in model:
colnames = [
k for k in model if k not in cls.NON_COLUMN_FIELDS and not k.endswith("_index")
]
model["colnames"] = colnames
else:
# add any columns not explicitly given an order at the end
colnames = [
k
for k in model
if k not in cls.NON_COLUMN_FIELDS
and not k.endswith("_index")
and k not in model["colnames"]
and not isinstance(model[k], VectorIndexMixin)
]
model["colnames"].extend(colnames)
model["colnames"] = colnames
else:
# add any columns not explicitly given an order at the end
colnames = model["colnames"].copy()
colnames.extend(
[
k
for k in model
if k not in cls.NON_COLUMN_FIELDS
and not k.endswith("_index")
and k not in model["colnames"]
and not isinstance(model[k], VectorIndexMixin)
]
)
model["colnames"] = colnames
return model
@model_validator(mode="after")
@ -322,7 +336,7 @@ class DynamicTableMixin(BaseModel):
)
class VectorDataMixin(BaseModel):
class VectorDataMixin(BaseModel, Generic[T]):
"""
Mixin class to give VectorData indexing abilities
"""
@ -330,7 +344,7 @@ class VectorDataMixin(BaseModel):
_index: Optional["VectorIndex"] = None
# redefined in `VectorData`, but included here for testing and type checking
value: Optional[NDArray] = None
value: Optional[T] = None
def __init__(self, value: Optional[NDArray] = None, **kwargs):
if value is not None and "value" not in kwargs:
@ -373,13 +387,13 @@ class VectorDataMixin(BaseModel):
return len(self.value)
class VectorIndexMixin(BaseModel):
class VectorIndexMixin(BaseModel, Generic[T]):
"""
Mixin class to give VectorIndex indexing abilities
"""
# redefined in `VectorData`, but included here for testing and type checking
value: Optional[NDArray] = None
value: Optional[T] = None
target: Optional["VectorData"] = None
def __init__(self, value: Optional[NDArray] = None, **kwargs):
@ -649,8 +663,10 @@ DYNAMIC_TABLE_IMPORTS = Imports(
module="typing",
objects=[
ObjectImport(name="ClassVar"),
ObjectImport(name="Generic"),
ObjectImport(name="Iterable"),
ObjectImport(name="Tuple"),
ObjectImport(name="TypeVar"),
ObjectImport(name="overload"),
],
),
@ -677,6 +693,7 @@ VectorData is purposefully excluded as an import or an inject so that it will be
resolved to the VectorData definition in the generated module
"""
DYNAMIC_TABLE_INJECTS = [
T_INJECT,
VectorDataMixin,
VectorIndexMixin,
DynamicTableRegionMixin,
@ -689,13 +706,27 @@ TSRVD_IMPORTS = Imports(
Import(
module="typing",
objects=[
ObjectImport(name="overload"),
ObjectImport(name="Generic"),
ObjectImport(name="Iterable"),
ObjectImport(name="Tuple"),
ObjectImport(name="TypeVar"),
ObjectImport(name="overload"),
],
),
Import(module="pydantic", objects=[ObjectImport(name="model_validator")]),
]
)
"""Imports for TimeSeriesReferenceVectorData"""
TSRVD_INJECTS = [VectorDataMixin, TimeSeriesReferenceVectorDataMixin]
TSRVD_INJECTS = [T_INJECT, VectorDataMixin, TimeSeriesReferenceVectorDataMixin]
if "pytest" in sys.modules:
    # Only when pytest has been imported: provide concrete, importable
    # ``VectorData``/``VectorIndex`` classes built directly on the mixins so
    # tests can use them from this module.

    class VectorData(VectorDataMixin):
        """Concrete ``VectorData`` stand-in, defined for testing only"""

    class VectorIndex(VectorIndexMixin):
        """Concrete ``VectorIndex`` stand-in, defined for testing only"""

View file

@ -114,7 +114,7 @@ def load_namespace_adapter(
for ns in namespaces.namespaces:
for schema in ns.schema_:
if schema.source is None:
if imported is None and schema.namespace == "hdmf-common":
if imported is None and schema.namespace == "hdmf-common" and ns.name == "core":
# special case - hdmf-common is imported by name without location or version,
# so to get the correct version we have to handle it separately
imported = _resolve_hdmf(namespace, path)

View file

@ -1,5 +1,9 @@
import numpy as np
import pandas as pd
from numpydantic import NDArray, Shape
from nwb_linkml.includes import hdmf
from nwb_linkml.includes.hdmf import DynamicTableMixin, VectorDataMixin, VectorIndexMixin
# FIXME: Make this just be the output of the provider by patching into import machinery
from nwb_linkml.models.pydantic.core.v2_7_0.namespace import (
@ -48,7 +52,7 @@ def test_dynamictable_indexing(electrical_series):
# get a single column
col = electrodes["y"]
assert all(col == [5, 6, 7, 8, 9])
assert all(col.value == [5, 6, 7, 8, 9])
# get a single cell
val = electrodes[0, "y"]
@ -198,3 +202,104 @@ def test_aligned_dynamictable(intracellular_recordings_table):
for i in range(len(stims)):
assert isinstance(stims[i], VoltageClampStimulusSeries)
assert all([i == val for val in stims[i][:]])
# --------------------------------------------------
# Direct mixin tests
# --------------------------------------------------
def test_dynamictable_mixin_indexing():
    """
    Placeholder only: indexing is exercised by the tests above that use
    actual model objects. Kept so that searching for this name lands here.
    """
    return None
def test_dynamictable_mixin_colnames():
    """
    ``colnames`` should be inferred from the declared column followed by the
    extra columns, in the order they were passed
    """

    class MyDT(DynamicTableMixin):
        existing_col: NDArray[Shape["* col"], int]

    # extra (undeclared) columns, passed after the declared one
    extra_columns = {
        name: VectorDataMixin(value=np.arange(10)) for name in ("new_col_1", "new_col_2")
    }
    table = MyDT(existing_col=np.arange(10), **extra_columns)

    expected = ["existing_col", "new_col_1", "new_col_2"]
    assert table.colnames == expected
def test_dynamictable_mixin_colnames_index():
    """
    Index columns should be left out of the inferred ``colnames``
    """

    class MyDT(DynamicTableMixin):
        existing_col: NDArray[Shape["* col"], int]

    first_col = hdmf.VectorData(value=np.arange(10))
    second_col = hdmf.VectorData(value=np.arange(10))

    table = MyDT(
        **{
            "existing_col": np.arange(10),
            "new_col_1": first_col,
            "new_col_2": second_col,
            # explicit index: name does not match its column, target given directly
            "weirdname_index": VectorIndexMixin(value=np.arange(10), target=first_col),
            # implicit index: associated with its column through the matching name
            "new_col_2_index": VectorIndexMixin(value=np.arange(10)),
        }
    )

    expected = ["existing_col", "new_col_1", "new_col_2"]
    assert table.colnames == expected
def test_dynamictable_mixin_colnames_ordered():
    """
    Should be able to pass explicit order to colnames

    Checks both a complete ordering, which must be used as-is, and a partial
    ordering, where columns not named in ``colnames`` are appended at the end.
    """

    class MyDT(DynamicTableMixin):
        existing_col: NDArray[Shape["* col"], int]

    cols = {
        "existing_col": np.arange(10),
        "new_col_1": hdmf.VectorData(value=np.arange(10)),
        "new_col_2": hdmf.VectorData(value=np.arange(10)),
        "new_col_3": hdmf.VectorData(value=np.arange(10)),
    }
    order = ["new_col_2", "existing_col", "new_col_1", "new_col_3"]

    inst = MyDT(**cols, colnames=order)
    assert inst.colnames == order

    # this should get reflected in the columns selector and the df it produces.
    # Compare whole lists rather than zipping: zip() stops at the shorter input,
    # so a column missing from ``_columns`` would otherwise pass unnoticed.
    assert list(inst._columns) == order
    assert all(inst[0].columns == order)

    # partial lists should append unnamed columns at the end
    partial_order = ["new_col_3", "new_col_2"]
    inst = MyDT(**cols, colnames=partial_order)
    assert inst.colnames == [*partial_order, "existing_col", "new_col_1"]
def test_dynamictable_mixin_getattr():
"""
Dynamictable should forward unknown getattr requests to the df
"""
class MyDT(DynamicTableMixin):
existing_col: NDArray[Shape["* col"], int]
class AModel(DynamicTableMixin):
col: hdmf.VectorData[NDArray[Shape["3, 3"], int]]
col = hdmf.VectorData(value=np.arange(10))
inst = MyDT(existing_col=col)
# regular lookup for attrs that exist
# pdb.set_trace()
# inst.existing_col
# assert inst.existing_col == col
# df lookup otherwise
# inst.columns