add json_model abstract class attr, make deserialize validation method

This commit is contained in:
sneakers-the-rat 2024-09-21 21:58:11 -07:00
parent 705de53838
commit 9026eb700f
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
7 changed files with 53 additions and 33 deletions

View file

@ -46,6 +46,11 @@ for interfaces to implement custom behavior that matches the array format.
{meth}`.Interface.validate` calls the following methods, in order:
A method to deserialize the array dumped with a {func}`~pydantic.BaseModel.model_dump_json`
with `round_trip = True` (see [serialization](./serialization.md))
- {meth}`.Interface.deserialize`
An initial hook for modifying the input data before validation, eg.
if it needs to be coerced or wrapped in some proxy class. This method
should accept all and only the types specified in that interface's

View file

@ -54,6 +54,7 @@ class DaskInterface(Interface):
name = "dask"
input_types = (DaskArray, dict)
return_type = DaskArray
json_model = DaskJsonDict
@classmethod
def check(cls, array: Any) -> bool:
@ -69,18 +70,6 @@ class DaskInterface(Interface):
else:
return False
def before_validation(self, array: Any) -> DaskArray:
"""
If given a dict (like that from ``model_dump_json(round_trip=True)`` ),
re-cast to dask array
"""
if isinstance(array, dict):
array = DaskJsonDict(**array).to_array_input()
elif isinstance(array, DaskJsonDict):
array = array.to_array_input()
return array
def get_object_dtype(self, array: NDArrayType) -> DtypeType:
"""Dask arrays require a compute() call to retrieve a single value"""
return type(array.ravel()[0].compute())

View file

@ -278,6 +278,7 @@ class H5Interface(Interface):
name = "hdf5"
input_types = (H5ArrayPath, H5Arraylike, H5Proxy)
return_type = H5Proxy
json_model = H5JsonDict
@classmethod
def enabled(cls) -> bool:
@ -326,11 +327,6 @@ class H5Interface(Interface):
def before_validation(self, array: Any) -> NDArrayType:
"""Create an :class:`.H5Proxy` to use throughout validation"""
if isinstance(array, dict):
array = H5JsonDict(**array).to_array_input()
elif isinstance(array, H5JsonDict):
array = array.to_array_input()
if isinstance(array, H5ArrayPath):
array = H5Proxy.from_h5array(h5array=array)
elif isinstance(array, H5Proxy):

View file

@ -21,6 +21,9 @@ from numpydantic.shape import check_shape
from numpydantic.types import DtypeType, NDArrayType, ShapeType
T = TypeVar("T", bound=NDArrayType)
U = TypeVar("U", bound="JsonDict")
V = TypeVar("V") # input type
W = TypeVar("W") # Any type in handle_input
class InterfaceMark(TypedDict):
@ -39,7 +42,7 @@ class JsonDict(BaseModel):
type: str
@abstractmethod
def to_array_input(self) -> Any:
def to_array_input(self) -> V:
"""
Convert this roundtrip specifier to the relevant input class
(one of the ``input_types`` of an interface).
@ -66,6 +69,20 @@ class JsonDict(BaseModel):
raise e
return False
@classmethod
def handle_input(cls: Type[U], value: Union[dict, U, W]) -> Union[V, W]:
"""
Handle input that is the json serialized roundtrip version
(from :func:`~pydantic.BaseModel.model_dump` with ``round_trip=True``)
converting it to the input format with :meth:`.JsonDict.to_array_input`
or passing it through if not applicable
"""
if isinstance(value, dict):
value = cls(**value).to_array_input()
elif isinstance(value, cls):
value = value.to_array_input()
return value
class Interface(ABC, Generic[T]):
"""
@ -86,6 +103,7 @@ class Interface(ABC, Generic[T]):
Calls the methods, in order:
* array = :meth:`.deserialize` (array)
* array = :meth:`.before_validation` (array)
* dtype = :meth:`.get_dtype` (array) - get the dtype from the array,
override if eg. the dtype is not contained in ``array.dtype``
@ -120,6 +138,8 @@ class Interface(ABC, Generic[T]):
:class:`.DtypeError` and :class:`.ShapeError` (both of which are children
of :class:`.InterfaceError` )
"""
array = self.deserialize(array)
array = self.before_validation(array)
dtype = self.get_dtype(array)
@ -135,6 +155,19 @@ class Interface(ABC, Generic[T]):
return array
def deserialize(self, array: Any) -> Union[V, Any]:
"""
If given a JSON serialized version of the array,
deserialize it first
Args:
array:
Returns:
"""
return self.json_model.handle_input(array)
def before_validation(self, array: Any) -> NDArrayType:
"""
Optional step pre-validation that coerces the input into a type that can be
@ -270,6 +303,14 @@ class Interface(ABC, Generic[T]):
Short name for this interface
"""
@property
@abstractmethod
def json_model(self) -> JsonDict:
"""
The :class:`.JsonDict` model used for roundtripping
JSON serialization
"""
@classmethod
@abstractmethod
def to_json(cls, array: Type[T], info: SerializationInfo) -> Union[list, JsonDict]:

View file

@ -44,6 +44,7 @@ class NumpyInterface(Interface):
name = "numpy"
input_types = (ndarray, list)
return_type = ndarray
json_model = NumpyJsonDict
priority = -999
"""
The numpy interface is usually the interface of last resort.
@ -74,11 +75,6 @@ class NumpyInterface(Interface):
Coerce to an ndarray. We have already checked if coercion is possible
in :meth:`.check`
"""
if isinstance(array, dict):
array = NumpyJsonDict(**array).to_array_input()
elif isinstance(array, NumpyJsonDict):
array = array.to_array_input()
if not isinstance(array, ndarray):
array = np.array(array)
return array

View file

@ -221,6 +221,7 @@ class VideoInterface(Interface):
name = "video"
input_types = (str, Path, VideoCapture, VideoProxy)
return_type = VideoProxy
json_model = VideoJsonDict
@classmethod
def enabled(cls) -> bool:
@ -252,11 +253,7 @@ class VideoInterface(Interface):
def before_validation(self, array: Any) -> VideoProxy:
"""Get a :class:`.VideoProxy` object for this video"""
if isinstance(array, dict):
proxy = VideoJsonDict(**array).to_array_input()
elif isinstance(array, VideoJsonDict):
proxy = array.to_array_input()
elif isinstance(array, VideoCapture):
if isinstance(array, VideoCapture):
proxy = VideoProxy(video=array)
elif isinstance(array, VideoProxy):
proxy = array

View file

@ -85,6 +85,7 @@ class ZarrInterface(Interface):
name = "zarr"
input_types = (Path, ZarrArray, ZarrArrayPath)
return_type = ZarrArray
json_model = ZarrJsonDict
@classmethod
def enabled(cls) -> bool:
@ -95,11 +96,6 @@ class ZarrInterface(Interface):
def _get_array(
array: Union[ZarrArray, str, dict, ZarrJsonDict, Path, ZarrArrayPath, Sequence]
) -> ZarrArray:
if isinstance(array, dict):
array = ZarrJsonDict(**array).to_array_input()
elif isinstance(array, ZarrJsonDict):
array = array.to_array_input()
if isinstance(array, ZarrArray):
return array