mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2024-11-14 10:44:28 +00:00
add json_model abstract class attr, make deserialize validation method
This commit is contained in:
parent
705de53838
commit
9026eb700f
7 changed files with 53 additions and 33 deletions
|
@ -46,6 +46,11 @@ for interfaces to implement custom behavior that matches the array format.
|
||||||
|
|
||||||
{meth}`.Interface.validate` calls the following methods, in order:
|
{meth}`.Interface.validate` calls the following methods, in order:
|
||||||
|
|
||||||
|
A method to deserialize the array dumped with a {func}`~pydantic.BaseModel.model_dump_json`
|
||||||
|
with `round_trip = True` (see [serialization](./serialization.md))
|
||||||
|
|
||||||
|
- {meth}`.Interface.deserialize`
|
||||||
|
|
||||||
An initial hook for modifying the input data before validation, eg.
|
An initial hook for modifying the input data before validation, eg.
|
||||||
if it needs to be coerced or wrapped in some proxy class. This method
|
if it needs to be coerced or wrapped in some proxy class. This method
|
||||||
should accept all and only the types specified in that interface's
|
should accept all and only the types specified in that interface's
|
||||||
|
|
|
@ -54,6 +54,7 @@ class DaskInterface(Interface):
|
||||||
name = "dask"
|
name = "dask"
|
||||||
input_types = (DaskArray, dict)
|
input_types = (DaskArray, dict)
|
||||||
return_type = DaskArray
|
return_type = DaskArray
|
||||||
|
json_model = DaskJsonDict
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def check(cls, array: Any) -> bool:
|
def check(cls, array: Any) -> bool:
|
||||||
|
@ -69,18 +70,6 @@ class DaskInterface(Interface):
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def before_validation(self, array: Any) -> DaskArray:
|
|
||||||
"""
|
|
||||||
If given a dict (like that from ``model_dump_json(round_trip=True)`` ),
|
|
||||||
re-cast to dask array
|
|
||||||
"""
|
|
||||||
if isinstance(array, dict):
|
|
||||||
array = DaskJsonDict(**array).to_array_input()
|
|
||||||
elif isinstance(array, DaskJsonDict):
|
|
||||||
array = array.to_array_input()
|
|
||||||
|
|
||||||
return array
|
|
||||||
|
|
||||||
def get_object_dtype(self, array: NDArrayType) -> DtypeType:
|
def get_object_dtype(self, array: NDArrayType) -> DtypeType:
|
||||||
"""Dask arrays require a compute() call to retrieve a single value"""
|
"""Dask arrays require a compute() call to retrieve a single value"""
|
||||||
return type(array.ravel()[0].compute())
|
return type(array.ravel()[0].compute())
|
||||||
|
|
|
@ -278,6 +278,7 @@ class H5Interface(Interface):
|
||||||
name = "hdf5"
|
name = "hdf5"
|
||||||
input_types = (H5ArrayPath, H5Arraylike, H5Proxy)
|
input_types = (H5ArrayPath, H5Arraylike, H5Proxy)
|
||||||
return_type = H5Proxy
|
return_type = H5Proxy
|
||||||
|
json_model = H5JsonDict
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def enabled(cls) -> bool:
|
def enabled(cls) -> bool:
|
||||||
|
@ -326,11 +327,6 @@ class H5Interface(Interface):
|
||||||
|
|
||||||
def before_validation(self, array: Any) -> NDArrayType:
|
def before_validation(self, array: Any) -> NDArrayType:
|
||||||
"""Create an :class:`.H5Proxy` to use throughout validation"""
|
"""Create an :class:`.H5Proxy` to use throughout validation"""
|
||||||
if isinstance(array, dict):
|
|
||||||
array = H5JsonDict(**array).to_array_input()
|
|
||||||
elif isinstance(array, H5JsonDict):
|
|
||||||
array = array.to_array_input()
|
|
||||||
|
|
||||||
if isinstance(array, H5ArrayPath):
|
if isinstance(array, H5ArrayPath):
|
||||||
array = H5Proxy.from_h5array(h5array=array)
|
array = H5Proxy.from_h5array(h5array=array)
|
||||||
elif isinstance(array, H5Proxy):
|
elif isinstance(array, H5Proxy):
|
||||||
|
|
|
@ -21,6 +21,9 @@ from numpydantic.shape import check_shape
|
||||||
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||||
|
|
||||||
T = TypeVar("T", bound=NDArrayType)
|
T = TypeVar("T", bound=NDArrayType)
|
||||||
|
U = TypeVar("U", bound="JsonDict")
|
||||||
|
V = TypeVar("V") # input type
|
||||||
|
W = TypeVar("W") # Any type in handle_input
|
||||||
|
|
||||||
|
|
||||||
class InterfaceMark(TypedDict):
|
class InterfaceMark(TypedDict):
|
||||||
|
@ -39,7 +42,7 @@ class JsonDict(BaseModel):
|
||||||
type: str
|
type: str
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def to_array_input(self) -> Any:
|
def to_array_input(self) -> V:
|
||||||
"""
|
"""
|
||||||
Convert this roundtrip specifier to the relevant input class
|
Convert this roundtrip specifier to the relevant input class
|
||||||
(one of the ``input_types`` of an interface).
|
(one of the ``input_types`` of an interface).
|
||||||
|
@ -66,6 +69,20 @@ class JsonDict(BaseModel):
|
||||||
raise e
|
raise e
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def handle_input(cls: Type[U], value: Union[dict, U, W]) -> Union[V, W]:
|
||||||
|
"""
|
||||||
|
Handle input that is the json serialized roundtrip version
|
||||||
|
(from :func:`~pydantic.BaseModel.model_dump` with ``round_trip=True``)
|
||||||
|
converting it to the input format with :meth:`.JsonDict.to_array_input`
|
||||||
|
or passing it through if not applicable
|
||||||
|
"""
|
||||||
|
if isinstance(value, dict):
|
||||||
|
value = cls(**value).to_array_input()
|
||||||
|
elif isinstance(value, cls):
|
||||||
|
value = value.to_array_input()
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
class Interface(ABC, Generic[T]):
|
class Interface(ABC, Generic[T]):
|
||||||
"""
|
"""
|
||||||
|
@ -86,6 +103,7 @@ class Interface(ABC, Generic[T]):
|
||||||
|
|
||||||
Calls the methods, in order:
|
Calls the methods, in order:
|
||||||
|
|
||||||
|
* array = :meth:`.deserialize` (array)
|
||||||
* array = :meth:`.before_validation` (array)
|
* array = :meth:`.before_validation` (array)
|
||||||
* dtype = :meth:`.get_dtype` (array) - get the dtype from the array,
|
* dtype = :meth:`.get_dtype` (array) - get the dtype from the array,
|
||||||
override if eg. the dtype is not contained in ``array.dtype``
|
override if eg. the dtype is not contained in ``array.dtype``
|
||||||
|
@ -120,6 +138,8 @@ class Interface(ABC, Generic[T]):
|
||||||
:class:`.DtypeError` and :class:`.ShapeError` (both of which are children
|
:class:`.DtypeError` and :class:`.ShapeError` (both of which are children
|
||||||
of :class:`.InterfaceError` )
|
of :class:`.InterfaceError` )
|
||||||
"""
|
"""
|
||||||
|
array = self.deserialize(array)
|
||||||
|
|
||||||
array = self.before_validation(array)
|
array = self.before_validation(array)
|
||||||
|
|
||||||
dtype = self.get_dtype(array)
|
dtype = self.get_dtype(array)
|
||||||
|
@ -135,6 +155,19 @@ class Interface(ABC, Generic[T]):
|
||||||
|
|
||||||
return array
|
return array
|
||||||
|
|
||||||
|
def deserialize(self, array: Any) -> Union[V, Any]:
|
||||||
|
"""
|
||||||
|
If given a JSON serialized version of the array,
|
||||||
|
deserialize it first
|
||||||
|
|
||||||
|
Args:
|
||||||
|
array:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
return self.json_model.handle_input(array)
|
||||||
|
|
||||||
def before_validation(self, array: Any) -> NDArrayType:
|
def before_validation(self, array: Any) -> NDArrayType:
|
||||||
"""
|
"""
|
||||||
Optional step pre-validation that coerces the input into a type that can be
|
Optional step pre-validation that coerces the input into a type that can be
|
||||||
|
@ -270,6 +303,14 @@ class Interface(ABC, Generic[T]):
|
||||||
Short name for this interface
|
Short name for this interface
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def json_model(self) -> JsonDict:
|
||||||
|
"""
|
||||||
|
The :class:`.JsonDict` model used for roundtripping
|
||||||
|
JSON serialization
|
||||||
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def to_json(cls, array: Type[T], info: SerializationInfo) -> Union[list, JsonDict]:
|
def to_json(cls, array: Type[T], info: SerializationInfo) -> Union[list, JsonDict]:
|
||||||
|
|
|
@ -44,6 +44,7 @@ class NumpyInterface(Interface):
|
||||||
name = "numpy"
|
name = "numpy"
|
||||||
input_types = (ndarray, list)
|
input_types = (ndarray, list)
|
||||||
return_type = ndarray
|
return_type = ndarray
|
||||||
|
json_model = NumpyJsonDict
|
||||||
priority = -999
|
priority = -999
|
||||||
"""
|
"""
|
||||||
The numpy interface is usually the interface of last resort.
|
The numpy interface is usually the interface of last resort.
|
||||||
|
@ -74,11 +75,6 @@ class NumpyInterface(Interface):
|
||||||
Coerce to an ndarray. We have already checked if coercion is possible
|
Coerce to an ndarray. We have already checked if coercion is possible
|
||||||
in :meth:`.check`
|
in :meth:`.check`
|
||||||
"""
|
"""
|
||||||
if isinstance(array, dict):
|
|
||||||
array = NumpyJsonDict(**array).to_array_input()
|
|
||||||
elif isinstance(array, NumpyJsonDict):
|
|
||||||
array = array.to_array_input()
|
|
||||||
|
|
||||||
if not isinstance(array, ndarray):
|
if not isinstance(array, ndarray):
|
||||||
array = np.array(array)
|
array = np.array(array)
|
||||||
return array
|
return array
|
||||||
|
|
|
@ -221,6 +221,7 @@ class VideoInterface(Interface):
|
||||||
name = "video"
|
name = "video"
|
||||||
input_types = (str, Path, VideoCapture, VideoProxy)
|
input_types = (str, Path, VideoCapture, VideoProxy)
|
||||||
return_type = VideoProxy
|
return_type = VideoProxy
|
||||||
|
json_model = VideoJsonDict
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def enabled(cls) -> bool:
|
def enabled(cls) -> bool:
|
||||||
|
@ -252,11 +253,7 @@ class VideoInterface(Interface):
|
||||||
|
|
||||||
def before_validation(self, array: Any) -> VideoProxy:
|
def before_validation(self, array: Any) -> VideoProxy:
|
||||||
"""Get a :class:`.VideoProxy` object for this video"""
|
"""Get a :class:`.VideoProxy` object for this video"""
|
||||||
if isinstance(array, dict):
|
if isinstance(array, VideoCapture):
|
||||||
proxy = VideoJsonDict(**array).to_array_input()
|
|
||||||
elif isinstance(array, VideoJsonDict):
|
|
||||||
proxy = array.to_array_input()
|
|
||||||
elif isinstance(array, VideoCapture):
|
|
||||||
proxy = VideoProxy(video=array)
|
proxy = VideoProxy(video=array)
|
||||||
elif isinstance(array, VideoProxy):
|
elif isinstance(array, VideoProxy):
|
||||||
proxy = array
|
proxy = array
|
||||||
|
|
|
@ -85,6 +85,7 @@ class ZarrInterface(Interface):
|
||||||
name = "zarr"
|
name = "zarr"
|
||||||
input_types = (Path, ZarrArray, ZarrArrayPath)
|
input_types = (Path, ZarrArray, ZarrArrayPath)
|
||||||
return_type = ZarrArray
|
return_type = ZarrArray
|
||||||
|
json_model = ZarrJsonDict
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def enabled(cls) -> bool:
|
def enabled(cls) -> bool:
|
||||||
|
@ -95,11 +96,6 @@ class ZarrInterface(Interface):
|
||||||
def _get_array(
|
def _get_array(
|
||||||
array: Union[ZarrArray, str, dict, ZarrJsonDict, Path, ZarrArrayPath, Sequence]
|
array: Union[ZarrArray, str, dict, ZarrJsonDict, Path, ZarrArrayPath, Sequence]
|
||||||
) -> ZarrArray:
|
) -> ZarrArray:
|
||||||
if isinstance(array, dict):
|
|
||||||
array = ZarrJsonDict(**array).to_array_input()
|
|
||||||
elif isinstance(array, ZarrJsonDict):
|
|
||||||
array = array.to_array_input()
|
|
||||||
|
|
||||||
if isinstance(array, ZarrArray):
|
if isinstance(array, ZarrArray):
|
||||||
return array
|
return array
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue