From b436d8d592ddac5aa18e038907ddf56d72f6c25c Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Fri, 20 Sep 2024 23:44:59 -0700 Subject: [PATCH] Roundtrip json serialization using dataclasses for all interfaces. Separate JSON schema generation from core_schema generation. More consistent json dumping behavior - always return just the array unless `round_trip=True`. --- docs/api/monkeypatch.md | 6 -- docs/index.md | 2 +- docs/serialization.md | 2 + pyproject.toml | 6 ++ src/numpydantic/interface/__init__.py | 3 +- src/numpydantic/interface/dask.py | 72 +++++++++++-- src/numpydantic/interface/hdf5.py | 67 +++++++++--- src/numpydantic/interface/interface.py | 129 ++++++++++++++++++++++-- src/numpydantic/interface/numpy.py | 49 ++++++++- src/numpydantic/interface/video.py | 43 +++++++- src/numpydantic/interface/zarr.py | 86 ++++++++++++---- src/numpydantic/ndarray.py | 25 +++-- src/numpydantic/schema.py | 49 +++++++-- tests/test_interface/test_hdf5.py | 17 ++-- tests/test_interface/test_interfaces.py | 26 +++++ tests/test_interface/test_zarr.py | 36 ++++--- tests/test_ndarray.py | 13 +++ 17 files changed, 532 insertions(+), 99 deletions(-) delete mode 100644 docs/api/monkeypatch.md create mode 100644 docs/serialization.md diff --git a/docs/api/monkeypatch.md b/docs/api/monkeypatch.md deleted file mode 100644 index d397869..0000000 --- a/docs/api/monkeypatch.md +++ /dev/null @@ -1,6 +0,0 @@ -# monkeypatch - -```{eval-rst} -.. automodule:: numpydantic.monkeypatch - :members: -``` \ No newline at end of file diff --git a/docs/index.md b/docs/index.md index af2c908..0cea5b4 100644 --- a/docs/index.md +++ b/docs/index.md @@ -473,6 +473,7 @@ dumped = instance.model_dump_json(context={'zarr_dump_array': True}) design syntax +serialization interfaces todo changelog @@ -489,7 +490,6 @@ api/dtype api/ndarray api/maps api/meta -api/monkeypatch api/schema api/shape api/types diff --git a/docs/serialization.md b/docs/serialization.md new file mode 100644 index 0000000..d5162bf --- /dev/null +++ b/docs/serialization.md @@ -0,0 +1,2 @@ +# Serialization + diff --git a/pyproject.toml b/pyproject.toml index 4dddb69..f1b7173 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,6 +109,12 @@ filterwarnings = [ # nptyping's alias warnings 'ignore:.*deprecated alias.*Deprecated NumPy 1\.24.*' ] +markers = [ + "dtype: mark test related to dtype validation", + "shape: mark test related to shape validation", + "json_schema: mark test related to json schema generation", + "serialization: mark test related to serialization" +] [tool.ruff] target-version = "py311" diff --git a/src/numpydantic/interface/__init__.py b/src/numpydantic/interface/__init__.py index 0a0c490..c5bd3f2 100644 --- a/src/numpydantic/interface/__init__.py +++ b/src/numpydantic/interface/__init__.py @@ -4,12 +4,13 @@ Interfaces between nptyping types and array backends from numpydantic.interface.dask import DaskInterface from numpydantic.interface.hdf5 import H5Interface -from numpydantic.interface.interface import Interface +from numpydantic.interface.interface import Interface, JsonDict from numpydantic.interface.numpy import NumpyInterface from numpydantic.interface.video import VideoInterface from numpydantic.interface.zarr import ZarrInterface __all__ = [ + "JsonDict", "Interface", "DaskInterface", "H5Interface", diff --git a/src/numpydantic/interface/dask.py b/src/numpydantic/interface/dask.py index 7719e98..6f02025 100644 --- a/src/numpydantic/interface/dask.py +++ b/src/numpydantic/interface/dask.py @@ -2,26 +2,59 @@ Interface for Dask 
arrays
 """

-from typing import Any, Optional
+from dataclasses import dataclass
+from typing import Any, Iterable, Literal, Optional

 import numpy as np
 from pydantic import SerializationInfo

-from numpydantic.interface.interface import Interface
+from numpydantic.interface.interface import Interface, JsonDict
 from numpydantic.types import DtypeType, NDArrayType

 try:
+    from dask.array import from_array
     from dask.array.core import Array as DaskArray
 except ImportError:  # pragma: no cover
     DaskArray = None


+def _as_tuple(a_list: list | Any) -> tuple:
+    """Make a list of lists into a tuple of tuples"""
+    return tuple(
+        [_as_tuple(item) if isinstance(item, list) else item for item in a_list]
+    )
+
+
+@dataclass(kw_only=True)
+class DaskJsonDict(JsonDict):
+    """
+    Round-trip json serialized form of a dask array
+    """
+
+    type: Literal["dask"]
+    name: str
+    chunks: Iterable[tuple[int, ...]]
+    dtype: str
+    array: list
+
+    def to_array_input(self) -> DaskArray:
+        """Construct a dask array"""
+        np_array = np.array(self.array, dtype=self.dtype)
+        array = from_array(
+            np_array,
+            name=self.name,
+            chunks=_as_tuple(self.chunks),
+        )
+        return array
+
+
 class DaskInterface(Interface):
     """
     Interface for Dask :class:`~dask.array.core.Array`
     """

-    input_types = (DaskArray,)
+    name = "dask"
+    input_types = (DaskArray, dict)
     return_type = DaskArray

     @classmethod
@@ -29,7 +62,24 @@ class DaskInterface(Interface):
         """
         check if array is a dask array
         """
-        return DaskArray is not None and isinstance(array, DaskArray)
+        if DaskArray is None:
+            return False
+        elif isinstance(array, DaskArray):
+            return True
+        elif isinstance(array, dict):
+            return DaskJsonDict.is_valid(array)
+        else:
+            return False
+
+    def before_validation(self, array: Any) -> DaskArray:
+        """
+        If given a dict (like that produced by ``model_dump_json(round_trip=True)``),
+        re-cast it to a dask array
+        """
+        if isinstance(array, dict):
+            array = DaskJsonDict(**array).to_array_input()
+
+        return array

     def get_object_dtype(self, array: NDArrayType) -> DtypeType:
         """Dask arrays require a compute() call to retrieve a single value"""
@@ -43,7 +93,7 @@ class DaskInterface(Interface):
     @classmethod
     def to_json(
         cls, array: DaskArray, info: Optional[SerializationInfo] = None
-    ) -> list:
+    ) -> list | DaskJsonDict:
        """
        Convert an array to a JSON serializable array by first converting to a numpy
        array and then to a list.
@@ -56,4 +106,14 @@
        method of serialization here using the python object itself rather than
        its JSON representation.
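+
+        For example, a round-trip dump of a hypothetical model with a dask
+        array field looks roughly like::
+
+            import dask.array as da
+
+            class MyModel(BaseModel):
+                array: NDArray
+
+            instance = MyModel(array=da.zeros((2, 2)))
+            instance.model_dump_json(round_trip=True)
+            # '{"array": {"type": "dask", "name": "...", "chunks": [[2], [2]],
+            #    "dtype": "float64", "array": [[0.0, 0.0], [0.0, 0.0]]}}'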
""" - return np.array(array).tolist() + np_array = np.array(array) + as_json = np_array.tolist() + if info.round_trip: + as_json = DaskJsonDict( + type=cls.name, + array=as_json, + name=array.name, + chunks=array.chunks, + dtype=str(np_array.dtype), + ) + return as_json diff --git a/src/numpydantic/interface/hdf5.py b/src/numpydantic/interface/hdf5.py index 3696a6f..9a5bf93 100644 --- a/src/numpydantic/interface/hdf5.py +++ b/src/numpydantic/interface/hdf5.py @@ -40,6 +40,7 @@ as ``S32`` isoformatted byte strings (timezones optional) like: """ import sys +from dataclasses import dataclass from datetime import datetime from pathlib import Path from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, TypeVar, Union @@ -47,7 +48,7 @@ from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, TypeVar, Un import numpy as np from pydantic import SerializationInfo -from numpydantic.interface.interface import Interface +from numpydantic.interface.interface import Interface, JsonDict from numpydantic.types import DtypeType, NDArrayType try: @@ -76,6 +77,21 @@ class H5ArrayPath(NamedTuple): """Refer to a specific field within a compound dtype""" +@dataclass +class H5JsonDict(JsonDict): + """Round-trip Json-able version of an HDF5 dataset""" + + file: str + path: str + field: Optional[str] = None + + def to_array_input(self) -> H5ArrayPath: + """Construct an :class:`.H5ArrayPath`""" + return H5ArrayPath( + **{k: v for k, v in self.to_dict().items() if k in H5ArrayPath._fields} + ) + + class H5Proxy: """ Proxy class to mimic numpy-like array behavior with an HDF5 array @@ -110,6 +126,7 @@ class H5Proxy: self.path = path self.field = field self._annotation_dtype = annotation_dtype + self._h5arraypath = H5ArrayPath(self.file, self.path, self.field) def array_exists(self) -> bool: """Check that there is in fact an array at :attr:`.path` within :attr:`.file`""" @@ -212,6 +229,15 @@ class H5Proxy: """self.shape[0]""" return self.shape[0] + def __eq__(self, other: "H5Proxy") -> bool: + """ + Check that we are referring to the same hdf5 array + """ + if isinstance(other, H5Proxy): + return self._h5arraypath == other._h5arraypath + else: + raise ValueError("Can only compare equality of two H5Proxies") + def open(self, mode: str = "r") -> "h5py.Dataset": """ Return the opened :class:`h5py.Dataset` object @@ -251,6 +277,7 @@ class H5Interface(Interface): passthrough numpy-like interface to the dataset. 
""" + name = "hdf5" input_types = (H5ArrayPath, H5Arraylike, H5Proxy) return_type = H5Proxy @@ -268,6 +295,13 @@ class H5Interface(Interface): if isinstance(array, (H5ArrayPath, H5Proxy)): return True + if isinstance(array, dict): + if array.get("type", False) == cls.name: + return True + # continue checking if dict contains an hdf5 file + file = array.get("file", "") + array = (file, "") + if isinstance(array, (tuple, list)) and len(array) in (2, 3): # check that the first arg is an hdf5 file try: @@ -294,6 +328,9 @@ class H5Interface(Interface): def before_validation(self, array: Any) -> NDArrayType: """Create an :class:`.H5Proxy` to use throughout validation""" + if isinstance(array, dict): + array = H5JsonDict(**array).to_array_input() + if isinstance(array, H5ArrayPath): array = H5Proxy.from_h5array(h5array=array) elif isinstance(array, H5Proxy): @@ -349,21 +386,27 @@ class H5Interface(Interface): @classmethod def to_json(cls, array: H5Proxy, info: Optional[SerializationInfo] = None) -> dict: """ - Dump to a dictionary containing + Render HDF5 array as JSON + + If ``round_trip == True``, we dump just the proxy info, a dictionary like: * ``file``: :attr:`.file` * ``path``: :attr:`.path` * ``attrs``: Any HDF5 attributes on the dataset * ``array``: The array as a list of lists + + Otherwise, we dump the array as a list of lists """ - try: - dset = array.open() - meta = { - "file": array.file, - "path": array.path, - "attrs": dict(dset.attrs), - "array": dset[:].tolist(), + if info.round_trip: + as_json = { + "type": cls.name, } - return meta - finally: - array.close() + as_json.update(array._h5arraypath._asdict()) + else: + try: + dset = array.open() + as_json = dset[:].tolist() + finally: + array.close() + + return as_json diff --git a/src/numpydantic/interface/interface.py b/src/numpydantic/interface/interface.py index 1ef307f..fa32b8e 100644 --- a/src/numpydantic/interface/interface.py +++ b/src/numpydantic/interface/interface.py @@ -2,12 +2,15 @@ Base Interface metaclass """ +import inspect from abc import ABC, abstractmethod +from dataclasses import asdict, dataclass +from importlib.metadata import PackageNotFoundError, version from operator import attrgetter -from typing import Any, Generic, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Generic, Tuple, Type, TypedDict, TypeVar, Union import numpy as np -from pydantic import SerializationInfo +from pydantic import SerializationInfo, TypeAdapter, ValidationError from numpydantic.exceptions import ( DtypeError, @@ -21,6 +24,60 @@ from numpydantic.types import DtypeType, NDArrayType, ShapeType T = TypeVar("T", bound=NDArrayType) +class InterfaceMark(TypedDict): + """JSON-able mark to be able to round-trip json dumps""" + + module: str + cls: str + version: str + + +@dataclass(kw_only=True) +class JsonDict: + """ + Representation of array when dumped with round_trip == True. + + Using a dataclass rather than a pydantic model to not tempt + us to use more sophisticated types than can be serialized to json. + """ + + type: str + + @abstractmethod + def to_array_input(self) -> Any: + """ + Convert this roundtrip specifier to the relevant input class + (one of the ``input_types`` of an interface). 
+ """ + + def to_dict(self) -> dict: + """ + Convenience method for casting dataclass to dict, + removing None-valued items + """ + return {k: v for k, v in asdict(self).items() if v is not None} + + @classmethod + def get_adapter(cls) -> TypeAdapter: + """Convenience method to get a typeadapter for this class""" + return TypeAdapter(cls) + + @classmethod + def is_valid(cls, val: dict) -> bool: + """ + Check whether a given dictionary matches this JsonDict specification + + Returns: + bool - true if valid, false if not + """ + adapter = cls.get_adapter() + try: + _ = adapter.validate_python(val) + return True + except ValidationError: + return False + + class Interface(ABC, Generic[T]): """ Abstract parent class for interfaces to different array formats @@ -30,7 +87,7 @@ class Interface(ABC, Generic[T]): return_type: Type[T] priority: int = 0 - def __init__(self, shape: ShapeType, dtype: DtypeType) -> None: + def __init__(self, shape: ShapeType = Any, dtype: DtypeType = Any) -> None: self.shape = shape self.dtype = dtype @@ -86,6 +143,7 @@ class Interface(ABC, Generic[T]): self.raise_for_shape(shape_valid, shape) array = self.after_validation(array) + return array def before_validation(self, array: Any) -> NDArrayType: @@ -117,8 +175,6 @@ class Interface(ABC, Generic[T]): """ Validate the dtype of the given array, returning ``True`` if valid, ``False`` if not. - - """ if self.dtype is Any: return True @@ -196,6 +252,13 @@ class Interface(ABC, Generic[T]): """ return array + def mark_input(self, array: Any) -> Any: + """ + Preserve metadata about the interface and passed input when dumping with + ``round_trip`` + """ + return array + @classmethod @abstractmethod def check(cls, array: Any) -> bool: @@ -211,17 +274,40 @@ class Interface(ABC, Generic[T]): installed, etc.) """ + @property + @abstractmethod + def name(self) -> str: + """ + Short name for this interface + """ + @classmethod - def to_json( - cls, array: Type[T], info: Optional[SerializationInfo] = None - ) -> Union[list, dict]: + @abstractmethod + def to_json(cls, array: Type[T], info: SerializationInfo) -> Union[list, JsonDict]: """ Convert an array of :attr:`.return_type` to a JSON-compatible format using base python types """ - if not isinstance(array, np.ndarray): # pragma: no cover - array = np.array(array) - return array.tolist() + + @classmethod + def mark_json(cls, array: Union[list, dict]) -> dict: + """ + When using ``model_dump_json`` with ``mark_interface: True`` in the ``context``, + add additional annotations that would allow the serialized array to be + roundtripped. 
+ + Default is just to add an :class:`.InterfaceMark` + + Examples: + + >>> from pprint import pprint + >>> pprint(Interface.mark_json([1.0, 2.0])) + {'interface': {'cls': 'Interface', + 'module': 'numpydantic.interface.interface', + 'version': '1.2.2'}, + 'value': [1.0, 2.0]} + """ + return {"interface": cls.mark_interface(), "value": array} @classmethod def interfaces( @@ -335,3 +421,24 @@ class Interface(ABC, Generic[T]): raise NoMatchError(f"No matching interfaces found for output {array}") else: return matches[0] + + @classmethod + def mark_interface(cls) -> InterfaceMark: + """ + Create an interface mark indicating this interface for validation after + JSON serialization with ``round_trip==True`` + """ + interface_module = inspect.getmodule(cls) + interface_module = ( + None if interface_module is None else interface_module.__name__ + ) + try: + v = ( + None + if interface_module is None + else version(interface_module.split(".")[0]) + ) + except PackageNotFoundError: + v = None + interface_name = cls.__name__ + return InterfaceMark(module=interface_module, cls=interface_name, version=v) diff --git a/src/numpydantic/interface/numpy.py b/src/numpydantic/interface/numpy.py index 5ee988a..8c68f1e 100644 --- a/src/numpydantic/interface/numpy.py +++ b/src/numpydantic/interface/numpy.py @@ -2,9 +2,12 @@ Interface to numpy arrays """ -from typing import Any +from dataclasses import dataclass +from typing import Any, Literal, Union -from numpydantic.interface.interface import Interface +from pydantic import SerializationInfo + +from numpydantic.interface.interface import Interface, JsonDict try: import numpy as np @@ -18,11 +21,29 @@ except ImportError: # pragma: no cover np = None +@dataclass +class NumpyJsonDict(JsonDict): + """ + JSON-able roundtrip representation of numpy array + """ + + type: Literal["numpy"] + dtype: str + array: list + + def to_array_input(self) -> ndarray: + """ + Construct a numpy array + """ + return np.array(self.array, dtype=self.dtype) + + class NumpyInterface(Interface): """ Numpy :class:`~numpy.ndarray` s! """ + name = "numpy" input_types = (ndarray, list) return_type = ndarray priority = -999 @@ -41,6 +62,8 @@ class NumpyInterface(Interface): """ if isinstance(array, ndarray): return True + elif isinstance(array, dict): + return NumpyJsonDict.is_valid(array) else: try: _ = np.array(array) @@ -53,6 +76,9 @@ class NumpyInterface(Interface): Coerce to an ndarray. 
We have already checked if coercion is possible in :meth:`.check` """ + if isinstance(array, dict): + array = NumpyJsonDict(**array).to_array_input() + if not isinstance(array, ndarray): array = np.array(array) return array @@ -61,3 +87,22 @@ class NumpyInterface(Interface): def enabled(cls) -> bool: """Check that numpy is present in the environment""" return ENABLED + + @classmethod + def to_json( + cls, array: ndarray, info: SerializationInfo = None + ) -> Union[list, JsonDict]: + """ + Convert an array of :attr:`.return_type` to a JSON-compatible format using + base python types + """ + if not isinstance(array, np.ndarray): # pragma: no cover + array = np.array(array) + + json_array = array.tolist() + + if info.round_trip: + json_array = NumpyJsonDict( + type=cls.name, dtype=str(array.dtype), array=json_array + ) + return json_array diff --git a/src/numpydantic/interface/video.py b/src/numpydantic/interface/video.py index 1660545..3b3b499 100644 --- a/src/numpydantic/interface/video.py +++ b/src/numpydantic/interface/video.py @@ -2,11 +2,14 @@ Interface to support treating videos like arrays using OpenCV """ +from dataclasses import dataclass from pathlib import Path -from typing import Any, Optional, Tuple, Union +from typing import Any, Literal, Optional, Tuple, Union import numpy as np +from pydantic_core.core_schema import SerializationInfo +from numpydantic.interface import JsonDict from numpydantic.interface.interface import Interface try: @@ -19,6 +22,20 @@ except ImportError: # pragma: no cover VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv") +@dataclass(kw_only=True) +class VideoJsonDict(JsonDict): + """Json-able roundtrip representation of a video file""" + + type: Literal["video"] + file: str + + def to_array_input(self) -> "VideoProxy": + """ + Construct a :class:`.VideoProxy` + """ + return VideoProxy(path=Path(self.file)) + + class VideoProxy: """ Passthrough proxy class to interact with videos as arrays @@ -184,6 +201,12 @@ class VideoProxy: def __getattr__(self, item: str): return getattr(self.video, item) + def __eq__(self, other: "VideoProxy") -> bool: + """Check if this is a proxy to the same video file""" + if not isinstance(other, VideoProxy): + raise TypeError("Can only compare equality of two VideoProxies") + return self.path == other.path + def __len__(self) -> int: """Number of frames in the video""" return self.shape[0] @@ -194,6 +217,7 @@ class VideoInterface(Interface): OpenCV interface to treat videos as arrays. 
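+
+    A hypothetical model field can be given a path to a video file; the file is
+    wrapped in a :class:`.VideoProxy` rather than being read into memory up
+    front, e.g.::
+
+        instance = MyModel(array="video.mp4")
+        first_frame = instance.array[0]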
""" + name = "video" input_types = (str, Path, VideoCapture, VideoProxy) return_type = VideoProxy @@ -213,6 +237,9 @@ class VideoInterface(Interface): ): return True + if isinstance(array, dict): + array = array.get("file", "") + if isinstance(array, str): try: array = Path(array) @@ -224,10 +251,22 @@ class VideoInterface(Interface): def before_validation(self, array: Any) -> VideoProxy: """Get a :class:`.VideoProxy` object for this video""" - if isinstance(array, VideoCapture): + if isinstance(array, dict): + proxy = VideoJsonDict(**array).to_array_input() + elif isinstance(array, VideoCapture): proxy = VideoProxy(video=array) elif isinstance(array, VideoProxy): proxy = array else: proxy = VideoProxy(path=array) return proxy + + @classmethod + def to_json( + cls, array: VideoProxy, info: SerializationInfo + ) -> Union[list, VideoJsonDict]: + """Return a json-representation of a video""" + if info.round_trip: + return VideoJsonDict(type=cls.name, file=str(array.path)) + else: + return np.array(array).tolist() diff --git a/src/numpydantic/interface/zarr.py b/src/numpydantic/interface/zarr.py index 87f538a..1e5b612 100644 --- a/src/numpydantic/interface/zarr.py +++ b/src/numpydantic/interface/zarr.py @@ -5,12 +5,12 @@ Interface to zarr arrays import contextlib from dataclasses import dataclass from pathlib import Path -from typing import Any, Optional, Sequence, Union +from typing import Any, Literal, Optional, Sequence, Union import numpy as np from pydantic import SerializationInfo -from numpydantic.interface.interface import Interface +from numpydantic.interface.interface import Interface, JsonDict from numpydantic.types import DtypeType try: @@ -56,11 +56,34 @@ class ZarrArrayPath: raise ValueError("Only len 1-2 iterables can be used for a ZarrArrayPath") +@dataclass(kw_only=True) +class ZarrJsonDict(JsonDict): + """Round-trip Json-able version of a Zarr Array""" + + info: dict[str, str] + type: Literal["zarr"] + file: Optional[str] = None + path: Optional[str] = None + array: Optional[list] = None + + def to_array_input(self) -> ZarrArray | ZarrArrayPath: + """ + Construct a ZarrArrayPath if file and path are present, + otherwise a ZarrArray + """ + if self.file: + array = ZarrArrayPath(file=self.file, path=self.path) + else: + array = zarr.array(self.array) + return array + + class ZarrInterface(Interface): """ Interface to in-memory or on-disk zarr arrays """ + name = "zarr" input_types = (Path, ZarrArray, ZarrArrayPath) return_type = ZarrArray @@ -73,6 +96,9 @@ class ZarrInterface(Interface): def _get_array( array: Union[ZarrArray, str, Path, ZarrArrayPath, Sequence] ) -> ZarrArray: + if isinstance(array, dict): + array = ZarrJsonDict(**array).to_array_input() + if isinstance(array, ZarrArray): return array @@ -92,6 +118,12 @@ class ZarrInterface(Interface): if isinstance(array, ZarrArray): return True + if isinstance(array, dict): + if array.get("type", False) == cls.name: + return True + # continue checking if dict contains a zarr file + array = array.get("file", "") + # See if can be coerced to ZarrArrayPath if isinstance(array, (Path, str)): array = ZarrArrayPath(file=array) @@ -135,26 +167,46 @@ class ZarrInterface(Interface): cls, array: Union[ZarrArray, str, Path, ZarrArrayPath, Sequence], info: Optional[SerializationInfo] = None, - ) -> dict: + ) -> list | ZarrJsonDict: """ - Dump just the metadata for an array from :meth:`zarr.core.Array.info_items` - plus the :meth:`zarr.core.Array.hexdigest`. 
+
+        Dump a Zarr Array to JSON
+
+        If ``info.round_trip == False``, dump the array as a list of lists.
+        This may be a memory-intensive operation.
+
+        Otherwise, dump the metadata for an array from :meth:`zarr.core.Array.info_items`
+        plus the :meth:`zarr.core.Array.hexdigest` as a :class:`.ZarrJsonDict`.
+
+        If either the ``zarr_dump_array`` value in the context dictionary is ``True``
+        or the zarr array is an in-memory array, dump the array as well
+        (since without a persistent array it would be impossible to round-trip,
+        and dumping to JSON would be meaningless).
+
+        Passing ``'zarr_dump_array': True`` to the serialization ``context`` looks like this::

             model.model_dump_json(context={'zarr_dump_array': True})
         """
-        dump_array = False
-        if info is not None and info.context is not None:
-            dump_array = info.context.get("zarr_dump_array", False)
-
         array = cls._get_array(array)
-        info = array.info_items()
-        info_dict = {i[0]: i[1] for i in info}
-        info_dict["hexdigest"] = array.hexdigest()

-        if dump_array:
-            info_dict["array"] = array[:].tolist()
+        if info.round_trip:
+            dump_array = False
+            if info is not None and info.context is not None:
+                dump_array = info.context.get("zarr_dump_array", False)
+            is_file = False

-        return info_dict
+            as_json = {"type": cls.name}
+            if hasattr(array.store, "dir_path"):
+                is_file = True
+                as_json["file"] = array.store.dir_path()
+                as_json["path"] = array.name
+            as_json["info"] = {i[0]: i[1] for i in array.info_items()}
+            as_json["info"]["hexdigest"] = array.hexdigest()
+
+            if dump_array or not is_file:
+                as_json["array"] = array[:].tolist()
+
+            as_json = ZarrJsonDict(**as_json)
+        else:
+            as_json = array[:].tolist()
+
+        return as_json
diff --git a/src/numpydantic/ndarray.py b/src/numpydantic/ndarray.py
index d951d3a..d494154 100644
--- a/src/numpydantic/ndarray.py
+++ b/src/numpydantic/ndarray.py
@@ -24,7 +24,6 @@ from numpydantic.exceptions import InterfaceError
 from numpydantic.interface import Interface
 from numpydantic.maps import python_to_nptyping
 from numpydantic.schema import (
-    _handler_type,
     _jsonize_array,
     get_validate_interface,
     make_json_schema,
@@ -41,6 +40,9 @@ from numpydantic.vendor.nptyping.typing_ import (

 if TYPE_CHECKING:  # pragma: no cover
     from nptyping.base_meta_classes import SubscriptableMeta
+    from pydantic._internal._schema_generation_shared import (
+        CallbackGetCoreSchemaHandler,
+    )

     from numpydantic import Shape

@@ -164,33 +166,34 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
     def __get_pydantic_core_schema__(
         cls,
         _source_type: "NDArray",
-        _handler: _handler_type,
+        _handler: "CallbackGetCoreSchemaHandler",
     ) -> core_schema.CoreSchema:
         shape, dtype = _source_type.__args__
         shape: ShapeType
         dtype: DtypeType

-        # get pydantic core schema as a list of lists for JSON schema
-        list_schema = make_json_schema(shape, dtype, _handler)
+        # make core schema for json schema, store it and any model definitions
+        # note that there is a bit of fragility in this function,
+        # as we need to access a private method of _handler to
+        # flatten out the json schema. 
See help(make_json_schema) + json_schema = make_json_schema(shape, dtype, _handler) - return core_schema.json_or_python_schema( - json_schema=list_schema, - python_schema=core_schema.with_info_plain_validator_function( - get_validate_interface(shape, dtype) - ), + return core_schema.with_info_plain_validator_function( + get_validate_interface(shape, dtype), serialization=core_schema.plain_serializer_function_ser_schema( _jsonize_array, when_used="json", info_arg=True ), + metadata=json_schema, ) @classmethod def __get_pydantic_json_schema__( cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler ) -> core_schema.JsonSchema: - json_schema = handler(schema) + shape, dtype = cls.__args__ + json_schema = handler(schema["metadata"]) json_schema = handler.resolve_ref_schema(json_schema) - dtype = cls.__args__[1] if not isinstance(dtype, tuple) and dtype.__module__ not in ( "builtins", "typing", diff --git a/src/numpydantic/schema.py b/src/numpydantic/schema.py index d98f880..36d6812 100644 --- a/src/numpydantic/schema.py +++ b/src/numpydantic/schema.py @@ -13,19 +13,24 @@ from pydantic_core import CoreSchema, core_schema from pydantic_core.core_schema import ListSchema, ValidationInfo from numpydantic import dtype as dt -from numpydantic.interface import Interface +from numpydantic.interface import Interface, JsonDict from numpydantic.maps import np_to_python from numpydantic.types import DtypeType, NDArrayType, ShapeType from numpydantic.vendor.nptyping.structure import StructureMeta if TYPE_CHECKING: # pragma: no cover + from pydantic._internal._schema_generation_shared import ( + CallbackGetCoreSchemaHandler, + ) + from numpydantic import Shape -_handler_type = Callable[[Any], core_schema.CoreSchema] _UNSUPPORTED_TYPES = (complex,) -def _numeric_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema: +def _numeric_dtype( + dtype: DtypeType, _handler: "CallbackGetCoreSchemaHandler" +) -> CoreSchema: """Make a numeric dtype that respects min/max values from extended numpy types""" if dtype in (np.number,): dtype = float @@ -36,14 +41,19 @@ def _numeric_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema: elif issubclass(dtype, np.integer): info = np.iinfo(dtype) schema = core_schema.int_schema(le=int(info.max), ge=int(info.min)) - + elif dtype is float: + schema = core_schema.float_schema() + elif dtype is int: + schema = core_schema.int_schema() else: schema = _handler.generate_schema(dtype) return schema -def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema: +def _lol_dtype( + dtype: DtypeType, _handler: "CallbackGetCoreSchemaHandler" +) -> CoreSchema: """Get the innermost dtype schema to use in the generated pydantic schema""" if isinstance(dtype, StructureMeta): # pragma: no cover raise NotImplementedError("Structured dtypes are currently unsupported") @@ -84,6 +94,10 @@ def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema: # TODO: warn and log here elif python_type in (float, int): array_type = _numeric_dtype(dtype, _handler) + elif python_type is bool: + array_type = core_schema.bool_schema() + elif python_type is Any: + array_type = core_schema.any_schema() else: array_type = _handler.generate_schema(python_type) @@ -208,14 +222,24 @@ def _unbounded_shape( def make_json_schema( - shape: ShapeType, dtype: DtypeType, _handler: _handler_type + shape: ShapeType, dtype: DtypeType, _handler: "CallbackGetCoreSchemaHandler" ) -> ListSchema: """ - Make a list of list JSON schema from a shape and a dtype. 
+ Make a list of list pydantic core schema for an array from a shape and a dtype. + Used to generate JSON schema in the containing model, but not for validation, + which is handled by interfaces. First resolves the dtype into a pydantic ``CoreSchema`` , and then uses that with :func:`.list_of_lists_schema` . + .. admonition:: Potentially Fragile + + Uses a private method from the handler to flatten out nested definitions + (e.g. when dtype is a pydantic model) + so that they are present in the generated schema directly rather than + as references. Otherwise, at the time __get_pydantic_json_schema__ is called, + the definition references are lost. + Args: shape ( ShapeType ): Specification of a shape, as a tuple or an nptyping ``Shape`` @@ -234,6 +258,8 @@ def make_json_schema( else: list_schema = list_of_lists_schema(shape, dtype_schema) + list_schema = _handler._generate_schema.clean_schema(list_schema) + return list_schema @@ -257,4 +283,11 @@ def get_validate_interface(shape: ShapeType, dtype: DtypeType) -> Callable: def _jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]: """Use an interface class to render an array as JSON""" interface_cls = Interface.match_output(value) - return interface_cls.to_json(value, info) + array = interface_cls.to_json(value, info) + if isinstance(array, JsonDict): + array = array.to_dict() + + if info.context and info.context.get("mark_interface", False): + array = interface_cls.mark_json(array) + + return array diff --git a/tests/test_interface/test_hdf5.py b/tests/test_interface/test_hdf5.py index 9ca9e94..b64d3fe 100644 --- a/tests/test_interface/test_hdf5.py +++ b/tests/test_interface/test_hdf5.py @@ -101,7 +101,8 @@ def test_assignment(hdf5_array, model_blank): assert (model.array[1:3, 2:4] == 10).all() -def test_to_json(hdf5_array, array_model): +@pytest.mark.parametrize("round_trip", (True, False)) +def test_to_json(hdf5_array, array_model, round_trip): """ Test serialization of HDF5 arrays to JSON Args: @@ -115,13 +116,13 @@ def test_to_json(hdf5_array, array_model): instance = model(array=array) # type: BaseModel - json_str = instance.model_dump_json() - json_dict = json.loads(json_str)["array"] - - assert json_dict["file"] == str(array.file) - assert json_dict["path"] == str(array.path) - assert json_dict["attrs"] == {} - assert json_dict["array"] == instance.array[:].tolist() + json_str = instance.model_dump_json(round_trip=round_trip) + json_dumped = json.loads(json_str)["array"] + if round_trip: + assert json_dumped["file"] == str(array.file) + assert json_dumped["path"] == str(array.path) + else: + assert json_dumped == instance.array[:].tolist() def test_compound_dtype(tmp_path): diff --git a/tests/test_interface/test_interfaces.py b/tests/test_interface/test_interfaces.py index 3b1370a..3d51ac0 100644 --- a/tests/test_interface/test_interfaces.py +++ b/tests/test_interface/test_interfaces.py @@ -2,8 +2,11 @@ Tests that should be applied to all interfaces """ +import pytest from typing import Callable import numpy as np +import dask.array as da +from zarr.core import Array as ZarrArray from numpydantic.interface import Interface @@ -35,8 +38,31 @@ def test_interface_to_numpy_array(all_interfaces): _ = np.array(all_interfaces.array) +@pytest.mark.serialization def test_interface_dump_json(all_interfaces): """ All interfaces should be able to dump to json """ all_interfaces.model_dump_json() + + +@pytest.mark.serialization +@pytest.mark.parametrize("round_trip", [True, False]) +def 
test_interface_roundtrip_json(all_interfaces, round_trip): + """ + All interfaces should be able to roundtrip to and from json + """ + json = all_interfaces.model_dump_json(round_trip=round_trip) + model = all_interfaces.model_validate_json(json) + if round_trip: + assert type(model.array) is type(all_interfaces.array) + if isinstance(all_interfaces.array, (np.ndarray, ZarrArray)): + assert np.array_equal(model.array, np.array(all_interfaces.array)) + elif isinstance(all_interfaces.array, da.Array): + assert np.all(da.equal(model.array, all_interfaces.array)) + else: + assert model.array == all_interfaces.array + + assert model.array.dtype == all_interfaces.array.dtype + else: + assert np.array_equal(model.array, np.array(all_interfaces.array)) diff --git a/tests/test_interface/test_zarr.py b/tests/test_interface/test_zarr.py index 2e465f2..05eb8d5 100644 --- a/tests/test_interface/test_zarr.py +++ b/tests/test_interface/test_zarr.py @@ -123,7 +123,10 @@ def test_zarr_array_path_from_iterable(zarr_array): assert apath.path == inner_path -def test_zarr_to_json(store, model_blank): +@pytest.mark.serialization +@pytest.mark.parametrize("dump_array", [True, False]) +@pytest.mark.parametrize("roundtrip", [True, False]) +def test_zarr_to_json(store, model_blank, roundtrip, dump_array): expected_fields = ( "Type", "Data type", @@ -137,17 +140,22 @@ def test_zarr_to_json(store, model_blank): array = zarr.array(lol_array, store=store) instance = model_blank(array=array) - as_json = json.loads(instance.model_dump_json())["array"] - assert "array" not in as_json - for field in expected_fields: - assert field in as_json - assert len(as_json["hexdigest"]) == 40 - # dump the array itself too - as_json = json.loads(instance.model_dump_json(context={"zarr_dump_array": True}))[ - "array" - ] - for field in expected_fields: - assert field in as_json - assert len(as_json["hexdigest"]) == 40 - assert as_json["array"] == lol_array + context = {"zarr_dump_array": dump_array} + as_json = json.loads( + instance.model_dump_json(round_trip=roundtrip, context=context) + )["array"] + + if roundtrip: + if dump_array: + assert as_json["array"] == lol_array + else: + if as_json.get("file", False): + assert "array" not in as_json + + for field in expected_fields: + assert field in as_json["info"] + assert len(as_json["info"]["hexdigest"]) == 40 + + else: + assert as_json == lol_array diff --git a/tests/test_ndarray.py b/tests/test_ndarray.py index f92a66d..cef7dc3 100644 --- a/tests/test_ndarray.py +++ b/tests/test_ndarray.py @@ -15,6 +15,7 @@ from numpydantic import dtype from numpydantic.dtype import Number +@pytest.mark.json_schema def test_ndarray_type(): class Model(BaseModel): array: NDArray[Shape["2 x, * y"], Number] @@ -40,6 +41,7 @@ def test_ndarray_type(): instance = Model(array=np.zeros((2, 3)), array_any=np.ones((3, 4, 5))) +@pytest.mark.json_schema def test_schema_unsupported_type(): """ Complex numbers should just be made with an `any` schema @@ -55,6 +57,7 @@ def test_schema_unsupported_type(): } +@pytest.mark.json_schema def test_schema_tuple(): """ Types specified as tupled should have their schemas as a union @@ -72,6 +75,7 @@ def test_schema_tuple(): assert all([i["minimum"] == 0 for i in conditions]) +@pytest.mark.json_schema def test_schema_number(): """ np.numeric should just be the float schema @@ -164,6 +168,7 @@ def test_ndarray_coercion(): amod = Model(array=["a", "b", "c"]) +@pytest.mark.serialization def test_ndarray_serialize(): """ Arrays should be dumped to a list when using json, but kept 
as ndarray otherwise @@ -188,6 +193,7 @@ _json_schema_types = [ ] +@pytest.mark.json_schema def test_json_schema_basic(array_model): """ NDArray types should correctly generate a list of lists JSON schema @@ -210,6 +216,8 @@ def test_json_schema_basic(array_model): assert inner["items"]["type"] == "number" +@pytest.mark.dtype +@pytest.mark.json_schema @pytest.mark.parametrize("dtype", [*dtype.Integer, *dtype.Float]) def test_json_schema_dtype_single(dtype, array_model): """ @@ -240,6 +248,7 @@ def test_json_schema_dtype_single(dtype, array_model): ) +@pytest.mark.dtype @pytest.mark.parametrize( "dtype,expected", [ @@ -266,6 +275,8 @@ def test_json_schema_dtype_builtin(dtype, expected, array_model): assert inner_type["type"] == expected +@pytest.mark.dtype +@pytest.mark.json_schema def test_json_schema_dtype_model(): """ Pydantic models can be used in arrays as dtypes @@ -314,6 +325,8 @@ def _recursive_array(schema): assert any_of[1]["minimum"] == 0 +@pytest.mark.shape +@pytest.mark.json_schema def test_json_schema_ellipsis(): """ NDArray types should create a recursive JSON schema for any-shaped arrays