mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2025-01-09 21:44:27 +00:00
stuck at the recursion error in pydantic core
This commit is contained in:
parent
1701ef9d7e
commit
33eb1c15d5
9 changed files with 218 additions and 64 deletions
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
|
@ -47,7 +47,7 @@ jobs:
|
|||
run: pip install "numpy${{ matrix.numpy-version }}"
|
||||
|
||||
- name: Run Tests
|
||||
run: pytest
|
||||
run: pytest -n auto
|
||||
|
||||
- name: Coveralls Parallel
|
||||
uses: coverallsapp/github-action@v2.3.0
|
||||
|
|
28
pdm.lock
28
pdm.lock
|
@ -5,7 +5,7 @@
|
|||
groups = ["default", "arrays", "dask", "dev", "docs", "hdf5", "tests", "video", "zarr"]
|
||||
strategy = ["cross_platform", "inherit_metadata"]
|
||||
lock_version = "4.5.0"
|
||||
content_hash = "sha256:cc2b0fb32896c6df0ad747ddb5dee89af22f5c4c4643ee7a52db47fef30da936"
|
||||
content_hash = "sha256:576c76e5a4a7616b7cb72bca277db1ada330bb8f3f9827d0e161728325e19264"
|
||||
|
||||
[[metadata.targets]]
|
||||
requires_python = "~=3.9"
|
||||
|
@ -576,6 +576,17 @@ files = [
|
|||
{file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "execnet"
|
||||
version = "2.1.1"
|
||||
requires_python = ">=3.8"
|
||||
summary = "execnet: rapid multi-Python deployment"
|
||||
groups = ["dev", "tests"]
|
||||
files = [
|
||||
{file = "execnet-2.1.1-py3-none-any.whl", hash = "sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc"},
|
||||
{file = "execnet-2.1.1.tar.gz", hash = "sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "executing"
|
||||
version = "2.1.0"
|
||||
|
@ -1647,6 +1658,21 @@ files = [
|
|||
{file = "pytest_depends-1.0.1-py3-none-any.whl", hash = "sha256:a1df072bcc93d77aca3f0946903f5fed8af2d9b0056db1dfc9ed5ac164ab0642"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-xdist"
|
||||
version = "3.6.1"
|
||||
requires_python = ">=3.8"
|
||||
summary = "pytest xdist plugin for distributed testing, most importantly across multiple CPUs"
|
||||
groups = ["dev", "tests"]
|
||||
dependencies = [
|
||||
"execnet>=2.1",
|
||||
"pytest>=7.0.0",
|
||||
]
|
||||
files = [
|
||||
{file = "pytest_xdist-3.6.1-py3-none-any.whl", hash = "sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7"},
|
||||
{file = "pytest_xdist-3.6.1.tar.gz", hash = "sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "python-dateutil"
|
||||
version = "2.9.0.post0"
|
||||
|
|
|
@ -71,6 +71,7 @@ tests = [
|
|||
"coverage>=6.1.1",
|
||||
"pytest-cov<5.0.0,>=4.1.0",
|
||||
"coveralls<4.0.0,>=3.3.1",
|
||||
"pytest-xdist>=3.6.1",
|
||||
]
|
||||
docs = [
|
||||
"numpydantic[arrays]",
|
||||
|
|
|
@ -55,7 +55,6 @@ from typing import (
|
|||
|
||||
import numpy as np
|
||||
from pydantic import SerializationInfo
|
||||
from pydantic_core import SchemaSerializer, core_schema
|
||||
|
||||
from numpydantic.interface.interface import Interface, JsonDict
|
||||
from numpydantic.types import DtypeType, NDArrayType
|
||||
|
@ -100,49 +99,49 @@ class H5JsonDict(JsonDict):
|
|||
)
|
||||
|
||||
|
||||
def to_json(self, info: SerializationInfo):
|
||||
"""
|
||||
Serialize H5Proxy to JSON, as the interface does,
|
||||
in cases when the interface is not able to be used
|
||||
(eg. like when used as an `extra` field in a model without a type annotation)
|
||||
"""
|
||||
from numpydantic.serialization import postprocess_json
|
||||
|
||||
if info.round_trip:
|
||||
as_json = {
|
||||
"type": H5Interface.name,
|
||||
}
|
||||
as_json.update(self._h5arraypath._asdict())
|
||||
else:
|
||||
try:
|
||||
dset = self.open()
|
||||
as_json = dset[:].tolist()
|
||||
finally:
|
||||
self.close()
|
||||
return postprocess_json(as_json, info)
|
||||
# def to_json(self, info: SerializationInfo):
|
||||
# """
|
||||
# Serialize H5Proxy to JSON, as the interface does,
|
||||
# in cases when the interface is not able to be used
|
||||
# (eg. like when used as an `extra` field in a model without a type annotation)
|
||||
# """
|
||||
# from numpydantic.serialization import postprocess_json
|
||||
#
|
||||
# if info.round_trip:
|
||||
# as_json = {
|
||||
# "type": H5Interface.name,
|
||||
# }
|
||||
# as_json.update(self._h5arraypath._asdict())
|
||||
# else:
|
||||
# try:
|
||||
# dset = self.open()
|
||||
# as_json = dset[:].tolist()
|
||||
# finally:
|
||||
# self.close()
|
||||
# return postprocess_json(as_json, info)
|
||||
|
||||
|
||||
def _make_pydantic_schema():
|
||||
return core_schema.typed_dict_schema(
|
||||
{
|
||||
"file": core_schema.typed_dict_field(
|
||||
core_schema.str_schema(), required=True
|
||||
),
|
||||
"path": core_schema.typed_dict_field(
|
||||
core_schema.str_schema(), required=True
|
||||
),
|
||||
"field": core_schema.typed_dict_field(
|
||||
core_schema.union_schema(
|
||||
[
|
||||
core_schema.str_schema(),
|
||||
core_schema.list_schema(core_schema.str_schema()),
|
||||
],
|
||||
),
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
# serialization=
|
||||
)
|
||||
# def _make_pydantic_schema():
|
||||
# return core_schema.typed_dict_schema(
|
||||
# {
|
||||
# "file": core_schema.typed_dict_field(
|
||||
# core_schema.str_schema(), required=True
|
||||
# ),
|
||||
# "path": core_schema.typed_dict_field(
|
||||
# core_schema.str_schema(), required=True
|
||||
# ),
|
||||
# "field": core_schema.typed_dict_field(
|
||||
# core_schema.union_schema(
|
||||
# [
|
||||
# core_schema.str_schema(),
|
||||
# core_schema.list_schema(core_schema.str_schema()),
|
||||
# ],
|
||||
# ),
|
||||
# required=True,
|
||||
# ),
|
||||
# },
|
||||
# # serialization=
|
||||
# )
|
||||
|
||||
|
||||
class H5Proxy:
|
||||
|
@ -167,11 +166,11 @@ class H5Proxy:
|
|||
annotation_dtype (dtype): Optional - the dtype of our type annotation
|
||||
"""
|
||||
|
||||
__pydantic_serializer__ = SchemaSerializer(
|
||||
core_schema.plain_serializer_function_ser_schema(
|
||||
to_json, when_used="json", info_arg=True
|
||||
),
|
||||
)
|
||||
# __pydantic_serializer__ = SchemaSerializer(
|
||||
# core_schema.plain_serializer_function_ser_schema(
|
||||
# jsonize_array, when_used="json", info_arg=True
|
||||
# ),
|
||||
# )
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -20,6 +20,7 @@ from numpydantic.exceptions import (
|
|||
ShapeError,
|
||||
TooManyMatchesError,
|
||||
)
|
||||
from numpydantic.serialization import pydantic_serializer
|
||||
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||
from numpydantic.validation import validate_dtype, validate_shape
|
||||
|
||||
|
@ -232,6 +233,7 @@ class Interface(ABC, Generic[T]):
|
|||
shape_valid = self.validate_shape(shape)
|
||||
self.raise_for_shape(shape_valid, shape)
|
||||
|
||||
array = self.apply_serializer(array)
|
||||
array = self.after_validation(array)
|
||||
|
||||
return array
|
||||
|
@ -338,6 +340,17 @@ class Interface(ABC, Generic[T]):
|
|||
f"got shape {shape}"
|
||||
)
|
||||
|
||||
def apply_serializer(self, array: NDArrayType) -> NDArrayType:
|
||||
"""
|
||||
Apply a __pydantic_serializer__ method that allows the array to be
|
||||
serialized directly without knowledge of the interface in the containing model.
|
||||
|
||||
Useful for when arrays are specified as `__pydantic_extra__` which has
|
||||
different serializer resolution logic.
|
||||
"""
|
||||
array.__pydantic_serializer__ = pydantic_serializer
|
||||
return array
|
||||
|
||||
def after_validation(self, array: NDArrayType) -> T:
|
||||
"""
|
||||
Optional step post-validation that coerces the intermediate array type into the
|
||||
|
|
|
@ -2,11 +2,12 @@
|
|||
Interface to numpy arrays
|
||||
"""
|
||||
|
||||
from typing import Any, Literal, Union
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, SerializationInfo
|
||||
|
||||
from numpydantic.interface.interface import Interface, JsonDict
|
||||
from numpydantic.serialization import pydantic_serializer
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
|
@ -36,6 +37,28 @@ class NumpyJsonDict(JsonDict):
|
|||
return np.array(self.value, dtype=self.dtype)
|
||||
|
||||
|
||||
class SerializableNDArray(np.ndarray):
|
||||
"""
|
||||
Trivial subclass of :class:`numpy.ndarray` that allows
|
||||
an additional ``__pydantic_serializer__`` attr to allow it to be
|
||||
json roundtripped without the help of an interface
|
||||
|
||||
References:
|
||||
https://numpy.org/doc/stable/user/basics.subclassing.html#simple-example-adding-an-extra-attribute-to-ndarray
|
||||
"""
|
||||
|
||||
def __new__(cls, input_array: np.ndarray, **kwargs: dict[str, Any]):
|
||||
"""Create a new ndarray instance, adding a new attribute"""
|
||||
obj = np.asarray(input_array, **kwargs).view(cls)
|
||||
obj.__pydantic_serializer__ = pydantic_serializer
|
||||
return obj
|
||||
|
||||
def __array_finalize__(self, obj: Optional[np.ndarray]) -> None:
|
||||
if obj is None:
|
||||
return
|
||||
self.__pydantic_serializer__ = getattr(obj, "__pydantic_serializer__", None)
|
||||
|
||||
|
||||
class NumpyInterface(Interface):
|
||||
"""
|
||||
Numpy :class:`~numpy.ndarray` s!
|
||||
|
@ -62,7 +85,7 @@ class NumpyInterface(Interface):
|
|||
if array is None:
|
||||
return False
|
||||
|
||||
if isinstance(array, ndarray):
|
||||
if isinstance(array, (ndarray, SerializableNDArray)):
|
||||
return True
|
||||
elif isinstance(array, dict):
|
||||
return NumpyJsonDict.is_valid(array)
|
||||
|
@ -78,12 +101,14 @@ class NumpyInterface(Interface):
|
|||
Coerce to an ndarray. We have already checked if coercion is possible
|
||||
in :meth:`.check`
|
||||
"""
|
||||
if not isinstance(array, ndarray):
|
||||
array = np.array(array)
|
||||
if not isinstance(array, SerializableNDArray):
|
||||
array = SerializableNDArray(array)
|
||||
|
||||
try:
|
||||
if issubclass(self.dtype, BaseModel) and isinstance(array.flat[0], dict):
|
||||
array = np.vectorize(lambda x: self.dtype(**x))(array)
|
||||
array = SerializableNDArray(
|
||||
np.vectorize(lambda x: self.dtype(**x))(array)
|
||||
)
|
||||
except TypeError:
|
||||
# fine, dtype isn't a type
|
||||
pass
|
||||
|
|
|
@ -4,11 +4,16 @@ and :func:`pydantic.BaseModel.model_dump_json` .
|
|||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Iterable, TypeVar, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar, Union
|
||||
|
||||
from pydantic_core.core_schema import SerializationInfo
|
||||
from pydantic_core import SchemaSerializer
|
||||
from pydantic_core.core_schema import (
|
||||
SerializationInfo,
|
||||
plain_serializer_function_ser_schema,
|
||||
)
|
||||
|
||||
from numpydantic.interface import Interface, JsonDict
|
||||
if TYPE_CHECKING:
|
||||
from numpydantic.interface import Interface
|
||||
|
||||
T = TypeVar("T")
|
||||
U = TypeVar("U")
|
||||
|
@ -16,8 +21,8 @@ U = TypeVar("U")
|
|||
|
||||
def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
|
||||
"""Use an interface class to render an array as JSON"""
|
||||
# return [1, 2, 3]
|
||||
# pdb.set_trace()
|
||||
from numpydantic.interface import Interface
|
||||
|
||||
interface_cls = Interface.match_output(value)
|
||||
array = interface_cls.to_json(value, info)
|
||||
array = postprocess_json(array, info, interface_cls)
|
||||
|
@ -25,11 +30,13 @@ def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
|
|||
|
||||
|
||||
def postprocess_json(
|
||||
array: Union[dict, list], info: SerializationInfo, interface_cls: type[Interface]
|
||||
array: Union[dict, list], info: SerializationInfo, interface_cls: type["Interface"]
|
||||
) -> Union[dict, list]:
|
||||
"""
|
||||
Modify json after dumping from an interface
|
||||
"""
|
||||
from numpydantic.interface import JsonDict
|
||||
|
||||
# perf: keys to skip in generation - anything named "value" is array data.
|
||||
skip = ["value"]
|
||||
if isinstance(array, JsonDict):
|
||||
|
@ -67,6 +74,17 @@ def postprocess_json(
|
|||
return array
|
||||
|
||||
|
||||
pydantic_serializer = SchemaSerializer(
|
||||
plain_serializer_function_ser_schema(
|
||||
jsonize_array, when_used="json", info_arg=True
|
||||
),
|
||||
)
|
||||
"""
|
||||
A generic serializer that can be applied to interface proxies et al as
|
||||
``__pydantic_serializer__`` .
|
||||
"""
|
||||
|
||||
|
||||
def _relativize_paths(
|
||||
value: dict, relative_to: str = ".", skip: Iterable = tuple()
|
||||
) -> dict:
|
||||
|
|
|
@ -97,6 +97,8 @@ def all_passing_cases(request) -> ValidationCase:
|
|||
that we want to be sure is *very true* in every circumstance.
|
||||
Typically, that means only use this in `test_interfaces.py`
|
||||
"""
|
||||
if "subclass" in request.param.id.lower():
|
||||
pytest.xfail()
|
||||
return request.param
|
||||
|
||||
|
||||
|
@ -136,10 +138,12 @@ def all_passing_cases_instance(all_passing_cases, tmp_output_dir_func):
|
|||
for p in DTYPE_AND_INTERFACE_CASES_PASSING
|
||||
)
|
||||
)
|
||||
def all_passing_cases(request):
|
||||
def dtype_by_interface(request):
|
||||
"""
|
||||
Tests for all dtypes by all interfaces
|
||||
"""
|
||||
if "subclass" in request.param.id.lower():
|
||||
pytest.xfail()
|
||||
return request.param
|
||||
|
||||
|
||||
|
|
|
@ -4,13 +4,15 @@ Tests that should be applied to all interfaces
|
|||
|
||||
import json
|
||||
from importlib.metadata import version
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
import dask.array as da
|
||||
import numpy as np
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from zarr.core import Array as ZarrArray
|
||||
|
||||
from numpydantic import NDArray
|
||||
from numpydantic.interface import Interface, InterfaceMark, MarkedJson
|
||||
from numpydantic.testing.helpers import ValidationCase
|
||||
|
||||
|
@ -30,6 +32,8 @@ def _test_roundtrip(source: BaseModel, target: BaseModel):
|
|||
assert target_type is source_type
|
||||
else:
|
||||
assert np.all(da.equal(target.array, source.array))
|
||||
elif isinstance(source.array, BaseModel):
|
||||
return _test_roundtrip(source.array, target.array)
|
||||
else:
|
||||
assert target.array == source.array
|
||||
|
||||
|
@ -95,8 +99,6 @@ def test_interface_roundtrip_json(all_passing_cases, tmp_output_dir_func):
|
|||
"""
|
||||
All interfaces should be able to roundtrip to and from json
|
||||
"""
|
||||
if "subclass" in all_passing_cases.id.lower():
|
||||
pytest.xfail()
|
||||
|
||||
array = all_passing_cases.array(path=tmp_output_dir_func)
|
||||
case = all_passing_cases.model(array=array)
|
||||
|
@ -154,3 +156,69 @@ def test_interface_mark_roundtrip(all_passing_cases, valid, tmp_output_dir_func)
|
|||
model = case.model_validate_json(dumped_json)
|
||||
|
||||
_test_roundtrip(case, model)
|
||||
|
||||
|
||||
@pytest.mark.serialization
|
||||
def test_roundtrip_from_extra(dtype_by_interface, tmp_output_dir_func):
|
||||
"""
|
||||
Arrays can be dumped when they are specified in an `__extra__` field
|
||||
"""
|
||||
|
||||
class Model(BaseModel):
|
||||
__pydantic_extra__: dict[str, dtype_by_interface.annotation]
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
instance = Model(array=dtype_by_interface.array(path=tmp_output_dir_func))
|
||||
dumped = instance.model_dump_json(round_trip=True)
|
||||
roundtripped = Model.model_validate_json(dumped)
|
||||
_test_roundtrip(instance, roundtripped)
|
||||
|
||||
|
||||
@pytest.mark.serialization
|
||||
def test_roundtrip_from_union(dtype_by_interface, tmp_output_dir_func):
|
||||
"""
|
||||
Arrays can be dumped when they are specified along with a union of another type field
|
||||
"""
|
||||
|
||||
class Model(BaseModel):
|
||||
array: str | dtype_by_interface.annotation
|
||||
|
||||
array = dtype_by_interface.array(path=tmp_output_dir_func)
|
||||
|
||||
instance = Model(array=array)
|
||||
dumped = instance.model_dump_json(round_trip=True)
|
||||
roundtripped = Model.model_validate_json(dumped)
|
||||
_test_roundtrip(instance, roundtripped)
|
||||
|
||||
|
||||
@pytest.mark.serialization
|
||||
def test_roundtrip_from_generic(dtype_by_interface, tmp_output_dir_func):
|
||||
"""
|
||||
Arrays can be dumped when they are specified in an `__extra__` field
|
||||
"""
|
||||
T = TypeVar("T", bound=NDArray)
|
||||
|
||||
class GenType(BaseModel, Generic[T]):
|
||||
array: T
|
||||
|
||||
class Model(BaseModel):
|
||||
array: GenType[dtype_by_interface.annotation]
|
||||
|
||||
array = dtype_by_interface.array(path=tmp_output_dir_func)
|
||||
instance = Model(**{"array": {"array": array}})
|
||||
dumped = instance.model_dump_json(round_trip=True)
|
||||
roundtripped = Model.model_validate_json(dumped)
|
||||
_test_roundtrip(instance, roundtripped)
|
||||
|
||||
|
||||
@pytest.mark.serialization
|
||||
def test_roundtrip_from_any(dtype_by_interface, tmp_output_dir_func):
|
||||
"""
|
||||
We can roundtrip from an AnyType
|
||||
Args:
|
||||
dtype_by_interface:
|
||||
tmp_output_dir_func:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
|
|
Loading…
Reference in a new issue