fix generics with defaults, use typing_extensions

This commit is contained in:
sneakers-the-rat 2024-10-03 00:11:38 -07:00
parent 748b304426
commit 4672f54630
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
8 changed files with 42 additions and 45 deletions

View file

@ -21,7 +21,7 @@ dependencies = [
"h5py>=3.9.0",
"pydantic-settings>=2.0.3",
"tqdm>=4.66.1",
'typing-extensions>=4.12.2;python_version<"3.11"',
'typing-extensions>=4.12.2;python_version<"3.13"',
"numpydantic>=1.6.0",
"black>=24.4.2",
"pandas>=2.2.2",

View file

@ -284,9 +284,9 @@ class AfterGenerateClass:
cls.cls.bases = ["AlignedDynamicTableMixin", "DynamicTable"]
elif cls.cls.name == "ElementIdentifiers":
cls.cls.bases = ["ElementIdentifiersMixin", "Data"]
# make ``value`` generic on T
# Formerly make this generic, but that breaks json serialization
if "value" in cls.cls.attributes:
cls.cls.attributes["value"].range = "Optional[T]"
cls.cls.attributes["value"].range = "Optional[NDArray]"
elif cls.cls.name == "TimeSeriesReferenceVectorData":
# in core.nwb.base, so need to inject and import again
cls.cls.bases = ["TimeSeriesReferenceVectorDataMixin", "VectorData"]

View file

@ -13,10 +13,10 @@ from typing import (
List,
Optional,
Tuple,
TypeVar,
Union,
overload,
)
from typing_extensions import TypeVar
import numpy as np
import pandas as pd
@ -36,8 +36,8 @@ from pydantic import (
if TYPE_CHECKING: # pragma: no cover
from nwb_models.models import VectorData, VectorIndex
T = TypeVar("T", bound=NDArray)
T_INJECT = 'T = TypeVar("T", bound=NDArray)'
T = TypeVar("T", default=NDArray)
T_INJECT = 'T = TypeVar("T", default=NDArray)'
if "pytest" in sys.modules:
from nwb_models.models import ConfiguredBaseModel
@ -71,7 +71,7 @@ class DynamicTableMixin(ConfiguredBaseModel):
"""
model_config = ConfigDict(extra="allow", validate_assignment=True)
__pydantic_extra__: Dict[str, Union["VectorDataMixin", "VectorIndexMixin", "NDArray", list]]
__pydantic_extra__: Dict[str, Union["VectorDataMixin", "VectorIndexMixin"]]
NON_COLUMN_FIELDS: ClassVar[tuple[str]] = (
"id",
"name",
@ -899,10 +899,10 @@ DYNAMIC_TABLE_IMPORTS = Imports(
ObjectImport(name="Generic"),
ObjectImport(name="Iterable"),
ObjectImport(name="Tuple"),
ObjectImport(name="TypeVar"),
ObjectImport(name="overload"),
],
),
Import(module="typing_extensions", objects=[ObjectImport(name="TypeVar")]),
Import(
module="numpydantic", objects=[ObjectImport(name="NDArray"), ObjectImport(name="Shape")]
),

View file

@ -1,26 +0,0 @@
"""
Abstract base classes for Map types
.. todo::
Make this consistent or don't call them all maps lmao
"""
from abc import ABC, abstractmethod
from typing import Any, Mapping, Sequence
class Map(ABC):
"""
The generic top-level mapping class is just a classmethod for checking if the map applies and a
method for applying the check if it does
"""
@classmethod
@abstractmethod
def check(cls, *args: Sequence, **kwargs: Mapping) -> bool:
"""Check if this map applies to the given item to read"""
@classmethod
@abstractmethod
def apply(cls, *args: Sequence, **kwargs: Mapping) -> Any:
"""Actually apply the map!"""

View file

@ -3,8 +3,13 @@ import networkx as nx
import numpy as np
import pytest
from nwb_linkml.io.hdf5 import HDF5IO, filter_dependency_graph, hdf_dependency_graph, truncate_file
from nwb_linkml.maps.hdf5 import resolve_hardlink
from nwb_linkml.io.hdf5 import (
HDF5IO,
filter_dependency_graph,
hdf_dependency_graph,
truncate_file,
resolve_hardlink,
)
@pytest.mark.skip()

View file

@ -66,6 +66,17 @@ def test_nwbfile_base(read_nwbfile, read_pynwb):
_compare_attrs(read_nwbfile, read_pynwb)
def test_nwbfile_dump(read_nwbfile):
electrode_id = read_nwbfile.general.extracellular_ephys.electrodes.id.model_dump_json(
round_trip=True
)
electrodes = read_nwbfile.general.extracellular_ephys.electrodes.model_dump_json(
round_trip=True
)
# data = read_nwbfile.model_dump_json(round_trip=True, serialize_as_any=True)
# pdb.set_trace()
def test_timeseries(read_nwbfile, read_pynwb):
py_acq = read_pynwb.get_acquisition("test_timeseries")
acq = read_nwbfile.acquisition["test_timeseries"]

View file

@ -8,7 +8,8 @@ authors = [
dependencies = [
"pydantic>=2.3.0",
"numpydantic>=1.3.3",
"pandas>=2.2.2"
"pandas>=2.2.2",
'typing-extensions>=4.12.2;python_version<"3.13"', # for default in TypeVar
]
requires-python = ">=3.10"
readme = "README.md"

View file

@ -1,5 +1,6 @@
from __future__ import annotations
import pdb
import re
import sys
from datetime import date, datetime, time
@ -15,10 +16,10 @@ from typing import (
Literal,
Optional,
Tuple,
TypeVar,
Union,
overload,
)
from typing_extensions import TypeVar
import numpy as np
import pandas as pd
@ -153,7 +154,8 @@ class LinkMLMeta(RootModel):
NUMPYDANTIC_VERSION = "1.2.1"
T = TypeVar("T", bound=NDArray)
T = TypeVar("T", default=NDArray)
U = TypeVar("U", default=NDArray)
class VectorDataMixin(ConfiguredBaseModel, Generic[T]):
@ -364,7 +366,7 @@ class DynamicTableMixin(ConfiguredBaseModel):
"""
model_config = ConfigDict(extra="allow", validate_assignment=True)
__pydantic_extra__: Dict[str, Union["VectorDataMixin", "VectorIndexMixin", "NDArray", list]]
__pydantic_extra__: Dict[str, Union["VectorDataMixin", "VectorIndexMixin"]]
NON_COLUMN_FIELDS: ClassVar[tuple[str]] = (
"id",
"name",
@ -657,10 +659,14 @@ class DynamicTableMixin(ConfiguredBaseModel):
Ensure that all columns are equal length
"""
lengths = [len(v) for v in self._columns.values() if v is not None] + [len(self.id)]
assert all([length == lengths[0] for length in lengths]), (
"DynamicTable columns are not of equal length! "
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
)
try:
assert all([length == lengths[0] for length in lengths]), (
"DynamicTable columns are not of equal length! "
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
)
except AssertionError:
pdb.set_trace()
return self
@field_validator("*", mode="wrap")
@ -958,7 +964,7 @@ class ElementIdentifiers(ElementIdentifiersMixin, Data):
name: str = Field(
"element_id", json_schema_extra={"linkml_meta": {"ifabsent": "string(element_id)"}}
)
value: Optional[T] = Field(
value: Optional[NDArray] = Field(
None,
json_schema_extra={"linkml_meta": {"array": {"dimensions": [{"alias": "num_elements"}]}}},
)