Roundtrip json serialization using dataclasses for all interfaces. Separate JSON schema generation from core_schema generation. More consistent json dumping behavior - always return just the array unless round_trip=True.

This commit is contained in:
sneakers-the-rat 2024-09-20 23:44:59 -07:00
parent 2cb09076fd
commit b436d8d592
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
17 changed files with 532 additions and 99 deletions

View file

@ -1,6 +0,0 @@
# monkeypatch
```{eval-rst}
.. automodule:: numpydantic.monkeypatch
:members:
```

View file

@ -473,6 +473,7 @@ dumped = instance.model_dump_json(context={'zarr_dump_array': True})
design design
syntax syntax
serialization
interfaces interfaces
todo todo
changelog changelog
@ -489,7 +490,6 @@ api/dtype
api/ndarray api/ndarray
api/maps api/maps
api/meta api/meta
api/monkeypatch
api/schema api/schema
api/shape api/shape
api/types api/types

2
docs/serialization.md Normal file
View file

@ -0,0 +1,2 @@
# Serialization

View file

@ -109,6 +109,12 @@ filterwarnings = [
# nptyping's alias warnings # nptyping's alias warnings
'ignore:.*deprecated alias.*Deprecated NumPy 1\.24.*' 'ignore:.*deprecated alias.*Deprecated NumPy 1\.24.*'
] ]
markers = [
"dtype: mark test related to dtype validation",
"shape: mark test related to shape validation",
"json_schema: mark test related to json schema generation",
"serialization: mark test related to serialization"
]
[tool.ruff] [tool.ruff]
target-version = "py311" target-version = "py311"

View file

