mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2024-11-15 03:04:29 +00:00
working roundtrip serialization from extra after correcting serialization schema. doh.
This commit is contained in:
parent
33eb1c15d5
commit
6ce3d58528
4 changed files with 45 additions and 27 deletions
|
@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar, Union
|
||||||
from pydantic_core import SchemaSerializer
|
from pydantic_core import SchemaSerializer
|
||||||
from pydantic_core.core_schema import (
|
from pydantic_core.core_schema import (
|
||||||
SerializationInfo,
|
SerializationInfo,
|
||||||
|
any_schema,
|
||||||
plain_serializer_function_ser_schema,
|
plain_serializer_function_ser_schema,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -75,9 +76,11 @@ def postprocess_json(
|
||||||
|
|
||||||
|
|
||||||
pydantic_serializer = SchemaSerializer(
|
pydantic_serializer = SchemaSerializer(
|
||||||
plain_serializer_function_ser_schema(
|
any_schema(
|
||||||
|
serialization=plain_serializer_function_ser_schema(
|
||||||
jsonize_array, when_used="json", info_arg=True
|
jsonize_array, when_used="json", info_arg=True
|
||||||
),
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
A generic serializer that can be applied to interface proxies et al as
|
A generic serializer that can be applied to interface proxies et al as
|
||||||
|
|
|
@ -225,34 +225,36 @@ All the interface cases
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
DTYPE_AND_SHAPE_CASES = merged_product(SHAPE_CASES, DTYPE_CASES)
|
DTYPE_AND_SHAPE_CASES = list(merged_product(SHAPE_CASES, DTYPE_CASES))
|
||||||
"""
|
"""
|
||||||
Merged product of dtype and shape cases
|
Merged product of dtype and shape cases
|
||||||
"""
|
"""
|
||||||
DTYPE_AND_SHAPE_CASES_PASSING = merged_product(
|
DTYPE_AND_SHAPE_CASES_PASSING = list(
|
||||||
SHAPE_CASES, DTYPE_CASES, conditions={"passes": True}
|
merged_product(SHAPE_CASES, DTYPE_CASES, conditions={"passes": True})
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
Merged product of dtype and shape cases that are valid
|
Merged product of dtype and shape cases that are valid
|
||||||
"""
|
"""
|
||||||
|
|
||||||
DTYPE_AND_INTERFACE_CASES = merged_product(INTERFACE_CASES, DTYPE_CASES)
|
DTYPE_AND_INTERFACE_CASES = list(merged_product(INTERFACE_CASES, DTYPE_CASES))
|
||||||
"""
|
"""
|
||||||
Merged product of dtype and interface cases
|
Merged product of dtype and interface cases
|
||||||
"""
|
"""
|
||||||
DTYPE_AND_INTERFACE_CASES_PASSING = merged_product(
|
DTYPE_AND_INTERFACE_CASES_PASSING = list(
|
||||||
INTERFACE_CASES, DTYPE_CASES, conditions={"passes": True}
|
merged_product(INTERFACE_CASES, DTYPE_CASES, conditions={"passes": True})
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
Merged product of dtype and interface cases that pass
|
Merged product of dtype and interface cases that pass
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ALL_CASES = merged_product(SHAPE_CASES, DTYPE_CASES, INTERFACE_CASES)
|
ALL_CASES = list(merged_product(SHAPE_CASES, DTYPE_CASES, INTERFACE_CASES))
|
||||||
"""
|
"""
|
||||||
Merged product of all cases - dtype, shape, and interface
|
Merged product of all cases - dtype, shape, and interface
|
||||||
"""
|
"""
|
||||||
ALL_CASES_PASSING = merged_product(
|
ALL_CASES_PASSING = list(
|
||||||
|
merged_product(
|
||||||
SHAPE_CASES, DTYPE_CASES, INTERFACE_CASES, conditions={"passes": True}
|
SHAPE_CASES, DTYPE_CASES, INTERFACE_CASES, conditions={"passes": True}
|
||||||
|
)
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
Merged product of all cases, but only those that pass
|
Merged product of all cases, but only those that pass
|
||||||
|
|
|
@ -129,7 +129,7 @@ class ValidationCase(BaseModel):
|
||||||
"""
|
"""
|
||||||
Dtype to use in computed annotation used to validate against
|
Dtype to use in computed annotation used to validate against
|
||||||
"""
|
"""
|
||||||
shape: Tuple[int, ...] = (10, 10, 2, 2)
|
shape: Tuple[int, ...] = (10, 10, 2, 3)
|
||||||
"""Shape of the array to validate"""
|
"""Shape of the array to validate"""
|
||||||
dtype: Union[Type, np.dtype] = float
|
dtype: Union[Type, np.dtype] = float
|
||||||
"""Dtype of the array to validate"""
|
"""Dtype of the array to validate"""
|
||||||
|
|
|
@ -14,6 +14,11 @@ from zarr.core import Array as ZarrArray
|
||||||
|
|
||||||
from numpydantic import NDArray
|
from numpydantic import NDArray
|
||||||
from numpydantic.interface import Interface, InterfaceMark, MarkedJson
|
from numpydantic.interface import Interface, InterfaceMark, MarkedJson
|
||||||
|
from numpydantic.testing.cases import (
|
||||||
|
ALL_CASES_PASSING,
|
||||||
|
DTYPE_AND_INTERFACE_CASES_PASSING,
|
||||||
|
INTERFACE_CASES,
|
||||||
|
)
|
||||||
from numpydantic.testing.helpers import ValidationCase
|
from numpydantic.testing.helpers import ValidationCase
|
||||||
|
|
||||||
|
|
||||||
|
@ -40,6 +45,26 @@ def _test_roundtrip(source: BaseModel, target: BaseModel):
|
||||||
assert target.array.dtype == source.array.dtype
|
assert target.array.dtype == source.array.dtype
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"interface",
|
||||||
|
[
|
||||||
|
pytest.param(i, marks=getattr(pytest.mark, i.interface.interface.name))
|
||||||
|
for i in INTERFACE_CASES
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"cases", [ALL_CASES_PASSING, DTYPE_AND_INTERFACE_CASES_PASSING]
|
||||||
|
)
|
||||||
|
def test_cases_include_all_interfaces(interface: ValidationCase, cases):
|
||||||
|
"""
|
||||||
|
Test our test cases - we should hit all interfaces in the common "test all" fixtures
|
||||||
|
"""
|
||||||
|
cases = list(cases)
|
||||||
|
assert any(
|
||||||
|
[case.interface is interface.interface for case in cases]
|
||||||
|
), f"Interface case unused in general test cases: {interface.interface}"
|
||||||
|
|
||||||
|
|
||||||
def test_dunder_len(interface_cases, tmp_output_dir_func):
|
def test_dunder_len(interface_cases, tmp_output_dir_func):
|
||||||
"""
|
"""
|
||||||
Each interface or proxy type should support __len__
|
Each interface or proxy type should support __len__
|
||||||
|
@ -177,7 +202,8 @@ def test_roundtrip_from_extra(dtype_by_interface, tmp_output_dir_func):
|
||||||
@pytest.mark.serialization
|
@pytest.mark.serialization
|
||||||
def test_roundtrip_from_union(dtype_by_interface, tmp_output_dir_func):
|
def test_roundtrip_from_union(dtype_by_interface, tmp_output_dir_func):
|
||||||
"""
|
"""
|
||||||
Arrays can be dumped when they are specified along with a union of another type field
|
Arrays can be dumped when they are specified along with a
|
||||||
|
union of another type field
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class Model(BaseModel):
|
class Model(BaseModel):
|
||||||
|
@ -209,16 +235,3 @@ def test_roundtrip_from_generic(dtype_by_interface, tmp_output_dir_func):
|
||||||
dumped = instance.model_dump_json(round_trip=True)
|
dumped = instance.model_dump_json(round_trip=True)
|
||||||
roundtripped = Model.model_validate_json(dumped)
|
roundtripped = Model.model_validate_json(dumped)
|
||||||
_test_roundtrip(instance, roundtripped)
|
_test_roundtrip(instance, roundtripped)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.serialization
|
|
||||||
def test_roundtrip_from_any(dtype_by_interface, tmp_output_dir_func):
|
|
||||||
"""
|
|
||||||
We can roundtrip from an AnyType
|
|
||||||
Args:
|
|
||||||
dtype_by_interface:
|
|
||||||
tmp_output_dir_func:
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
Loading…
Reference in a new issue