stuck at the recursion error in pydantic core

This commit is contained in:
sneakers-the-rat 2024-10-18 00:30:33 -07:00
parent 1701ef9d7e
commit 33eb1c15d5
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
9 changed files with 218 additions and 64 deletions

View file

@ -47,7 +47,7 @@ jobs:
run: pip install "numpy${{ matrix.numpy-version }}" run: pip install "numpy${{ matrix.numpy-version }}"
- name: Run Tests - name: Run Tests
run: pytest run: pytest -n auto
- name: Coveralls Parallel - name: Coveralls Parallel
uses: coverallsapp/github-action@v2.3.0 uses: coverallsapp/github-action@v2.3.0

View file

@ -5,7 +5,7 @@
groups = ["default", "arrays", "dask", "dev", "docs", "hdf5", "tests", "video", "zarr"] groups = ["default", "arrays", "dask", "dev", "docs", "hdf5", "tests", "video", "zarr"]
strategy = ["cross_platform", "inherit_metadata"] strategy = ["cross_platform", "inherit_metadata"]
lock_version = "4.5.0" lock_version = "4.5.0"
content_hash = "sha256:cc2b0fb32896c6df0ad747ddb5dee89af22f5c4c4643ee7a52db47fef30da936" content_hash = "sha256:576c76e5a4a7616b7cb72bca277db1ada330bb8f3f9827d0e161728325e19264"
[[metadata.targets]] [[metadata.targets]]
requires_python = "~=3.9" requires_python = "~=3.9"
@ -576,6 +576,17 @@ files = [
{file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"},
] ]
[[package]]
name = "execnet"
version = "2.1.1"
requires_python = ">=3.8"
summary = "execnet: rapid multi-Python deployment"
groups = ["dev", "tests"]
files = [
{file = "execnet-2.1.1-py3-none-any.whl", hash = "sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc"},
{file = "execnet-2.1.1.tar.gz", hash = "sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3"},
]
[[package]] [[package]]
name = "executing" name = "executing"
version = "2.1.0" version = "2.1.0"
@ -1647,6 +1658,21 @@ files = [
{file = "pytest_depends-1.0.1-py3-none-any.whl", hash = "sha256:a1df072bcc93d77aca3f0946903f5fed8af2d9b0056db1dfc9ed5ac164ab0642"}, {file = "pytest_depends-1.0.1-py3-none-any.whl", hash = "sha256:a1df072bcc93d77aca3f0946903f5fed8af2d9b0056db1dfc9ed5ac164ab0642"},
] ]
[[package]]
name = "pytest-xdist"
version = "3.6.1"
requires_python = ">=3.8"
summary = "pytest xdist plugin for distributed testing, most importantly across multiple CPUs"
groups = ["dev", "tests"]
dependencies = [
"execnet>=2.1",
"pytest>=7.0.0",
]
files = [
{file = "pytest_xdist-3.6.1-py3-none-any.whl", hash = "sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7"},
{file = "pytest_xdist-3.6.1.tar.gz", hash = "sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d"},
]
[[package]] [[package]]
name = "python-dateutil" name = "python-dateutil"
version = "2.9.0.post0" version = "2.9.0.post0"

View file

