mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2024-11-14 10:44:28 +00:00
add json_model abstract class attr, make deserialize validation method
This commit is contained in:
parent
705de53838
commit
9026eb700f
7 changed files with 53 additions and 33 deletions
|
@ -46,6 +46,11 @@ for interfaces to implement custom behavior that matches the array format.
|
|||
|
||||
{meth}`.Interface.validate` calls the following methods, in order:
|
||||
|
||||
A method to deserialize the array dumped with a {func}`~pydantic.BaseModel.model_dump_json`
|
||||
with `round_trip = True` (see [serialization](./serialization.md))
|
||||
|
||||
- {meth}`.Interface.deserialize`
|
||||
|
||||
An initial hook for modifying the input data before validation, eg.
|
||||
if it needs to be coerced or wrapped in some proxy class. This method
|
||||
should accept all and only the types specified in that interface's
|
||||
|
|
|
@ -54,6 +54,7 @@ class DaskInterface(Interface):
|
|||
name = "dask"
|
||||
input_types = (DaskArray, dict)
|
||||
return_type = DaskArray
|
||||
json_model = DaskJsonDict
|
||||
|
||||
@classmethod
|
||||
def check(cls, array: Any) -> bool:
|
||||
|
@ -69,18 +70,6 @@ class DaskInterface(Interface):
|
|||
else:
|
||||
return False
|
||||
|
||||
def before_validation(self, array: Any) -> DaskArray:
|
||||
"""
|
||||
If given a dict (like that from ``model_dump_json(round_trip=True)`` ),
|
||||
re-cast to dask array
|
||||
"""
|
||||
if isinstance(array, dict):
|
||||
array = DaskJsonDict(**array).to_array_input()
|
||||
elif isinstance(array, DaskJsonDict):
|
||||
array = array.to_array_input()
|
||||
|
||||
return array
|
||||
|
||||
def get_object_dtype(self, array: NDArrayType) -> DtypeType:
|
||||
"""Dask arrays require a compute() call to retrieve a single value"""
|
||||
return type(array.ravel()[0].compute())
|
||||
|
|
|
@ -278,6 +278,7 @@ class H5Interface(Interface):
|
|||
name = "hdf5"
|
||||
input_types = (H5ArrayPath, H5Arraylike, H5Proxy)
|
||||
return_type = H5Proxy
|
||||
json_model = H5JsonDict
|
||||
|
||||
@classmethod
|
||||
def enabled(cls) -> bool:
|
||||
|
@ -326,11 +327,6 @@ class H5Interface(Interface):
|
|||
|
||||
def before_validation(self, array: Any) -> NDArrayType:
|
||||
"""Create an :class:`.H5Proxy` to use throughout validation"""
|
||||
if isinstance(array, dict):
|
||||
array = H5JsonDict(**array).to_array_input()
|
||||
elif isinstance(array, H5JsonDict):
|
||||
array = array.to_array_input()
|
||||
|
||||
if isinstance(array, H5ArrayPath):
|
||||
array = H5Proxy.from_h5array(h5array=array)
|
||||
elif isinstance(array, H5Proxy):
|
||||
|
|
|
@ -21,6 +21,9 @@ from numpydantic.shape import check_shape
|
|||
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||
|
||||
T = TypeVar("T", bound=NDArrayType)
|
||||
U = TypeVar("U", bound="JsonDict")
|
||||
V = TypeVar("V") # input type
|
||||
W = TypeVar("W") # Any type in handle_input
|
||||
|
||||
|
||||
class InterfaceMark(TypedDict):
|
||||
|
@ -39,7 +42,7 @@ class JsonDict(BaseModel):
|
|||
type: str
|
||||
|
||||
@abstractmethod
|
||||
def to_array_input(self) -> Any:
|
||||
def to_array_input(self) -> V:
|
||||
"""
|
||||
Convert this roundtrip specifier to the relevant input class
|
||||
(one of the ``input_types`` of an interface).
|
||||
|
@ -66,6 +69,20 @@ class JsonDict(BaseModel):
|
|||
raise e
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def handle_input(cls: Type[U], value: Union[dict, U, W]) -> Union[V, W]:
|
||||
"""
|
||||
Handle input that is the json serialized roundtrip version
|
||||
(from :func:`~pydantic.BaseModel.model_dump` with ``round_trip=True``)
|
||||
converting it to the input format with :meth:`.JsonDict.to_array_input`
|
||||
or passing it through if not applicable
|
||||
"""
|
||||
if isinstance(value, dict):
|
||||
value = cls(**value).to_array_input()
|
||||
elif isinstance(value, cls):
|
||||
value = value.to_array_input()
|
||||
return value
|
||||
|
||||
|
||||
class Interface(ABC, Generic[T]):
|
||||
"""
|
||||
|
@ -86,6 +103,7 @@ class Interface(ABC, Generic[T]):
|
|||
|
||||
Calls the methods, in order:
|
||||
|
||||
* array = :meth:`.deserialize` (array)
|
||||
* array = :meth:`.before_validation` (array)
|
||||
* dtype = :meth:`.get_dtype` (array) - get the dtype from the array,
|
||||
override if eg. the dtype is not contained in ``array.dtype``
|
||||
|
@ -120,6 +138,8 @@ class Interface(ABC, Generic[T]):
|
|||
:class:`.DtypeError` and :class:`.ShapeError` (both of which are children
|
||||
of :class:`.InterfaceError` )
|
||||
"""
|
||||
array = self.deserialize(array)
|
||||
|
||||
array = self.before_validation(array)
|
||||
|
||||
dtype = self.get_dtype(array)
|
||||
|
@ -135,6 +155,19 @@ class Interface(ABC, Generic[T]):
|
|||
|
||||
return array
|
||||
|
||||
def deserialize(self, array: Any) -> Union[V, Any]:
|
||||
"""
|
||||
If given a JSON serialized version of the array,
|
||||
deserialize it first
|
||||
|
||||
Args:
|
||||
array:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
return self.json_model.handle_input(array)
|
||||
|
||||
def before_validation(self, array: Any) -> NDArrayType:
|
||||
"""
|
||||
Optional step pre-validation that coerces the input into a type that can be
|
||||
|
@ -270,6 +303,14 @@ class Interface(ABC, Generic[T]):
|
|||
Short name for this interface
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def json_model(self) -> JsonDict:
|
||||
"""
|
||||
The :class:`.JsonDict` model used for roundtripping
|
||||
JSON serialization
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def to_json(cls, array: Type[T], info: SerializationInfo) -> Union[list, JsonDict]:
|
||||
|
|
|
@ -44,6 +44,7 @@ class NumpyInterface(Interface):
|
|||
name = "numpy"
|
||||
input_types = (ndarray, list)
|
||||
return_type = ndarray
|
||||
json_model = NumpyJsonDict
|
||||
priority = -999
|
||||
"""
|
||||
The numpy interface is usually the interface of last resort.
|
||||
|
@ -74,11 +75,6 @@ class NumpyInterface(Interface):
|
|||
Coerce to an ndarray. We have already checked if coercion is possible
|
||||
in :meth:`.check`
|
||||
"""
|
||||
if isinstance(array, dict):
|
||||
array = NumpyJsonDict(**array).to_array_input()
|
||||
elif isinstance(array, NumpyJsonDict):
|
||||
array = array.to_array_input()
|
||||
|
||||
if not isinstance(array, ndarray):
|
||||
array = np.array(array)
|
||||
return array
|
||||
|
|
|
@ -221,6 +221,7 @@ class VideoInterface(Interface):
|
|||
name = "video"
|
||||
input_types = (str, Path, VideoCapture, VideoProxy)
|
||||
return_type = VideoProxy
|
||||
json_model = VideoJsonDict
|
||||
|
||||
@classmethod
|
||||
def enabled(cls) -> bool:
|
||||
|
@ -252,11 +253,7 @@ class VideoInterface(Interface):
|
|||
|
||||
def before_validation(self, array: Any) -> VideoProxy:
|
||||
"""Get a :class:`.VideoProxy` object for this video"""
|
||||
if isinstance(array, dict):
|
||||
proxy = VideoJsonDict(**array).to_array_input()
|
||||
elif isinstance(array, VideoJsonDict):
|
||||
proxy = array.to_array_input()
|
||||
elif isinstance(array, VideoCapture):
|
||||
if isinstance(array, VideoCapture):
|
||||
proxy = VideoProxy(video=array)
|
||||
elif isinstance(array, VideoProxy):
|
||||
proxy = array
|
||||
|
|
|
@ -85,6 +85,7 @@ class ZarrInterface(Interface):
|
|||
name = "zarr"
|
||||
input_types = (Path, ZarrArray, ZarrArrayPath)
|
||||
return_type = ZarrArray
|
||||
json_model = ZarrJsonDict
|
||||
|
||||
@classmethod
|
||||
def enabled(cls) -> bool:
|
||||
|
@ -95,11 +96,6 @@ class ZarrInterface(Interface):
|
|||
def _get_array(
|
||||
array: Union[ZarrArray, str, dict, ZarrJsonDict, Path, ZarrArrayPath, Sequence]
|
||||
) -> ZarrArray:
|
||||
if isinstance(array, dict):
|
||||
array = ZarrJsonDict(**array).to_array_input()
|
||||
elif isinstance(array, ZarrJsonDict):
|
||||
array = array.to_array_input()
|
||||
|
||||
if isinstance(array, ZarrArray):
|
||||
return array
|
||||
|
||||
|
|
Loading…
Reference in a new issue