mirror of
https://github.com/p2p-ld/nwb-linkml.git
synced 2024-11-10 00:34:29 +00:00
Make VectorData and VectorIndex generics to ensure coercion to VectorData for declared columns
This commit is contained in:
parent
50005d33e5
commit
01e46f7531
5 changed files with 176 additions and 21 deletions
|
@ -36,14 +36,10 @@ plot = [
|
||||||
"dash-cytoscape<1.0.0,>=0.3.0",
|
"dash-cytoscape<1.0.0,>=0.3.0",
|
||||||
]
|
]
|
||||||
tests = [
|
tests = [
|
||||||
"nwb-linkml[plot]",
|
"nwb-linkml",
|
||||||
"pytest<8.0.0,>=7.4.0",
|
"pytest<8.0.0,>=7.4.0",
|
||||||
"pytest-depends<2.0.0,>=1.0.1",
|
"pytest-depends<2.0.0,>=1.0.1",
|
||||||
"coverage<7.0.0,>=6.1.1",
|
|
||||||
"pytest-md<1.0.0,>=0.2.0",
|
|
||||||
"pytest-cov<5.0.0,>=4.1.0",
|
"pytest-cov<5.0.0,>=4.1.0",
|
||||||
"coveralls<4.0.0,>=3.3.1",
|
|
||||||
"pytest-profiling<2.0.0,>=1.7.0",
|
|
||||||
"sybil<6.0.0,>=5.0.3",
|
"sybil<6.0.0,>=5.0.3",
|
||||||
"requests-cache>=1.2.1",
|
"requests-cache>=1.2.1",
|
||||||
]
|
]
|
||||||
|
|
|
@ -116,6 +116,7 @@ class NWBPydanticGenerator(PydanticGenerator):
|
||||||
def after_generate_class(self, cls: ClassResult, sv: SchemaView) -> ClassResult:
|
def after_generate_class(self, cls: ClassResult, sv: SchemaView) -> ClassResult:
|
||||||
"""Customize dynamictable behavior"""
|
"""Customize dynamictable behavior"""
|
||||||
cls = AfterGenerateClass.inject_dynamictable(cls)
|
cls = AfterGenerateClass.inject_dynamictable(cls)
|
||||||
|
cls = AfterGenerateClass.wrap_dynamictable_columns(cls, sv)
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
def before_render_template(self, template: PydanticModule, sv: SchemaView) -> PydanticModule:
|
def before_render_template(self, template: PydanticModule, sv: SchemaView) -> PydanticModule:
|
||||||
|
@ -278,6 +279,28 @@ class AfterGenerateClass:
|
||||||
|
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def wrap_dynamictable_columns(cls: ClassResult, sv: SchemaView) -> ClassResult:
|
||||||
|
"""
|
||||||
|
Wrap NDArray columns inside of dynamictables with ``VectorData`` or
|
||||||
|
``VectorIndex``, which are generic classes whose value slot is
|
||||||
|
parameterized by the NDArray
|
||||||
|
"""
|
||||||
|
if cls.source.is_a == "DynamicTable" or "DynamicTable" in sv.class_ancestors(
|
||||||
|
cls.source.name
|
||||||
|
):
|
||||||
|
for an_attr in cls.cls.attributes:
|
||||||
|
if "NDArray" in (slot_range := cls.cls.attributes[an_attr].range):
|
||||||
|
if an_attr.endswith("_index"):
|
||||||
|
cls.cls.attributes[an_attr].range = "".join(
|
||||||
|
["VectorIndex[", slot_range, "]"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cls.cls.attributes[an_attr].range = "".join(
|
||||||
|
["VectorData[", slot_range, "]"]
|
||||||
|
)
|
||||||
|
return cls
|
||||||
|
|
||||||
|
|
||||||
def compile_python(
|
def compile_python(
|
||||||
text_or_fn: str, package_path: Path = None, module_name: str = "test"
|
text_or_fn: str, package_path: Path = None, module_name: str = "test"
|
||||||
|
|
|
@ -2,15 +2,18 @@
|
||||||
Special types for mimicking HDMF special case behavior
|
Special types for mimicking HDMF special case behavior
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
ClassVar,
|
ClassVar,
|
||||||
Dict,
|
Dict,
|
||||||
|
Generic,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
overload,
|
overload,
|
||||||
)
|
)
|
||||||
|
@ -33,6 +36,9 @@ from pydantic import (
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from nwb_linkml.models import VectorData, VectorIndex
|
from nwb_linkml.models import VectorData, VectorIndex
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=NDArray)
|
||||||
|
T_INJECT = 'T = TypeVar("T", bound=NDArray)'
|
||||||
|
|
||||||
|
|
||||||
class DynamicTableMixin(BaseModel):
|
class DynamicTableMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -220,19 +226,27 @@ class DynamicTableMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
if "colnames" not in model:
|
if "colnames" not in model:
|
||||||
colnames = [
|
colnames = [
|
||||||
k for k in model if k not in cls.NON_COLUMN_FIELDS and not k.endswith("_index")
|
k
|
||||||
|
for k in model
|
||||||
|
if k not in cls.NON_COLUMN_FIELDS
|
||||||
|
and not k.endswith("_index")
|
||||||
|
and not isinstance(model[k], VectorIndexMixin)
|
||||||
]
|
]
|
||||||
model["colnames"] = colnames
|
model["colnames"] = colnames
|
||||||
else:
|
else:
|
||||||
# add any columns not explicitly given an order at the end
|
# add any columns not explicitly given an order at the end
|
||||||
colnames = [
|
colnames = model["colnames"].copy()
|
||||||
|
colnames.extend(
|
||||||
|
[
|
||||||
k
|
k
|
||||||
for k in model
|
for k in model
|
||||||
if k not in cls.NON_COLUMN_FIELDS
|
if k not in cls.NON_COLUMN_FIELDS
|
||||||
and not k.endswith("_index")
|
and not k.endswith("_index")
|
||||||
and k not in model["colnames"]
|
and k not in model["colnames"]
|
||||||
|
and not isinstance(model[k], VectorIndexMixin)
|
||||||
]
|
]
|
||||||
model["colnames"].extend(colnames)
|
)
|
||||||
|
model["colnames"] = colnames
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
|
@ -322,7 +336,7 @@ class DynamicTableMixin(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class VectorDataMixin(BaseModel):
|
class VectorDataMixin(BaseModel, Generic[T]):
|
||||||
"""
|
"""
|
||||||
Mixin class to give VectorData indexing abilities
|
Mixin class to give VectorData indexing abilities
|
||||||
"""
|
"""
|
||||||
|
@ -330,7 +344,7 @@ class VectorDataMixin(BaseModel):
|
||||||
_index: Optional["VectorIndex"] = None
|
_index: Optional["VectorIndex"] = None
|
||||||
|
|
||||||
# redefined in `VectorData`, but included here for testing and type checking
|
# redefined in `VectorData`, but included here for testing and type checking
|
||||||
value: Optional[NDArray] = None
|
value: Optional[T] = None
|
||||||
|
|
||||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||||
if value is not None and "value" not in kwargs:
|
if value is not None and "value" not in kwargs:
|
||||||
|
@ -373,13 +387,13 @@ class VectorDataMixin(BaseModel):
|
||||||
return len(self.value)
|
return len(self.value)
|
||||||
|
|
||||||
|
|
||||||
class VectorIndexMixin(BaseModel):
|
class VectorIndexMixin(BaseModel, Generic[T]):
|
||||||
"""
|
"""
|
||||||
Mixin class to give VectorIndex indexing abilities
|
Mixin class to give VectorIndex indexing abilities
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# redefined in `VectorData`, but included here for testing and type checking
|
# redefined in `VectorData`, but included here for testing and type checking
|
||||||
value: Optional[NDArray] = None
|
value: Optional[T] = None
|
||||||
target: Optional["VectorData"] = None
|
target: Optional["VectorData"] = None
|
||||||
|
|
||||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||||
|
@ -649,8 +663,10 @@ DYNAMIC_TABLE_IMPORTS = Imports(
|
||||||
module="typing",
|
module="typing",
|
||||||
objects=[
|
objects=[
|
||||||
ObjectImport(name="ClassVar"),
|
ObjectImport(name="ClassVar"),
|
||||||
|
ObjectImport(name="Generic"),
|
||||||
ObjectImport(name="Iterable"),
|
ObjectImport(name="Iterable"),
|
||||||
ObjectImport(name="Tuple"),
|
ObjectImport(name="Tuple"),
|
||||||
|
ObjectImport(name="TypeVar"),
|
||||||
ObjectImport(name="overload"),
|
ObjectImport(name="overload"),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
|
@ -677,6 +693,7 @@ VectorData is purposefully excluded as an import or an inject so that it will be
|
||||||
resolved to the VectorData definition in the generated module
|
resolved to the VectorData definition in the generated module
|
||||||
"""
|
"""
|
||||||
DYNAMIC_TABLE_INJECTS = [
|
DYNAMIC_TABLE_INJECTS = [
|
||||||
|
T_INJECT,
|
||||||
VectorDataMixin,
|
VectorDataMixin,
|
||||||
VectorIndexMixin,
|
VectorIndexMixin,
|
||||||
DynamicTableRegionMixin,
|
DynamicTableRegionMixin,
|
||||||
|
@ -689,13 +706,27 @@ TSRVD_IMPORTS = Imports(
|
||||||
Import(
|
Import(
|
||||||
module="typing",
|
module="typing",
|
||||||
objects=[
|
objects=[
|
||||||
ObjectImport(name="overload"),
|
ObjectImport(name="Generic"),
|
||||||
ObjectImport(name="Iterable"),
|
ObjectImport(name="Iterable"),
|
||||||
ObjectImport(name="Tuple"),
|
ObjectImport(name="Tuple"),
|
||||||
|
ObjectImport(name="TypeVar"),
|
||||||
|
ObjectImport(name="overload"),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
Import(module="pydantic", objects=[ObjectImport(name="model_validator")]),
|
Import(module="pydantic", objects=[ObjectImport(name="model_validator")]),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
"""Imports for TimeSeriesReferenceVectorData"""
|
"""Imports for TimeSeriesReferenceVectorData"""
|
||||||
TSRVD_INJECTS = [VectorDataMixin, TimeSeriesReferenceVectorDataMixin]
|
TSRVD_INJECTS = [T_INJECT, VectorDataMixin, TimeSeriesReferenceVectorDataMixin]
|
||||||
|
|
||||||
|
if "pytest" in sys.modules:
|
||||||
|
# during testing define concrete subclasses...
|
||||||
|
class VectorData(VectorDataMixin):
|
||||||
|
"""VectorData subclass for testing"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
class VectorIndex(VectorIndexMixin):
|
||||||
|
"""VectorIndex subclass for testing"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
|
@ -114,7 +114,7 @@ def load_namespace_adapter(
|
||||||
for ns in namespaces.namespaces:
|
for ns in namespaces.namespaces:
|
||||||
for schema in ns.schema_:
|
for schema in ns.schema_:
|
||||||
if schema.source is None:
|
if schema.source is None:
|
||||||
if imported is None and schema.namespace == "hdmf-common":
|
if imported is None and schema.namespace == "hdmf-common" and ns.name == "core":
|
||||||
# special case - hdmf-common is imported by name without location or version,
|
# special case - hdmf-common is imported by name without location or version,
|
||||||
# so to get the correct version we have to handle it separately
|
# so to get the correct version we have to handle it separately
|
||||||
imported = _resolve_hdmf(namespace, path)
|
imported = _resolve_hdmf(namespace, path)
|
||||||
|
|
|
@ -1,5 +1,9 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
from numpydantic import NDArray, Shape
|
||||||
|
|
||||||
|
from nwb_linkml.includes import hdmf
|
||||||
|
from nwb_linkml.includes.hdmf import DynamicTableMixin, VectorDataMixin, VectorIndexMixin
|
||||||
|
|
||||||
# FIXME: Make this just be the output of the provider by patching into import machinery
|
# FIXME: Make this just be the output of the provider by patching into import machinery
|
||||||
from nwb_linkml.models.pydantic.core.v2_7_0.namespace import (
|
from nwb_linkml.models.pydantic.core.v2_7_0.namespace import (
|
||||||
|
@ -48,7 +52,7 @@ def test_dynamictable_indexing(electrical_series):
|
||||||
|
|
||||||
# get a single column
|
# get a single column
|
||||||
col = electrodes["y"]
|
col = electrodes["y"]
|
||||||
assert all(col == [5, 6, 7, 8, 9])
|
assert all(col.value == [5, 6, 7, 8, 9])
|
||||||
|
|
||||||
# get a single cell
|
# get a single cell
|
||||||
val = electrodes[0, "y"]
|
val = electrodes[0, "y"]
|
||||||
|
@ -198,3 +202,104 @@ def test_aligned_dynamictable(intracellular_recordings_table):
|
||||||
for i in range(len(stims)):
|
for i in range(len(stims)):
|
||||||
assert isinstance(stims[i], VoltageClampStimulusSeries)
|
assert isinstance(stims[i], VoltageClampStimulusSeries)
|
||||||
assert all([i == val for val in stims[i][:]])
|
assert all([i == val for val in stims[i][:]])
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------
|
||||||
|
# Direct mixin tests
|
||||||
|
# --------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_dynamictable_mixin_indexing():
|
||||||
|
"""
|
||||||
|
This is just a placeholder test to say that indexing is tested above
|
||||||
|
with actual model objects in case i ever ctrl+f for this
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_dynamictable_mixin_colnames():
|
||||||
|
"""
|
||||||
|
Should correctly infer colnames
|
||||||
|
"""
|
||||||
|
|
||||||
|
class MyDT(DynamicTableMixin):
|
||||||
|
existing_col: NDArray[Shape["* col"], int]
|
||||||
|
|
||||||
|
new_col_1 = VectorDataMixin(value=np.arange(10))
|
||||||
|
new_col_2 = VectorDataMixin(value=np.arange(10))
|
||||||
|
|
||||||
|
inst = MyDT(existing_col=np.arange(10), new_col_1=new_col_1, new_col_2=new_col_2)
|
||||||
|
assert inst.colnames == ["existing_col", "new_col_1", "new_col_2"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_dynamictable_mixin_colnames_index():
|
||||||
|
"""
|
||||||
|
Exclude index columns in colnames
|
||||||
|
"""
|
||||||
|
|
||||||
|
class MyDT(DynamicTableMixin):
|
||||||
|
existing_col: NDArray[Shape["* col"], int]
|
||||||
|
|
||||||
|
cols = {
|
||||||
|
"existing_col": np.arange(10),
|
||||||
|
"new_col_1": hdmf.VectorData(value=np.arange(10)),
|
||||||
|
"new_col_2": hdmf.VectorData(value=np.arange(10)),
|
||||||
|
}
|
||||||
|
# explicit index with mismatching name
|
||||||
|
cols["weirdname_index"] = VectorIndexMixin(value=np.arange(10), target=cols["new_col_1"])
|
||||||
|
# implicit index with matching name
|
||||||
|
cols["new_col_2_index"] = VectorIndexMixin(value=np.arange(10))
|
||||||
|
|
||||||
|
inst = MyDT(**cols)
|
||||||
|
assert inst.colnames == ["existing_col", "new_col_1", "new_col_2"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_dynamictable_mixin_colnames_ordered():
|
||||||
|
"""
|
||||||
|
Should be able to pass explicit order to colnames
|
||||||
|
"""
|
||||||
|
|
||||||
|
class MyDT(DynamicTableMixin):
|
||||||
|
existing_col: NDArray[Shape["* col"], int]
|
||||||
|
|
||||||
|
cols = {
|
||||||
|
"existing_col": np.arange(10),
|
||||||
|
"new_col_1": hdmf.VectorData(value=np.arange(10)),
|
||||||
|
"new_col_2": hdmf.VectorData(value=np.arange(10)),
|
||||||
|
"new_col_3": hdmf.VectorData(value=np.arange(10)),
|
||||||
|
}
|
||||||
|
order = ["new_col_2", "existing_col", "new_col_1", "new_col_3"]
|
||||||
|
|
||||||
|
inst = MyDT(**cols, colnames=order)
|
||||||
|
assert inst.colnames == order
|
||||||
|
|
||||||
|
# this should get reflected in the columns selector and the df produces
|
||||||
|
assert all([key1 == key2 for key1, key2 in zip(order, inst._columns)])
|
||||||
|
assert all(inst[0].columns == order)
|
||||||
|
|
||||||
|
# partial lists should append unnamed columsn at the end
|
||||||
|
partial_order = ["new_col_3", "new_col_2"]
|
||||||
|
inst = MyDT(**cols, colnames=partial_order)
|
||||||
|
assert inst.colnames == [*partial_order, "existing_col", "new_col_1"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_dynamictable_mixin_getattr():
|
||||||
|
"""
|
||||||
|
Dynamictable should forward unknown getattr requests to the df
|
||||||
|
"""
|
||||||
|
|
||||||
|
class MyDT(DynamicTableMixin):
|
||||||
|
existing_col: NDArray[Shape["* col"], int]
|
||||||
|
|
||||||
|
class AModel(DynamicTableMixin):
|
||||||
|
col: hdmf.VectorData[NDArray[Shape["3, 3"], int]]
|
||||||
|
|
||||||
|
col = hdmf.VectorData(value=np.arange(10))
|
||||||
|
inst = MyDT(existing_col=col)
|
||||||
|
# regular lookup for attrs that exist
|
||||||
|
|
||||||
|
# pdb.set_trace()
|
||||||
|
# inst.existing_col
|
||||||
|
# assert inst.existing_col == col
|
||||||
|
# df lookup otherwise
|
||||||
|
# inst.columns
|
||||||
|
|
Loading…
Reference in a new issue