@ -71,6 +71,7 @@ tests = [
"coverage>=6.1.1", "coverage>=6.1.1",
"pytest-cov<5.0.0,>=4.1.0", "pytest-cov<5.0.0,>=4.1.0",
"coveralls<4.0.0,>=3.3.1", "coveralls<4.0.0,>=3.3.1",
"pytest-xdist>=3.6.1",
] ]
docs = [ docs = [
"numpydantic[arrays]", "numpydantic[arrays]",

View file

@ -55,7 +55,6 @@ from typing import (
import numpy as np import numpy as np
from pydantic import SerializationInfo from pydantic import SerializationInfo
from pydantic_core import SchemaSerializer, core_schema
from numpydantic.interface.interface import Interface, JsonDict from numpydantic.interface.interface import Interface, JsonDict
from numpydantic.types import DtypeType, NDArrayType from numpydantic.types import DtypeType, NDArrayType
@ -100,49 +99,49 @@ class H5JsonDict(JsonDict):
) )
def to_json(self, info: SerializationInfo): # def to_json(self, info: SerializationInfo):
""" # """
Serialize H5Proxy to JSON, as the interface does, # Serialize H5Proxy to JSON, as the interface does,
in cases when the interface is not able to be used # in cases when the interface is not able to be used
(eg. like when used as an `extra` field in a model without a type annotation) # (eg. like when used as an `extra` field in a model without a type annotation)
""" # """
from numpydantic.serialization import postprocess_json # from numpydantic.serialization import postprocess_json
#
if info.round_trip: # if info.round_trip:
as_json = { # as_json = {
"type": H5Interface.name, # "type": H5Interface.name,
} # }
as_json.update(self._h5arraypath._asdict()) # as_json.update(self._h5arraypath._asdict())
else: # else:
try: # try:
dset = self.open() # dset = self.open()
as_json = dset[:].tolist() # as_json = dset[:].tolist()
finally: # finally:
self.close() # self.close()
return postprocess_json(as_json, info) # return postprocess_json(as_json, info)
def _make_pydantic_schema(): # def _make_pydantic_schema():
return core_schema.typed_dict_schema( # return core_schema.typed_dict_schema(
{ # {
"file": core_schema.typed_dict_field( # "file": core_schema.typed_dict_field(
core_schema.str_schema(), required=True # core_schema.str_schema(), required=True
), # ),
"path": core_schema.typed_dict_field( # "path": core_schema.typed_dict_field(
core_schema.str_schema(), required=True # core_schema.str_schema(), required=True
), # ),
"field": core_schema.typed_dict_field( # "field": core_schema.typed_dict_field(
core_schema.union_schema( # core_schema.union_schema(
[ # [
core_schema.str_schema(), # core_schema.str_schema(),
core_schema.list_schema(core_schema.str_schema()), # core_schema.list_schema(core_schema.str_schema()),
], # ],
), # ),
required=True, # required=True,
), # ),
}, # },
# serialization= # # serialization=
) # )
class H5Proxy: class H5Proxy:
@ -167,11 +166,11 @@ class H5Proxy:
annotation_dtype (dtype): Optional - the dtype of our type annotation annotation_dtype (dtype): Optional - the dtype of our type annotation
""" """
__pydantic_serializer__ = SchemaSerializer( # __pydantic_serializer__ = SchemaSerializer(
core_schema.plain_serializer_function_ser_schema( # core_schema.plain_serializer_function_ser_schema(
to_json, when_used="json", info_arg=True # jsonize_array, when_used="json", info_arg=True
), # ),
) # )
def __init__( def __init__(
self, self,

View file

@ -20,6 +20,7 @@ from numpydantic.exceptions import (
ShapeError, ShapeError,
TooManyMatchesError, TooManyMatchesError,
) )
from numpydantic.serialization import pydantic_serializer
from numpydantic.types import DtypeType, NDArrayType, ShapeType from numpydantic.types import DtypeType, NDArrayType, ShapeType
from numpydantic.validation import validate_dtype, validate_shape from numpydantic.validation import validate_dtype, validate_shape
@ -232,6 +233,7 @@ class Interface(ABC, Generic[T]):
shape_valid = self.validate_shape(shape) shape_valid = self.validate_shape(shape)
self.raise_for_shape(shape_valid, shape) self.raise_for_shape(shape_valid, shape)
array = self.apply_serializer(array)
array = self.after_validation(array) array = self.after_validation(array)
return array return array
@ -338,6 +340,17 @@ class Interface(ABC, Generic[T]):
f"got shape {shape}" f"got shape {shape}"
) )
def apply_serializer(self, array: NDArrayType) -> NDArrayType:
    """
    Tag ``array`` with the shared ``pydantic_serializer`` by storing it as the
    array's ``__pydantic_serializer__`` attribute, then hand the array back.

    With the serializer attached to the object itself, the array can be
    serialized directly, without the containing model knowing which interface
    it belongs to. This matters when arrays are specified as
    `__pydantic_extra__`, which has different serializer resolution logic.

    Args:
        array: The array to tag; it is modified in place.

    Returns:
        The same ``array`` object that was passed in.
    """
    # NOTE(review): assumes ``array`` accepts attribute assignment - confirm
    # this holds for every interface's intermediate array type.
    array.__pydantic_serializer__ = pydantic_serializer
    return array