@ -4,12 +4,13 @@ Interfaces between nptyping types and array backends
from numpydantic.interface.dask import DaskInterface from numpydantic.interface.dask import DaskInterface
from numpydantic.interface.hdf5 import H5Interface from numpydantic.interface.hdf5 import H5Interface
from numpydantic.interface.interface import Interface from numpydantic.interface.interface import Interface, JsonDict
from numpydantic.interface.numpy import NumpyInterface from numpydantic.interface.numpy import NumpyInterface
from numpydantic.interface.video import VideoInterface from numpydantic.interface.video import VideoInterface
from numpydantic.interface.zarr import ZarrInterface from numpydantic.interface.zarr import ZarrInterface
__all__ = [ __all__ = [
"JsonDict",
"Interface", "Interface",
"DaskInterface", "DaskInterface",
"H5Interface", "H5Interface",

View file

@ -2,26 +2,59 @@
Interface for Dask arrays Interface for Dask arrays
""" """
from typing import Any, Optional from dataclasses import dataclass
from typing import Any, Iterable, Literal, Optional
import numpy as np import numpy as np
from pydantic import SerializationInfo from pydantic import SerializationInfo
from numpydantic.interface.interface import Interface from numpydantic.interface.interface import Interface, JsonDict
from numpydantic.types import DtypeType, NDArrayType from numpydantic.types import DtypeType, NDArrayType
try: try:
from dask.array import from_array
from dask.array.core import Array as DaskArray from dask.array.core import Array as DaskArray
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
DaskArray = None DaskArray = None
def _as_tuple(a_list: list | Any) -> tuple:
"""Make a list of list into a tuple of tuples"""
return tuple(
[_as_tuple(item) if isinstance(item, list) else item for item in a_list]
)
@dataclass(kw_only=True)
class DaskJsonDict(JsonDict):
"""
Round-trip json serialized form of a dask array
"""
type: Literal["dask"]
name: str
chunks: Iterable[tuple[int, ...]]
dtype: str
array: list
def to_array_input(self) -> DaskArray:
"""Construct a dask array"""
np_array = np.array(self.array, dtype=self.dtype)
array = from_array(
np_array,
name=self.name,
chunks=_as_tuple(self.chunks),
)
return array
class DaskInterface(Interface): class DaskInterface(Interface):
""" """
Interface for Dask :class:`~dask.array.core.Array` Interface for Dask :class:`~dask.array.core.Array`
""" """
input_types = (DaskArray,) name = "dask"
input_types = (DaskArray, dict)
return_type = DaskArray return_type = DaskArray
@classmethod @classmethod
@ -29,7 +62,24 @@ class DaskInterface(Interface):
""" """
check if array is a dask array check if array is a dask array
""" """
return DaskArray is not None and isinstance(array, DaskArray) if DaskArray is None:
return False
elif isinstance(array, DaskArray):
return True
elif isinstance(array, dict):
return DaskJsonDict.is_valid(array)
else:
return False
def before_validation(self, array: Any) -> DaskArray:
"""
If given a dict (like that from ``model_dump_json(round_trip=True)`` ),
re-cast to dask array
"""
if isinstance(array, dict):
array = DaskJsonDict(**array).to_array_input()
return array
def get_object_dtype(self, array: NDArrayType) -> DtypeType: def get_object_dtype(self, array: NDArrayType) -> DtypeType:
"""Dask arrays require a compute() call to retrieve a single value""" """Dask arrays require a compute() call to retrieve a single value"""
@ -43,7 +93,7 @@ class DaskInterface(Interface):
@classmethod @classmethod
def to_json( def to_json(
cls, array: DaskArray, info: Optional[SerializationInfo] = None cls, array: DaskArray, info: Optional[SerializationInfo] = None
) -> list: ) -> list | DaskJsonDict:
""" """
Convert an array to a JSON serializable array by first converting to a numpy Convert an array to a JSON serializable array by first converting to a numpy
array and then to a list. array and then to a list.
@ -56,4 +106,14 @@ class DaskInterface(Interface):
method of serialization here using the python object itself rather than method of serialization here using the python object itself rather than
its JSON representation. its JSON representation.
""" """
return np.array(array).tolist() np_array = np.array(array)
as_json = np_array.tolist()
if info.round_trip:
as_json = DaskJsonDict(
type=cls.name,
array=as_json,
name=array.name,
chunks=array.chunks,
dtype=str(np_array.dtype),
)
return as_json

View file

@ -40,6 +40,7 @@ as ``S32`` isoformatted byte strings (timezones optional) like:
""" """
import sys import sys
from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, TypeVar, Union from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, TypeVar, Union
@ -47,7 +48,7 @@ from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, TypeVar, Un
import numpy as np import numpy as np
from pydantic import SerializationInfo from pydantic import SerializationInfo
from numpydantic.interface.interface import Interface from numpydantic.interface.interface import Interface, JsonDict
from numpydantic.types import DtypeType, NDArrayType from numpydantic.types import DtypeType, NDArrayType
try: try:
@ -76,6 +77,21 @@ class H5ArrayPath(NamedTuple):
"""Refer to a specific field within a compound dtype""" """Refer to a specific field within a compound dtype"""
@dataclass
class H5JsonDict(JsonDict):
"""Round-trip Json-able version of an HDF5 dataset"""
file: str
path: str
field: Optional[str] = None
def to_array_input(self) -> H5ArrayPath:
"""Construct an :class:`.H5ArrayPath`"""
return H5ArrayPath(
**{k: v for k, v in self.to_dict().items() if k in H5ArrayPath._fields}
)
class H5Proxy: class H5Proxy:
""" """
Proxy class to mimic numpy-like array behavior with an HDF5 array Proxy class to mimic numpy-like array behavior with an HDF5 array
@ -110,6 +126,7 @@ class H5Proxy:
self.path = path self.path = path
self.field = field self.field = field
self._annotation_dtype = annotation_dtype self._annotation_dtype = annotation_dtype
self._h5arraypath = H5ArrayPath(self.file, self.path, self.field)
def array_exists(self) -> bool: def array_exists(self) -> bool:
"""Check that there is in fact an array at :attr:`.path` within :attr:`.file`""" """Check that there is in fact an array at :attr:`.path` within :attr:`.file`"""
@ -212,6 +229,15 @@ class H5Proxy:
"""self.shape[0]""" """self.shape[0]"""
return self.shape[0] return self.shape[0]
def __eq__(self, other: "H5Proxy") -> bool:
"""
Check that we are referring to the same hdf5 array
"""
if isinstance(other, H5Proxy):
return self._h5arraypath == other._h5arraypath
else:
raise ValueError("Can only compare equality of two H5Proxies")
def open(self, mode: str = "r") -> "h5py.Dataset": def open(self, mode: str = "r") -> "h5py.Dataset":
""" """
Return the opened :class:`h5py.Dataset` object Return the opened :class:`h5py.Dataset` object
@ -251,6 +277,7 @@ class H5Interface(Interface):
passthrough numpy-like interface to the dataset. passthrough numpy-like interface to the dataset.
""" """
name = "hdf5"
input_types = (H5ArrayPath, H5Arraylike, H5Proxy) input_types = (H5ArrayPath, H5Arraylike, H5Proxy)
return_type = H5Proxy return_type = H5Proxy
@ -268,6 +295,13 @@ class H5Interface(Interface):
if isinstance(array, (H5ArrayPath, H5Proxy)): if isinstance(array, (H5ArrayPath, H5Proxy)):
return True return True
if isinstance(array, dict):
if array.get("type", False) == cls.name:
return True
# continue checking if dict contains an hdf5 file
file = array.get("file", "")
array = (file, "")
if isinstance(array, (tuple, list)) and len(array) in (2, 3): if isinstance(array, (tuple, list)) and len(array) in (2, 3):
# check that the first arg is an hdf5 file # check that the first arg is an hdf5 file
try: try:
@ -294,6 +328,9 @@ class H5Interface(Interface):
def before_validation(self, array: Any) -> NDArrayType: def before_validation(self, array: Any) -> NDArrayType:
"""Create an :class:`.H5Proxy` to use throughout validation""" """Create an :class:`.H5Proxy` to use throughout validation"""
if isinstance(array, dict):
array = H5JsonDict(**array).to_array_input()
if isinstance(array, H5ArrayPath): if isinstance(array, H5ArrayPath):
array = H5Proxy.from_h5array(h5array=array) array = H5Proxy.from_h5array(h5array=array)
elif isinstance(array, H5Proxy): elif isinstance(array, H5Proxy):
@ -349,21 +386,27 @@ class H5Interface(Interface):
@classmethod @classmethod
def to_json(cls, array: H5Proxy, info: Optional[SerializationInfo] = None) -> dict: def to_json(cls, array: H5Proxy, info: Optional[SerializationInfo] = None) -> dict:
""" """
Dump to a dictionary containing Render HDF5 array as JSON
If ``round_trip == True``, we dump just the proxy info, a dictionary like:
* ``file``: :attr:`.file` * ``file``: :attr:`.file`
* ``path``: :attr:`.path` * ``path``: :attr:`.path`
* ``attrs``: Any HDF5 attributes on the dataset * ``attrs``: Any HDF5 attributes on the dataset
* ``array``: The array as a list of lists * ``array``: The array as a list of lists
Otherwise, we dump the array as a list of lists
""" """
try: if info.round_trip:
dset = array.open() as_json = {
meta = { "type": cls.name,
"file": array.file,
"path": array.path,
"attrs": dict(dset.attrs),
"array": dset[:].tolist(),
} }
return meta as_json.update(array._h5arraypath._asdict())
finally: else:
array.close() try:
dset = array.open()
as_json = dset[:].tolist()
finally:
array.close()
return as_json

View file

@ -2,12 +2,15 @@
Base Interface metaclass Base Interface metaclass
""" """
import inspect
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from importlib.metadata import PackageNotFoundError, version
from operator import attrgetter from operator import attrgetter
from typing import Any, Generic, Optional, Tuple, Type, TypeVar, Union from typing import Any, Generic, Tuple, Type, TypedDict, TypeVar, Union
import numpy as np import numpy as np
from pydantic import SerializationInfo from pydantic import SerializationInfo, TypeAdapter, ValidationError
from numpydantic.exceptions import ( from numpydantic.exceptions import (
DtypeError, DtypeError,
@ -21,6 +24,60 @@ from numpydantic.types import DtypeType, NDArrayType, ShapeType
T = TypeVar("T", bound=NDArrayType) T = TypeVar("T", bound=NDArrayType)
class InterfaceMark(TypedDict):
"""JSON-able mark to be able to round-trip json dumps"""
module: str
cls: str
version: str
@dataclass(kw_only=True)
class JsonDict:
"""
Representation of array when dumped with round_trip == True.
Using a dataclass rather than a pydantic model to not tempt
us to use more sophisticated types than can be serialized to json.
"""
type: str
@abstractmethod
def to_array_input(self) -> Any:
"""
Convert this roundtrip specifier to the relevant input class
(one of the ``input_types`` of an interface).
"""
def to_dict(self) -> dict:
"""
Convenience method for casting dataclass to dict,
removing None-valued items
"""
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def get_adapter(cls) -> TypeAdapter:
"""Convenience method to get a typeadapter for this class"""
return TypeAdapter(cls)
@classmethod
def is_valid(cls, val: dict) -> bool:
"""
Check whether a given dictionary matches this JsonDict specification
Returns:
bool - true if valid, false if not
"""
adapter = cls.get_adapter()
try:
_ = adapter.validate_python(val)
return True
except ValidationError:
return False
class Interface(ABC, Generic[T]): class Interface(ABC, Generic[T]):
""" """
Abstract parent class for interfaces to different array formats Abstract parent class for interfaces to different array formats
@ -30,7 +87,7 @@ class Interface(ABC, Generic[T]):
return_type: Type[T] return_type: Type[T]
priority: int = 0 priority: int = 0
def __init__(self, shape: ShapeType, dtype: DtypeType) -> None: def __init__(self, shape: ShapeType = Any, dtype: DtypeType = Any) -> None:
self.shape = shape self.shape = shape
self.dtype = dtype self.dtype = dtype
@ -86,6 +143,7 @@ class Interface(ABC, Generic[T]):
self.raise_for_shape(shape_valid, shape) self.raise_for_shape(shape_valid, shape)
array = self.after_validation(array) array = self.after_validation(array)
return array return array
def before_validation(self, array: Any) -> NDArrayType: def before_validation(self, array: Any) -> NDArrayType:
@ -117,8 +175,6 @@ class Interface(ABC, Generic[T]):
""" """
Validate the dtype of the given array, returning Validate the dtype of the given array, returning
``True`` if valid, ``False`` if not. ``True`` if valid, ``False`` if not.
""" """
if self.dtype is Any: if self.dtype is Any:
return True return True
@ -196,6 +252,13 @@ class Interface(ABC, Generic[T]):
""" """
return array return array
def mark_input(self, array: Any) -> Any:
"""
Preserve metadata about the interface and passed input when dumping with
``round_trip``
"""
return array
@classmethod @classmethod
@abstractmethod @abstractmethod
def check(cls, array: Any) -> bool: def check(cls, array: Any) -> bool:
@ -211,17 +274,40 @@ class Interface(ABC, Generic[T]):
installed, etc.) installed, etc.)
""" """
@property
@abstractmethod
def name(self) -> str:
"""
Short name for this interface
"""
@classmethod @classmethod
def to_json( @abstractmethod
cls, array: Type[T], info: Optional[SerializationInfo] = None def to_json(cls, array: Type[T], info: SerializationInfo) -> Union[list, JsonDict]:
) -> Union[list, dict]:
""" """
Convert an array of :attr:`.return_type` to a JSON-compatible format using Convert an array of :attr:`.return_type` to a JSON-compatible format using
base python types base python types
""" """
if not isinstance(array, np.ndarray): # pragma: no cover
array = np.array(array) @classmethod
return array.tolist() def mark_json(cls, array: Union[list, dict]) -> dict:
"""
When using ``model_dump_json`` with ``mark_interface: True`` in the ``context``,
add additional annotations that would allow the serialized array to be
roundtripped.
Default is just to add an :class:`.InterfaceMark`
Examples:
>>> from pprint import pprint
>>> pprint(Interface.mark_json([1.0, 2.0]))
{'interface': {'cls': 'Interface',
'module': 'numpydantic.interface.interface',
'version': '1.2.2'},
'value': [1.0, 2.0]}
"""
return {"interface": cls.mark_interface(), "value": array}
@classmethod @classmethod
def interfaces( def interfaces(
@ -335,3 +421,24 @@ class Interface(ABC, Generic[T]):
raise NoMatchError(f"No matching interfaces found for output {array}") raise NoMatchError(f"No matching interfaces found for output {array}")
else: else:
return matches[0] return matches[0]
@classmethod
def mark_interface(cls) -> InterfaceMark:
"""
Create an interface mark indicating this interface for validation after
JSON serialization with ``round_trip==True``
"""
interface_module = inspect.getmodule(cls)
interface_module = (
None if interface_module is None else interface_module.__name__
)
try:
v = (
None
if interface_module is None
else version(interface_module.split(".")[0])
)
except PackageNotFoundError:
v = None
interface_name = cls.__name__
return InterfaceMark(module=interface_module, cls=interface_name, version=v)

View file

@ -2,9 +2,12 @@
Interface to numpy arrays Interface to numpy arrays
""" """
from typing import Any from dataclasses import dataclass
from typing import Any, Literal, Union
from numpydantic.interface.interface import Interface from pydantic import SerializationInfo
from numpydantic.interface.interface import Interface, JsonDict
try: try:
import numpy as np import numpy as np
@ -18,11 +21,29 @@ except ImportError: # pragma: no cover
np = None np = None
@dataclass
class NumpyJsonDict(JsonDict):
"""
JSON-able roundtrip representation of numpy array
"""
type: Literal["numpy"]
dtype: str
array: list
def to_array_input(self) -> ndarray:
"""
Construct a numpy array
"""
return np.array(self.array, dtype=self.dtype)
class NumpyInterface(Interface): class NumpyInterface(Interface):
""" """
Numpy :class:`~numpy.ndarray` s! Numpy :class:`~numpy.ndarray` s!
""" """
name = "numpy"
input_types = (ndarray, list) input_types = (ndarray, list)
return_type = ndarray return_type = ndarray
priority = -999 priority = -999
@ -41,6 +62,8 @@ class NumpyInterface(Interface):
""" """
if isinstance(array, ndarray): if isinstance(array, ndarray):
return True return True
elif isinstance(array, dict):
return NumpyJsonDict.is_valid(array)
else: else:
try: try:
_ = np.array(array) _ = np.array(array)
@ -53,6 +76,9 @@ class NumpyInterface(Interface):
Coerce to an ndarray. We have already checked if coercion is possible Coerce to an ndarray. We have already checked if coercion is possible
in :meth:`.check` in :meth:`.check`
""" """
if isinstance(array, dict):
array = NumpyJsonDict(**array).to_array_input()
if not isinstance(array, ndarray): if not isinstance(array, ndarray):
array = np.array(array) array = np.array(array)
return array return array
@ -61,3 +87,22 @@ class NumpyInterface(Interface):
def enabled(cls) -> bool: def enabled(cls) -> bool:
"""Check that numpy is present in the environment""" """Check that numpy is present in the environment"""
return ENABLED return ENABLED
@classmethod
def to_json(
cls, array: ndarray, info: SerializationInfo = None
) -> Union[list, JsonDict]:
"""
Convert an array of :attr:`.return_type` to a JSON-compatible format using
base python types
"""
if not isinstance(array, np.ndarray): # pragma: no cover
array = np.array(array)
json_array = array.tolist()
if info.round_trip:
json_array = NumpyJsonDict(
type=cls.name, dtype=str(array.dtype), array=json_array
)
return json_array

View file

@ -2,11 +2,14 @@
Interface to support treating videos like arrays using OpenCV Interface to support treating videos like arrays using OpenCV
""" """
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any, Optional, Tuple, Union from typing import Any, Literal, Optional, Tuple, Union
import numpy as np import numpy as np
from pydantic_core.core_schema import SerializationInfo
from numpydantic.interface import JsonDict
from numpydantic.interface.interface import Interface from numpydantic.interface.interface import Interface
try: try:
@ -19,6 +22,20 @@ except ImportError: # pragma: no cover
VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv") VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
@dataclass(kw_only=True)
class VideoJsonDict(JsonDict):
"""Json-able roundtrip representation of a video file"""
type: Literal["video"]
file: str
def to_array_input(self) -> "VideoProxy":
"""
Construct a :class:`.VideoProxy`
"""
return VideoProxy(path=Path(self.file))
class VideoProxy: class VideoProxy:
""" """
Passthrough proxy class to interact with videos as arrays Passthrough proxy class to interact with videos as arrays
@ -184,6 +201,12 @@ class VideoProxy:
def __getattr__(self, item: str): def __getattr__(self, item: str):
return getattr(self.video, item) return getattr(self.video, item)
def __eq__(self, other: "VideoProxy") -> bool:
"""Check if this is a proxy to the same video file"""
if not isinstance(other, VideoProxy):
raise TypeError("Can only compare equality of two VideoProxies")
return self.path == other.path
def __len__(self) -> int: def __len__(self) -> int:
"""Number of frames in the video""" """Number of frames in the video"""
return self.shape[0] return self.shape[0]
@ -194,6 +217,7 @@ class VideoInterface(Interface):
OpenCV interface to treat videos as arrays. OpenCV interface to treat videos as arrays.
""" """
name = "video"
input_types = (str, Path, VideoCapture, VideoProxy) input_types = (str, Path, VideoCapture, VideoProxy)
return_type = VideoProxy return_type = VideoProxy
@ -213,6 +237,9 @@ class VideoInterface(Interface):
): ):
return True return True
if isinstance(array, dict):
array = array.get("file", "")
if isinstance(array, str): if isinstance(array, str):
try: try:
array = Path(array) array = Path(array)
@ -224,10 +251,22 @@ class VideoInterface(Interface):
def before_validation(self, array: Any) -> VideoProxy: def before_validation(self, array: Any) -> VideoProxy:
"""Get a :class:`.VideoProxy` object for this video""" """Get a :class:`.VideoProxy` object for this video"""
if isinstance(array, VideoCapture): if isinstance(array, dict):
proxy = VideoJsonDict(**array).to_array_input()
elif isinstance(array, VideoCapture):
proxy = VideoProxy(video=array) proxy = VideoProxy(video=array)
elif isinstance(array, VideoProxy): elif isinstance(array, VideoProxy):
proxy = array proxy = array
else: else:
proxy = VideoProxy(path=array) proxy = VideoProxy(path=array)
return proxy return proxy
@classmethod
def to_json(
cls, array: VideoProxy, info: SerializationInfo
) -> Union[list, VideoJsonDict]:
"""Return a json-representation of a video"""
if info.round_trip:
return VideoJsonDict(type=cls.name, file=str(array.path))
else:
return np.array(array).tolist()

View file

@ -5,12 +5,12 @@ Interface to zarr arrays
import contextlib import contextlib
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any, Optional, Sequence, Union from typing import Any, Literal, Optional, Sequence, Union
import numpy as np import numpy as np
from pydantic import SerializationInfo from pydantic import SerializationInfo
from numpydantic.interface.interface import Interface from numpydantic.interface.interface import Interface, JsonDict
from numpydantic.types import DtypeType from numpydantic.types import DtypeType
try: try:
@ -56,11 +56,34 @@ class ZarrArrayPath:
raise ValueError("Only len 1-2 iterables can be used for a ZarrArrayPath") raise ValueError("Only len 1-2 iterables can be used for a ZarrArrayPath")
@dataclass(kw_only=True)
class ZarrJsonDict(JsonDict):
"""Round-trip Json-able version of a Zarr Array"""
info: dict[str, str]
type: Literal["zarr"]
file: Optional[str] = None
path: Optional[str] = None
array: Optional[list] = None
def to_array_input(self) -> ZarrArray | ZarrArrayPath:
"""
Construct a ZarrArrayPath if file and path are present,
otherwise a ZarrArray
"""
if self.file:
array = ZarrArrayPath(file=self.file, path=self.path)
else:
array = zarr.array(self.array)
return array
class ZarrInterface(Interface): class ZarrInterface(Interface):
""" """
Interface to in-memory or on-disk zarr arrays Interface to in-memory or on-disk zarr arrays
""" """
name = "zarr"
input_types = (Path, ZarrArray, ZarrArrayPath) input_types = (Path, ZarrArray, ZarrArrayPath)
return_type = ZarrArray return_type = ZarrArray
@ -73,6 +96,9 @@ class ZarrInterface(Interface):
def _get_array( def _get_array(
array: Union[ZarrArray, str, Path, ZarrArrayPath, Sequence] array: Union[ZarrArray, str, Path, ZarrArrayPath, Sequence]
) -> ZarrArray: ) -> ZarrArray:
if isinstance(array, dict):
array = ZarrJsonDict(**array).to_array_input()
if isinstance(array, ZarrArray): if isinstance(array, ZarrArray):
return array return array
@ -92,6 +118,12 @@ class ZarrInterface(Interface):
if isinstance(array, ZarrArray): if isinstance(array, ZarrArray):
return True return True
if isinstance(array, dict):
if array.get("type", False) == cls.name:
return True
# continue checking if dict contains a zarr file
array = array.get("file", "")
# See if can be coerced to ZarrArrayPath # See if can be coerced to ZarrArrayPath
if isinstance(array, (Path, str)): if isinstance(array, (Path, str)):
array = ZarrArrayPath(file=array) array = ZarrArrayPath(file=array)
@ -135,26 +167,46 @@ class ZarrInterface(Interface):
cls, cls,
array: Union[ZarrArray, str, Path, ZarrArrayPath, Sequence], array: Union[ZarrArray, str, Path, ZarrArrayPath, Sequence],
info: Optional[SerializationInfo] = None, info: Optional[SerializationInfo] = None,
) -> dict: ) -> list | ZarrJsonDict:
""" """
Dump just the metadata for an array from :meth:`zarr.core.Array.info_items` Dump a Zarr Array to JSON
plus the :meth:`zarr.core.Array.hexdigest`.
If ``info.round_trip == False``, dump the array as a list of lists.
This may be a memory-intensive operation.
Otherwise, dump the metadata for an array from :meth:`zarr.core.Array.info_items`
plus the :meth:`zarr.core.Array.hexdigest` as a :class:`.ZarrJsonDict`
If either the ``zarr_dump_array`` value in the context dictionary is ``True``
or the zarr array is an in-memory array, dump the array as well
(since without a persistent array it would be impossible to roundtrip and
dumping to JSON would be meaningless)
The full array can be returned by passing ``'zarr_dump_array': True`` to the Passing ``'zarr_dump_array': True`` to the serialization ``context`` looks like this::
serialization ``context`` ::
model.model_dump_json(context={'zarr_dump_array': True}) model.model_dump_json(context={'zarr_dump_array': True})
""" """
dump_array = False
if info is not None and info.context is not None:
dump_array = info.context.get("zarr_dump_array", False)
array = cls._get_array(array) array = cls._get_array(array)
info = array.info_items()
info_dict = {i[0]: i[1] for i in info}
info_dict["hexdigest"] = array.hexdigest()
if dump_array: if info.round_trip:
info_dict["array"] = array[:].tolist() dump_array = False
if info is not None and info.context is not None:
dump_array = info.context.get("zarr_dump_array", False)
is_file = False
return info_dict as_json = {"type": cls.name}
if hasattr(array.store, "dir_path"):
is_file = True
as_json["file"] = array.store.dir_path()
as_json["path"] = array.name
as_json["info"] = {i[0]: i[1] for i in array.info_items()}
as_json["info"]["hexdigest"] = array.hexdigest()
if dump_array or not is_file:
as_json["array"] = array[:].tolist()
as_json = ZarrJsonDict(**as_json)
else:
as_json = array[:].tolist()
return as_json

