From 01e46f753160a8217059f818b61aff21040e3f17 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Tue, 13 Aug 2024 21:25:56 -0700 Subject: [PATCH] Make VectorData and VectorIndex generics to ensure coercion to VectorData for declared columns --- nwb_linkml/pyproject.toml | 6 +- .../src/nwb_linkml/generators/pydantic.py | 23 ++++ nwb_linkml/src/nwb_linkml/includes/hdmf.py | 59 +++++++--- nwb_linkml/src/nwb_linkml/io/schema.py | 2 +- nwb_linkml/tests/test_includes/test_hdmf.py | 107 +++++++++++++++++- 5 files changed, 176 insertions(+), 21 deletions(-) diff --git a/nwb_linkml/pyproject.toml b/nwb_linkml/pyproject.toml index 8119a7f..ffe0f54 100644 --- a/nwb_linkml/pyproject.toml +++ b/nwb_linkml/pyproject.toml @@ -36,14 +36,10 @@ plot = [ "dash-cytoscape<1.0.0,>=0.3.0", ] tests = [ - "nwb-linkml[plot]", + "nwb-linkml", "pytest<8.0.0,>=7.4.0", "pytest-depends<2.0.0,>=1.0.1", - "coverage<7.0.0,>=6.1.1", - "pytest-md<1.0.0,>=0.2.0", "pytest-cov<5.0.0,>=4.1.0", - "coveralls<4.0.0,>=3.3.1", - "pytest-profiling<2.0.0,>=1.7.0", "sybil<6.0.0,>=5.0.3", "requests-cache>=1.2.1", ] diff --git a/nwb_linkml/src/nwb_linkml/generators/pydantic.py b/nwb_linkml/src/nwb_linkml/generators/pydantic.py index 35ae598..0cdfd23 100644 --- a/nwb_linkml/src/nwb_linkml/generators/pydantic.py +++ b/nwb_linkml/src/nwb_linkml/generators/pydantic.py @@ -116,6 +116,7 @@ class NWBPydanticGenerator(PydanticGenerator): def after_generate_class(self, cls: ClassResult, sv: SchemaView) -> ClassResult: """Customize dynamictable behavior""" cls = AfterGenerateClass.inject_dynamictable(cls) + cls = AfterGenerateClass.wrap_dynamictable_columns(cls, sv) return cls def before_render_template(self, template: PydanticModule, sv: SchemaView) -> PydanticModule: @@ -278,6 +279,28 @@ class AfterGenerateClass: return cls + @staticmethod + def wrap_dynamictable_columns(cls: ClassResult, sv: SchemaView) -> ClassResult: + """ + Wrap NDArray columns inside of dynamictables with ``VectorData`` or + ``VectorIndex``, which 
are generic classes whose value slot is + parameterized by the NDArray + """ + if cls.source.is_a == "DynamicTable" or "DynamicTable" in sv.class_ancestors( + cls.source.name + ): + for an_attr in cls.cls.attributes: + if "NDArray" in (slot_range := cls.cls.attributes[an_attr].range): + if an_attr.endswith("_index"): + cls.cls.attributes[an_attr].range = "".join( + ["VectorIndex[", slot_range, "]"] + ) + else: + cls.cls.attributes[an_attr].range = "".join( + ["VectorData[", slot_range, "]"] + ) + return cls + def compile_python( text_or_fn: str, package_path: Path = None, module_name: str = "test" diff --git a/nwb_linkml/src/nwb_linkml/includes/hdmf.py b/nwb_linkml/src/nwb_linkml/includes/hdmf.py index 2ddf2cc..e4b9ff1 100644 --- a/nwb_linkml/src/nwb_linkml/includes/hdmf.py +++ b/nwb_linkml/src/nwb_linkml/includes/hdmf.py @@ -2,15 +2,18 @@ Special types for mimicking HDMF special case behavior """ +import sys from typing import ( TYPE_CHECKING, Any, ClassVar, Dict, + Generic, Iterable, List, Optional, Tuple, + TypeVar, Union, overload, ) @@ -33,6 +36,9 @@ from pydantic import ( if TYPE_CHECKING: from nwb_linkml.models import VectorData, VectorIndex +T = TypeVar("T", bound=NDArray) +T_INJECT = 'T = TypeVar("T", bound=NDArray)' + class DynamicTableMixin(BaseModel): """ @@ -219,20 +225,28 @@ class DynamicTableMixin(BaseModel): anything in :attr:`.NON_COLUMN_FIELDS` to determine order implied from passage order """ if "colnames" not in model: - colnames = [ - k for k in model if k not in cls.NON_COLUMN_FIELDS and not k.endswith("_index") - ] - model["colnames"] = colnames - else: - # add any columns not explicitly given an order at the end colnames = [ k for k in model if k not in cls.NON_COLUMN_FIELDS and not k.endswith("_index") - and k not in model["colnames"] + and not isinstance(model[k], VectorIndexMixin) ] - model["colnames"].extend(colnames) + model["colnames"] = colnames + else: + # add any columns not explicitly given an order at the end + colnames = 
model["colnames"].copy() + colnames.extend( + [ + k + for k in model + if k not in cls.NON_COLUMN_FIELDS + and not k.endswith("_index") + and k not in model["colnames"] + and not isinstance(model[k], VectorIndexMixin) + ] + ) + model["colnames"] = colnames return model @model_validator(mode="after") @@ -322,7 +336,7 @@ class DynamicTableMixin(BaseModel): ) -class VectorDataMixin(BaseModel): +class VectorDataMixin(BaseModel, Generic[T]): """ Mixin class to give VectorData indexing abilities """ @@ -330,7 +344,7 @@ class VectorDataMixin(BaseModel): _index: Optional["VectorIndex"] = None # redefined in `VectorData`, but included here for testing and type checking - value: Optional[NDArray] = None + value: Optional[T] = None def __init__(self, value: Optional[NDArray] = None, **kwargs): if value is not None and "value" not in kwargs: @@ -373,13 +387,13 @@ class VectorDataMixin(BaseModel): return len(self.value) -class VectorIndexMixin(BaseModel): +class VectorIndexMixin(BaseModel, Generic[T]): """ Mixin class to give VectorIndex indexing abilities """ # redefined in `VectorData`, but included here for testing and type checking - value: Optional[NDArray] = None + value: Optional[T] = None target: Optional["VectorData"] = None def __init__(self, value: Optional[NDArray] = None, **kwargs): @@ -649,8 +663,10 @@ DYNAMIC_TABLE_IMPORTS = Imports( module="typing", objects=[ ObjectImport(name="ClassVar"), + ObjectImport(name="Generic"), ObjectImport(name="Iterable"), ObjectImport(name="Tuple"), + ObjectImport(name="TypeVar"), ObjectImport(name="overload"), ], ), @@ -677,6 +693,7 @@ VectorData is purposefully excluded as an import or an inject so that it will be resolved to the VectorData definition in the generated module """ DYNAMIC_TABLE_INJECTS = [ + T_INJECT, VectorDataMixin, VectorIndexMixin, DynamicTableRegionMixin, @@ -689,13 +706,27 @@ TSRVD_IMPORTS = Imports( Import( module="typing", objects=[ - ObjectImport(name="overload"), + ObjectImport(name="Generic"), 
ObjectImport(name="Iterable"), ObjectImport(name="Tuple"), + ObjectImport(name="TypeVar"), + ObjectImport(name="overload"), ], ), Import(module="pydantic", objects=[ObjectImport(name="model_validator")]), ] ) """Imports for TimeSeriesReferenceVectorData""" -TSRVD_INJECTS = [VectorDataMixin, TimeSeriesReferenceVectorDataMixin] +TSRVD_INJECTS = [T_INJECT, VectorDataMixin, TimeSeriesReferenceVectorDataMixin] + +if "pytest" in sys.modules: + # during testing define concrete subclasses... + class VectorData(VectorDataMixin): + """VectorData subclass for testing""" + + pass + + class VectorIndex(VectorIndexMixin): + """VectorIndex subclass for testing""" + + pass diff --git a/nwb_linkml/src/nwb_linkml/io/schema.py b/nwb_linkml/src/nwb_linkml/io/schema.py index 029fc70..42718f5 100644 --- a/nwb_linkml/src/nwb_linkml/io/schema.py +++ b/nwb_linkml/src/nwb_linkml/io/schema.py @@ -114,7 +114,7 @@ def load_namespace_adapter( for ns in namespaces.namespaces: for schema in ns.schema_: if schema.source is None: - if imported is None and schema.namespace == "hdmf-common": + if imported is None and schema.namespace == "hdmf-common" and ns.name == "core": # special case - hdmf-common is imported by name without location or version, # so to get the correct version we have to handle it separately imported = _resolve_hdmf(namespace, path) diff --git a/nwb_linkml/tests/test_includes/test_hdmf.py b/nwb_linkml/tests/test_includes/test_hdmf.py index f7fd862..bde829b 100644 --- a/nwb_linkml/tests/test_includes/test_hdmf.py +++ b/nwb_linkml/tests/test_includes/test_hdmf.py @@ -1,5 +1,9 @@ import numpy as np import pandas as pd +from numpydantic import NDArray, Shape + +from nwb_linkml.includes import hdmf +from nwb_linkml.includes.hdmf import DynamicTableMixin, VectorDataMixin, VectorIndexMixin # FIXME: Make this just be the output of the provider by patching into import machinery from nwb_linkml.models.pydantic.core.v2_7_0.namespace import ( @@ -48,7 +52,7 @@ def 
test_dynamictable_indexing(electrical_series): # get a single column col = electrodes["y"] - assert all(col == [5, 6, 7, 8, 9]) + assert all(col.value == [5, 6, 7, 8, 9]) # get a single cell val = electrodes[0, "y"] @@ -198,3 +202,104 @@ def test_aligned_dynamictable(intracellular_recordings_table): for i in range(len(stims)): assert isinstance(stims[i], VoltageClampStimulusSeries) assert all([i == val for val in stims[i][:]]) + + +# -------------------------------------------------- +# Direct mixin tests +# -------------------------------------------------- + + +def test_dynamictable_mixin_indexing(): + """ + This is just a placeholder test to say that indexing is tested above + with actual model objects in case i ever ctrl+f for this + """ + pass + + +def test_dynamictable_mixin_colnames(): + """ + Should correctly infer colnames + """ + + class MyDT(DynamicTableMixin): + existing_col: NDArray[Shape["* col"], int] + + new_col_1 = VectorDataMixin(value=np.arange(10)) + new_col_2 = VectorDataMixin(value=np.arange(10)) + + inst = MyDT(existing_col=np.arange(10), new_col_1=new_col_1, new_col_2=new_col_2) + assert inst.colnames == ["existing_col", "new_col_1", "new_col_2"] + + +def test_dynamictable_mixin_colnames_index(): + """ + Exclude index columns in colnames + """ + + class MyDT(DynamicTableMixin): + existing_col: NDArray[Shape["* col"], int] + + cols = { + "existing_col": np.arange(10), + "new_col_1": hdmf.VectorData(value=np.arange(10)), + "new_col_2": hdmf.VectorData(value=np.arange(10)), + } + # explicit index with mismatching name + cols["weirdname_index"] = VectorIndexMixin(value=np.arange(10), target=cols["new_col_1"]) + # implicit index with matching name + cols["new_col_2_index"] = VectorIndexMixin(value=np.arange(10)) + + inst = MyDT(**cols) + assert inst.colnames == ["existing_col", "new_col_1", "new_col_2"] + + +def test_dynamictable_mixin_colnames_ordered(): + """ + Should be able to pass explicit order to colnames + """ + + class 
MyDT(DynamicTableMixin): + existing_col: NDArray[Shape["* col"], int] + + cols = { + "existing_col": np.arange(10), + "new_col_1": hdmf.VectorData(value=np.arange(10)), + "new_col_2": hdmf.VectorData(value=np.arange(10)), + "new_col_3": hdmf.VectorData(value=np.arange(10)), + } + order = ["new_col_2", "existing_col", "new_col_1", "new_col_3"] + + inst = MyDT(**cols, colnames=order) + assert inst.colnames == order + + # this should get reflected in the columns selector and the df it produces + assert all([key1 == key2 for key1, key2 in zip(order, inst._columns)]) + assert all(inst[0].columns == order) + + # partial lists should append unnamed columns at the end + partial_order = ["new_col_3", "new_col_2"] + inst = MyDT(**cols, colnames=partial_order) + assert inst.colnames == [*partial_order, "existing_col", "new_col_1"] + + +def test_dynamictable_mixin_getattr(): + """ + Dynamictable should forward unknown getattr requests to the df + """ + + class MyDT(DynamicTableMixin): + existing_col: NDArray[Shape["* col"], int] + + class AModel(DynamicTableMixin): + col: hdmf.VectorData[NDArray[Shape["3, 3"], int]] + + col = hdmf.VectorData(value=np.arange(10)) + inst = MyDT(existing_col=col) + # regular lookup for attrs that exist + + # pdb.set_trace() + # inst.existing_col + # assert inst.existing_col == col + # df lookup otherwise + # inst.columns