def after_validation(self, array: NDArrayType) -> T: def after_validation(self, array: NDArrayType) -> T:
""" """
Optional step post-validation that coerces the intermediate array type into the Optional step post-validation that coerces the intermediate array type into the

View file

@ -2,11 +2,12 @@
Interface to numpy arrays Interface to numpy arrays
""" """
from typing import Any, Literal, Union from typing import Any, Literal, Optional, Union
from pydantic import BaseModel, SerializationInfo from pydantic import BaseModel, SerializationInfo
from numpydantic.interface.interface import Interface, JsonDict from numpydantic.interface.interface import Interface, JsonDict
from numpydantic.serialization import pydantic_serializer
try: try:
import numpy as np import numpy as np
@ -36,6 +37,28 @@ class NumpyJsonDict(JsonDict):
return np.array(self.value, dtype=self.dtype) return np.array(self.value, dtype=self.dtype)
class SerializableNDArray(np.ndarray):
    """
    Trivial subclass of :class:`numpy.ndarray` that carries an additional
    ``__pydantic_serializer__`` attribute, allowing it to be
    json roundtripped without the help of an interface.

    References:
        https://numpy.org/doc/stable/user/basics.subclassing.html#simple-example-adding-an-extra-attribute-to-ndarray
    """

    def __new__(
        cls, input_array: np.ndarray, **kwargs: Any
    ) -> "SerializableNDArray":
        """
        Create a new instance that views ``input_array`` and has the
        serializer attached.

        Args:
            input_array: The array (or array-like) to wrap.
            **kwargs: Passed through unchanged to :func:`numpy.asarray`.

        Returns:
            A view of the converted input, typed as this class.
        """
        obj = np.asarray(input_array, **kwargs).view(cls)
        obj.__pydantic_serializer__ = pydantic_serializer
        return obj

    def __array_finalize__(self, obj: Optional[np.ndarray]) -> None:
        """
        Copy ``__pydantic_serializer__`` from the array this one was made from.

        Args:
            obj: The source array, or ``None``, in which case nothing is set
                here.
        """
        if obj is None:
            return
        # fall back to None when the source array carries no serializer
        self.__pydantic_serializer__ = getattr(obj, "__pydantic_serializer__", None)
class NumpyInterface(Interface): class NumpyInterface(Interface):
""" """
Numpy :class:`~numpy.ndarray` s! Numpy :class:`~numpy.ndarray` s!
@ -62,7 +85,7 @@ class NumpyInterface(Interface):
if array is None: if array is None:
return False return False
if isinstance(array, ndarray): if isinstance(array, (ndarray, SerializableNDArray)):
return True return True
elif isinstance(array, dict): elif isinstance(array, dict):
return NumpyJsonDict.is_valid(array) return NumpyJsonDict.is_valid(array)
@ -78,12 +101,14 @@ class NumpyInterface(Interface):
Coerce to an ndarray. We have already checked if coercion is possible Coerce to an ndarray. We have already checked if coercion is possible
in :meth:`.check` in :meth:`.check`
""" """
if not isinstance(array, ndarray): if not isinstance(array, SerializableNDArray):
array = np.array(array) array = SerializableNDArray(array)
try: try:
if issubclass(self.dtype, BaseModel) and isinstance(array.flat[0], dict): if issubclass(self.dtype, BaseModel) and isinstance(array.flat[0], dict):
array = np.vectorize(lambda x: self.dtype(**x))(array) array = SerializableNDArray(
np.vectorize(lambda x: self.dtype(**x))(array)
)
except TypeError: except TypeError:
# fine, dtype isn't a type # fine, dtype isn't a type
pass pass

