diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index f2c8974..9b57320 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -47,7 +47,7 @@ jobs: run: pip install "numpy${{ matrix.numpy-version }}" - name: Run Tests - run: pytest + run: pytest -n auto - name: Coveralls Parallel uses: coverallsapp/github-action@v2.3.0 diff --git a/pdm.lock b/pdm.lock index 8fa42d5..c1b8ae4 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "arrays", "dask", "dev", "docs", "hdf5", "tests", "video", "zarr"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:cc2b0fb32896c6df0ad747ddb5dee89af22f5c4c4643ee7a52db47fef30da936" +content_hash = "sha256:576c76e5a4a7616b7cb72bca277db1ada330bb8f3f9827d0e161728325e19264" [[metadata.targets]] requires_python = "~=3.9" @@ -576,6 +576,17 @@ files = [ {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, ] +[[package]] +name = "execnet" +version = "2.1.1" +requires_python = ">=3.8" +summary = "execnet: rapid multi-Python deployment" +groups = ["dev", "tests"] +files = [ + {file = "execnet-2.1.1-py3-none-any.whl", hash = "sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc"}, + {file = "execnet-2.1.1.tar.gz", hash = "sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3"}, +] + [[package]] name = "executing" version = "2.1.0" @@ -1647,6 +1658,21 @@ files = [ {file = "pytest_depends-1.0.1-py3-none-any.whl", hash = "sha256:a1df072bcc93d77aca3f0946903f5fed8af2d9b0056db1dfc9ed5ac164ab0642"}, ] +[[package]] +name = "pytest-xdist" +version = "3.6.1" +requires_python = ">=3.8" +summary = "pytest xdist plugin for distributed testing, most importantly across multiple CPUs" +groups = ["dev", "tests"] +dependencies = [ + "execnet>=2.1", + "pytest>=7.0.0", +] +files = [ + {file = "pytest_xdist-3.6.1-py3-none-any.whl", hash = "sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7"}, + {file = "pytest_xdist-3.6.1.tar.gz", hash = "sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d"}, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0" diff --git a/pyproject.toml b/pyproject.toml index bbbce4c..fa2c6a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ tests = [ "coverage>=6.1.1", "pytest-cov<5.0.0,>=4.1.0", "coveralls<4.0.0,>=3.3.1", + "pytest-xdist>=3.6.1", ] docs = [ "numpydantic[arrays]", diff --git a/src/numpydantic/interface/hdf5.py b/src/numpydantic/interface/hdf5.py index d243233..3f3195d 100644 --- a/src/numpydantic/interface/hdf5.py +++ b/src/numpydantic/interface/hdf5.py @@ -55,7 +55,6 @@ from typing import ( import numpy as np from pydantic import SerializationInfo -from pydantic_core import SchemaSerializer, core_schema from numpydantic.interface.interface import Interface, JsonDict from numpydantic.types import DtypeType, NDArrayType @@ -100,49 +99,49 @@ class H5JsonDict(JsonDict): ) -def to_json(self, info: SerializationInfo): - """ - Serialize H5Proxy to JSON, as the interface does, - in cases when the interface is not able to be used - (eg. like when used as an `extra` field in a model without a type annotation) - """ - from numpydantic.serialization import postprocess_json - - if info.round_trip: - as_json = { - "type": H5Interface.name, - } - as_json.update(self._h5arraypath._asdict()) - else: - try: - dset = self.open() - as_json = dset[:].tolist() - finally: - self.close() - return postprocess_json(as_json, info) +# def to_json(self, info: SerializationInfo): +# """ +# Serialize H5Proxy to JSON, as the interface does, +# in cases when the interface is not able to be used +# (eg. like when used as an `extra` field in a model without a type annotation) +# """ +# from numpydantic.serialization import postprocess_json +# +# if info.round_trip: +# as_json = { +# "type": H5Interface.name, +# } +# as_json.update(self._h5arraypath._asdict()) +# else: +# try: +# dset = self.open() +# as_json = dset[:].tolist() +# finally: +# self.close() +# return postprocess_json(as_json, info) -def _make_pydantic_schema(): - return core_schema.typed_dict_schema( - { - "file": core_schema.typed_dict_field( - core_schema.str_schema(), required=True - ), - "path": core_schema.typed_dict_field( - core_schema.str_schema(), required=True - ), - "field": core_schema.typed_dict_field( - core_schema.union_schema( - [ - core_schema.str_schema(), - core_schema.list_schema(core_schema.str_schema()), - ], - ), - required=True, - ), - }, - # serialization= - ) +# def _make_pydantic_schema(): +# return core_schema.typed_dict_schema( +# { +# "file": core_schema.typed_dict_field( +# core_schema.str_schema(), required=True +# ), +# "path": core_schema.typed_dict_field( +# core_schema.str_schema(), required=True +# ), +# "field": core_schema.typed_dict_field( +# core_schema.union_schema( +# [ +# core_schema.str_schema(), +# core_schema.list_schema(core_schema.str_schema()), +# ], +# ), +# required=True, +# ), +# }, +# # serialization= +# ) class H5Proxy: @@ -167,11 +166,11 @@ class H5Proxy: annotation_dtype (dtype): Optional - the dtype of our type annotation """ - __pydantic_serializer__ = SchemaSerializer( - core_schema.plain_serializer_function_ser_schema( - to_json, when_used="json", info_arg=True - ), - ) + # __pydantic_serializer__ = SchemaSerializer( + # core_schema.plain_serializer_function_ser_schema( + # jsonize_array, when_used="json", info_arg=True + # ), + # ) def __init__( self, diff --git a/src/numpydantic/interface/interface.py b/src/numpydantic/interface/interface.py index 42bb891..c8117d8 100644 --- a/src/numpydantic/interface/interface.py +++ b/src/numpydantic/interface/interface.py @@ -20,6 +20,7 @@ from numpydantic.exceptions import ( ShapeError, TooManyMatchesError, ) +from numpydantic.serialization import pydantic_serializer from numpydantic.types import DtypeType, NDArrayType, ShapeType from numpydantic.validation import validate_dtype, validate_shape @@ -232,6 +233,7 @@ class Interface(ABC, Generic[T]): shape_valid = self.validate_shape(shape) self.raise_for_shape(shape_valid, shape) + array = self.apply_serializer(array) array = self.after_validation(array) return array @@ -338,6 +340,17 @@ class Interface(ABC, Generic[T]): f"got shape {shape}" ) + def apply_serializer(self, array: NDArrayType) -> NDArrayType: + """ + Apply a __pydantic_serializer__ method that allows the array to be + serialized directly without knowledge of the interface in the containing model. + + Useful for when arrays are specified as `__pydantic_extra__` which has + different serializer resolution logic. + """ + array.__pydantic_serializer__ = pydantic_serializer + return array + def after_validation(self, array: NDArrayType) -> T: """ Optional step post-validation that coerces the intermediate array type into the diff --git a/src/numpydantic/interface/numpy.py b/src/numpydantic/interface/numpy.py index 20019f7..f6004b3 100644 --- a/src/numpydantic/interface/numpy.py +++ b/src/numpydantic/interface/numpy.py @@ -2,11 +2,12 @@ Interface to numpy arrays """ -from typing import Any, Literal, Union +from typing import Any, Literal, Optional, Union from pydantic import BaseModel, SerializationInfo from numpydantic.interface.interface import Interface, JsonDict +from numpydantic.serialization import pydantic_serializer try: import numpy as np @@ -36,6 +37,28 @@ class NumpyJsonDict(JsonDict): return np.array(self.value, dtype=self.dtype) +class SerializableNDArray(np.ndarray): + """ + Trivial subclass of :class:`numpy.ndarray` that allows + an additional ``__pydantic_serializer__`` attr to allow it to be + json roundtripped without the help of an interface + + References: + https://numpy.org/doc/stable/user/basics.subclassing.html#simple-example-adding-an-extra-attribute-to-ndarray + """ + + def __new__(cls, input_array: np.ndarray, **kwargs: dict[str, Any]): + """Create a new ndarray instance, adding a new attribute""" + obj = np.asarray(input_array, **kwargs).view(cls) + obj.__pydantic_serializer__ = pydantic_serializer + return obj + + def __array_finalize__(self, obj: Optional[np.ndarray]) -> None: + if obj is None: + return + self.__pydantic_serializer__ = getattr(obj, "__pydantic_serializer__", None) + + class NumpyInterface(Interface): """ Numpy :class:`~numpy.ndarray` s! @@ -62,7 +85,7 @@ class NumpyInterface(Interface): if array is None: return False - if isinstance(array, ndarray): + if isinstance(array, (ndarray, SerializableNDArray)): return True elif isinstance(array, dict): return NumpyJsonDict.is_valid(array) @@ -78,12 +101,14 @@ class NumpyInterface(Interface): Coerce to an ndarray. We have already checked if coercion is possible in :meth:`.check` """ - if not isinstance(array, ndarray): - array = np.array(array) + if not isinstance(array, SerializableNDArray): + array = SerializableNDArray(array) try: if issubclass(self.dtype, BaseModel) and isinstance(array.flat[0], dict): - array = np.vectorize(lambda x: self.dtype(**x))(array) + array = SerializableNDArray( + np.vectorize(lambda x: self.dtype(**x))(array) + ) except TypeError: # fine, dtype isn't a type pass diff --git a/src/numpydantic/serialization.py b/src/numpydantic/serialization.py index 2b570cf..d9a73e0 100644 --- a/src/numpydantic/serialization.py +++ b/src/numpydantic/serialization.py @@ -4,11 +4,16 @@ and :func:`pydantic.BaseModel.model_dump_json` . """ from pathlib import Path -from typing import Any, Callable, Iterable, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar, Union -from pydantic_core.core_schema import SerializationInfo +from pydantic_core import SchemaSerializer +from pydantic_core.core_schema import ( + SerializationInfo, + plain_serializer_function_ser_schema, +) -from numpydantic.interface import Interface, JsonDict +if TYPE_CHECKING: + from numpydantic.interface import Interface T = TypeVar("T") U = TypeVar("U") @@ -16,8 +21,8 @@ U = TypeVar("U") def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]: """Use an interface class to render an array as JSON""" - # return [1, 2, 3] - # pdb.set_trace() + from numpydantic.interface import Interface + interface_cls = Interface.match_output(value) array = interface_cls.to_json(value, info) array = postprocess_json(array, info, interface_cls) @@ -25,11 +30,13 @@ def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]: def postprocess_json( - array: Union[dict, list], info: SerializationInfo, interface_cls: type[Interface] + array: Union[dict, list], info: SerializationInfo, interface_cls: type["Interface"] ) -> Union[dict, list]: """ Modify json after dumping from an interface """ + from numpydantic.interface import JsonDict + # perf: keys to skip in generation - anything named "value" is array data. skip = ["value"] if isinstance(array, JsonDict): @@ -67,6 +74,17 @@ def postprocess_json( return array +pydantic_serializer = SchemaSerializer( + plain_serializer_function_ser_schema( + jsonize_array, when_used="json", info_arg=True + ), +) +""" +A generic serializer that can be applied to interface proxies et al as +``__pydantic_serializer__`` . +""" + + def _relativize_paths( value: dict, relative_to: str = ".", skip: Iterable = tuple() ) -> dict: diff --git a/tests/test_interface/conftest.py b/tests/test_interface/conftest.py index 5ba52fb..3e667bf 100644 --- a/tests/test_interface/conftest.py +++ b/tests/test_interface/conftest.py @@ -97,6 +97,8 @@ def all_passing_cases(request) -> ValidationCase: that we want to be sure is *very true* in every circumstance. Typically, that means only use this in `test_interfaces.py` """ + if "subclass" in request.param.id.lower(): + pytest.xfail() return request.param @@ -136,10 +138,12 @@ def all_passing_cases_instance(all_passing_cases, tmp_output_dir_func): for p in DTYPE_AND_INTERFACE_CASES_PASSING ) ) -def all_passing_cases(request): +def dtype_by_interface(request): """ Tests for all dtypes by all interfaces """ + if "subclass" in request.param.id.lower(): + pytest.xfail() return request.param diff --git a/tests/test_interface/test_interfaces.py b/tests/test_interface/test_interfaces.py index 10adec3..cd113bf 100644 --- a/tests/test_interface/test_interfaces.py +++ b/tests/test_interface/test_interfaces.py @@ -4,13 +4,15 @@ Tests that should be applied to all interfaces import json from importlib.metadata import version +from typing import Generic, TypeVar import dask.array as da import numpy as np import pytest -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from zarr.core import Array as ZarrArray +from numpydantic import NDArray from numpydantic.interface import Interface, InterfaceMark, MarkedJson from numpydantic.testing.helpers import ValidationCase @@ -30,6 +32,8 @@ def _test_roundtrip(source: BaseModel, target: BaseModel): assert target_type is source_type else: assert np.all(da.equal(target.array, source.array)) + elif isinstance(source.array, BaseModel): + return _test_roundtrip(source.array, target.array) else: assert target.array == source.array @@ -95,8 +99,6 @@ def test_interface_roundtrip_json(all_passing_cases, tmp_output_dir_func): """ All interfaces should be able to roundtrip to and from json """ - if "subclass" in all_passing_cases.id.lower(): - pytest.xfail() array = all_passing_cases.array(path=tmp_output_dir_func) case = all_passing_cases.model(array=array) @@ -154,3 +156,69 @@ def test_interface_mark_roundtrip(all_passing_cases, valid, tmp_output_dir_func) model = case.model_validate_json(dumped_json) _test_roundtrip(case, model) + + +@pytest.mark.serialization +def test_roundtrip_from_extra(dtype_by_interface, tmp_output_dir_func): + """ + Arrays can be dumped when they are specified in an `__extra__` field + """ + + class Model(BaseModel): + __pydantic_extra__: dict[str, dtype_by_interface.annotation] + model_config = ConfigDict(extra="allow") + + instance = Model(array=dtype_by_interface.array(path=tmp_output_dir_func)) + dumped = instance.model_dump_json(round_trip=True) + roundtripped = Model.model_validate_json(dumped) + _test_roundtrip(instance, roundtripped) + + +@pytest.mark.serialization +def test_roundtrip_from_union(dtype_by_interface, tmp_output_dir_func): + """ + Arrays can be dumped when they are specified along with a union of another type field + """ + + class Model(BaseModel): + array: str | dtype_by_interface.annotation + + array = dtype_by_interface.array(path=tmp_output_dir_func) + + instance = Model(array=array) + dumped = instance.model_dump_json(round_trip=True) + roundtripped = Model.model_validate_json(dumped) + _test_roundtrip(instance, roundtripped) + + +@pytest.mark.serialization +def test_roundtrip_from_generic(dtype_by_interface, tmp_output_dir_func): + """ + Arrays can be dumped when they are specified in an `__extra__` field + """ + T = TypeVar("T", bound=NDArray) + + class GenType(BaseModel, Generic[T]): + array: T + + class Model(BaseModel): + array: GenType[dtype_by_interface.annotation] + + array = dtype_by_interface.array(path=tmp_output_dir_func) + instance = Model(**{"array": {"array": array}}) + dumped = instance.model_dump_json(round_trip=True) + roundtripped = Model.model_validate_json(dumped) + _test_roundtrip(instance, roundtripped) + + +@pytest.mark.serialization +def test_roundtrip_from_any(dtype_by_interface, tmp_output_dir_func): + """ + We can roundtrip from an AnyType + Args: + dtype_by_interface: + tmp_output_dir_func: + + Returns: + + """