mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2024-11-14 02:34:28 +00:00
Roundtrip json serialization using dataclasses for all interfaces. Separate JSON schema generation from core_schema generation. More consistent json dumping behavior - always return just the array unless round_trip=True
.
This commit is contained in:
parent
2cb09076fd
commit
b436d8d592
17 changed files with 532 additions and 99 deletions
|
@ -1,6 +0,0 @@
|
|||
# monkeypatch
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: numpydantic.monkeypatch
|
||||
:members:
|
||||
```
|
|
@ -473,6 +473,7 @@ dumped = instance.model_dump_json(context={'zarr_dump_array': True})
|
|||
|
||||
design
|
||||
syntax
|
||||
serialization
|
||||
interfaces
|
||||
todo
|
||||
changelog
|
||||
|
@ -489,7 +490,6 @@ api/dtype
|
|||
api/ndarray
|
||||
api/maps
|
||||
api/meta
|
||||
api/monkeypatch
|
||||
api/schema
|
||||
api/shape
|
||||
api/types
|
||||
|
|
2
docs/serialization.md
Normal file
2
docs/serialization.md
Normal file
|
@ -0,0 +1,2 @@
|
|||
# Serialization
|
||||
|
|
@ -109,6 +109,12 @@ filterwarnings = [
|
|||
# nptyping's alias warnings
|
||||
'ignore:.*deprecated alias.*Deprecated NumPy 1\.24.*'
|
||||
]
|
||||
markers = [
|
||||
"dtype: mark test related to dtype validation",
|
||||
"shape: mark test related to shape validation",
|
||||
"json_schema: mark test related to json schema generation",
|
||||
"serialization: mark test related to serialization"
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py311"
|
||||
|
|
|
@ -4,12 +4,13 @@ Interfaces between nptyping types and array backends
|
|||
|
||||
from numpydantic.interface.dask import DaskInterface
|
||||
from numpydantic.interface.hdf5 import H5Interface
|
||||
from numpydantic.interface.interface import Interface
|
||||
from numpydantic.interface.interface import Interface, JsonDict
|
||||
from numpydantic.interface.numpy import NumpyInterface
|
||||
from numpydantic.interface.video import VideoInterface
|
||||
from numpydantic.interface.zarr import ZarrInterface
|
||||
|
||||
__all__ = [
|
||||
"JsonDict",
|
||||
"Interface",
|
||||
"DaskInterface",
|
||||
"H5Interface",
|
||||
|
|
|
@ -2,26 +2,59 @@
|
|||
Interface for Dask arrays
|
||||
"""
|
||||
|
||||
from typing import Any, Optional
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Iterable, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
from pydantic import SerializationInfo
|
||||
|
||||
from numpydantic.interface.interface import Interface
|
||||
from numpydantic.interface.interface import Interface, JsonDict
|
||||
from numpydantic.types import DtypeType, NDArrayType
|
||||
|
||||
try:
|
||||
from dask.array import from_array
|
||||
from dask.array.core import Array as DaskArray
|
||||
except ImportError: # pragma: no cover
|
||||
DaskArray = None
|
||||
|
||||
|
||||
def _as_tuple(a_list: list | Any) -> tuple:
|
||||
"""Make a list of list into a tuple of tuples"""
|
||||
return tuple(
|
||||
[_as_tuple(item) if isinstance(item, list) else item for item in a_list]
|
||||
)
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class DaskJsonDict(JsonDict):
|
||||
"""
|
||||
Round-trip json serialized form of a dask array
|
||||
"""
|
||||
|
||||
type: Literal["dask"]
|
||||
name: str
|
||||
chunks: Iterable[tuple[int, ...]]
|
||||
dtype: str
|
||||
array: list
|
||||
|
||||
def to_array_input(self) -> DaskArray:
|
||||
"""Construct a dask array"""
|
||||
np_array = np.array(self.array, dtype=self.dtype)
|
||||
array = from_array(
|
||||
np_array,
|
||||
name=self.name,
|
||||
chunks=_as_tuple(self.chunks),
|
||||
)
|
||||
return array
|
||||
|
||||
|
||||
class DaskInterface(Interface):
|
||||
"""
|
||||
Interface for Dask :class:`~dask.array.core.Array`
|
||||
"""
|
||||
|
||||
input_types = (DaskArray,)
|
||||
name = "dask"
|
||||
input_types = (DaskArray, dict)
|
||||
return_type = DaskArray
|
||||
|
||||
@classmethod
|
||||
|
@ -29,7 +62,24 @@ class DaskInterface(Interface):
|
|||
"""
|
||||
check if array is a dask array
|
||||
"""
|
||||
return DaskArray is not None and isinstance(array, DaskArray)
|
||||
if DaskArray is None:
|
||||
return False
|
||||
elif isinstance(array, DaskArray):
|
||||
return True
|
||||
elif isinstance(array, dict):
|
||||
return DaskJsonDict.is_valid(array)
|
||||
else:
|
||||
return False
|
||||
|
||||
def before_validation(self, array: Any) -> DaskArray:
|
||||
"""
|
||||
If given a dict (like that from ``model_dump_json(round_trip=True)`` ),
|
||||
re-cast to dask array
|
||||
"""
|
||||
if isinstance(array, dict):
|
||||
array = DaskJsonDict(**array).to_array_input()
|
||||
|
||||
return array
|
||||
|
||||
def get_object_dtype(self, array: NDArrayType) -> DtypeType:
|
||||
"""Dask arrays require a compute() call to retrieve a single value"""
|
||||
|
@ -43,7 +93,7 @@ class DaskInterface(Interface):
|
|||
@classmethod
|
||||
def to_json(
|
||||
cls, array: DaskArray, info: Optional[SerializationInfo] = None
|
||||
) -> list:
|
||||
) -> list | DaskJsonDict:
|
||||
"""
|
||||
Convert an array to a JSON serializable array by first converting to a numpy
|
||||
array and then to a list.
|
||||
|
@ -56,4 +106,14 @@ class DaskInterface(Interface):
|
|||
method of serialization here using the python object itself rather than
|
||||
its JSON representation.
|
||||
"""
|
||||
return np.array(array).tolist()
|
||||
np_array = np.array(array)
|
||||
as_json = np_array.tolist()
|
||||
if info.round_trip:
|
||||
as_json = DaskJsonDict(
|
||||
type=cls.name,
|
||||
array=as_json,
|
||||
name=array.name,
|
||||
chunks=array.chunks,
|
||||
dtype=str(np_array.dtype),
|
||||
)
|
||||
return as_json
|
||||
|
|
|
@ -40,6 +40,7 @@ as ``S32`` isoformatted byte strings (timezones optional) like:
|
|||
"""
|
||||
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, TypeVar, Union
|
||||
|
@ -47,7 +48,7 @@ from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, TypeVar, Un
|
|||
import numpy as np
|
||||
from pydantic import SerializationInfo
|
||||
|
||||
from numpydantic.interface.interface import Interface
|
||||
from numpydantic.interface.interface import Interface, JsonDict
|
||||
from numpydantic.types import DtypeType, NDArrayType
|
||||
|
||||
try:
|
||||
|
@ -76,6 +77,21 @@ class H5ArrayPath(NamedTuple):
|
|||
"""Refer to a specific field within a compound dtype"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class H5JsonDict(JsonDict):
|
||||
"""Round-trip Json-able version of an HDF5 dataset"""
|
||||
|
||||
file: str
|
||||
path: str
|
||||
field: Optional[str] = None
|
||||
|
||||
def to_array_input(self) -> H5ArrayPath:
|
||||
"""Construct an :class:`.H5ArrayPath`"""
|
||||
return H5ArrayPath(
|
||||
**{k: v for k, v in self.to_dict().items() if k in H5ArrayPath._fields}
|
||||
)
|
||||
|
||||
|
||||
class H5Proxy:
|
||||
"""
|
||||
Proxy class to mimic numpy-like array behavior with an HDF5 array
|
||||
|
@ -110,6 +126,7 @@ class H5Proxy:
|
|||
self.path = path
|
||||
self.field = field
|
||||
self._annotation_dtype = annotation_dtype
|
||||
self._h5arraypath = H5ArrayPath(self.file, self.path, self.field)
|
||||
|
||||
def array_exists(self) -> bool:
|
||||
"""Check that there is in fact an array at :attr:`.path` within :attr:`.file`"""
|
||||
|
@ -212,6 +229,15 @@ class H5Proxy:
|
|||
"""self.shape[0]"""
|
||||
return self.shape[0]
|
||||
|
||||
def __eq__(self, other: "H5Proxy") -> bool:
|
||||
"""
|
||||
Check that we are referring to the same hdf5 array
|
||||
"""
|
||||
if isinstance(other, H5Proxy):
|
||||
return self._h5arraypath == other._h5arraypath
|
||||
else:
|
||||
raise ValueError("Can only compare equality of two H5Proxies")
|
||||
|
||||
def open(self, mode: str = "r") -> "h5py.Dataset":
|
||||
"""
|
||||
Return the opened :class:`h5py.Dataset` object
|
||||
|
@ -251,6 +277,7 @@ class H5Interface(Interface):
|
|||
passthrough numpy-like interface to the dataset.
|
||||
"""
|
||||
|
||||
name = "hdf5"
|
||||
input_types = (H5ArrayPath, H5Arraylike, H5Proxy)
|
||||
return_type = H5Proxy
|
||||
|
||||
|
@ -268,6 +295,13 @@ class H5Interface(Interface):
|
|||
if isinstance(array, (H5ArrayPath, H5Proxy)):
|
||||
return True
|
||||
|
||||
if isinstance(array, dict):
|
||||
if array.get("type", False) == cls.name:
|
||||
return True
|
||||
# continue checking if dict contains an hdf5 file
|
||||
file = array.get("file", "")
|
||||
array = (file, "")
|
||||
|
||||
if isinstance(array, (tuple, list)) and len(array) in (2, 3):
|
||||
# check that the first arg is an hdf5 file
|
||||
try:
|
||||
|
@ -294,6 +328,9 @@ class H5Interface(Interface):
|
|||
|
||||
def before_validation(self, array: Any) -> NDArrayType:
|
||||
"""Create an :class:`.H5Proxy` to use throughout validation"""
|
||||
if isinstance(array, dict):
|
||||
array = H5JsonDict(**array).to_array_input()
|
||||
|
||||
if isinstance(array, H5ArrayPath):
|
||||
array = H5Proxy.from_h5array(h5array=array)
|
||||
elif isinstance(array, H5Proxy):
|
||||
|
@ -349,21 +386,27 @@ class H5Interface(Interface):
|
|||
@classmethod
|
||||
def to_json(cls, array: H5Proxy, info: Optional[SerializationInfo] = None) -> dict:
|
||||
"""
|
||||
Dump to a dictionary containing
|
||||
Render HDF5 array as JSON
|
||||
|
||||
If ``round_trip == True``, we dump just the proxy info, a dictionary like:
|
||||
|
||||
* ``file``: :attr:`.file`
|
||||
* ``path``: :attr:`.path`
|
||||
* ``attrs``: Any HDF5 attributes on the dataset
|
||||
* ``array``: The array as a list of lists
|
||||
|
||||
Otherwise, we dump the array as a list of lists
|
||||
"""
|
||||
if info.round_trip:
|
||||
as_json = {
|
||||
"type": cls.name,
|
||||
}
|
||||
as_json.update(array._h5arraypath._asdict())
|
||||
else:
|
||||
try:
|
||||
dset = array.open()
|
||||
meta = {
|
||||
"file": array.file,
|
||||
"path": array.path,
|
||||
"attrs": dict(dset.attrs),
|
||||
"array": dset[:].tolist(),
|
||||
}
|
||||
return meta
|
||||
as_json = dset[:].tolist()
|
||||
finally:
|
||||
array.close()
|
||||
|
||||
return as_json
|
||||
|
|
|
@ -2,12 +2,15 @@
|
|||
Base Interface metaclass
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import asdict, dataclass
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from operator import attrgetter
|
||||
from typing import Any, Generic, Optional, Tuple, Type, TypeVar, Union
|
||||
from typing import Any, Generic, Tuple, Type, TypedDict, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import SerializationInfo
|
||||
from pydantic import SerializationInfo, TypeAdapter, ValidationError
|
||||
|
||||
from numpydantic.exceptions import (
|
||||
DtypeError,
|
||||
|
@ -21,6 +24,60 @@ from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
|||
T = TypeVar("T", bound=NDArrayType)
|
||||
|
||||
|
||||
class InterfaceMark(TypedDict):
|
||||
"""JSON-able mark to be able to round-trip json dumps"""
|
||||
|
||||
module: str
|
||||
cls: str
|
||||
version: str
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class JsonDict:
|
||||
"""
|
||||
Representation of array when dumped with round_trip == True.
|
||||
|
||||
Using a dataclass rather than a pydantic model to not tempt
|
||||
us to use more sophisticated types than can be serialized to json.
|
||||
"""
|
||||
|
||||
type: str
|
||||
|
||||
@abstractmethod
|
||||
def to_array_input(self) -> Any:
|
||||
"""
|
||||
Convert this roundtrip specifier to the relevant input class
|
||||
(one of the ``input_types`` of an interface).
|
||||
"""
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""
|
||||
Convenience method for casting dataclass to dict,
|
||||
removing None-valued items
|
||||
"""
|
||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||
|
||||
@classmethod
|
||||
def get_adapter(cls) -> TypeAdapter:
|
||||
"""Convenience method to get a typeadapter for this class"""
|
||||
return TypeAdapter(cls)
|
||||
|
||||
@classmethod
|
||||
def is_valid(cls, val: dict) -> bool:
|
||||
"""
|
||||
Check whether a given dictionary matches this JsonDict specification
|
||||
|
||||
Returns:
|
||||
bool - true if valid, false if not
|
||||
"""
|
||||
adapter = cls.get_adapter()
|
||||
try:
|
||||
_ = adapter.validate_python(val)
|
||||
return True
|
||||
except ValidationError:
|
||||
return False
|
||||
|
||||
|
||||
class Interface(ABC, Generic[T]):
|
||||
"""
|
||||
Abstract parent class for interfaces to different array formats
|
||||
|
@ -30,7 +87,7 @@ class Interface(ABC, Generic[T]):
|
|||
return_type: Type[T]
|
||||
priority: int = 0
|
||||
|
||||
def __init__(self, shape: ShapeType, dtype: DtypeType) -> None:
|
||||
def __init__(self, shape: ShapeType = Any, dtype: DtypeType = Any) -> None:
|
||||
self.shape = shape
|
||||
self.dtype = dtype
|
||||
|
||||
|
@ -86,6 +143,7 @@ class Interface(ABC, Generic[T]):
|
|||
self.raise_for_shape(shape_valid, shape)
|
||||
|
||||
array = self.after_validation(array)
|
||||
|
||||
return array
|
||||
|
||||
def before_validation(self, array: Any) -> NDArrayType:
|
||||
|
@ -117,8 +175,6 @@ class Interface(ABC, Generic[T]):
|
|||
"""
|
||||
Validate the dtype of the given array, returning
|
||||
``True`` if valid, ``False`` if not.
|
||||
|
||||
|
||||
"""
|
||||
if self.dtype is Any:
|
||||
return True
|
||||
|
@ -196,6 +252,13 @@ class Interface(ABC, Generic[T]):
|
|||
"""
|
||||
return array
|
||||
|
||||
def mark_input(self, array: Any) -> Any:
|
||||
"""
|
||||
Preserve metadata about the interface and passed input when dumping with
|
||||
``round_trip``
|
||||
"""
|
||||
return array
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def check(cls, array: Any) -> bool:
|
||||
|
@ -211,17 +274,40 @@ class Interface(ABC, Generic[T]):
|
|||
installed, etc.)
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""
|
||||
Short name for this interface
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def to_json(
|
||||
cls, array: Type[T], info: Optional[SerializationInfo] = None
|
||||
) -> Union[list, dict]:
|
||||
@abstractmethod
|
||||
def to_json(cls, array: Type[T], info: SerializationInfo) -> Union[list, JsonDict]:
|
||||
"""
|
||||
Convert an array of :attr:`.return_type` to a JSON-compatible format using
|
||||
base python types
|
||||
"""
|
||||
if not isinstance(array, np.ndarray): # pragma: no cover
|
||||
array = np.array(array)
|
||||
return array.tolist()
|
||||
|
||||
@classmethod
|
||||
def mark_json(cls, array: Union[list, dict]) -> dict:
|
||||
"""
|
||||
When using ``model_dump_json`` with ``mark_interface: True`` in the ``context``,
|
||||
add additional annotations that would allow the serialized array to be
|
||||
roundtripped.
|
||||
|
||||
Default is just to add an :class:`.InterfaceMark`
|
||||
|
||||
Examples:
|
||||
|
||||
>>> from pprint import pprint
|
||||
>>> pprint(Interface.mark_json([1.0, 2.0]))
|
||||
{'interface': {'cls': 'Interface',
|
||||
'module': 'numpydantic.interface.interface',
|
||||
'version': '1.2.2'},
|
||||
'value': [1.0, 2.0]}
|
||||
"""
|
||||
return {"interface": cls.mark_interface(), "value": array}
|
||||
|
||||
@classmethod
|
||||
def interfaces(
|
||||
|
@ -335,3 +421,24 @@ class Interface(ABC, Generic[T]):
|
|||
raise NoMatchError(f"No matching interfaces found for output {array}")
|
||||
else:
|
||||
return matches[0]
|
||||
|
||||
@classmethod
|
||||
def mark_interface(cls) -> InterfaceMark:
|
||||
"""
|
||||
Create an interface mark indicating this interface for validation after
|
||||
JSON serialization with ``round_trip==True``
|
||||
"""
|
||||
interface_module = inspect.getmodule(cls)
|
||||
interface_module = (
|
||||
None if interface_module is None else interface_module.__name__
|
||||
)
|
||||
try:
|
||||
v = (
|
||||
None
|
||||
if interface_module is None
|
||||
else version(interface_module.split(".")[0])
|
||||
)
|
||||
except PackageNotFoundError:
|
||||
v = None
|
||||
interface_name = cls.__name__
|
||||
return InterfaceMark(module=interface_module, cls=interface_name, version=v)
|
||||
|
|
|
@ -2,9 +2,12 @@
|
|||
Interface to numpy arrays
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from numpydantic.interface.interface import Interface
|
||||
from pydantic import SerializationInfo
|
||||
|
||||
from numpydantic.interface.interface import Interface, JsonDict
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
|
@ -18,11 +21,29 @@ except ImportError: # pragma: no cover
|
|||
np = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class NumpyJsonDict(JsonDict):
|
||||
"""
|
||||
JSON-able roundtrip representation of numpy array
|
||||
"""
|
||||
|
||||
type: Literal["numpy"]
|
||||
dtype: str
|
||||
array: list
|
||||
|
||||
def to_array_input(self) -> ndarray:
|
||||
"""
|
||||
Construct a numpy array
|
||||
"""
|
||||
return np.array(self.array, dtype=self.dtype)
|
||||
|
||||
|
||||
class NumpyInterface(Interface):
|
||||
"""
|
||||
Numpy :class:`~numpy.ndarray` s!
|
||||
"""
|
||||
|
||||
name = "numpy"
|
||||
input_types = (ndarray, list)
|
||||
return_type = ndarray
|
||||
priority = -999
|
||||
|
@ -41,6 +62,8 @@ class NumpyInterface(Interface):
|
|||
"""
|
||||
if isinstance(array, ndarray):
|
||||
return True
|
||||
elif isinstance(array, dict):
|
||||
return NumpyJsonDict.is_valid(array)
|
||||
else:
|
||||
try:
|
||||
_ = np.array(array)
|
||||
|
@ -53,6 +76,9 @@ class NumpyInterface(Interface):
|
|||
Coerce to an ndarray. We have already checked if coercion is possible
|
||||
in :meth:`.check`
|
||||
"""
|
||||
if isinstance(array, dict):
|
||||
array = NumpyJsonDict(**array).to_array_input()
|
||||
|
||||
if not isinstance(array, ndarray):
|
||||
array = np.array(array)
|
||||
return array
|
||||
|
@ -61,3 +87,22 @@ class NumpyInterface(Interface):
|
|||
def enabled(cls) -> bool:
|
||||
"""Check that numpy is present in the environment"""
|
||||
return ENABLED
|
||||
|
||||
@classmethod
|
||||
def to_json(
|
||||
cls, array: ndarray, info: SerializationInfo = None
|
||||
) -> Union[list, JsonDict]:
|
||||
"""
|
||||
Convert an array of :attr:`.return_type` to a JSON-compatible format using
|
||||
base python types
|
||||
"""
|
||||
if not isinstance(array, np.ndarray): # pragma: no cover
|
||||
array = np.array(array)
|
||||
|
||||
json_array = array.tolist()
|
||||
|
||||
if info.round_trip:
|
||||
json_array = NumpyJsonDict(
|
||||
type=cls.name, dtype=str(array.dtype), array=json_array
|
||||
)
|
||||
return json_array
|
||||
|
|
|
@ -2,11 +2,14 @@
|
|||
Interface to support treating videos like arrays using OpenCV
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
from typing import Any, Literal, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic_core.core_schema import SerializationInfo
|
||||
|
||||
from numpydantic.interface import JsonDict
|
||||
from numpydantic.interface.interface import Interface
|
||||
|
||||
try:
|
||||
|
@ -19,6 +22,20 @@ except ImportError: # pragma: no cover
|
|||
VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class VideoJsonDict(JsonDict):
|
||||
"""Json-able roundtrip representation of a video file"""
|
||||
|
||||
type: Literal["video"]
|
||||
file: str
|
||||
|
||||
def to_array_input(self) -> "VideoProxy":
|
||||
"""
|
||||
Construct a :class:`.VideoProxy`
|
||||
"""
|
||||
return VideoProxy(path=Path(self.file))
|
||||
|
||||
|
||||
class VideoProxy:
|
||||
"""
|
||||
Passthrough proxy class to interact with videos as arrays
|
||||
|
@ -184,6 +201,12 @@ class VideoProxy:
|
|||
def __getattr__(self, item: str):
|
||||
return getattr(self.video, item)
|
||||
|
||||
def __eq__(self, other: "VideoProxy") -> bool:
|
||||
"""Check if this is a proxy to the same video file"""
|
||||
if not isinstance(other, VideoProxy):
|
||||
raise TypeError("Can only compare equality of two VideoProxies")
|
||||
return self.path == other.path
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Number of frames in the video"""
|
||||
return self.shape[0]
|
||||
|
@ -194,6 +217,7 @@ class VideoInterface(Interface):
|
|||
OpenCV interface to treat videos as arrays.
|
||||
"""
|
||||
|
||||
name = "video"
|
||||
input_types = (str, Path, VideoCapture, VideoProxy)
|
||||
return_type = VideoProxy
|
||||
|
||||
|
@ -213,6 +237,9 @@ class VideoInterface(Interface):
|
|||
):
|
||||
return True
|
||||
|
||||
if isinstance(array, dict):
|
||||
array = array.get("file", "")
|
||||
|
||||
if isinstance(array, str):
|
||||
try:
|
||||
array = Path(array)
|
||||
|
@ -224,10 +251,22 @@ class VideoInterface(Interface):
|
|||
|
||||
def before_validation(self, array: Any) -> VideoProxy:
|
||||
"""Get a :class:`.VideoProxy` object for this video"""
|
||||
if isinstance(array, VideoCapture):
|
||||
if isinstance(array, dict):
|
||||
proxy = VideoJsonDict(**array).to_array_input()
|
||||
elif isinstance(array, VideoCapture):
|
||||
proxy = VideoProxy(video=array)
|
||||
elif isinstance(array, VideoProxy):
|
||||
proxy = array
|
||||
else:
|
||||
proxy = VideoProxy(path=array)
|
||||
return proxy
|
||||
|
||||
@classmethod
|
||||
def to_json(
|
||||
cls, array: VideoProxy, info: SerializationInfo
|
||||
) -> Union[list, VideoJsonDict]:
|
||||
"""Return a json-representation of a video"""
|
||||
if info.round_trip:
|
||||
return VideoJsonDict(type=cls.name, file=str(array.path))
|
||||
else:
|
||||
return np.array(array).tolist()
|
||||
|
|
|
@ -5,12 +5,12 @@ Interface to zarr arrays
|
|||
import contextlib
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Sequence, Union
|
||||
from typing import Any, Literal, Optional, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import SerializationInfo
|
||||
|
||||
from numpydantic.interface.interface import Interface
|
||||
from numpydantic.interface.interface import Interface, JsonDict
|
||||
from numpydantic.types import DtypeType
|
||||
|
||||
try:
|
||||
|
@ -56,11 +56,34 @@ class ZarrArrayPath:
|
|||
raise ValueError("Only len 1-2 iterables can be used for a ZarrArrayPath")
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class ZarrJsonDict(JsonDict):
|
||||
"""Round-trip Json-able version of a Zarr Array"""
|
||||
|
||||
info: dict[str, str]
|
||||
type: Literal["zarr"]
|
||||
file: Optional[str] = None
|
||||
path: Optional[str] = None
|
||||
array: Optional[list] = None
|
||||
|
||||
def to_array_input(self) -> ZarrArray | ZarrArrayPath:
|
||||
"""
|
||||
Construct a ZarrArrayPath if file and path are present,
|
||||
otherwise a ZarrArray
|
||||
"""
|
||||
if self.file:
|
||||
array = ZarrArrayPath(file=self.file, path=self.path)
|
||||
else:
|
||||
array = zarr.array(self.array)
|
||||
return array
|
||||
|
||||
|
||||
class ZarrInterface(Interface):
|
||||
"""
|
||||
Interface to in-memory or on-disk zarr arrays
|
||||
"""
|
||||
|
||||
name = "zarr"
|
||||
input_types = (Path, ZarrArray, ZarrArrayPath)
|
||||
return_type = ZarrArray
|
||||
|
||||
|
@ -73,6 +96,9 @@ class ZarrInterface(Interface):
|
|||
def _get_array(
|
||||
array: Union[ZarrArray, str, Path, ZarrArrayPath, Sequence]
|
||||
) -> ZarrArray:
|
||||
if isinstance(array, dict):
|
||||
array = ZarrJsonDict(**array).to_array_input()
|
||||
|
||||
if isinstance(array, ZarrArray):
|
||||
return array
|
||||
|
||||
|
@ -92,6 +118,12 @@ class ZarrInterface(Interface):
|
|||
if isinstance(array, ZarrArray):
|
||||
return True
|
||||
|
||||
if isinstance(array, dict):
|
||||
if array.get("type", False) == cls.name:
|
||||
return True
|
||||
# continue checking if dict contains a zarr file
|
||||
array = array.get("file", "")
|
||||
|
||||
# See if can be coerced to ZarrArrayPath
|
||||
if isinstance(array, (Path, str)):
|
||||
array = ZarrArrayPath(file=array)
|
||||
|
@ -135,26 +167,46 @@ class ZarrInterface(Interface):
|
|||
cls,
|
||||
array: Union[ZarrArray, str, Path, ZarrArrayPath, Sequence],
|
||||
info: Optional[SerializationInfo] = None,
|
||||
) -> dict:
|
||||
) -> list | ZarrJsonDict:
|
||||
"""
|
||||
Dump just the metadata for an array from :meth:`zarr.core.Array.info_items`
|
||||
plus the :meth:`zarr.core.Array.hexdigest`.
|
||||
Dump a Zarr Array to JSON
|
||||
|
||||
The full array can be returned by passing ``'zarr_dump_array': True`` to the
|
||||
serialization ``context`` ::
|
||||
If ``info.round_trip == False``, dump the array as a list of lists.
|
||||
This may be a memory-intensive operation.
|
||||
|
||||
Otherwise, dump the metadata for an array from :meth:`zarr.core.Array.info_items`
|
||||
plus the :meth:`zarr.core.Array.hexdigest` as a :class:`.ZarrJsonDict`
|
||||
|
||||
If either the ``zarr_dump_array`` value in the context dictionary is ``True``
|
||||
or the zarr array is an in-memory array, dump the array as well
|
||||
(since without a persistent array it would be impossible to roundtrip and
|
||||
dumping to JSON would be meaningless)
|
||||
|
||||
Passing ``'zarr_dump_array': True`` to the serialization ``context`` looks like this::
|
||||
|
||||
model.model_dump_json(context={'zarr_dump_array': True})
|
||||
"""
|
||||
array = cls._get_array(array)
|
||||
|
||||
if info.round_trip:
|
||||
dump_array = False
|
||||
if info is not None and info.context is not None:
|
||||
dump_array = info.context.get("zarr_dump_array", False)
|
||||
is_file = False
|
||||
|
||||
array = cls._get_array(array)
|
||||
info = array.info_items()
|
||||
info_dict = {i[0]: i[1] for i in info}
|
||||
info_dict["hexdigest"] = array.hexdigest()
|
||||
as_json = {"type": cls.name}
|
||||
if hasattr(array.store, "dir_path"):
|
||||
is_file = True
|
||||
as_json["file"] = array.store.dir_path()
|
||||
as_json["path"] = array.name
|
||||
as_json["info"] = {i[0]: i[1] for i in array.info_items()}
|
||||
as_json["info"]["hexdigest"] = array.hexdigest()
|
||||
|
||||
if dump_array:
|
||||
info_dict["array"] = array[:].tolist()
|
||||
if dump_array or not is_file:
|
||||
as_json["array"] = array[:].tolist()
|
||||
|
||||
return info_dict
|
||||
as_json = ZarrJsonDict(**as_json)
|
||||
else:
|
||||
as_json = array[:].tolist()
|
||||
|
||||
return as_json
|
||||
|
|
|
@ -24,7 +24,6 @@ from numpydantic.exceptions import InterfaceError
|
|||
from numpydantic.interface import Interface
|
||||
from numpydantic.maps import python_to_nptyping
|
||||
from numpydantic.schema import (
|
||||
_handler_type,
|
||||
_jsonize_array,
|
||||
get_validate_interface,
|
||||
make_json_schema,
|
||||
|
@ -41,6 +40,9 @@ from numpydantic.vendor.nptyping.typing_ import (
|
|||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from nptyping.base_meta_classes import SubscriptableMeta
|
||||
from pydantic._internal._schema_generation_shared import (
|
||||
CallbackGetCoreSchemaHandler,
|
||||
)
|
||||
|
||||
from numpydantic import Shape
|
||||
|
||||
|
@ -164,33 +166,34 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
|
|||
def __get_pydantic_core_schema__(
|
||||
cls,
|
||||
_source_type: "NDArray",
|
||||
_handler: _handler_type,
|
||||
_handler: "CallbackGetCoreSchemaHandler",
|
||||
) -> core_schema.CoreSchema:
|
||||
shape, dtype = _source_type.__args__
|
||||
shape: ShapeType
|
||||
dtype: DtypeType
|
||||
|
||||
# get pydantic core schema as a list of lists for JSON schema
|
||||
list_schema = make_json_schema(shape, dtype, _handler)
|
||||
# make core schema for json schema, store it and any model definitions
|
||||
# note that there is a big of fragility in this function,
|
||||
# as we need to access a private method of _handler to
|
||||
# flatten out the json schema. See help(make_json_schema)
|
||||
json_schema = make_json_schema(shape, dtype, _handler)
|
||||
|
||||
return core_schema.json_or_python_schema(
|
||||
json_schema=list_schema,
|
||||
python_schema=core_schema.with_info_plain_validator_function(
|
||||
get_validate_interface(shape, dtype)
|
||||
),
|
||||
return core_schema.with_info_plain_validator_function(
|
||||
get_validate_interface(shape, dtype),
|
||||
serialization=core_schema.plain_serializer_function_ser_schema(
|
||||
_jsonize_array, when_used="json", info_arg=True
|
||||
),
|
||||
metadata=json_schema,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(
|
||||
cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
|
||||
) -> core_schema.JsonSchema:
|
||||
json_schema = handler(schema)
|
||||
shape, dtype = cls.__args__
|
||||
json_schema = handler(schema["metadata"])
|
||||
json_schema = handler.resolve_ref_schema(json_schema)
|
||||
|
||||
dtype = cls.__args__[1]
|
||||
if not isinstance(dtype, tuple) and dtype.__module__ not in (
|
||||
"builtins",
|
||||
"typing",
|
||||
|
|
|
@ -13,19 +13,24 @@ from pydantic_core import CoreSchema, core_schema
|
|||
from pydantic_core.core_schema import ListSchema, ValidationInfo
|
||||
|
||||
from numpydantic import dtype as dt
|
||||
from numpydantic.interface import Interface
|
||||
from numpydantic.interface import Interface, JsonDict
|
||||
from numpydantic.maps import np_to_python
|
||||
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||
from numpydantic.vendor.nptyping.structure import StructureMeta
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from pydantic._internal._schema_generation_shared import (
|
||||
CallbackGetCoreSchemaHandler,
|
||||
)
|
||||
|
||||
from numpydantic import Shape
|
||||
|
||||
_handler_type = Callable[[Any], core_schema.CoreSchema]
|
||||
_UNSUPPORTED_TYPES = (complex,)
|
||||
|
||||
|
||||
def _numeric_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
|
||||
def _numeric_dtype(
|
||||
dtype: DtypeType, _handler: "CallbackGetCoreSchemaHandler"
|
||||
) -> CoreSchema:
|
||||
"""Make a numeric dtype that respects min/max values from extended numpy types"""
|
||||
if dtype in (np.number,):
|
||||
dtype = float
|
||||
|
@ -36,14 +41,19 @@ def _numeric_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
|
|||
elif issubclass(dtype, np.integer):
|
||||
info = np.iinfo(dtype)
|
||||
schema = core_schema.int_schema(le=int(info.max), ge=int(info.min))
|
||||
|
||||
elif dtype is float:
|
||||
schema = core_schema.float_schema()
|
||||
elif dtype is int:
|
||||
schema = core_schema.int_schema()
|
||||
else:
|
||||
schema = _handler.generate_schema(dtype)
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
|
||||
def _lol_dtype(
|
||||
dtype: DtypeType, _handler: "CallbackGetCoreSchemaHandler"
|
||||
) -> CoreSchema:
|
||||
"""Get the innermost dtype schema to use in the generated pydantic schema"""
|
||||
if isinstance(dtype, StructureMeta): # pragma: no cover
|
||||
raise NotImplementedError("Structured dtypes are currently unsupported")
|
||||
|
@ -84,6 +94,10 @@ def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
|
|||
# TODO: warn and log here
|
||||
elif python_type in (float, int):
|
||||
array_type = _numeric_dtype(dtype, _handler)
|
||||
elif python_type is bool:
|
||||
array_type = core_schema.bool_schema()
|
||||
elif python_type is Any:
|
||||
array_type = core_schema.any_schema()
|
||||
else:
|
||||
array_type = _handler.generate_schema(python_type)
|
||||
|
||||
|
@ -208,14 +222,24 @@ def _unbounded_shape(
|
|||
|
||||
|
||||
def make_json_schema(
|
||||
shape: ShapeType, dtype: DtypeType, _handler: _handler_type
|
||||
shape: ShapeType, dtype: DtypeType, _handler: "CallbackGetCoreSchemaHandler"
|
||||
) -> ListSchema:
|
||||
"""
|
||||
Make a list of list JSON schema from a shape and a dtype.
|
||||
Make a list of list pydantic core schema for an array from a shape and a dtype.
|
||||
Used to generate JSON schema in the containing model, but not for validation,
|
||||
which is handled by interfaces.
|
||||
|
||||
First resolves the dtype into a pydantic ``CoreSchema`` ,
|
||||
and then uses that with :func:`.list_of_lists_schema` .
|
||||
|
||||
.. admonition:: Potentially Fragile
|
||||
|
||||
Uses a private method from the handler to flatten out nested definitions
|
||||
(e.g. when dtype is a pydantic model)
|
||||
so that they are present in the generated schema directly rather than
|
||||
as references. Otherwise, at the time __get_pydantic_json_schema__ is called,
|
||||
the definition references are lost.
|
||||
|
||||
Args:
|
||||
shape ( ShapeType ): Specification of a shape, as a tuple or
|
||||
an nptyping ``Shape``
|
||||
|
@ -234,6 +258,8 @@ def make_json_schema(
|
|||
else:
|
||||
list_schema = list_of_lists_schema(shape, dtype_schema)
|
||||
|
||||
list_schema = _handler._generate_schema.clean_schema(list_schema)
|
||||
|
||||
return list_schema
|
||||
|
||||
|
||||
|
@ -257,4 +283,11 @@ def get_validate_interface(shape: ShapeType, dtype: DtypeType) -> Callable:
|
|||
def _jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
|
||||
"""Use an interface class to render an array as JSON"""
|
||||
interface_cls = Interface.match_output(value)
|
||||
return interface_cls.to_json(value, info)
|
||||
array = interface_cls.to_json(value, info)
|
||||
if isinstance(array, JsonDict):
|
||||
array = array.to_dict()
|
||||
|
||||
if info.context and info.context.get("mark_interface", False):
|
||||
array = interface_cls.mark_json(array)
|
||||
|
||||
return array
|
||||
|
|
|
@ -101,7 +101,8 @@ def test_assignment(hdf5_array, model_blank):
|
|||
assert (model.array[1:3, 2:4] == 10).all()
|
||||
|
||||
|
||||
def test_to_json(hdf5_array, array_model):
|
||||
@pytest.mark.parametrize("round_trip", (True, False))
|
||||
def test_to_json(hdf5_array, array_model, round_trip):
|
||||
"""
|
||||
Test serialization of HDF5 arrays to JSON
|
||||
Args:
|
||||
|
@ -115,13 +116,13 @@ def test_to_json(hdf5_array, array_model):
|
|||
|
||||
instance = model(array=array) # type: BaseModel
|
||||
|
||||
json_str = instance.model_dump_json()
|
||||
json_dict = json.loads(json_str)["array"]
|
||||
|
||||
assert json_dict["file"] == str(array.file)
|
||||
assert json_dict["path"] == str(array.path)
|
||||
assert json_dict["attrs"] == {}
|
||||
assert json_dict["array"] == instance.array[:].tolist()
|
||||
json_str = instance.model_dump_json(round_trip=round_trip)
|
||||
json_dumped = json.loads(json_str)["array"]
|
||||
if round_trip:
|
||||
assert json_dumped["file"] == str(array.file)
|
||||
assert json_dumped["path"] == str(array.path)
|
||||
else:
|
||||
assert json_dumped == instance.array[:].tolist()
|
||||
|
||||
|
||||
def test_compound_dtype(tmp_path):
|
||||
|
|
|
@ -2,8 +2,11 @@
|
|||
Tests that should be applied to all interfaces
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from typing import Callable
|
||||
import numpy as np
|
||||
import dask.array as da
|
||||
from zarr.core import Array as ZarrArray
|
||||
from numpydantic.interface import Interface
|
||||
|
||||
|
||||
|
@ -35,8 +38,31 @@ def test_interface_to_numpy_array(all_interfaces):
|
|||
_ = np.array(all_interfaces.array)
|
||||
|
||||
|
||||
@pytest.mark.serialization
|
||||
def test_interface_dump_json(all_interfaces):
|
||||
"""
|
||||
All interfaces should be able to dump to json
|
||||
"""
|
||||
all_interfaces.model_dump_json()
|
||||
|
||||
|
||||
@pytest.mark.serialization
|
||||
@pytest.mark.parametrize("round_trip", [True, False])
|
||||
def test_interface_roundtrip_json(all_interfaces, round_trip):
|
||||
"""
|
||||
All interfaces should be able to roundtrip to and from json
|
||||
"""
|
||||
json = all_interfaces.model_dump_json(round_trip=round_trip)
|
||||
model = all_interfaces.model_validate_json(json)
|
||||
if round_trip:
|
||||
assert type(model.array) is type(all_interfaces.array)
|
||||
if isinstance(all_interfaces.array, (np.ndarray, ZarrArray)):
|
||||
assert np.array_equal(model.array, np.array(all_interfaces.array))
|
||||
elif isinstance(all_interfaces.array, da.Array):
|
||||
assert np.all(da.equal(model.array, all_interfaces.array))
|
||||
else:
|
||||
assert model.array == all_interfaces.array
|
||||
|
||||
assert model.array.dtype == all_interfaces.array.dtype
|
||||
else:
|
||||
assert np.array_equal(model.array, np.array(all_interfaces.array))
|
||||
|
|
|
@ -123,7 +123,10 @@ def test_zarr_array_path_from_iterable(zarr_array):
|
|||
assert apath.path == inner_path
|
||||
|
||||
|
||||
def test_zarr_to_json(store, model_blank):
|
||||
@pytest.mark.serialization
|
||||
@pytest.mark.parametrize("dump_array", [True, False])
|
||||
@pytest.mark.parametrize("roundtrip", [True, False])
|
||||
def test_zarr_to_json(store, model_blank, roundtrip, dump_array):
|
||||
expected_fields = (
|
||||
"Type",
|
||||
"Data type",
|
||||
|
@ -137,17 +140,22 @@ def test_zarr_to_json(store, model_blank):
|
|||
|
||||
array = zarr.array(lol_array, store=store)
|
||||
instance = model_blank(array=array)
|
||||
as_json = json.loads(instance.model_dump_json())["array"]
|
||||
assert "array" not in as_json
|
||||
for field in expected_fields:
|
||||
assert field in as_json
|
||||
assert len(as_json["hexdigest"]) == 40
|
||||
|
||||
# dump the array itself too
|
||||
as_json = json.loads(instance.model_dump_json(context={"zarr_dump_array": True}))[
|
||||
"array"
|
||||
]
|
||||
for field in expected_fields:
|
||||
assert field in as_json
|
||||
assert len(as_json["hexdigest"]) == 40
|
||||
context = {"zarr_dump_array": dump_array}
|
||||
as_json = json.loads(
|
||||
instance.model_dump_json(round_trip=roundtrip, context=context)
|
||||
)["array"]
|
||||
|
||||
if roundtrip:
|
||||
if dump_array:
|
||||
assert as_json["array"] == lol_array
|
||||
else:
|
||||
if as_json.get("file", False):
|
||||
assert "array" not in as_json
|
||||
|
||||
for field in expected_fields:
|
||||
assert field in as_json["info"]
|
||||
assert len(as_json["info"]["hexdigest"]) == 40
|
||||
|
||||
else:
|
||||
assert as_json == lol_array
|
||||
|
|
|
@ -15,6 +15,7 @@ from numpydantic import dtype
|
|||
from numpydantic.dtype import Number
|
||||
|
||||
|
||||
@pytest.mark.json_schema
|
||||
def test_ndarray_type():
|
||||
class Model(BaseModel):
|
||||
array: NDArray[Shape["2 x, * y"], Number]
|
||||
|
@ -40,6 +41,7 @@ def test_ndarray_type():
|
|||
instance = Model(array=np.zeros((2, 3)), array_any=np.ones((3, 4, 5)))
|
||||
|
||||
|
||||
@pytest.mark.json_schema
|
||||
def test_schema_unsupported_type():
|
||||
"""
|
||||
Complex numbers should just be made with an `any` schema
|
||||
|
@ -55,6 +57,7 @@ def test_schema_unsupported_type():
|
|||
}
|
||||
|
||||
|
||||
@pytest.mark.json_schema
|
||||
def test_schema_tuple():
|
||||
"""
|
||||
Types specified as tupled should have their schemas as a union
|
||||
|
@ -72,6 +75,7 @@ def test_schema_tuple():
|
|||
assert all([i["minimum"] == 0 for i in conditions])
|
||||
|
||||
|
||||
@pytest.mark.json_schema
|
||||
def test_schema_number():
|
||||
"""
|
||||
np.numeric should just be the float schema
|
||||
|
@ -164,6 +168,7 @@ def test_ndarray_coercion():
|
|||
amod = Model(array=["a", "b", "c"])
|
||||
|
||||
|
||||
@pytest.mark.serialization
|
||||
def test_ndarray_serialize():
|
||||
"""
|
||||
Arrays should be dumped to a list when using json, but kept as ndarray otherwise
|
||||
|
@ -188,6 +193,7 @@ _json_schema_types = [
|
|||
]
|
||||
|
||||
|
||||
@pytest.mark.json_schema
|
||||
def test_json_schema_basic(array_model):
|
||||
"""
|
||||
NDArray types should correctly generate a list of lists JSON schema
|
||||
|
@ -210,6 +216,8 @@ def test_json_schema_basic(array_model):
|
|||
assert inner["items"]["type"] == "number"
|
||||
|
||||
|
||||
@pytest.mark.dtype
|
||||
@pytest.mark.json_schema
|
||||
@pytest.mark.parametrize("dtype", [*dtype.Integer, *dtype.Float])
|
||||
def test_json_schema_dtype_single(dtype, array_model):
|
||||
"""
|
||||
|
@ -240,6 +248,7 @@ def test_json_schema_dtype_single(dtype, array_model):
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.dtype
|
||||
@pytest.mark.parametrize(
|
||||
"dtype,expected",
|
||||
[
|
||||
|
@ -266,6 +275,8 @@ def test_json_schema_dtype_builtin(dtype, expected, array_model):
|
|||
assert inner_type["type"] == expected
|
||||
|
||||
|
||||
@pytest.mark.dtype
|
||||
@pytest.mark.json_schema
|
||||
def test_json_schema_dtype_model():
|
||||
"""
|
||||
Pydantic models can be used in arrays as dtypes
|
||||
|
@ -314,6 +325,8 @@ def _recursive_array(schema):
|
|||
assert any_of[1]["minimum"] == 0
|
||||
|
||||
|
||||
@pytest.mark.shape
|
||||
@pytest.mark.json_schema
|
||||
def test_json_schema_ellipsis():
|
||||
"""
|
||||
NDArray types should create a recursive JSON schema for any-shaped arrays
|
||||
|
|
Loading…
Reference in a new issue