View file

@ -4,11 +4,16 @@ and :func:`pydantic.BaseModel.model_dump_json` .
""" """
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Iterable, TypeVar, Union from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar, Union
from pydantic_core.core_schema import SerializationInfo from pydantic_core import SchemaSerializer
from pydantic_core.core_schema import (
SerializationInfo,
plain_serializer_function_ser_schema,
)
from numpydantic.interface import Interface, JsonDict if TYPE_CHECKING:
from numpydantic.interface import Interface
T = TypeVar("T") T = TypeVar("T")
U = TypeVar("U") U = TypeVar("U")
@ -16,8 +21,8 @@ U = TypeVar("U")
def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]: def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
"""Use an interface class to render an array as JSON""" """Use an interface class to render an array as JSON"""
# return [1, 2, 3] from numpydantic.interface import Interface
# pdb.set_trace()
interface_cls = Interface.match_output(value) interface_cls = Interface.match_output(value)
array = interface_cls.to_json(value, info) array = interface_cls.to_json(value, info)
array = postprocess_json(array, info, interface_cls) array = postprocess_json(array, info, interface_cls)
@ -25,11 +30,13 @@ def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
def postprocess_json( def postprocess_json(
array: Union[dict, list], info: SerializationInfo, interface_cls: type[Interface] array: Union[dict, list], info: SerializationInfo, interface_cls: type["Interface"]
) -> Union[dict, list]: ) -> Union[dict, list]:
""" """
Modify json after dumping from an interface Modify json after dumping from an interface
""" """
from numpydantic.interface import JsonDict
# perf: keys to skip in generation - anything named "value" is array data. # perf: keys to skip in generation - anything named "value" is array data.
skip = ["value"] skip = ["value"]
if isinstance(array, JsonDict): if isinstance(array, JsonDict):
@ -67,6 +74,17 @@ def postprocess_json(
return array return array
# Built once at module level. The serializer wraps ``jsonize_array`` as a plain
# function serializer that is only used for JSON output (``when_used="json"``)
# and is called with the ``SerializationInfo`` (``info_arg=True``).
pydantic_serializer = SchemaSerializer(
plain_serializer_function_ser_schema(
jsonize_array, when_used="json", info_arg=True
),
)
"""
A generic serializer that can be applied to interface proxies et al as
``__pydantic_serializer__`` .
"""
def _relativize_paths( def _relativize_paths(
value: dict, relative_to: str = ".", skip: Iterable = tuple() value: dict, relative_to: str = ".", skip: Iterable = tuple()
) -> dict: ) -> dict:

View file

@ -97,6 +97,8 @@ def all_passing_cases(request) -> ValidationCase:
that we want to be sure is *very true* in every circumstance. that we want to be sure is *very true* in every circumstance.
Typically, that means only use this in `test_interfaces.py` Typically, that means only use this in `test_interfaces.py`
""" """
if "subclass" in request.param.id.lower():
pytest.xfail()
return request.param return request.param
@ -136,10 +138,12 @@ def all_passing_cases_instance(all_passing_cases, tmp_output_dir_func):
for p in DTYPE_AND_INTERFACE_CASES_PASSING for p in DTYPE_AND_INTERFACE_CASES_PASSING
) )
) )
def all_passing_cases(request): def dtype_by_interface(request):
""" """
Tests for all dtypes by all interfaces Tests for all dtypes by all interfaces
""" """
if "subclass" in request.param.id.lower():
pytest.xfail()
return request.param return request.param

View file

