mirror of
https://github.com/p2p-ld/nwb-linkml.git
synced 2025-01-10 06:04:28 +00:00
Make VectorData and VectorIndex generics to ensure coercion to VectorData for declared columns
This commit is contained in:
parent
50005d33e5
commit
01e46f7531
5 changed files with 176 additions and 21 deletions
|
@ -36,14 +36,10 @@ plot = [
|
|||
"dash-cytoscape<1.0.0,>=0.3.0",
|
||||
]
|
||||
tests = [
|
||||
"nwb-linkml[plot]",
|
||||
"nwb-linkml",
|
||||
"pytest<8.0.0,>=7.4.0",
|
||||
"pytest-depends<2.0.0,>=1.0.1",
|
||||
"coverage<7.0.0,>=6.1.1",
|
||||
"pytest-md<1.0.0,>=0.2.0",
|
||||
"pytest-cov<5.0.0,>=4.1.0",
|
||||
"coveralls<4.0.0,>=3.3.1",
|
||||
"pytest-profiling<2.0.0,>=1.7.0",
|
||||
"sybil<6.0.0,>=5.0.3",
|
||||
"requests-cache>=1.2.1",
|
||||
]
|
||||
|
|
|
@ -116,6 +116,7 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
def after_generate_class(self, cls: ClassResult, sv: SchemaView) -> ClassResult:
|
||||
"""Customize dynamictable behavior"""
|
||||
cls = AfterGenerateClass.inject_dynamictable(cls)
|
||||
cls = AfterGenerateClass.wrap_dynamictable_columns(cls, sv)
|
||||
return cls
|
||||
|
||||
def before_render_template(self, template: PydanticModule, sv: SchemaView) -> PydanticModule:
|
||||
|
@ -278,6 +279,28 @@ class AfterGenerateClass:
|
|||
|
||||
return cls
|
||||
|
||||
@staticmethod
|
||||
def wrap_dynamictable_columns(cls: ClassResult, sv: SchemaView) -> ClassResult:
|
||||
"""
|
||||
Wrap NDArray columns inside of dynamictables with ``VectorData`` or
|
||||
``VectorIndex``, which are generic classes whose value slot is
|
||||
parameterized by the NDArray
|
||||
"""
|
||||
if cls.source.is_a == "DynamicTable" or "DynamicTable" in sv.class_ancestors(
|
||||
cls.source.name
|
||||
):
|
||||
for an_attr in cls.cls.attributes:
|
||||
if "NDArray" in (slot_range := cls.cls.attributes[an_attr].range):
|
||||
if an_attr.endswith("_index"):
|
||||
cls.cls.attributes[an_attr].range = "".join(
|
||||
["VectorIndex[", slot_range, "]"]
|
||||
)
|
||||
else:
|
||||
cls.cls.attributes[an_attr].range = "".join(
|
||||
["VectorData[", slot_range, "]"]
|
||||
)
|
||||
return cls
|
||||
|
||||
|
||||
def compile_python(
|
||||
text_or_fn: str, package_path: Path = None, module_name: str = "test"
|
||||
|
|
|
@ -2,15 +2,18 @@
|
|||
Special types for mimicking HDMF special case behavior
|
||||
"""
|
||||
|
||||
import sys
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
ClassVar,
|
||||
Dict,
|
||||
Generic,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
@ -33,6 +36,9 @@ from pydantic import (
|
|||
if TYPE_CHECKING:
|
||||
from nwb_linkml.models import VectorData, VectorIndex
|
||||
|
||||
T = TypeVar("T", bound=NDArray)
|
||||
T_INJECT = 'T = TypeVar("T", bound=NDArray)'
|
||||
|
||||
|
||||
class DynamicTableMixin(BaseModel):
|
||||
"""
|
||||
|
@ -219,20 +225,28 @@ class DynamicTableMixin(BaseModel):
|
|||
anything in :attr:`.NON_COLUMN_FIELDS` to determine order implied from passage order
|
||||
"""
|
||||
if "colnames" not in model:
|
||||
colnames = [
|
||||
k for k in model if k not in cls.NON_COLUMN_FIELDS and not k.endswith("_index")
|
||||
]
|
||||
model["colnames"] = colnames
|
||||
else:
|
||||
# add any columns not explicitly given an order at the end
|
||||
colnames = [
|
||||
k
|
||||
for k in model
|
||||
if k not in cls.NON_COLUMN_FIELDS
|
||||
and not k.endswith("_index")
|
||||
and k not in model["colnames"]
|
||||
and not isinstance(model[k], VectorIndexMixin)
|
||||
]
|
||||
model["colnames"].extend(colnames)
|
||||
model["colnames"] = colnames
|
||||
else:
|
||||
# add any columns not explicitly given an order at the end
|
||||
colnames = model["colnames"].copy()
|
||||
colnames.extend(
|
||||
[
|
||||
k
|
||||
for k in model
|
||||
if k not in cls.NON_COLUMN_FIELDS
|
||||
and not k.endswith("_index")
|
||||
and k not in model["colnames"]
|
||||
and not isinstance(model[k], VectorIndexMixin)
|
||||
]
|
||||
)
|
||||
model["colnames"] = colnames
|
||||
return model
|
||||
|
||||
@model_validator(mode="after")
|
||||
|
@ -322,7 +336,7 @@ class DynamicTableMixin(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class VectorDataMixin(BaseModel):
|
||||
class VectorDataMixin(BaseModel, Generic[T]):
|
||||
"""
|
||||
Mixin class to give VectorData indexing abilities
|
||||
"""
|
||||
|
@ -330,7 +344,7 @@ class VectorDataMixin(BaseModel):
|
|||
_index: Optional["VectorIndex"] = None
|
||||
|
||||
# redefined in `VectorData`, but included here for testing and type checking
|
||||
value: Optional[NDArray] = None
|
||||
value: Optional[T] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
|
@ -373,13 +387,13 @@ class VectorDataMixin(BaseModel):
|
|||
return len(self.value)
|
||||
|
||||
|
||||
class VectorIndexMixin(BaseModel):
|
||||
class VectorIndexMixin(BaseModel, Generic[T]):
|
||||
"""
|
||||
Mixin class to give VectorIndex indexing abilities
|
||||
"""
|
||||
|
||||
# redefined in `VectorData`, but included here for testing and type checking
|
||||
value: Optional[NDArray] = None
|
||||
value: Optional[T] = None
|
||||
target: Optional["VectorData"] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
|
@ -649,8 +663,10 @@ DYNAMIC_TABLE_IMPORTS = Imports(
|
|||
module="typing",
|
||||
objects=[
|
||||
ObjectImport(name="ClassVar"),
|
||||
ObjectImport(name="Generic"),
|
||||
ObjectImport(name="Iterable"),
|
||||
ObjectImport(name="Tuple"),
|
||||
ObjectImport(name="TypeVar"),
|
||||
ObjectImport(name="overload"),
|
||||
],
|
||||
),
|
||||
|
@ -677,6 +693,7 @@ VectorData is purposefully excluded as an import or an inject so that it will be
|
|||
resolved to the VectorData definition in the generated module
|
||||
"""
|
||||
DYNAMIC_TABLE_INJECTS = [
|
||||
T_INJECT,
|
||||
VectorDataMixin,
|
||||
VectorIndexMixin,
|
||||
DynamicTableRegionMixin,
|
||||
|
@ -689,13 +706,27 @@ TSRVD_IMPORTS = Imports(
|
|||
Import(
|
||||
module="typing",
|
||||
objects=[
|
||||
ObjectImport(name="overload"),
|
||||
ObjectImport(name="Generic"),
|
||||
ObjectImport(name="Iterable"),
|
||||
ObjectImport(name="Tuple"),
|
||||
ObjectImport(name="TypeVar"),
|
||||
ObjectImport(name="overload"),
|
||||
],
|
||||
),
|
||||
Import(module="pydantic", objects=[ObjectImport(name="model_validator")]),
|
||||
]
|
||||
)
|
||||
"""Imports for TimeSeriesReferenceVectorData"""
|
||||
TSRVD_INJECTS = [VectorDataMixin, TimeSeriesReferenceVectorDataMixin]
|
||||
TSRVD_INJECTS = [T_INJECT, VectorDataMixin, TimeSeriesReferenceVectorDataMixin]
|
||||
|
||||
if "pytest" in sys.modules:
|
||||
# during testing define concrete subclasses...
|
||||
class VectorData(VectorDataMixin):
|
||||
"""VectorData subclass for testing"""
|
||||
|
||||
pass
|
||||
|
||||
class VectorIndex(VectorIndexMixin):
|
||||
"""VectorIndex subclass for testing"""
|
||||
|
||||
pass
|
||||
|
|
|
@ -114,7 +114,7 @@ def load_namespace_adapter(
|
|||
for ns in namespaces.namespaces:
|
||||
for schema in ns.schema_:
|
||||
if schema.source is None:
|
||||
if imported is None and schema.namespace == "hdmf-common":
|
||||
if imported is None and schema.namespace == "hdmf-common" and ns.name == "core":
|
||||
# special case - hdmf-common is imported by name without location or version,
|
||||
# so to get the correct version we have to handle it separately
|
||||
imported = _resolve_hdmf(namespace, path)
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
from numpydantic import NDArray, Shape
|
||||
|
||||
from nwb_linkml.includes import hdmf
|
||||
from nwb_linkml.includes.hdmf import DynamicTableMixin, VectorDataMixin, VectorIndexMixin
|
||||
|
||||
# FIXME: Make this just be the output of the provider by patching into import machinery
|
||||
from nwb_linkml.models.pydantic.core.v2_7_0.namespace import (
|
||||
|
@ -48,7 +52,7 @@ def test_dynamictable_indexing(electrical_series):
|
|||
|
||||
# get a single column
|
||||
col = electrodes["y"]
|
||||
assert all(col == [5, 6, 7, 8, 9])
|
||||
assert all(col.value == [5, 6, 7, 8, 9])
|
||||
|
||||
# get a single cell
|
||||
val = electrodes[0, "y"]
|
||||
|
@ -198,3 +202,104 @@ def test_aligned_dynamictable(intracellular_recordings_table):
|
|||
for i in range(len(stims)):
|
||||
assert isinstance(stims[i], VoltageClampStimulusSeries)
|
||||
assert all([i == val for val in stims[i][:]])
|
||||
|
||||
|
||||
# --------------------------------------------------
|
||||
# Direct mixin tests
|
||||
# --------------------------------------------------
|
||||
|
||||
|
||||
def test_dynamictable_mixin_indexing():
|
||||
"""
|
||||
This is just a placeholder test to say that indexing is tested above
|
||||
with actual model objects in case i ever ctrl+f for this
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def test_dynamictable_mixin_colnames():
|
||||
"""
|
||||
Should correctly infer colnames
|
||||
"""
|
||||
|
||||
class MyDT(DynamicTableMixin):
|
||||
existing_col: NDArray[Shape["* col"], int]
|
||||
|
||||
new_col_1 = VectorDataMixin(value=np.arange(10))
|
||||
new_col_2 = VectorDataMixin(value=np.arange(10))
|
||||
|
||||
inst = MyDT(existing_col=np.arange(10), new_col_1=new_col_1, new_col_2=new_col_2)
|
||||
assert inst.colnames == ["existing_col", "new_col_1", "new_col_2"]
|
||||
|
||||
|
||||
def test_dynamictable_mixin_colnames_index():
|
||||
"""
|
||||
Exclude index columns in colnames
|
||||
"""
|
||||
|
||||
class MyDT(DynamicTableMixin):
|
||||
existing_col: NDArray[Shape["* col"], int]
|
||||
|
||||
cols = {
|
||||
"existing_col": np.arange(10),
|
||||
"new_col_1": hdmf.VectorData(value=np.arange(10)),
|
||||
"new_col_2": hdmf.VectorData(value=np.arange(10)),
|
||||
}
|
||||
# explicit index with mismatching name
|
||||
cols["weirdname_index"] = VectorIndexMixin(value=np.arange(10), target=cols["new_col_1"])
|
||||
# implicit index with matching name
|
||||
cols["new_col_2_index"] = VectorIndexMixin(value=np.arange(10))
|
||||
|
||||
inst = MyDT(**cols)
|
||||
assert inst.colnames == ["existing_col", "new_col_1", "new_col_2"]
|
||||
|
||||
|
||||
def test_dynamictable_mixin_colnames_ordered():
|
||||
"""
|
||||
Should be able to pass explicit order to colnames
|
||||
"""
|
||||
|
||||
class MyDT(DynamicTableMixin):
|
||||
existing_col: NDArray[Shape["* col"], int]
|
||||
|
||||
cols = {
|
||||
"existing_col": np.arange(10),
|
||||
"new_col_1": hdmf.VectorData(value=np.arange(10)),
|
||||
"new_col_2": hdmf.VectorData(value=np.arange(10)),
|
||||
"new_col_3": hdmf.VectorData(value=np.arange(10)),
|
||||
}
|
||||
order = ["new_col_2", "existing_col", "new_col_1", "new_col_3"]
|
||||
|
||||
inst = MyDT(**cols, colnames=order)
|
||||
assert inst.colnames == order
|
||||
|
||||
# this should get reflected in the columns selector and the df produces
|
||||
assert all([key1 == key2 for key1, key2 in zip(order, inst._columns)])
|
||||
assert all(inst[0].columns == order)
|
||||
|
||||
# partial lists should append unnamed columsn at the end
|
||||
partial_order = ["new_col_3", "new_col_2"]
|
||||
inst = MyDT(**cols, colnames=partial_order)
|
||||
assert inst.colnames == [*partial_order, "existing_col", "new_col_1"]
|
||||
|
||||
|
||||
def test_dynamictable_mixin_getattr():
|
||||
"""
|
||||
Dynamictable should forward unknown getattr requests to the df
|
||||
"""
|
||||
|
||||
class MyDT(DynamicTableMixin):
|
||||
existing_col: NDArray[Shape["* col"], int]
|
||||
|
||||
class AModel(DynamicTableMixin):
|
||||
col: hdmf.VectorData[NDArray[Shape["3, 3"], int]]
|
||||
|
||||
col = hdmf.VectorData(value=np.arange(10))
|
||||
inst = MyDT(existing_col=col)
|
||||
# regular lookup for attrs that exist
|
||||
|
||||
# pdb.set_trace()
|
||||
# inst.existing_col
|
||||
# assert inst.existing_col == col
|
||||
# df lookup otherwise
|
||||
# inst.columns
|
||||
|
|
Loading…
Reference in a new issue