From 4672f54630221eded8a9862a5c85efbe04f90f66 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Thu, 3 Oct 2024 00:11:38 -0700 Subject: [PATCH] fix generics with defaults, use typing_extensions --- nwb_linkml/pyproject.toml | 2 +- .../src/nwb_linkml/generators/pydantic.py | 4 +-- nwb_linkml/src/nwb_linkml/includes/hdmf.py | 10 +++---- nwb_linkml/src/nwb_linkml/maps/map.py | 26 ------------------- nwb_linkml/tests/test_io/test_io_hdf5.py | 9 +++++-- nwb_linkml/tests/test_io/test_io_nwb.py | 11 ++++++++ nwb_models/pyproject.toml | 3 ++- .../hdmf_common/v1_8_0/hdmf_common_table.py | 22 ++++++++++------ 8 files changed, 42 insertions(+), 45 deletions(-) delete mode 100644 nwb_linkml/src/nwb_linkml/maps/map.py diff --git a/nwb_linkml/pyproject.toml b/nwb_linkml/pyproject.toml index 2670310..903131a 100644 --- a/nwb_linkml/pyproject.toml +++ b/nwb_linkml/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ "h5py>=3.9.0", "pydantic-settings>=2.0.3", "tqdm>=4.66.1", - 'typing-extensions>=4.12.2;python_version<"3.11"', + 'typing-extensions>=4.12.2;python_version<"3.13"', "numpydantic>=1.6.0", "black>=24.4.2", "pandas>=2.2.2", diff --git a/nwb_linkml/src/nwb_linkml/generators/pydantic.py b/nwb_linkml/src/nwb_linkml/generators/pydantic.py index 7506ff8..777001e 100644 --- a/nwb_linkml/src/nwb_linkml/generators/pydantic.py +++ b/nwb_linkml/src/nwb_linkml/generators/pydantic.py @@ -284,9 +284,9 @@ class AfterGenerateClass: cls.cls.bases = ["AlignedDynamicTableMixin", "DynamicTable"] elif cls.cls.name == "ElementIdentifiers": cls.cls.bases = ["ElementIdentifiersMixin", "Data"] - # make ``value`` generic on T + # Formerly make this generic, but that breaks json serialization if "value" in cls.cls.attributes: - cls.cls.attributes["value"].range = "Optional[T]" + cls.cls.attributes["value"].range = "Optional[NDArray]" elif cls.cls.name == "TimeSeriesReferenceVectorData": # in core.nwb.base, so need to inject and import again cls.cls.bases = ["TimeSeriesReferenceVectorDataMixin", "VectorData"] diff --git a/nwb_linkml/src/nwb_linkml/includes/hdmf.py b/nwb_linkml/src/nwb_linkml/includes/hdmf.py index 3d456d0..6655a41 100644 --- a/nwb_linkml/src/nwb_linkml/includes/hdmf.py +++ b/nwb_linkml/src/nwb_linkml/includes/hdmf.py @@ -13,10 +13,10 @@ from typing import ( List, Optional, Tuple, - TypeVar, Union, overload, ) +from typing_extensions import TypeVar import numpy as np import pandas as pd @@ -36,8 +36,8 @@ from pydantic import ( if TYPE_CHECKING: # pragma: no cover from nwb_models.models import VectorData, VectorIndex -T = TypeVar("T", bound=NDArray) -T_INJECT = 'T = TypeVar("T", bound=NDArray)' +T = TypeVar("T", default=NDArray) +T_INJECT = 'T = TypeVar("T", default=NDArray)' if "pytest" in sys.modules: from nwb_models.models import ConfiguredBaseModel @@ -71,7 +71,7 @@ class DynamicTableMixin(ConfiguredBaseModel): """ model_config = ConfigDict(extra="allow", validate_assignment=True) - __pydantic_extra__: Dict[str, Union["VectorDataMixin", "VectorIndexMixin", "NDArray", list]] + __pydantic_extra__: Dict[str, Union["VectorDataMixin", "VectorIndexMixin"]] NON_COLUMN_FIELDS: ClassVar[tuple[str]] = ( "id", "name", @@ -899,10 +899,10 @@ DYNAMIC_TABLE_IMPORTS = Imports( ObjectImport(name="Generic"), ObjectImport(name="Iterable"), ObjectImport(name="Tuple"), - ObjectImport(name="TypeVar"), ObjectImport(name="overload"), ], ), + Import(module="typing_extensions", objects=[ObjectImport(name="TypeVar")]), Import( module="numpydantic", objects=[ObjectImport(name="NDArray"), ObjectImport(name="Shape")] ), diff --git a/nwb_linkml/src/nwb_linkml/maps/map.py b/nwb_linkml/src/nwb_linkml/maps/map.py deleted file mode 100644 index f03a9be..0000000 --- a/nwb_linkml/src/nwb_linkml/maps/map.py +++ /dev/null @@ -1,26 +0,0 @@ -""" -Abstract base classes for Map types - -.. todo:: - Make this consistent or don't call them all maps lmao -""" - -from abc import ABC, abstractmethod -from typing import Any, Mapping, Sequence - - -class Map(ABC): - """ - The generic top-level mapping class is just a classmethod for checking if the map applies and a - method for applying the check if it does - """ - - @classmethod - @abstractmethod - def check(cls, *args: Sequence, **kwargs: Mapping) -> bool: - """Check if this map applies to the given item to read""" - - @classmethod - @abstractmethod - def apply(cls, *args: Sequence, **kwargs: Mapping) -> Any: - """Actually apply the map!""" diff --git a/nwb_linkml/tests/test_io/test_io_hdf5.py b/nwb_linkml/tests/test_io/test_io_hdf5.py index 4222a2c..59a2291 100644 --- a/nwb_linkml/tests/test_io/test_io_hdf5.py +++ b/nwb_linkml/tests/test_io/test_io_hdf5.py @@ -3,8 +3,13 @@ import networkx as nx import numpy as np import pytest -from nwb_linkml.io.hdf5 import HDF5IO, filter_dependency_graph, hdf_dependency_graph, truncate_file -from nwb_linkml.maps.hdf5 import resolve_hardlink +from nwb_linkml.io.hdf5 import ( + HDF5IO, + filter_dependency_graph, + hdf_dependency_graph, + truncate_file, + resolve_hardlink, +) @pytest.mark.skip() diff --git a/nwb_linkml/tests/test_io/test_io_nwb.py b/nwb_linkml/tests/test_io/test_io_nwb.py index 32a50d1..b91c64f 100644 --- a/nwb_linkml/tests/test_io/test_io_nwb.py +++ b/nwb_linkml/tests/test_io/test_io_nwb.py @@ -66,6 +66,17 @@ def test_nwbfile_base(read_nwbfile, read_pynwb): _compare_attrs(read_nwbfile, read_pynwb) +def test_nwbfile_dump(read_nwbfile): + electrode_id = read_nwbfile.general.extracellular_ephys.electrodes.id.model_dump_json( + round_trip=True + ) + electrodes = read_nwbfile.general.extracellular_ephys.electrodes.model_dump_json( + round_trip=True + ) + # data = read_nwbfile.model_dump_json(round_trip=True, serialize_as_any=True) + # pdb.set_trace() + + def test_timeseries(read_nwbfile, read_pynwb): py_acq = read_pynwb.get_acquisition("test_timeseries") acq = read_nwbfile.acquisition["test_timeseries"] diff --git a/nwb_models/pyproject.toml b/nwb_models/pyproject.toml index 59b0b6d..b7b3b1a 100644 --- a/nwb_models/pyproject.toml +++ b/nwb_models/pyproject.toml @@ -8,7 +8,8 @@ authors = [ dependencies = [ "pydantic>=2.3.0", "numpydantic>=1.3.3", - "pandas>=2.2.2" + "pandas>=2.2.2", + 'typing-extensions>=4.12.2;python_version<"3.13"', # for default in TypeVar ] requires-python = ">=3.10" readme = "README.md" diff --git a/nwb_models/src/nwb_models/models/pydantic/hdmf_common/v1_8_0/hdmf_common_table.py b/nwb_models/src/nwb_models/models/pydantic/hdmf_common/v1_8_0/hdmf_common_table.py index b779c48..0560025 100644 --- a/nwb_models/src/nwb_models/models/pydantic/hdmf_common/v1_8_0/hdmf_common_table.py +++ b/nwb_models/src/nwb_models/models/pydantic/hdmf_common/v1_8_0/hdmf_common_table.py @@ -1,5 +1,6 @@ from __future__ import annotations +import pdb import re import sys from datetime import date, datetime, time @@ -15,10 +16,10 @@ from typing import ( Literal, Optional, Tuple, - TypeVar, Union, overload, ) +from typing_extensions import TypeVar import numpy as np import pandas as pd @@ -153,7 +154,8 @@ class LinkMLMeta(RootModel): NUMPYDANTIC_VERSION = "1.2.1" -T = TypeVar("T", bound=NDArray) +T = TypeVar("T", default=NDArray) +U = TypeVar("U", default=NDArray) class VectorDataMixin(ConfiguredBaseModel, Generic[T]): @@ -364,7 +366,7 @@ class DynamicTableMixin(ConfiguredBaseModel): """ model_config = ConfigDict(extra="allow", validate_assignment=True) - __pydantic_extra__: Dict[str, Union["VectorDataMixin", "VectorIndexMixin", "NDArray", list]] + __pydantic_extra__: Dict[str, Union["VectorDataMixin", "VectorIndexMixin"]] NON_COLUMN_FIELDS: ClassVar[tuple[str]] = ( "id", "name", @@ -657,10 +659,14 @@ class DynamicTableMixin(ConfiguredBaseModel): Ensure that all columns are equal length """ lengths = [len(v) for v in self._columns.values() if v is not None] + [len(self.id)] - assert all([length == lengths[0] for length in lengths]), ( - "DynamicTable columns are not of equal length! " - f"Got colnames:\n{self.colnames}\nand lengths: {lengths}" - ) + + try: + assert all([length == lengths[0] for length in lengths]), ( + "DynamicTable columns are not of equal length! " + f"Got colnames:\n{self.colnames}\nand lengths: {lengths}" + ) + except AssertionError: + pdb.set_trace() return self @field_validator("*", mode="wrap") @@ -958,7 +964,7 @@ class ElementIdentifiers(ElementIdentifiersMixin, Data): name: str = Field( "element_id", json_schema_extra={"linkml_meta": {"ifabsent": "string(element_id)"}} ) - value: Optional[T] = Field( + value: Optional[NDArray] = Field( None, json_schema_extra={"linkml_meta": {"array": {"dimensions": [{"alias": "num_elements"}]}}}, )