@ -4,13 +4,15 @@ Tests that should be applied to all interfaces
import json import json
from importlib.metadata import version from importlib.metadata import version
from typing import Generic, TypeVar
import dask.array as da import dask.array as da
import numpy as np import numpy as np
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from zarr.core import Array as ZarrArray from zarr.core import Array as ZarrArray
from numpydantic import NDArray
from numpydantic.interface import Interface, InterfaceMark, MarkedJson from numpydantic.interface import Interface, InterfaceMark, MarkedJson
from numpydantic.testing.helpers import ValidationCase from numpydantic.testing.helpers import ValidationCase
@ -30,6 +32,8 @@ def _test_roundtrip(source: BaseModel, target: BaseModel):
assert target_type is source_type assert target_type is source_type
else: else:
assert np.all(da.equal(target.array, source.array)) assert np.all(da.equal(target.array, source.array))
elif isinstance(source.array, BaseModel):
return _test_roundtrip(source.array, target.array)
else: else:
assert target.array == source.array assert target.array == source.array
@ -95,8 +99,6 @@ def test_interface_roundtrip_json(all_passing_cases, tmp_output_dir_func):
""" """
All interfaces should be able to roundtrip to and from json All interfaces should be able to roundtrip to and from json
""" """
if "subclass" in all_passing_cases.id.lower():
pytest.xfail()
array = all_passing_cases.array(path=tmp_output_dir_func) array = all_passing_cases.array(path=tmp_output_dir_func)
case = all_passing_cases.model(array=array) case = all_passing_cases.model(array=array)
@ -154,3 +156,69 @@ def test_interface_mark_roundtrip(all_passing_cases, valid, tmp_output_dir_func)
model = case.model_validate_json(dumped_json) model = case.model_validate_json(dumped_json)
_test_roundtrip(case, model) _test_roundtrip(case, model)
@pytest.mark.serialization
def test_roundtrip_from_extra(dtype_by_interface, tmp_output_dir_func):
    """
    Arrays passed as extra fields, typed through ``__pydantic_extra__``,
    can be dumped to round-trip JSON and validated back into the model.
    """

    class Model(BaseModel):
        __pydantic_extra__: dict[str, dtype_by_interface.annotation]
        model_config = ConfigDict(extra="allow")

    source_array = dtype_by_interface.array(path=tmp_output_dir_func)
    original = Model(array=source_array)
    as_json = original.model_dump_json(round_trip=True)
    restored = Model.model_validate_json(as_json)
    _test_roundtrip(original, restored)
@pytest.mark.serialization
def test_roundtrip_from_union(dtype_by_interface, tmp_output_dir_func):
    """
    Arrays can be dumped when they are annotated as a union with another type.

    The union is spelled with :data:`typing.Union` rather than ``str | ...``.
    The annotation is evaluated when the class body runs, and the ``|``
    operator between types only exists from Python 3.10 onwards, while the
    package still declares support for Python 3.9.
    """
    from typing import Union

    class Model(BaseModel):
        array: Union[str, dtype_by_interface.annotation]

    array = dtype_by_interface.array(path=tmp_output_dir_func)
    instance = Model(array=array)
    dumped = instance.model_dump_json(round_trip=True)
    roundtripped = Model.model_validate_json(dumped)
    _test_roundtrip(instance, roundtripped)
@pytest.mark.serialization
def test_roundtrip_from_generic(dtype_by_interface, tmp_output_dir_func):
    """
    Arrays can be dumped when they are the type parameter of a generic model
    that is itself nested as a field of another model.
    """
    T = TypeVar("T", bound=NDArray)

    class GenType(BaseModel, Generic[T]):
        array: T

    class Model(BaseModel):
        array: GenType[dtype_by_interface.annotation]

    inner_array = dtype_by_interface.array(path=tmp_output_dir_func)
    # outer ``array`` is the generic model, inner ``array`` is the actual data
    original = Model(array={"array": inner_array})
    as_json = original.model_dump_json(round_trip=True)
    restored = Model.model_validate_json(as_json)
    _test_roundtrip(original, restored)
@pytest.mark.serialization
def test_roundtrip_from_any(dtype_by_interface, tmp_output_dir_func):
"""
We can roundtrip from an AnyType
Args:
dtype_by_interface:
tmp_output_dir_func:
Returns:
"""