View file

@ -24,7 +24,6 @@ from numpydantic.exceptions import InterfaceError
from numpydantic.interface import Interface from numpydantic.interface import Interface
from numpydantic.maps import python_to_nptyping from numpydantic.maps import python_to_nptyping
from numpydantic.schema import ( from numpydantic.schema import (
_handler_type,
_jsonize_array, _jsonize_array,
get_validate_interface, get_validate_interface,
make_json_schema, make_json_schema,
@ -41,6 +40,9 @@ from numpydantic.vendor.nptyping.typing_ import (
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
from nptyping.base_meta_classes import SubscriptableMeta from nptyping.base_meta_classes import SubscriptableMeta
from pydantic._internal._schema_generation_shared import (
CallbackGetCoreSchemaHandler,
)
from numpydantic import Shape from numpydantic import Shape
@ -164,33 +166,34 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
def __get_pydantic_core_schema__( def __get_pydantic_core_schema__(
cls, cls,
_source_type: "NDArray", _source_type: "NDArray",
_handler: _handler_type, _handler: "CallbackGetCoreSchemaHandler",
) -> core_schema.CoreSchema: ) -> core_schema.CoreSchema:
shape, dtype = _source_type.__args__ shape, dtype = _source_type.__args__
shape: ShapeType shape: ShapeType
dtype: DtypeType dtype: DtypeType
# get pydantic core schema as a list of lists for JSON schema # make core schema for json schema, store it and any model definitions
list_schema = make_json_schema(shape, dtype, _handler) # note that there is a big of fragility in this function,
# as we need to access a private method of _handler to
# flatten out the json schema. See help(make_json_schema)
json_schema = make_json_schema(shape, dtype, _handler)
return core_schema.json_or_python_schema( return core_schema.with_info_plain_validator_function(
json_schema=list_schema, get_validate_interface(shape, dtype),
python_schema=core_schema.with_info_plain_validator_function(
get_validate_interface(shape, dtype)
),
serialization=core_schema.plain_serializer_function_ser_schema( serialization=core_schema.plain_serializer_function_ser_schema(
_jsonize_array, when_used="json", info_arg=True _jsonize_array, when_used="json", info_arg=True
), ),
metadata=json_schema,
) )
@classmethod @classmethod
def __get_pydantic_json_schema__( def __get_pydantic_json_schema__(
cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> core_schema.JsonSchema: ) -> core_schema.JsonSchema:
json_schema = handler(schema) shape, dtype = cls.__args__
json_schema = handler(schema["metadata"])
json_schema = handler.resolve_ref_schema(json_schema) json_schema = handler.resolve_ref_schema(json_schema)
dtype = cls.__args__[1]
if not isinstance(dtype, tuple) and dtype.__module__ not in ( if not isinstance(dtype, tuple) and dtype.__module__ not in (
"builtins", "builtins",
"typing", "typing",

View file

@ -13,19 +13,24 @@ from pydantic_core import CoreSchema, core_schema
from pydantic_core.core_schema import ListSchema, ValidationInfo from pydantic_core.core_schema import ListSchema, ValidationInfo
from numpydantic import dtype as dt from numpydantic import dtype as dt
from numpydantic.interface import Interface from numpydantic.interface import Interface, JsonDict
from numpydantic.maps import np_to_python from numpydantic.maps import np_to_python
from numpydantic.types import DtypeType, NDArrayType, ShapeType from numpydantic.types import DtypeType, NDArrayType, ShapeType
from numpydantic.vendor.nptyping.structure import StructureMeta from numpydantic.vendor.nptyping.structure import StructureMeta
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
from pydantic._internal._schema_generation_shared import (
CallbackGetCoreSchemaHandler,
)
from numpydantic import Shape from numpydantic import Shape
_handler_type = Callable[[Any], core_schema.CoreSchema]
_UNSUPPORTED_TYPES = (complex,) _UNSUPPORTED_TYPES = (complex,)
def _numeric_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema: def _numeric_dtype(
dtype: DtypeType, _handler: "CallbackGetCoreSchemaHandler"
) -> CoreSchema:
"""Make a numeric dtype that respects min/max values from extended numpy types""" """Make a numeric dtype that respects min/max values from extended numpy types"""
if dtype in (np.number,): if dtype in (np.number,):
dtype = float dtype = float
@ -36,14 +41,19 @@ def _numeric_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
elif issubclass(dtype, np.integer): elif issubclass(dtype, np.integer):
info = np.iinfo(dtype) info = np.iinfo(dtype)
schema = core_schema.int_schema(le=int(info.max), ge=int(info.min)) schema = core_schema.int_schema(le=int(info.max), ge=int(info.min))
elif dtype is float:
schema = core_schema.float_schema()
elif dtype is int:
schema = core_schema.int_schema()
else: else:
schema = _handler.generate_schema(dtype) schema = _handler.generate_schema(dtype)
return schema return schema
def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema: def _lol_dtype(
dtype: DtypeType, _handler: "CallbackGetCoreSchemaHandler"
) -> CoreSchema:
"""Get the innermost dtype schema to use in the generated pydantic schema""" """Get the innermost dtype schema to use in the generated pydantic schema"""
if isinstance(dtype, StructureMeta): # pragma: no cover if isinstance(dtype, StructureMeta): # pragma: no cover
raise NotImplementedError("Structured dtypes are currently unsupported") raise NotImplementedError("Structured dtypes are currently unsupported")
@ -84,6 +94,10 @@ def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
# TODO: warn and log here # TODO: warn and log here
elif python_type in (float, int): elif python_type in (float, int):
array_type = _numeric_dtype(dtype, _handler) array_type = _numeric_dtype(dtype, _handler)
elif python_type is bool:
array_type = core_schema.bool_schema()
elif python_type is Any:
array_type = core_schema.any_schema()
else: else:
array_type = _handler.generate_schema(python_type) array_type = _handler.generate_schema(python_type)
@ -208,14 +222,24 @@ def _unbounded_shape(
def make_json_schema( def make_json_schema(
shape: ShapeType, dtype: DtypeType, _handler: _handler_type shape: ShapeType, dtype: DtypeType, _handler: "CallbackGetCoreSchemaHandler"
) -> ListSchema: ) -> ListSchema:
""" """
Make a list of list JSON schema from a shape and a dtype. Make a list of list pydantic core schema for an array from a shape and a dtype.
Used to generate JSON schema in the containing model, but not for validation,
which is handled by interfaces.
First resolves the dtype into a pydantic ``CoreSchema`` , First resolves the dtype into a pydantic ``CoreSchema`` ,
and then uses that with :func:`.list_of_lists_schema` . and then uses that with :func:`.list_of_lists_schema` .
.. admonition:: Potentially Fragile
Uses a private method from the handler to flatten out nested definitions
(e.g. when dtype is a pydantic model)
so that they are present in the generated schema directly rather than
as references. Otherwise, at the time __get_pydantic_json_schema__ is called,
the definition references are lost.
Args: Args:
shape ( ShapeType ): Specification of a shape, as a tuple or shape ( ShapeType ): Specification of a shape, as a tuple or
an nptyping ``Shape`` an nptyping ``Shape``
@ -234,6 +258,8 @@ def make_json_schema(
else: else:
list_schema = list_of_lists_schema(shape, dtype_schema) list_schema = list_of_lists_schema(shape, dtype_schema)
list_schema = _handler._generate_schema.clean_schema(list_schema)
return list_schema return list_schema
@ -257,4 +283,11 @@ def get_validate_interface(shape: ShapeType, dtype: DtypeType) -> Callable:
def _jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]: def _jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
"""Use an interface class to render an array as JSON""" """Use an interface class to render an array as JSON"""
interface_cls = Interface.match_output(value) interface_cls = Interface.match_output(value)
return interface_cls.to_json(value, info) array = interface_cls.to_json(value, info)
if isinstance(array, JsonDict):
array = array.to_dict()
if info.context and info.context.get("mark_interface", False):
array = interface_cls.mark_json(array)
return array

View file

@ -101,7 +101,8 @@ def test_assignment(hdf5_array, model_blank):
assert (model.array[1:3, 2:4] == 10).all() assert (model.array[1:3, 2:4] == 10).all()
def test_to_json(hdf5_array, array_model): @pytest.mark.parametrize("round_trip", (True, False))
def test_to_json(hdf5_array, array_model, round_trip):
""" """
Test serialization of HDF5 arrays to JSON Test serialization of HDF5 arrays to JSON
Args: Args:
@ -115,13 +116,13 @@ def test_to_json(hdf5_array, array_model):
instance = model(array=array) # type: BaseModel instance = model(array=array) # type: BaseModel
json_str = instance.model_dump_json() json_str = instance.model_dump_json(round_trip=round_trip)
json_dict = json.loads(json_str)["array"] json_dumped = json.loads(json_str)["array"]
if round_trip:
assert json_dict["file"] == str(array.file) assert json_dumped["file"] == str(array.file)
assert json_dict["path"] == str(array.path) assert json_dumped["path"] == str(array.path)
assert json_dict["attrs"] == {} else:
assert json_dict["array"] == instance.array[:].tolist() assert json_dumped == instance.array[:].tolist()
def test_compound_dtype(tmp_path): def test_compound_dtype(tmp_path):

View file

@ -2,8 +2,11 @@
Tests that should be applied to all interfaces Tests that should be applied to all interfaces
""" """
import pytest
from typing import Callable from typing import Callable
import numpy as np import numpy as np
import dask.array as da
from zarr.core import Array as ZarrArray
from numpydantic.interface import Interface from numpydantic.interface import Interface
@ -35,8 +38,31 @@ def test_interface_to_numpy_array(all_interfaces):
_ = np.array(all_interfaces.array) _ = np.array(all_interfaces.array)
@pytest.mark.serialization
def test_interface_dump_json(all_interfaces): def test_interface_dump_json(all_interfaces):
""" """
All interfaces should be able to dump to json All interfaces should be able to dump to json
""" """
all_interfaces.model_dump_json() all_interfaces.model_dump_json()
@pytest.mark.serialization
@pytest.mark.parametrize("round_trip", [True, False])
def test_interface_roundtrip_json(all_interfaces, round_trip):
"""
All interfaces should be able to roundtrip to and from json
"""
json = all_interfaces.model_dump_json(round_trip=round_trip)
model = all_interfaces.model_validate_json(json)
if round_trip:
assert type(model.array) is type(all_interfaces.array)
if isinstance(all_interfaces.array, (np.ndarray, ZarrArray)):
assert np.array_equal(model.array, np.array(all_interfaces.array))
elif isinstance(all_interfaces.array, da.Array):
assert np.all(da.equal(model.array, all_interfaces.array))
else:
assert model.array == all_interfaces.array
assert model.array.dtype == all_interfaces.array.dtype
else:
assert np.array_equal(model.array, np.array(all_interfaces.array))

View file

@ -123,7 +123,10 @@ def test_zarr_array_path_from_iterable(zarr_array):
assert apath.path == inner_path assert apath.path == inner_path
def test_zarr_to_json(store, model_blank): @pytest.mark.serialization
@pytest.mark.parametrize("dump_array", [True, False])
@pytest.mark.parametrize("roundtrip", [True, False])
def test_zarr_to_json(store, model_blank, roundtrip, dump_array):
expected_fields = ( expected_fields = (
"Type", "Type",
"Data type", "Data type",
@ -137,17 +140,22 @@ def test_zarr_to_json(store, model_blank):
array = zarr.array(lol_array, store=store) array = zarr.array(lol_array, store=store)
instance = model_blank(array=array) instance = model_blank(array=array)
as_json = json.loads(instance.model_dump_json())["array"]
assert "array" not in as_json
for field in expected_fields:
assert field in as_json
assert len(as_json["hexdigest"]) == 40
# dump the array itself too context = {"zarr_dump_array": dump_array}
as_json = json.loads(instance.model_dump_json(context={"zarr_dump_array": True}))[ as_json = json.loads(
"array" instance.model_dump_json(round_trip=roundtrip, context=context)
] )["array"]
for field in expected_fields:
assert field in as_json if roundtrip:
assert len(as_json["hexdigest"]) == 40 if dump_array:
assert as_json["array"] == lol_array assert as_json["array"] == lol_array
else:
if as_json.get("file", False):
assert "array" not in as_json
for field in expected_fields:
assert field in as_json["info"]
assert len(as_json["info"]["hexdigest"]) == 40
else:
assert as_json == lol_array

View file

@ -15,6 +15,7 @@ from numpydantic import dtype
from numpydantic.dtype import Number from numpydantic.dtype import Number
@pytest.mark.json_schema
def test_ndarray_type(): def test_ndarray_type():
class Model(BaseModel): class Model(BaseModel):
array: NDArray[Shape["2 x, * y"], Number] array: NDArray[Shape["2 x, * y"], Number]
@ -40,6 +41,7 @@ def test_ndarray_type():
instance = Model(array=np.zeros((2, 3)), array_any=np.ones((3, 4, 5))) instance = Model(array=np.zeros((2, 3)), array_any=np.ones((3, 4, 5)))
@pytest.mark.json_schema
def test_schema_unsupported_type(): def test_schema_unsupported_type():
""" """
Complex numbers should just be made with an `any` schema Complex numbers should just be made with an `any` schema
@ -55,6 +57,7 @@ def test_schema_unsupported_type():
} }
@pytest.mark.json_schema
def test_schema_tuple(): def test_schema_tuple():
""" """
Types specified as tupled should have their schemas as a union Types specified as tupled should have their schemas as a union
@ -72,6 +75,7 @@ def test_schema_tuple():
assert all([i["minimum"] == 0 for i in conditions]) assert all([i["minimum"] == 0 for i in conditions])
@pytest.mark.json_schema
def test_schema_number(): def test_schema_number():
""" """
np.numeric should just be the float schema np.numeric should just be the float schema
@ -164,6 +168,7 @@ def test_ndarray_coercion():
amod = Model(array=["a", "b", "c"]) amod = Model(array=["a", "b", "c"])
@pytest.mark.serialization
def test_ndarray_serialize(): def test_ndarray_serialize():
""" """
Arrays should be dumped to a list when using json, but kept as ndarray otherwise Arrays should be dumped to a list when using json, but kept as ndarray otherwise
@ -188,6 +193,7 @@ _json_schema_types = [
] ]
@pytest.mark.json_schema
def test_json_schema_basic(array_model): def test_json_schema_basic(array_model):
""" """
NDArray types should correctly generate a list of lists JSON schema NDArray types should correctly generate a list of lists JSON schema
@ -210,6 +216,8 @@ def test_json_schema_basic(array_model):
assert inner["items"]["type"] == "number" assert inner["items"]["type"] == "number"
@pytest.mark.dtype
@pytest.mark.json_schema
@pytest.mark.parametrize("dtype", [*dtype.Integer, *dtype.Float]) @pytest.mark.parametrize("dtype", [*dtype.Integer, *dtype.Float])
def test_json_schema_dtype_single(dtype, array_model): def test_json_schema_dtype_single(dtype, array_model):
""" """
@ -240,6 +248,7 @@ def test_json_schema_dtype_single(dtype, array_model):
) )
@pytest.mark.dtype
@pytest.mark.parametrize( @pytest.mark.parametrize(
"dtype,expected", "dtype,expected",
[ [
@ -266,6 +275,8 @@ def test_json_schema_dtype_builtin(dtype, expected, array_model):
assert inner_type["type"] == expected assert inner_type["type"] == expected
@pytest.mark.dtype
@pytest.mark.json_schema
def test_json_schema_dtype_model(): def test_json_schema_dtype_model():
""" """
Pydantic models can be used in arrays as dtypes Pydantic models can be used in arrays as dtypes
@ -314,6 +325,8 @@ def _recursive_array(schema):
assert any_of[1]["minimum"] == 0 assert any_of[1]["minimum"] == 0
@pytest.mark.shape
@pytest.mark.json_schema
def test_json_schema_ellipsis(): def test_json_schema_ellipsis():
""" """
NDArray types should create a recursive JSON schema for any-shaped arrays NDArray types should create a recursive JSON schema for any-shaped arrays