mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2025-01-09 21:44:27 +00:00
stuck at the recursion error in pydantic core
This commit is contained in:
parent
1701ef9d7e
commit
33eb1c15d5
9 changed files with 218 additions and 64 deletions
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
|
@ -47,7 +47,7 @@ jobs:
|
||||||
run: pip install "numpy${{ matrix.numpy-version }}"
|
run: pip install "numpy${{ matrix.numpy-version }}"
|
||||||
|
|
||||||
- name: Run Tests
|
- name: Run Tests
|
||||||
run: pytest
|
run: pytest -n auto
|
||||||
|
|
||||||
- name: Coveralls Parallel
|
- name: Coveralls Parallel
|
||||||
uses: coverallsapp/github-action@v2.3.0
|
uses: coverallsapp/github-action@v2.3.0
|
||||||
|
|
28
pdm.lock
28
pdm.lock
|
@ -5,7 +5,7 @@
|
||||||
groups = ["default", "arrays", "dask", "dev", "docs", "hdf5", "tests", "video", "zarr"]
|
groups = ["default", "arrays", "dask", "dev", "docs", "hdf5", "tests", "video", "zarr"]
|
||||||
strategy = ["cross_platform", "inherit_metadata"]
|
strategy = ["cross_platform", "inherit_metadata"]
|
||||||
lock_version = "4.5.0"
|
lock_version = "4.5.0"
|
||||||
content_hash = "sha256:cc2b0fb32896c6df0ad747ddb5dee89af22f5c4c4643ee7a52db47fef30da936"
|
content_hash = "sha256:576c76e5a4a7616b7cb72bca277db1ada330bb8f3f9827d0e161728325e19264"
|
||||||
|
|
||||||
[[metadata.targets]]
|
[[metadata.targets]]
|
||||||
requires_python = "~=3.9"
|
requires_python = "~=3.9"
|
||||||
|
@ -576,6 +576,17 @@ files = [
|
||||||
{file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"},
|
{file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "execnet"
|
||||||
|
version = "2.1.1"
|
||||||
|
requires_python = ">=3.8"
|
||||||
|
summary = "execnet: rapid multi-Python deployment"
|
||||||
|
groups = ["dev", "tests"]
|
||||||
|
files = [
|
||||||
|
{file = "execnet-2.1.1-py3-none-any.whl", hash = "sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc"},
|
||||||
|
{file = "execnet-2.1.1.tar.gz", hash = "sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "executing"
|
name = "executing"
|
||||||
version = "2.1.0"
|
version = "2.1.0"
|
||||||
|
@ -1647,6 +1658,21 @@ files = [
|
||||||
{file = "pytest_depends-1.0.1-py3-none-any.whl", hash = "sha256:a1df072bcc93d77aca3f0946903f5fed8af2d9b0056db1dfc9ed5ac164ab0642"},
|
{file = "pytest_depends-1.0.1-py3-none-any.whl", hash = "sha256:a1df072bcc93d77aca3f0946903f5fed8af2d9b0056db1dfc9ed5ac164ab0642"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pytest-xdist"
|
||||||
|
version = "3.6.1"
|
||||||
|
requires_python = ">=3.8"
|
||||||
|
summary = "pytest xdist plugin for distributed testing, most importantly across multiple CPUs"
|
||||||
|
groups = ["dev", "tests"]
|
||||||
|
dependencies = [
|
||||||
|
"execnet>=2.1",
|
||||||
|
"pytest>=7.0.0",
|
||||||
|
]
|
||||||
|
files = [
|
||||||
|
{file = "pytest_xdist-3.6.1-py3-none-any.whl", hash = "sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7"},
|
||||||
|
{file = "pytest_xdist-3.6.1.tar.gz", hash = "sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "python-dateutil"
|
name = "python-dateutil"
|
||||||
version = "2.9.0.post0"
|
version = "2.9.0.post0"
|
||||||
|
|
|
@ -71,6 +71,7 @@ tests = [
|
||||||
"coverage>=6.1.1",
|
"coverage>=6.1.1",
|
||||||
"pytest-cov<5.0.0,>=4.1.0",
|
"pytest-cov<5.0.0,>=4.1.0",
|
||||||
"coveralls<4.0.0,>=3.3.1",
|
"coveralls<4.0.0,>=3.3.1",
|
||||||
|
"pytest-xdist>=3.6.1",
|
||||||
]
|
]
|
||||||
docs = [
|
docs = [
|
||||||
"numpydantic[arrays]",
|
"numpydantic[arrays]",
|
||||||
|
|
|
@ -55,7 +55,6 @@ from typing import (
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import SerializationInfo
|
from pydantic import SerializationInfo
|
||||||
from pydantic_core import SchemaSerializer, core_schema
|
|
||||||
|
|
||||||
from numpydantic.interface.interface import Interface, JsonDict
|
from numpydantic.interface.interface import Interface, JsonDict
|
||||||
from numpydantic.types import DtypeType, NDArrayType
|
from numpydantic.types import DtypeType, NDArrayType
|
||||||
|
@ -100,49 +99,49 @@ class H5JsonDict(JsonDict):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def to_json(self, info: SerializationInfo):
|
# def to_json(self, info: SerializationInfo):
|
||||||
"""
|
# """
|
||||||
Serialize H5Proxy to JSON, as the interface does,
|
# Serialize H5Proxy to JSON, as the interface does,
|
||||||
in cases when the interface is not able to be used
|
# in cases when the interface is not able to be used
|
||||||
(eg. like when used as an `extra` field in a model without a type annotation)
|
# (eg. like when used as an `extra` field in a model without a type annotation)
|
||||||
"""
|
# """
|
||||||
from numpydantic.serialization import postprocess_json
|
# from numpydantic.serialization import postprocess_json
|
||||||
|
#
|
||||||
if info.round_trip:
|
# if info.round_trip:
|
||||||
as_json = {
|
# as_json = {
|
||||||
"type": H5Interface.name,
|
# "type": H5Interface.name,
|
||||||
}
|
# }
|
||||||
as_json.update(self._h5arraypath._asdict())
|
# as_json.update(self._h5arraypath._asdict())
|
||||||
else:
|
# else:
|
||||||
try:
|
# try:
|
||||||
dset = self.open()
|
# dset = self.open()
|
||||||
as_json = dset[:].tolist()
|
# as_json = dset[:].tolist()
|
||||||
finally:
|
# finally:
|
||||||
self.close()
|
# self.close()
|
||||||
return postprocess_json(as_json, info)
|
# return postprocess_json(as_json, info)
|
||||||
|
|
||||||
|
|
||||||
def _make_pydantic_schema():
|
# def _make_pydantic_schema():
|
||||||
return core_schema.typed_dict_schema(
|
# return core_schema.typed_dict_schema(
|
||||||
{
|
# {
|
||||||
"file": core_schema.typed_dict_field(
|
# "file": core_schema.typed_dict_field(
|
||||||
core_schema.str_schema(), required=True
|
# core_schema.str_schema(), required=True
|
||||||
),
|
# ),
|
||||||
"path": core_schema.typed_dict_field(
|
# "path": core_schema.typed_dict_field(
|
||||||
core_schema.str_schema(), required=True
|
# core_schema.str_schema(), required=True
|
||||||
),
|
# ),
|
||||||
"field": core_schema.typed_dict_field(
|
# "field": core_schema.typed_dict_field(
|
||||||
core_schema.union_schema(
|
# core_schema.union_schema(
|
||||||
[
|
# [
|
||||||
core_schema.str_schema(),
|
# core_schema.str_schema(),
|
||||||
core_schema.list_schema(core_schema.str_schema()),
|
# core_schema.list_schema(core_schema.str_schema()),
|
||||||
],
|
# ],
|
||||||
),
|
# ),
|
||||||
required=True,
|
# required=True,
|
||||||
),
|
# ),
|
||||||
},
|
# },
|
||||||
# serialization=
|
# # serialization=
|
||||||
)
|
# )
|
||||||
|
|
||||||
|
|
||||||
class H5Proxy:
|
class H5Proxy:
|
||||||
|
@ -167,11 +166,11 @@ class H5Proxy:
|
||||||
annotation_dtype (dtype): Optional - the dtype of our type annotation
|
annotation_dtype (dtype): Optional - the dtype of our type annotation
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__pydantic_serializer__ = SchemaSerializer(
|
# __pydantic_serializer__ = SchemaSerializer(
|
||||||
core_schema.plain_serializer_function_ser_schema(
|
# core_schema.plain_serializer_function_ser_schema(
|
||||||
to_json, when_used="json", info_arg=True
|
# jsonize_array, when_used="json", info_arg=True
|
||||||
),
|
# ),
|
||||||
)
|
# )
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -20,6 +20,7 @@ from numpydantic.exceptions import (
|
||||||
ShapeError,
|
ShapeError,
|
||||||
TooManyMatchesError,
|
TooManyMatchesError,
|
||||||
)
|
)
|
||||||
|
from numpydantic.serialization import pydantic_serializer
|
||||||
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||||
from numpydantic.validation import validate_dtype, validate_shape
|
from numpydantic.validation import validate_dtype, validate_shape
|
||||||
|
|
||||||
|
@ -232,6 +233,7 @@ class Interface(ABC, Generic[T]):
|
||||||
shape_valid = self.validate_shape(shape)
|
shape_valid = self.validate_shape(shape)
|
||||||
self.raise_for_shape(shape_valid, shape)
|
self.raise_for_shape(shape_valid, shape)
|
||||||
|
|
||||||
|
array = self.apply_serializer(array)
|
||||||
array = self.after_validation(array)
|
array = self.after_validation(array)
|
||||||
|
|
||||||
return array
|
return array
|
||||||
|
@ -338,6 +340,17 @@ class Interface(ABC, Generic[T]):
|
||||||
f"got shape {shape}"
|
f"got shape {shape}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def apply_serializer(self, array: NDArrayType) -> NDArrayType:
|
||||||
|
"""
|
||||||
|
Apply a __pydantic_serializer__ method that allows the array to be
|
||||||
|
serialized directly without knowledge of the interface in the containing model.
|
||||||
|
|
||||||
|
Useful for when arrays are specified as `__pydantic_extra__` which has
|
||||||
|
different serializer resolution logic.
|
||||||
|
"""
|
||||||
|
array.__pydantic_serializer__ = pydantic_serializer
|
||||||
|
return array
|
||||||
|
|
||||||
def after_validation(self, array: NDArrayType) -> T:
|
def after_validation(self, array: NDArrayType) -> T:
|
||||||
"""
|
"""
|
||||||
Optional step post-validation that coerces the intermediate array type into the
|
Optional step post-validation that coerces the intermediate array type into the
|
||||||
|
|
|
@ -2,11 +2,12 @@
|
||||||
Interface to numpy arrays
|
Interface to numpy arrays
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Literal, Union
|
from typing import Any, Literal, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, SerializationInfo
|
from pydantic import BaseModel, SerializationInfo
|
||||||
|
|
||||||
from numpydantic.interface.interface import Interface, JsonDict
|
from numpydantic.interface.interface import Interface, JsonDict
|
||||||
|
from numpydantic.serialization import pydantic_serializer
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -36,6 +37,28 @@ class NumpyJsonDict(JsonDict):
|
||||||
return np.array(self.value, dtype=self.dtype)
|
return np.array(self.value, dtype=self.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class SerializableNDArray(np.ndarray):
|
||||||
|
"""
|
||||||
|
Trivial subclass of :class:`numpy.ndarray` that allows
|
||||||
|
an additional ``__pydantic_serializer__`` attr to allow it to be
|
||||||
|
json roundtripped without the help of an interface
|
||||||
|
|
||||||
|
References:
|
||||||
|
https://numpy.org/doc/stable/user/basics.subclassing.html#simple-example-adding-an-extra-attribute-to-ndarray
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __new__(cls, input_array: np.ndarray, **kwargs: dict[str, Any]):
|
||||||
|
"""Create a new ndarray instance, adding a new attribute"""
|
||||||
|
obj = np.asarray(input_array, **kwargs).view(cls)
|
||||||
|
obj.__pydantic_serializer__ = pydantic_serializer
|
||||||
|
return obj
|
||||||
|
|
||||||
|
def __array_finalize__(self, obj: Optional[np.ndarray]) -> None:
|
||||||
|
if obj is None:
|
||||||
|
return
|
||||||
|
self.__pydantic_serializer__ = getattr(obj, "__pydantic_serializer__", None)
|
||||||
|
|
||||||
|
|
||||||
class NumpyInterface(Interface):
|
class NumpyInterface(Interface):
|
||||||
"""
|
"""
|
||||||
Numpy :class:`~numpy.ndarray` s!
|
Numpy :class:`~numpy.ndarray` s!
|
||||||
|
@ -62,7 +85,7 @@ class NumpyInterface(Interface):
|
||||||
if array is None:
|
if array is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if isinstance(array, ndarray):
|
if isinstance(array, (ndarray, SerializableNDArray)):
|
||||||
return True
|
return True
|
||||||
elif isinstance(array, dict):
|
elif isinstance(array, dict):
|
||||||
return NumpyJsonDict.is_valid(array)
|
return NumpyJsonDict.is_valid(array)
|
||||||
|
@ -78,12 +101,14 @@ class NumpyInterface(Interface):
|
||||||
Coerce to an ndarray. We have already checked if coercion is possible
|
Coerce to an ndarray. We have already checked if coercion is possible
|
||||||
in :meth:`.check`
|
in :meth:`.check`
|
||||||
"""
|
"""
|
||||||
if not isinstance(array, ndarray):
|
if not isinstance(array, SerializableNDArray):
|
||||||
array = np.array(array)
|
array = SerializableNDArray(array)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if issubclass(self.dtype, BaseModel) and isinstance(array.flat[0], dict):
|
if issubclass(self.dtype, BaseModel) and isinstance(array.flat[0], dict):
|
||||||
array = np.vectorize(lambda x: self.dtype(**x))(array)
|
array = SerializableNDArray(
|
||||||
|
np.vectorize(lambda x: self.dtype(**x))(array)
|
||||||
|
)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
# fine, dtype isn't a type
|
# fine, dtype isn't a type
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -4,11 +4,16 @@ and :func:`pydantic.BaseModel.model_dump_json` .
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Iterable, TypeVar, Union
|
from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar, Union
|
||||||
|
|
||||||
from pydantic_core.core_schema import SerializationInfo
|
from pydantic_core import SchemaSerializer
|
||||||
|
from pydantic_core.core_schema import (
|
||||||
|
SerializationInfo,
|
||||||
|
plain_serializer_function_ser_schema,
|
||||||
|
)
|
||||||
|
|
||||||
from numpydantic.interface import Interface, JsonDict
|
if TYPE_CHECKING:
|
||||||
|
from numpydantic.interface import Interface
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
U = TypeVar("U")
|
U = TypeVar("U")
|
||||||
|
@ -16,8 +21,8 @@ U = TypeVar("U")
|
||||||
|
|
||||||
def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
|
def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
|
||||||
"""Use an interface class to render an array as JSON"""
|
"""Use an interface class to render an array as JSON"""
|
||||||
# return [1, 2, 3]
|
from numpydantic.interface import Interface
|
||||||
# pdb.set_trace()
|
|
||||||
interface_cls = Interface.match_output(value)
|
interface_cls = Interface.match_output(value)
|
||||||
array = interface_cls.to_json(value, info)
|
array = interface_cls.to_json(value, info)
|
||||||
array = postprocess_json(array, info, interface_cls)
|
array = postprocess_json(array, info, interface_cls)
|
||||||
|
@ -25,11 +30,13 @@ def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
|
||||||
|
|
||||||
|
|
||||||
def postprocess_json(
|
def postprocess_json(
|
||||||
array: Union[dict, list], info: SerializationInfo, interface_cls: type[Interface]
|
array: Union[dict, list], info: SerializationInfo, interface_cls: type["Interface"]
|
||||||
) -> Union[dict, list]:
|
) -> Union[dict, list]:
|
||||||
"""
|
"""
|
||||||
Modify json after dumping from an interface
|
Modify json after dumping from an interface
|
||||||
"""
|
"""
|
||||||
|
from numpydantic.interface import JsonDict
|
||||||
|
|
||||||
# perf: keys to skip in generation - anything named "value" is array data.
|
# perf: keys to skip in generation - anything named "value" is array data.
|
||||||
skip = ["value"]
|
skip = ["value"]
|
||||||
if isinstance(array, JsonDict):
|
if isinstance(array, JsonDict):
|
||||||
|
@ -67,6 +74,17 @@ def postprocess_json(
|
||||||
return array
|
return array
|
||||||
|
|
||||||
|
|
||||||
|
pydantic_serializer = SchemaSerializer(
|
||||||
|
plain_serializer_function_ser_schema(
|
||||||
|
jsonize_array, when_used="json", info_arg=True
|
||||||
|
),
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
A generic serializer that can be applied to interface proxies et al as
|
||||||
|
``__pydantic_serializer__`` .
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _relativize_paths(
|
def _relativize_paths(
|
||||||
value: dict, relative_to: str = ".", skip: Iterable = tuple()
|
value: dict, relative_to: str = ".", skip: Iterable = tuple()
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
|
|
@ -97,6 +97,8 @@ def all_passing_cases(request) -> ValidationCase:
|
||||||
that we want to be sure is *very true* in every circumstance.
|
that we want to be sure is *very true* in every circumstance.
|
||||||
Typically, that means only use this in `test_interfaces.py`
|
Typically, that means only use this in `test_interfaces.py`
|
||||||
"""
|
"""
|
||||||
|
if "subclass" in request.param.id.lower():
|
||||||
|
pytest.xfail()
|
||||||
return request.param
|
return request.param
|
||||||
|
|
||||||
|
|
||||||
|
@ -136,10 +138,12 @@ def all_passing_cases_instance(all_passing_cases, tmp_output_dir_func):
|
||||||
for p in DTYPE_AND_INTERFACE_CASES_PASSING
|
for p in DTYPE_AND_INTERFACE_CASES_PASSING
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
def all_passing_cases(request):
|
def dtype_by_interface(request):
|
||||||
"""
|
"""
|
||||||
Tests for all dtypes by all interfaces
|
Tests for all dtypes by all interfaces
|
||||||
"""
|
"""
|
||||||
|
if "subclass" in request.param.id.lower():
|
||||||
|
pytest.xfail()
|
||||||
return request.param
|
return request.param
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -4,13 +4,15 @@ Tests that should be applied to all interfaces
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
import dask.array as da
|
import dask.array as da
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, ConfigDict
|
||||||
from zarr.core import Array as ZarrArray
|
from zarr.core import Array as ZarrArray
|
||||||
|
|
||||||
|
from numpydantic import NDArray
|
||||||
from numpydantic.interface import Interface, InterfaceMark, MarkedJson
|
from numpydantic.interface import Interface, InterfaceMark, MarkedJson
|
||||||
from numpydantic.testing.helpers import ValidationCase
|
from numpydantic.testing.helpers import ValidationCase
|
||||||
|
|
||||||
|
@ -30,6 +32,8 @@ def _test_roundtrip(source: BaseModel, target: BaseModel):
|
||||||
assert target_type is source_type
|
assert target_type is source_type
|
||||||
else:
|
else:
|
||||||
assert np.all(da.equal(target.array, source.array))
|
assert np.all(da.equal(target.array, source.array))
|
||||||
|
elif isinstance(source.array, BaseModel):
|
||||||
|
return _test_roundtrip(source.array, target.array)
|
||||||
else:
|
else:
|
||||||
assert target.array == source.array
|
assert target.array == source.array
|
||||||
|
|
||||||
|
@ -95,8 +99,6 @@ def test_interface_roundtrip_json(all_passing_cases, tmp_output_dir_func):
|
||||||
"""
|
"""
|
||||||
All interfaces should be able to roundtrip to and from json
|
All interfaces should be able to roundtrip to and from json
|
||||||
"""
|
"""
|
||||||
if "subclass" in all_passing_cases.id.lower():
|
|
||||||
pytest.xfail()
|
|
||||||
|
|
||||||
array = all_passing_cases.array(path=tmp_output_dir_func)
|
array = all_passing_cases.array(path=tmp_output_dir_func)
|
||||||
case = all_passing_cases.model(array=array)
|
case = all_passing_cases.model(array=array)
|
||||||
|
@ -154,3 +156,69 @@ def test_interface_mark_roundtrip(all_passing_cases, valid, tmp_output_dir_func)
|
||||||
model = case.model_validate_json(dumped_json)
|
model = case.model_validate_json(dumped_json)
|
||||||
|
|
||||||
_test_roundtrip(case, model)
|
_test_roundtrip(case, model)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.serialization
|
||||||
|
def test_roundtrip_from_extra(dtype_by_interface, tmp_output_dir_func):
|
||||||
|
"""
|
||||||
|
Arrays can be dumped when they are specified in an `__extra__` field
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Model(BaseModel):
|
||||||
|
__pydantic_extra__: dict[str, dtype_by_interface.annotation]
|
||||||
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
instance = Model(array=dtype_by_interface.array(path=tmp_output_dir_func))
|
||||||
|
dumped = instance.model_dump_json(round_trip=True)
|
||||||
|
roundtripped = Model.model_validate_json(dumped)
|
||||||
|
_test_roundtrip(instance, roundtripped)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.serialization
|
||||||
|
def test_roundtrip_from_union(dtype_by_interface, tmp_output_dir_func):
|
||||||
|
"""
|
||||||
|
Arrays can be dumped when they are specified along with a union of another type field
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Model(BaseModel):
|
||||||
|
array: str | dtype_by_interface.annotation
|
||||||
|
|
||||||
|
array = dtype_by_interface.array(path=tmp_output_dir_func)
|
||||||
|
|
||||||
|
instance = Model(array=array)
|
||||||
|
dumped = instance.model_dump_json(round_trip=True)
|
||||||
|
roundtripped = Model.model_validate_json(dumped)
|
||||||
|
_test_roundtrip(instance, roundtripped)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.serialization
|
||||||
|
def test_roundtrip_from_generic(dtype_by_interface, tmp_output_dir_func):
|
||||||
|
"""
|
||||||
|
Arrays can be dumped when they are specified in an `__extra__` field
|
||||||
|
"""
|
||||||
|
T = TypeVar("T", bound=NDArray)
|
||||||
|
|
||||||
|
class GenType(BaseModel, Generic[T]):
|
||||||
|
array: T
|
||||||
|
|
||||||
|
class Model(BaseModel):
|
||||||
|
array: GenType[dtype_by_interface.annotation]
|
||||||
|
|
||||||
|
array = dtype_by_interface.array(path=tmp_output_dir_func)
|
||||||
|
instance = Model(**{"array": {"array": array}})
|
||||||
|
dumped = instance.model_dump_json(round_trip=True)
|
||||||
|
roundtripped = Model.model_validate_json(dumped)
|
||||||
|
_test_roundtrip(instance, roundtripped)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.serialization
|
||||||
|
def test_roundtrip_from_any(dtype_by_interface, tmp_output_dir_func):
|
||||||
|
"""
|
||||||
|
We can roundtrip from an AnyType
|
||||||
|
Args:
|
||||||
|
dtype_by_interface:
|
||||||
|
tmp_output_dir_func:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
Loading…
Reference in a new issue