Round-trip JSON serialization using dataclasses for all interfaces. Separate JSON schema generation from core_schema generation. More consistent JSON dumping behavior: always return just the array unless round_trip=True.
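For illustration, a minimal sketch of the new dumping behavior (model and field names are hypothetical; output shown for the numpy interface, and exact key order may differ):

```python
import numpy as np
from pydantic import BaseModel
from numpydantic import NDArray

class Model(BaseModel):
    array: NDArray

m = Model(array=np.zeros((2, 2), dtype=np.uint8))

# default: just the array, as a list of lists
m.model_dump_json()
# '{"array":[[0,0],[0,0]]}'

# round_trip=True: a JsonDict carrying what is needed to revalidate
m.model_dump_json(round_trip=True)
# '{"array":{"type":"numpy","dtype":"uint8","array":[[0,0],[0,0]]}}'
```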

sneakers-the-rat 2024-09-20 23:44:59 -07:00
parent 2cb09076fd
commit b436d8d592
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
17 changed files with 532 additions and 99 deletions

@@ -1,6 +0,0 @@
# monkeypatch
```{eval-rst}
.. automodule:: numpydantic.monkeypatch
:members:
```

@@ -473,6 +473,7 @@ dumped = instance.model_dump_json(context={'zarr_dump_array': True})
design
syntax
serialization
interfaces
todo
changelog
@@ -489,7 +490,6 @@ api/dtype
api/ndarray
api/maps
api/meta
api/monkeypatch
api/schema
api/shape
api/types

docs/serialization.md Normal file
@@ -0,0 +1,2 @@
# Serialization

@@ -109,6 +109,12 @@ filterwarnings = [
# nptyping's alias warnings
'ignore:.*deprecated alias.*Deprecated NumPy 1\.24.*'
]
markers = [
"dtype: mark test related to dtype validation",
"shape: mark test related to shape validation",
"json_schema: mark test related to json schema generation",
"serialization: mark test related to serialization"
]
[tool.ruff]
target-version = "py311"

@@ -4,12 +4,13 @@ Interfaces between nptyping types and array backends
from numpydantic.interface.dask import DaskInterface
from numpydantic.interface.hdf5 import H5Interface
from numpydantic.interface.interface import Interface
from numpydantic.interface.interface import Interface, JsonDict
from numpydantic.interface.numpy import NumpyInterface
from numpydantic.interface.video import VideoInterface
from numpydantic.interface.zarr import ZarrInterface
__all__ = [
"JsonDict",
"Interface",
"DaskInterface",
"H5Interface",

@@ -2,26 +2,59 @@
Interface for Dask arrays
"""
from typing import Any, Optional
from dataclasses import dataclass
from typing import Any, Iterable, Literal, Optional
import numpy as np
from pydantic import SerializationInfo
from numpydantic.interface.interface import Interface
from numpydantic.interface.interface import Interface, JsonDict
from numpydantic.types import DtypeType, NDArrayType
try:
from dask.array import from_array
from dask.array.core import Array as DaskArray
except ImportError: # pragma: no cover
DaskArray = None
def _as_tuple(a_list: list | Any) -> tuple:
"""Make a list of list into a tuple of tuples"""
return tuple(
[_as_tuple(item) if isinstance(item, list) else item for item in a_list]
)
@dataclass(kw_only=True)
class DaskJsonDict(JsonDict):
"""
Round-trip json serialized form of a dask array
"""
type: Literal["dask"]
name: str
chunks: Iterable[tuple[int, ...]]
dtype: str
array: list
def to_array_input(self) -> DaskArray:
"""Construct a dask array"""
np_array = np.array(self.array, dtype=self.dtype)
array = from_array(
np_array,
name=self.name,
chunks=_as_tuple(self.chunks),
)
return array
class DaskInterface(Interface):
"""
Interface for Dask :class:`~dask.array.core.Array`
"""
input_types = (DaskArray,)
name = "dask"
input_types = (DaskArray, dict)
return_type = DaskArray
@classmethod
@@ -29,7 +62,24 @@ class DaskInterface(Interface):
"""
check if array is a dask array
"""
return DaskArray is not None and isinstance(array, DaskArray)
if DaskArray is None:
return False
elif isinstance(array, DaskArray):
return True
elif isinstance(array, dict):
return DaskJsonDict.is_valid(array)
else:
return False
def before_validation(self, array: Any) -> DaskArray:
"""
If given a dict (like that from ``model_dump_json(round_trip=True)`` ),
re-cast to dask array
"""
if isinstance(array, dict):
array = DaskJsonDict(**array).to_array_input()
return array
def get_object_dtype(self, array: NDArrayType) -> DtypeType:
"""Dask arrays require a compute() call to retrieve a single value"""
@@ -43,7 +93,7 @@ class DaskInterface(Interface):
@classmethod
def to_json(
cls, array: DaskArray, info: Optional[SerializationInfo] = None
) -> list:
) -> list | DaskJsonDict:
"""
Convert an array to a JSON serializable array by first converting to a numpy
array and then to a list.
@@ -56,4 +106,14 @@
method of serialization here using the python object itself rather than
its JSON representation.
"""
return np.array(array).tolist()
np_array = np.array(array)
as_json = np_array.tolist()
if info.round_trip:
as_json = DaskJsonDict(
type=cls.name,
array=as_json,
name=array.name,
chunks=array.chunks,
dtype=str(np_array.dtype),
)
return as_json
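For context, a hedged sketch of the round trip this interface now supports, reusing the hypothetical `Model` from the commit message above:

```python
import dask.array as da

m = Model(array=da.zeros((4, 4), chunks=(2, 2)))

# the round-trip dump is a DaskJsonDict with name/chunks/dtype alongside the data
dumped = m.model_dump_json(round_trip=True)

# on revalidation, before_validation() re-casts the dict to a dask array
restored = Model.model_validate_json(dumped)
assert isinstance(restored.array, da.Array)
```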

@@ -40,6 +40,7 @@ as ``S32`` isoformatted byte strings (timezones optional) like:
"""
import sys
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, TypeVar, Union
@@ -47,7 +48,7 @@ from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, TypeVar, Un
import numpy as np
from pydantic import SerializationInfo
from numpydantic.interface.interface import Interface
from numpydantic.interface.interface import Interface, JsonDict
from numpydantic.types import DtypeType, NDArrayType
try:
@@ -76,6 +77,21 @@ class H5ArrayPath(NamedTuple):
"""Refer to a specific field within a compound dtype"""
@dataclass
class H5JsonDict(JsonDict):
"""Round-trip Json-able version of an HDF5 dataset"""
file: str
path: str
field: Optional[str] = None
def to_array_input(self) -> H5ArrayPath:
"""Construct an :class:`.H5ArrayPath`"""
return H5ArrayPath(
**{k: v for k, v in self.to_dict().items() if k in H5ArrayPath._fields}
)
class H5Proxy:
"""
Proxy class to mimic numpy-like array behavior with an HDF5 array
@@ -110,6 +126,7 @@ class H5Proxy:
self.path = path
self.field = field
self._annotation_dtype = annotation_dtype
self._h5arraypath = H5ArrayPath(self.file, self.path, self.field)
def array_exists(self) -> bool:
"""Check that there is in fact an array at :attr:`.path` within :attr:`.file`"""
@@ -212,6 +229,15 @@ class H5Proxy:
"""self.shape[0]"""
return self.shape[0]
def __eq__(self, other: "H5Proxy") -> bool:
"""
Check that we are referring to the same hdf5 array
"""
if isinstance(other, H5Proxy):
return self._h5arraypath == other._h5arraypath
else:
raise ValueError("Can only compare equality of two H5Proxies")
def open(self, mode: str = "r") -> "h5py.Dataset":
"""
Return the opened :class:`h5py.Dataset` object
@@ -251,6 +277,7 @@ class H5Interface(Interface):
passthrough numpy-like interface to the dataset.
"""
name = "hdf5"
input_types = (H5ArrayPath, H5Arraylike, H5Proxy)
return_type = H5Proxy
@@ -268,6 +295,13 @@
if isinstance(array, (H5ArrayPath, H5Proxy)):
return True
if isinstance(array, dict):
if array.get("type", False) == cls.name:
return True
# continue checking if dict contains an hdf5 file
file = array.get("file", "")
array = (file, "")
if isinstance(array, (tuple, list)) and len(array) in (2, 3):
# check that the first arg is an hdf5 file
try:
@@ -294,6 +328,9 @@
def before_validation(self, array: Any) -> NDArrayType:
"""Create an :class:`.H5Proxy` to use throughout validation"""
if isinstance(array, dict):
array = H5JsonDict(**array).to_array_input()
if isinstance(array, H5ArrayPath):
array = H5Proxy.from_h5array(h5array=array)
elif isinstance(array, H5Proxy):
@@ -349,21 +386,27 @@
@classmethod
def to_json(cls, array: H5Proxy, info: Optional[SerializationInfo] = None) -> dict:
"""
Dump to a dictionary containing
Render HDF5 array as JSON
If ``round_trip == True``, we dump just the proxy info, a dictionary like:
* ``file``: :attr:`.file`
* ``path``: :attr:`.path`
* ``attrs``: Any HDF5 attributes on the dataset
* ``array``: The array as a list of lists
Otherwise, we dump the array as a list of lists
"""
try:
dset = array.open()
meta = {
"file": array.file,
"path": array.path,
"attrs": dict(dset.attrs),
"array": dset[:].tolist(),
if info.round_trip:
as_json = {
"type": cls.name,
}
return meta
finally:
array.close()
as_json.update(array._h5arraypath._asdict())
else:
try:
dset = array.open()
as_json = dset[:].tolist()
finally:
array.close()
return as_json
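A hedged sketch of the two HDF5 dump modes (file and dataset paths hypothetical), assuming `instance` holds an array validated from an `H5ArrayPath`:

```python
instance.model_dump_json()
# '{"array": [[1, 2], [3, 4]]}'  -- the data, as a list of lists

instance.model_dump_json(round_trip=True)
# '{"array": {"type": "hdf5", "file": "/data/test.h5",
#             "path": "/mydata", "field": null}}'  -- proxy info only, no data
```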

@@ -2,12 +2,15 @@
Base Interface metaclass
"""
import inspect
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from importlib.metadata import PackageNotFoundError, version
from operator import attrgetter
from typing import Any, Generic, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Generic, Tuple, Type, TypedDict, TypeVar, Union
import numpy as np
from pydantic import SerializationInfo
from pydantic import SerializationInfo, TypeAdapter, ValidationError
from numpydantic.exceptions import (
DtypeError,
@@ -21,6 +24,60 @@ from numpydantic.types import DtypeType, NDArrayType, ShapeType
T = TypeVar("T", bound=NDArrayType)
class InterfaceMark(TypedDict):
"""JSON-able mark to be able to round-trip json dumps"""
module: str
cls: str
version: str
@dataclass(kw_only=True)
class JsonDict:
"""
Representation of array when dumped with round_trip == True.
Using a dataclass rather than a pydantic model to not tempt
us to use more sophisticated types than can be serialized to json.
"""
type: str
@abstractmethod
def to_array_input(self) -> Any:
"""
Convert this roundtrip specifier to the relevant input class
(one of the ``input_types`` of an interface).
"""
def to_dict(self) -> dict:
"""
Convenience method for casting dataclass to dict,
removing None-valued items
"""
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def get_adapter(cls) -> TypeAdapter:
"""Convenience method to get a typeadapter for this class"""
return TypeAdapter(cls)
@classmethod
def is_valid(cls, val: dict) -> bool:
"""
Check whether a given dictionary matches this JsonDict specification
Returns:
bool - true if valid, false if not
"""
adapter = cls.get_adapter()
try:
_ = adapter.validate_python(val)
return True
except ValidationError:
return False
class Interface(ABC, Generic[T]):
"""
Abstract parent class for interfaces to different array formats
@@ -30,7 +87,7 @@ class Interface(ABC, Generic[T]):
return_type: Type[T]
priority: int = 0
def __init__(self, shape: ShapeType, dtype: DtypeType) -> None:
def __init__(self, shape: ShapeType = Any, dtype: DtypeType = Any) -> None:
self.shape = shape
self.dtype = dtype
@@ -86,6 +143,7 @@ class Interface(ABC, Generic[T]):
self.raise_for_shape(shape_valid, shape)
array = self.after_validation(array)
return array
def before_validation(self, array: Any) -> NDArrayType:
@@ -117,8 +175,6 @@
"""
Validate the dtype of the given array, returning
``True`` if valid, ``False`` if not.
"""
if self.dtype is Any:
return True
@@ -196,6 +252,13 @@
"""
return array
def mark_input(self, array: Any) -> Any:
"""
Preserve metadata about the interface and passed input when dumping with
``round_trip``
"""
return array
@classmethod
@abstractmethod
def check(cls, array: Any) -> bool:
@@ -211,17 +274,40 @@
installed, etc.)
"""
@property
@abstractmethod
def name(self) -> str:
"""
Short name for this interface
"""
@classmethod
def to_json(
cls, array: Type[T], info: Optional[SerializationInfo] = None
) -> Union[list, dict]:
@abstractmethod
def to_json(cls, array: Type[T], info: SerializationInfo) -> Union[list, JsonDict]:
"""
Convert an array of :attr:`.return_type` to a JSON-compatible format using
base python types
"""
if not isinstance(array, np.ndarray): # pragma: no cover
array = np.array(array)
return array.tolist()
@classmethod
def mark_json(cls, array: Union[list, dict]) -> dict:
"""
When using ``model_dump_json`` with ``mark_interface: True`` in the ``context``,
add additional annotations that would allow the serialized array to be
roundtripped.
Default is just to add an :class:`.InterfaceMark`
Examples:
>>> from pprint import pprint
>>> pprint(Interface.mark_json([1.0, 2.0]))
{'interface': {'cls': 'Interface',
'module': 'numpydantic.interface.interface',
'version': '1.2.2'},
'value': [1.0, 2.0]}
"""
return {"interface": cls.mark_interface(), "value": array}
@classmethod
def interfaces(
@@ -335,3 +421,24 @@
raise NoMatchError(f"No matching interfaces found for output {array}")
else:
return matches[0]
@classmethod
def mark_interface(cls) -> InterfaceMark:
"""
Create an interface mark indicating this interface for validation after
JSON serialization with ``round_trip==True``
"""
interface_module = inspect.getmodule(cls)
interface_module = (
None if interface_module is None else interface_module.__name__
)
try:
v = (
None
if interface_module is None
else version(interface_module.split(".")[0])
)
except PackageNotFoundError:
v = None
interface_name = cls.__name__
return InterfaceMark(module=interface_module, cls=interface_name, version=v)
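A hedged sketch of how `JsonDict.is_valid` behaves; interfaces use it in `check()` to claim dict inputs (import path taken from this commit, example values hypothetical):

```python
from numpydantic.interface.numpy import NumpyJsonDict

# builds a pydantic TypeAdapter for the dataclass and validates without raising
NumpyJsonDict.is_valid({"type": "numpy", "dtype": "uint8", "array": [[0, 1]]})  # True
NumpyJsonDict.is_valid({"type": "zarr", "file": "array.zarr"})  # False
```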

@@ -2,9 +2,12 @@
Interface to numpy arrays
"""
from typing import Any
from dataclasses import dataclass
from typing import Any, Literal, Union
from numpydantic.interface.interface import Interface
from pydantic import SerializationInfo
from numpydantic.interface.interface import Interface, JsonDict
try:
import numpy as np
@@ -18,11 +21,29 @@ except ImportError: # pragma: no cover
np = None
@dataclass
class NumpyJsonDict(JsonDict):
"""
JSON-able roundtrip representation of numpy array
"""
type: Literal["numpy"]
dtype: str
array: list
def to_array_input(self) -> ndarray:
"""
Construct a numpy array
"""
return np.array(self.array, dtype=self.dtype)
class NumpyInterface(Interface):
"""
Numpy :class:`~numpy.ndarray` s!
"""
name = "numpy"
input_types = (ndarray, list)
return_type = ndarray
priority = -999
@@ -41,6 +62,8 @@ class NumpyInterface(Interface):
"""
if isinstance(array, ndarray):
return True
elif isinstance(array, dict):
return NumpyJsonDict.is_valid(array)
else:
try:
_ = np.array(array)
@@ -53,6 +76,9 @@
Coerce to an ndarray. We have already checked if coercion is possible
in :meth:`.check`
"""
if isinstance(array, dict):
array = NumpyJsonDict(**array).to_array_input()
if not isinstance(array, ndarray):
array = np.array(array)
return array
@@ -61,3 +87,22 @@
def enabled(cls) -> bool:
"""Check that numpy is present in the environment"""
return ENABLED
@classmethod
def to_json(
cls, array: ndarray, info: SerializationInfo = None
) -> Union[list, JsonDict]:
"""
Convert an array of :attr:`.return_type` to a JSON-compatible format using
base python types
"""
if not isinstance(array, np.ndarray): # pragma: no cover
array = np.array(array)
json_array = array.tolist()
if info.round_trip:
json_array = NumpyJsonDict(
type=cls.name, dtype=str(array.dtype), array=json_array
)
return json_array
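A hedged sketch showing that the numpy round trip preserves dtype through the dumped `dtype` field (reusing the hypothetical `Model` from the commit message above):

```python
m = Model(array=np.array([[1, 2]], dtype=np.uint8))
restored = Model.model_validate_json(m.model_dump_json(round_trip=True))

# to_array_input() rebuilds with np.array(self.array, dtype=self.dtype)
assert restored.array.dtype == np.uint8
```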

@@ -2,11 +2,14 @@
Interface to support treating videos like arrays using OpenCV
"""
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional, Tuple, Union
from typing import Any, Literal, Optional, Tuple, Union
import numpy as np
from pydantic_core.core_schema import SerializationInfo
from numpydantic.interface import JsonDict
from numpydantic.interface.interface import Interface
try:
@@ -19,6 +22,20 @@ except ImportError: # pragma: no cover
VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
@dataclass(kw_only=True)
class VideoJsonDict(JsonDict):
"""Json-able roundtrip representation of a video file"""
type: Literal["video"]
file: str
def to_array_input(self) -> "VideoProxy":
"""
Construct a :class:`.VideoProxy`
"""
return VideoProxy(path=Path(self.file))
class VideoProxy:
"""
Passthrough proxy class to interact with videos as arrays
@@ -184,6 +201,12 @@ class VideoProxy:
def __getattr__(self, item: str):
return getattr(self.video, item)
def __eq__(self, other: "VideoProxy") -> bool:
"""Check if this is a proxy to the same video file"""
if not isinstance(other, VideoProxy):
raise TypeError("Can only compare equality of two VideoProxies")
return self.path == other.path
def __len__(self) -> int:
"""Number of frames in the video"""
return self.shape[0]
@@ -194,6 +217,7 @@ class VideoInterface(Interface):
OpenCV interface to treat videos as arrays.
"""
name = "video"
input_types = (str, Path, VideoCapture, VideoProxy)
return_type = VideoProxy
@@ -213,6 +237,9 @@
):
return True
if isinstance(array, dict):
array = array.get("file", "")
if isinstance(array, str):
try:
array = Path(array)
@@ -224,10 +251,22 @@
def before_validation(self, array: Any) -> VideoProxy:
"""Get a :class:`.VideoProxy` object for this video"""
if isinstance(array, VideoCapture):
if isinstance(array, dict):
proxy = VideoJsonDict(**array).to_array_input()
elif isinstance(array, VideoCapture):
proxy = VideoProxy(video=array)
elif isinstance(array, VideoProxy):
proxy = array
else:
proxy = VideoProxy(path=array)
return proxy
@classmethod
def to_json(
cls, array: VideoProxy, info: SerializationInfo
) -> Union[list, VideoJsonDict]:
"""Return a json-representation of a video"""
if info.round_trip:
return VideoJsonDict(type=cls.name, file=str(array.path))
else:
return np.array(array).tolist()
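A hedged sketch of the video round trip (file path hypothetical): only the file reference is dumped, and revalidation rebuilds a `VideoProxy`:

```python
dumped = instance.model_dump_json(round_trip=True)
# '{"array": {"type": "video", "file": "/data/recording.mp4"}}'

restored = Model.model_validate_json(dumped)
assert restored.array == instance.array  # __eq__ compares the underlying paths
```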

@@ -5,12 +5,12 @@ Interface to zarr arrays
import contextlib
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional, Sequence, Union
from typing import Any, Literal, Optional, Sequence, Union
import numpy as np
from pydantic import SerializationInfo
from numpydantic.interface.interface import Interface
from numpydantic.interface.interface import Interface, JsonDict
from numpydantic.types import DtypeType
try:
@@ -56,11 +56,34 @@ class ZarrArrayPath:
raise ValueError("Only len 1-2 iterables can be used for a ZarrArrayPath")
@dataclass(kw_only=True)
class ZarrJsonDict(JsonDict):
"""Round-trip Json-able version of a Zarr Array"""
info: dict[str, str]
type: Literal["zarr"]
file: Optional[str] = None
path: Optional[str] = None
array: Optional[list] = None
def to_array_input(self) -> ZarrArray | ZarrArrayPath:
"""
Construct a ZarrArrayPath if file and path are present,
otherwise a ZarrArray
"""
if self.file:
array = ZarrArrayPath(file=self.file, path=self.path)
else:
array = zarr.array(self.array)
return array
class ZarrInterface(Interface):
"""
Interface to in-memory or on-disk zarr arrays
"""
name = "zarr"
input_types = (Path, ZarrArray, ZarrArrayPath)
return_type = ZarrArray
@@ -73,6 +96,9 @@ class ZarrInterface(Interface):
def _get_array(
array: Union[ZarrArray, str, Path, ZarrArrayPath, Sequence]
) -> ZarrArray:
if isinstance(array, dict):
array = ZarrJsonDict(**array).to_array_input()
if isinstance(array, ZarrArray):
return array
@@ -92,6 +118,12 @@
if isinstance(array, ZarrArray):
return True
if isinstance(array, dict):
if array.get("type", False) == cls.name:
return True
# continue checking if dict contains a zarr file
array = array.get("file", "")
# See if can be coerced to ZarrArrayPath
if isinstance(array, (Path, str)):
array = ZarrArrayPath(file=array)
@@ -135,26 +167,46 @@
cls,
array: Union[ZarrArray, str, Path, ZarrArrayPath, Sequence],
info: Optional[SerializationInfo] = None,
) -> dict:
) -> list | ZarrJsonDict:
"""
Dump just the metadata for an array from :meth:`zarr.core.Array.info_items`
plus the :meth:`zarr.core.Array.hexdigest`.
Dump a Zarr Array to JSON
If ``info.round_trip == False``, dump the array as a list of lists.
This may be a memory-intensive operation.
Otherwise, dump the metadata for an array from :meth:`zarr.core.Array.info_items`
plus the :meth:`zarr.core.Array.hexdigest` as a :class:`.ZarrJsonDict`
If either the ``zarr_dump_array`` value in the context dictionary is ``True``
or the zarr array is an in-memory array, dump the array as well
(since without a persistent array it would be impossible to roundtrip and
dumping to JSON would be meaningless)
The full array can be returned by passing ``'zarr_dump_array': True`` to the
serialization ``context`` ::
Passing ``'zarr_dump_array': True`` to the serialization ``context`` looks like this::
model.model_dump_json(context={'zarr_dump_array': True})
"""
dump_array = False
if info is not None and info.context is not None:
dump_array = info.context.get("zarr_dump_array", False)
array = cls._get_array(array)
info = array.info_items()
info_dict = {i[0]: i[1] for i in info}
info_dict["hexdigest"] = array.hexdigest()
if dump_array:
info_dict["array"] = array[:].tolist()
if info.round_trip:
dump_array = False
if info is not None and info.context is not None:
dump_array = info.context.get("zarr_dump_array", False)
is_file = False
return info_dict
as_json = {"type": cls.name}
if hasattr(array.store, "dir_path"):
is_file = True
as_json["file"] = array.store.dir_path()
as_json["path"] = array.name
as_json["info"] = {i[0]: i[1] for i in array.info_items()}
as_json["info"]["hexdigest"] = array.hexdigest()
if dump_array or not is_file:
as_json["array"] = array[:].tolist()
as_json = ZarrJsonDict(**as_json)
else:
as_json = array[:].tolist()
return as_json
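A hedged sketch of the three zarr dump modes (store paths and values hypothetical):

```python
instance.model_dump_json()
# '{"array": [[1, 2], [3, 4]]}'  -- data only; may be memory-intensive

instance.model_dump_json(round_trip=True)
# '{"array": {"type": "zarr", "file": "...", "path": "...", "info": {...}}}'

instance.model_dump_json(round_trip=True, context={"zarr_dump_array": True})
# as above, plus "array": the full data
```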

@@ -24,7 +24,6 @@ from numpydantic.exceptions import InterfaceError
from numpydantic.interface import Interface
from numpydantic.maps import python_to_nptyping
from numpydantic.schema import (
_handler_type,
_jsonize_array,
get_validate_interface,
make_json_schema,
@@ -41,6 +40,9 @@ from numpydantic.vendor.nptyping.typing_ import (
if TYPE_CHECKING: # pragma: no cover
from nptyping.base_meta_classes import SubscriptableMeta
from pydantic._internal._schema_generation_shared import (
CallbackGetCoreSchemaHandler,
)
from numpydantic import Shape
@@ -164,33 +166,34 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
def __get_pydantic_core_schema__(
cls,
_source_type: "NDArray",
_handler: _handler_type,
_handler: "CallbackGetCoreSchemaHandler",
) -> core_schema.CoreSchema:
shape, dtype = _source_type.__args__
shape: ShapeType
dtype: DtypeType
# get pydantic core schema as a list of lists for JSON schema
list_schema = make_json_schema(shape, dtype, _handler)
# make core schema for json schema, store it and any model definitions
# note that there is a bit of fragility in this function,
# as we need to access a private method of _handler to
# flatten out the json schema. See help(make_json_schema)
json_schema = make_json_schema(shape, dtype, _handler)
return core_schema.json_or_python_schema(
json_schema=list_schema,
python_schema=core_schema.with_info_plain_validator_function(
get_validate_interface(shape, dtype)
),
return core_schema.with_info_plain_validator_function(
get_validate_interface(shape, dtype),
serialization=core_schema.plain_serializer_function_ser_schema(
_jsonize_array, when_used="json", info_arg=True
),
metadata=json_schema,
)
@classmethod
def __get_pydantic_json_schema__(
cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> core_schema.JsonSchema:
json_schema = handler(schema)
shape, dtype = cls.__args__
json_schema = handler(schema["metadata"])
json_schema = handler.resolve_ref_schema(json_schema)
dtype = cls.__args__[1]
if not isinstance(dtype, tuple) and dtype.__module__ not in (
"builtins",
"typing",

@@ -13,19 +13,24 @@ from pydantic_core import CoreSchema, core_schema
from pydantic_core.core_schema import ListSchema, ValidationInfo
from numpydantic import dtype as dt
from numpydantic.interface import Interface
from numpydantic.interface import Interface, JsonDict
from numpydantic.maps import np_to_python
from numpydantic.types import DtypeType, NDArrayType, ShapeType
from numpydantic.vendor.nptyping.structure import StructureMeta
if TYPE_CHECKING: # pragma: no cover
from pydantic._internal._schema_generation_shared import (
CallbackGetCoreSchemaHandler,
)
from numpydantic import Shape
_handler_type = Callable[[Any], core_schema.CoreSchema]
_UNSUPPORTED_TYPES = (complex,)
def _numeric_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
def _numeric_dtype(
dtype: DtypeType, _handler: "CallbackGetCoreSchemaHandler"
) -> CoreSchema:
"""Make a numeric dtype that respects min/max values from extended numpy types"""
if dtype in (np.number,):
dtype = float
@@ -36,14 +41,19 @@ def _numeric_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
elif issubclass(dtype, np.integer):
info = np.iinfo(dtype)
schema = core_schema.int_schema(le=int(info.max), ge=int(info.min))
elif dtype is float:
schema = core_schema.float_schema()
elif dtype is int:
schema = core_schema.int_schema()
else:
schema = _handler.generate_schema(dtype)
return schema
def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
def _lol_dtype(
dtype: DtypeType, _handler: "CallbackGetCoreSchemaHandler"
) -> CoreSchema:
"""Get the innermost dtype schema to use in the generated pydantic schema"""
if isinstance(dtype, StructureMeta): # pragma: no cover
raise NotImplementedError("Structured dtypes are currently unsupported")
@@ -84,6 +94,10 @@ def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
# TODO: warn and log here
elif python_type in (float, int):
array_type = _numeric_dtype(dtype, _handler)
elif python_type is bool:
array_type = core_schema.bool_schema()
elif python_type is Any:
array_type = core_schema.any_schema()
else:
array_type = _handler.generate_schema(python_type)
@@ -208,14 +222,24 @@
def make_json_schema(
shape: ShapeType, dtype: DtypeType, _handler: _handler_type
shape: ShapeType, dtype: DtypeType, _handler: "CallbackGetCoreSchemaHandler"
) -> ListSchema:
"""
Make a list of list JSON schema from a shape and a dtype.
Make a list of list pydantic core schema for an array from a shape and a dtype.
Used to generate JSON schema in the containing model, but not for validation,
which is handled by interfaces.
First resolves the dtype into a pydantic ``CoreSchema`` ,
and then uses that with :func:`.list_of_lists_schema` .
.. admonition:: Potentially Fragile
Uses a private method from the handler to flatten out nested definitions
(e.g. when dtype is a pydantic model)
so that they are present in the generated schema directly rather than
as references. Otherwise, at the time __get_pydantic_json_schema__ is called,
the definition references are lost.
Args:
shape ( ShapeType ): Specification of a shape, as a tuple or
an nptyping ``Shape``
@@ -234,6 +258,8 @@
else:
list_schema = list_of_lists_schema(shape, dtype_schema)
list_schema = _handler._generate_schema.clean_schema(list_schema)
return list_schema
@@ -257,4 +283,11 @@ def get_validate_interface(shape: ShapeType, dtype: DtypeType) -> Callable:
def _jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
"""Use an interface class to render an array as JSON"""
interface_cls = Interface.match_output(value)
return interface_cls.to_json(value, info)
array = interface_cls.to_json(value, info)
if isinstance(array, JsonDict):
array = array.to_dict()
if info.context and info.context.get("mark_interface", False):
array = interface_cls.mark_json(array)
return array
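A hedged sketch of the `mark_interface` context flag wired up here (module and version values hypothetical):

```python
instance.model_dump_json(context={"mark_interface": True})
# '{"array": {"interface": {"module": "numpydantic.interface.numpy",
#                           "cls": "NumpyInterface",
#                           "version": "1.6.0"},
#             "value": [[0.0, 0.0]]}}'
```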

@@ -101,7 +101,8 @@ def test_assignment(hdf5_array, model_blank):
assert (model.array[1:3, 2:4] == 10).all()
def test_to_json(hdf5_array, array_model):
@pytest.mark.parametrize("round_trip", (True, False))
def test_to_json(hdf5_array, array_model, round_trip):
"""
Test serialization of HDF5 arrays to JSON
Args:
@@ -115,13 +116,13 @@ def test_to_json(hdf5_array, array_model):
instance = model(array=array) # type: BaseModel
json_str = instance.model_dump_json()
json_dict = json.loads(json_str)["array"]
assert json_dict["file"] == str(array.file)
assert json_dict["path"] == str(array.path)
assert json_dict["attrs"] == {}
assert json_dict["array"] == instance.array[:].tolist()
json_str = instance.model_dump_json(round_trip=round_trip)
json_dumped = json.loads(json_str)["array"]
if round_trip:
assert json_dumped["file"] == str(array.file)
assert json_dumped["path"] == str(array.path)
else:
assert json_dumped == instance.array[:].tolist()
def test_compound_dtype(tmp_path):

@@ -2,8 +2,11 @@
Tests that should be applied to all interfaces
"""
import pytest
from typing import Callable
import numpy as np
import dask.array as da
from zarr.core import Array as ZarrArray
from numpydantic.interface import Interface
@@ -35,8 +38,31 @@ def test_interface_to_numpy_array(all_interfaces):
_ = np.array(all_interfaces.array)
@pytest.mark.serialization
def test_interface_dump_json(all_interfaces):
"""
All interfaces should be able to dump to json
"""
all_interfaces.model_dump_json()
@pytest.mark.serialization
@pytest.mark.parametrize("round_trip", [True, False])
def test_interface_roundtrip_json(all_interfaces, round_trip):
"""
All interfaces should be able to roundtrip to and from json
"""
json = all_interfaces.model_dump_json(round_trip=round_trip)
model = all_interfaces.model_validate_json(json)
if round_trip:
assert type(model.array) is type(all_interfaces.array)
if isinstance(all_interfaces.array, (np.ndarray, ZarrArray)):
assert np.array_equal(model.array, np.array(all_interfaces.array))
elif isinstance(all_interfaces.array, da.Array):
assert np.all(da.equal(model.array, all_interfaces.array))
else:
assert model.array == all_interfaces.array
assert model.array.dtype == all_interfaces.array.dtype
else:
assert np.array_equal(model.array, np.array(all_interfaces.array))

@@ -123,7 +123,10 @@ def test_zarr_array_path_from_iterable(zarr_array):
assert apath.path == inner_path
def test_zarr_to_json(store, model_blank):
@pytest.mark.serialization
@pytest.mark.parametrize("dump_array", [True, False])
@pytest.mark.parametrize("roundtrip", [True, False])
def test_zarr_to_json(store, model_blank, roundtrip, dump_array):
expected_fields = (
"Type",
"Data type",
@@ -137,17 +140,22 @@ def test_zarr_to_json(store, model_blank):
array = zarr.array(lol_array, store=store)
instance = model_blank(array=array)
as_json = json.loads(instance.model_dump_json())["array"]
assert "array" not in as_json
for field in expected_fields:
assert field in as_json
assert len(as_json["hexdigest"]) == 40
# dump the array itself too
as_json = json.loads(instance.model_dump_json(context={"zarr_dump_array": True}))[
"array"
]
for field in expected_fields:
assert field in as_json
assert len(as_json["hexdigest"]) == 40
assert as_json["array"] == lol_array
context = {"zarr_dump_array": dump_array}
as_json = json.loads(
instance.model_dump_json(round_trip=roundtrip, context=context)
)["array"]
if roundtrip:
if dump_array:
assert as_json["array"] == lol_array
else:
if as_json.get("file", False):
assert "array" not in as_json
for field in expected_fields:
assert field in as_json["info"]
assert len(as_json["info"]["hexdigest"]) == 40
else:
assert as_json == lol_array

@@ -15,6 +15,7 @@ from numpydantic import dtype
from numpydantic.dtype import Number
@pytest.mark.json_schema
def test_ndarray_type():
class Model(BaseModel):
array: NDArray[Shape["2 x, * y"], Number]
@@ -40,6 +41,7 @@ def test_ndarray_type():
instance = Model(array=np.zeros((2, 3)), array_any=np.ones((3, 4, 5)))
@pytest.mark.json_schema
def test_schema_unsupported_type():
"""
Complex numbers should just be made with an `any` schema
@@ -55,6 +57,7 @@ def test_schema_unsupported_type():
}
@pytest.mark.json_schema
def test_schema_tuple():
"""
Types specified as tuples should have their schemas as a union
@@ -72,6 +75,7 @@ def test_schema_tuple():
assert all([i["minimum"] == 0 for i in conditions])
@pytest.mark.json_schema
def test_schema_number():
"""
np.numeric should just be the float schema
@@ -164,6 +168,7 @@ def test_ndarray_coercion():
amod = Model(array=["a", "b", "c"])
@pytest.mark.serialization
def test_ndarray_serialize():
"""
Arrays should be dumped to a list when using json, but kept as ndarray otherwise
@@ -188,6 +193,7 @@ _json_schema_types = [
]
@pytest.mark.json_schema
def test_json_schema_basic(array_model):
"""
NDArray types should correctly generate a list of lists JSON schema
@@ -210,6 +216,8 @@ def test_json_schema_basic(array_model):
assert inner["items"]["type"] == "number"
@pytest.mark.dtype
@pytest.mark.json_schema
@pytest.mark.parametrize("dtype", [*dtype.Integer, *dtype.Float])
def test_json_schema_dtype_single(dtype, array_model):
"""
@@ -240,6 +248,7 @@ def test_json_schema_dtype_single(dtype, array_model):
)
@pytest.mark.dtype
@pytest.mark.parametrize(
"dtype,expected",
[
@@ -266,6 +275,8 @@ def test_json_schema_dtype_builtin(dtype, expected, array_model):
assert inner_type["type"] == expected
@pytest.mark.dtype
@pytest.mark.json_schema
def test_json_schema_dtype_model():
"""
Pydantic models can be used in arrays as dtypes
@@ -314,6 +325,8 @@ def _recursive_array(schema):
assert any_of[1]["minimum"] == 0
@pytest.mark.shape
@pytest.mark.json_schema
def test_json_schema_ellipsis():
"""
NDArray types should create a recursive JSON schema for any-shaped arrays