mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2024-11-14 18:54:28 +00:00
working roundtrip serialization from extra after correcting serialization schema. doh.
This commit is contained in:
parent
33eb1c15d5
commit
6ce3d58528
4 changed files with 45 additions and 27 deletions
|
@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar, Union
|
|||
from pydantic_core import SchemaSerializer
|
||||
from pydantic_core.core_schema import (
|
||||
SerializationInfo,
|
||||
any_schema,
|
||||
plain_serializer_function_ser_schema,
|
||||
)
|
||||
|
||||
|
@ -75,9 +76,11 @@ def postprocess_json(
|
|||
|
||||
|
||||
pydantic_serializer = SchemaSerializer(
|
||||
plain_serializer_function_ser_schema(
|
||||
any_schema(
|
||||
serialization=plain_serializer_function_ser_schema(
|
||||
jsonize_array, when_used="json", info_arg=True
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
"""
|
||||
A generic serializer that can be applied to interface proxies et al as
|
||||
|
|
|
@ -225,34 +225,36 @@ All the interface cases
|
|||
"""
|
||||
|
||||
|
||||
DTYPE_AND_SHAPE_CASES = merged_product(SHAPE_CASES, DTYPE_CASES)
|
||||
DTYPE_AND_SHAPE_CASES = list(merged_product(SHAPE_CASES, DTYPE_CASES))
|
||||
"""
|
||||
Merged product of dtype and shape cases
|
||||
"""
|
||||
DTYPE_AND_SHAPE_CASES_PASSING = merged_product(
|
||||
SHAPE_CASES, DTYPE_CASES, conditions={"passes": True}
|
||||
DTYPE_AND_SHAPE_CASES_PASSING = list(
|
||||
merged_product(SHAPE_CASES, DTYPE_CASES, conditions={"passes": True})
|
||||
)
|
||||
"""
|
||||
Merged product of dtype and shape cases that are valid
|
||||
"""
|
||||
|
||||
DTYPE_AND_INTERFACE_CASES = merged_product(INTERFACE_CASES, DTYPE_CASES)
|
||||
DTYPE_AND_INTERFACE_CASES = list(merged_product(INTERFACE_CASES, DTYPE_CASES))
|
||||
"""
|
||||
Merged product of dtype and interface cases
|
||||
"""
|
||||
DTYPE_AND_INTERFACE_CASES_PASSING = merged_product(
|
||||
INTERFACE_CASES, DTYPE_CASES, conditions={"passes": True}
|
||||
DTYPE_AND_INTERFACE_CASES_PASSING = list(
|
||||
merged_product(INTERFACE_CASES, DTYPE_CASES, conditions={"passes": True})
|
||||
)
|
||||
"""
|
||||
Merged product of dtype and interface cases that pass
|
||||
"""
|
||||
|
||||
ALL_CASES = merged_product(SHAPE_CASES, DTYPE_CASES, INTERFACE_CASES)
|
||||
ALL_CASES = list(merged_product(SHAPE_CASES, DTYPE_CASES, INTERFACE_CASES))
|
||||
"""
|
||||
Merged product of all cases - dtype, shape, and interface
|
||||
"""
|
||||
ALL_CASES_PASSING = merged_product(
|
||||
ALL_CASES_PASSING = list(
|
||||
merged_product(
|
||||
SHAPE_CASES, DTYPE_CASES, INTERFACE_CASES, conditions={"passes": True}
|
||||
)
|
||||
)
|
||||
"""
|
||||
Merged product of all cases, but only those that pass
|
||||
|
|
|
@ -129,7 +129,7 @@ class ValidationCase(BaseModel):
|
|||
"""
|
||||
Dtype to use in computed annotation used to validate against
|
||||
"""
|
||||
shape: Tuple[int, ...] = (10, 10, 2, 2)
|
||||
shape: Tuple[int, ...] = (10, 10, 2, 3)
|
||||
"""Shape of the array to validate"""
|
||||
dtype: Union[Type, np.dtype] = float
|
||||
"""Dtype of the array to validate"""
|
||||
|
|
|
@ -14,6 +14,11 @@ from zarr.core import Array as ZarrArray
|
|||
|
||||
from numpydantic import NDArray
|
||||
from numpydantic.interface import Interface, InterfaceMark, MarkedJson
|
||||
from numpydantic.testing.cases import (
|
||||
ALL_CASES_PASSING,
|
||||
DTYPE_AND_INTERFACE_CASES_PASSING,
|
||||
INTERFACE_CASES,
|
||||
)
|
||||
from numpydantic.testing.helpers import ValidationCase
|
||||
|
||||
|
||||
|
@ -40,6 +45,26 @@ def _test_roundtrip(source: BaseModel, target: BaseModel):
|
|||
assert target.array.dtype == source.array.dtype
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"interface",
|
||||
[
|
||||
pytest.param(i, marks=getattr(pytest.mark, i.interface.interface.name))
|
||||
for i in INTERFACE_CASES
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"cases", [ALL_CASES_PASSING, DTYPE_AND_INTERFACE_CASES_PASSING]
|
||||
)
|
||||
def test_cases_include_all_interfaces(interface: ValidationCase, cases):
|
||||
"""
|
||||
Test our test cases - we should hit all interfaces in the common "test all" fixtures
|
||||
"""
|
||||
cases = list(cases)
|
||||
assert any(
|
||||
[case.interface is interface.interface for case in cases]
|
||||
), f"Interface case unused in general test cases: {interface.interface}"
|
||||
|
||||
|
||||
def test_dunder_len(interface_cases, tmp_output_dir_func):
|
||||
"""
|
||||
Each interface or proxy type should support __len__
|
||||
|
@ -177,7 +202,8 @@ def test_roundtrip_from_extra(dtype_by_interface, tmp_output_dir_func):
|
|||
@pytest.mark.serialization
|
||||
def test_roundtrip_from_union(dtype_by_interface, tmp_output_dir_func):
|
||||
"""
|
||||
Arrays can be dumped when they are specified along with a union of another type field
|
||||
Arrays can be dumped when they are specified along with a
|
||||
union of another type field
|
||||
"""
|
||||
|
||||
class Model(BaseModel):
|
||||
|
@ -209,16 +235,3 @@ def test_roundtrip_from_generic(dtype_by_interface, tmp_output_dir_func):
|
|||
dumped = instance.model_dump_json(round_trip=True)
|
||||
roundtripped = Model.model_validate_json(dumped)
|
||||
_test_roundtrip(instance, roundtripped)
|
||||
|
||||
|
||||
@pytest.mark.serialization
|
||||
def test_roundtrip_from_any(dtype_by_interface, tmp_output_dir_func):
|
||||
"""
|
||||
We can roundtrip from an AnyType
|
||||
Args:
|
||||
dtype_by_interface:
|
||||
tmp_output_dir_func:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
|
|
Loading…
Reference in a new issue