working roundtrip serialization from extra after correcting serialization schema. doh.

This commit is contained in:
sneakers-the-rat 2024-10-18 19:31:22 -07:00
parent 33eb1c15d5
commit 6ce3d58528
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
4 changed files with 45 additions and 27 deletions

View file

@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar, Union
from pydantic_core import SchemaSerializer
from pydantic_core.core_schema import (
SerializationInfo,
any_schema,
plain_serializer_function_ser_schema,
)
@ -75,9 +76,11 @@ def postprocess_json(
pydantic_serializer = SchemaSerializer(
plain_serializer_function_ser_schema(
any_schema(
serialization=plain_serializer_function_ser_schema(
jsonize_array, when_used="json", info_arg=True
),
)
)
)
"""
A generic serializer that can be applied to interface proxies et al as

View file

@ -225,34 +225,36 @@ All the interface cases
"""
DTYPE_AND_SHAPE_CASES = merged_product(SHAPE_CASES, DTYPE_CASES)
DTYPE_AND_SHAPE_CASES = list(merged_product(SHAPE_CASES, DTYPE_CASES))
"""
Merged product of dtype and shape cases
"""
DTYPE_AND_SHAPE_CASES_PASSING = merged_product(
SHAPE_CASES, DTYPE_CASES, conditions={"passes": True}
DTYPE_AND_SHAPE_CASES_PASSING = list(
merged_product(SHAPE_CASES, DTYPE_CASES, conditions={"passes": True})
)
"""
Merged product of dtype and shape cases that are valid
"""
DTYPE_AND_INTERFACE_CASES = merged_product(INTERFACE_CASES, DTYPE_CASES)
DTYPE_AND_INTERFACE_CASES = list(merged_product(INTERFACE_CASES, DTYPE_CASES))
"""
Merged product of dtype and interface cases
"""
DTYPE_AND_INTERFACE_CASES_PASSING = merged_product(
INTERFACE_CASES, DTYPE_CASES, conditions={"passes": True}
DTYPE_AND_INTERFACE_CASES_PASSING = list(
merged_product(INTERFACE_CASES, DTYPE_CASES, conditions={"passes": True})
)
"""
Merged product of dtype and interface cases that pass
"""
ALL_CASES = merged_product(SHAPE_CASES, DTYPE_CASES, INTERFACE_CASES)
ALL_CASES = list(merged_product(SHAPE_CASES, DTYPE_CASES, INTERFACE_CASES))
"""
Merged product of all cases - dtype, shape, and interface
"""
ALL_CASES_PASSING = merged_product(
ALL_CASES_PASSING = list(
merged_product(
SHAPE_CASES, DTYPE_CASES, INTERFACE_CASES, conditions={"passes": True}
)
)
"""
Merged product of all cases, but only those that pass

View file

@ -129,7 +129,7 @@ class ValidationCase(BaseModel):
"""
Dtype to use in computed annotation used to validate against
"""
shape: Tuple[int, ...] = (10, 10, 2, 2)
shape: Tuple[int, ...] = (10, 10, 2, 3)
"""Shape of the array to validate"""
dtype: Union[Type, np.dtype] = float
"""Dtype of the array to validate"""

View file

@ -14,6 +14,11 @@ from zarr.core import Array as ZarrArray
from numpydantic import NDArray
from numpydantic.interface import Interface, InterfaceMark, MarkedJson
from numpydantic.testing.cases import (
ALL_CASES_PASSING,
DTYPE_AND_INTERFACE_CASES_PASSING,
INTERFACE_CASES,
)
from numpydantic.testing.helpers import ValidationCase
@ -40,6 +45,26 @@ def _test_roundtrip(source: BaseModel, target: BaseModel):
assert target.array.dtype == source.array.dtype
@pytest.mark.parametrize(
"interface",
[
pytest.param(i, marks=getattr(pytest.mark, i.interface.interface.name))
for i in INTERFACE_CASES
],
)
@pytest.mark.parametrize(
"cases", [ALL_CASES_PASSING, DTYPE_AND_INTERFACE_CASES_PASSING]
)
def test_cases_include_all_interfaces(interface: ValidationCase, cases):
"""
Test our test cases - we should hit all interfaces in the common "test all" fixtures
"""
cases = list(cases)
assert any(
[case.interface is interface.interface for case in cases]
), f"Interface case unused in general test cases: {interface.interface}"
def test_dunder_len(interface_cases, tmp_output_dir_func):
"""
Each interface or proxy type should support __len__
@ -177,7 +202,8 @@ def test_roundtrip_from_extra(dtype_by_interface, tmp_output_dir_func):
@pytest.mark.serialization
def test_roundtrip_from_union(dtype_by_interface, tmp_output_dir_func):
"""
Arrays can be dumped when they are specified along with a union of another type field
Arrays can be dumped when they are specified along with a
union of another type field
"""
class Model(BaseModel):
@ -209,16 +235,3 @@ def test_roundtrip_from_generic(dtype_by_interface, tmp_output_dir_func):
dumped = instance.model_dump_json(round_trip=True)
roundtripped = Model.model_validate_json(dumped)
_test_roundtrip(instance, roundtripped)
@pytest.mark.serialization
def test_roundtrip_from_any(dtype_by_interface, tmp_output_dir_func):
"""
We can roundtrip from an AnyType
Args:
dtype_by_interface:
tmp_output_dir_func:
Returns:
"""