From 6ce3d585283a53552d974511fbb5f6fbf5884c0e Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Fri, 18 Oct 2024 19:31:22 -0700 Subject: [PATCH] working roundtrip serialization from extra after correcting serialization schema. doh. --- src/numpydantic/serialization.py | 9 ++++-- src/numpydantic/testing/cases.py | 20 ++++++------ src/numpydantic/testing/helpers.py | 2 +- tests/test_interface/test_interfaces.py | 41 ++++++++++++++++--------- 4 files changed, 45 insertions(+), 27 deletions(-) diff --git a/src/numpydantic/serialization.py b/src/numpydantic/serialization.py index d9a73e0..2fda25c 100644 --- a/src/numpydantic/serialization.py +++ b/src/numpydantic/serialization.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar, Union from pydantic_core import SchemaSerializer from pydantic_core.core_schema import ( SerializationInfo, + any_schema, plain_serializer_function_ser_schema, ) @@ -75,9 +76,11 @@ def postprocess_json( pydantic_serializer = SchemaSerializer( - plain_serializer_function_ser_schema( - jsonize_array, when_used="json", info_arg=True - ), + any_schema( + serialization=plain_serializer_function_ser_schema( + jsonize_array, when_used="json", info_arg=True + ) + ) ) """ A generic serializer that can be applied to interface proxies et al as diff --git a/src/numpydantic/testing/cases.py b/src/numpydantic/testing/cases.py index d0f44ee..73b401c 100644 --- a/src/numpydantic/testing/cases.py +++ b/src/numpydantic/testing/cases.py @@ -225,34 +225,36 @@ All the interface cases """ -DTYPE_AND_SHAPE_CASES = merged_product(SHAPE_CASES, DTYPE_CASES) +DTYPE_AND_SHAPE_CASES = list(merged_product(SHAPE_CASES, DTYPE_CASES)) """ Merged product of dtype and shape cases """ -DTYPE_AND_SHAPE_CASES_PASSING = merged_product( - SHAPE_CASES, DTYPE_CASES, conditions={"passes": True} +DTYPE_AND_SHAPE_CASES_PASSING = list( + merged_product(SHAPE_CASES, DTYPE_CASES, conditions={"passes": True}) ) """ Merged product of dtype and shape cases that are valid """ -DTYPE_AND_INTERFACE_CASES = merged_product(INTERFACE_CASES, DTYPE_CASES) +DTYPE_AND_INTERFACE_CASES = list(merged_product(INTERFACE_CASES, DTYPE_CASES)) """ Merged product of dtype and interface cases """ -DTYPE_AND_INTERFACE_CASES_PASSING = merged_product( - INTERFACE_CASES, DTYPE_CASES, conditions={"passes": True} +DTYPE_AND_INTERFACE_CASES_PASSING = list( + merged_product(INTERFACE_CASES, DTYPE_CASES, conditions={"passes": True}) ) """ Merged product of dtype and interface cases that pass """ -ALL_CASES = merged_product(SHAPE_CASES, DTYPE_CASES, INTERFACE_CASES) +ALL_CASES = list(merged_product(SHAPE_CASES, DTYPE_CASES, INTERFACE_CASES)) """ Merged product of all cases - dtype, shape, and interface """ -ALL_CASES_PASSING = merged_product( - SHAPE_CASES, DTYPE_CASES, INTERFACE_CASES, conditions={"passes": True} +ALL_CASES_PASSING = list( + merged_product( + SHAPE_CASES, DTYPE_CASES, INTERFACE_CASES, conditions={"passes": True} + ) ) """ Merged product of all cases, but only those that pass diff --git a/src/numpydantic/testing/helpers.py b/src/numpydantic/testing/helpers.py index b337e7d..37d8f1e 100644 --- a/src/numpydantic/testing/helpers.py +++ b/src/numpydantic/testing/helpers.py @@ -129,7 +129,7 @@ class ValidationCase(BaseModel): """ Dtype to use in computed annotation used to validate against """ - shape: Tuple[int, ...] = (10, 10, 2, 2) + shape: Tuple[int, ...] = (10, 10, 2, 3) """Shape of the array to validate""" dtype: Union[Type, np.dtype] = float """Dtype of the array to validate""" diff --git a/tests/test_interface/test_interfaces.py b/tests/test_interface/test_interfaces.py index cd113bf..a2eaf82 100644 --- a/tests/test_interface/test_interfaces.py +++ b/tests/test_interface/test_interfaces.py @@ -14,6 +14,11 @@ from zarr.core import Array as ZarrArray from numpydantic import NDArray from numpydantic.interface import Interface, InterfaceMark, MarkedJson +from numpydantic.testing.cases import ( + ALL_CASES_PASSING, + DTYPE_AND_INTERFACE_CASES_PASSING, + INTERFACE_CASES, +) from numpydantic.testing.helpers import ValidationCase @@ -40,6 +45,26 @@ def _test_roundtrip(source: BaseModel, target: BaseModel): assert target.array.dtype == source.array.dtype +@pytest.mark.parametrize( + "interface", + [ + pytest.param(i, marks=getattr(pytest.mark, i.interface.interface.name)) + for i in INTERFACE_CASES + ], +) +@pytest.mark.parametrize( + "cases", [ALL_CASES_PASSING, DTYPE_AND_INTERFACE_CASES_PASSING] +) +def test_cases_include_all_interfaces(interface: ValidationCase, cases): + """ + Test our test cases - we should hit all interfaces in the common "test all" fixtures + """ + cases = list(cases) + assert any( + [case.interface is interface.interface for case in cases] + ), f"Interface case unused in general test cases: {interface.interface}" + + def test_dunder_len(interface_cases, tmp_output_dir_func): """ Each interface or proxy type should support __len__ @@ -177,7 +202,8 @@ def test_roundtrip_from_extra(dtype_by_interface, tmp_output_dir_func): @pytest.mark.serialization def test_roundtrip_from_union(dtype_by_interface, tmp_output_dir_func): """ - Arrays can be dumped when they are specified along with a union of another type field + Arrays can be dumped when they are specified along with a + union of another type field """ class Model(BaseModel): @@ -209,16 +235,3 @@ def test_roundtrip_from_generic(dtype_by_interface, tmp_output_dir_func): dumped = instance.model_dump_json(round_trip=True) roundtripped = Model.model_validate_json(dumped) _test_roundtrip(instance, roundtripped) - - -@pytest.mark.serialization -def test_roundtrip_from_any(dtype_by_interface, tmp_output_dir_func): - """ - We can roundtrip from an AnyType - Args: - dtype_by_interface: - tmp_output_dir_func: - - Returns: - - """