diff --git a/docs/interfaces.md b/docs/interfaces.md index c4d873d..b26ebba 100644 --- a/docs/interfaces.md +++ b/docs/interfaces.md @@ -46,6 +46,11 @@ for interfaces to implement custom behavior that matches the array format. {meth}`.Interface.validate` calls the following methods, in order: +A method to deserialize the array dumped with a {func}`~pydantic.BaseModel.model_dump_json` +with `round_trip = True` (see [serialization](./serialization.md)) + +- {meth}`.Interface.deserialize` + An initial hook for modifying the input data before validation, eg. if it needs to be coerced or wrapped in some proxy class. This method should accept all and only the types specified in that interface's diff --git a/src/numpydantic/interface/dask.py b/src/numpydantic/interface/dask.py index 960c94b..bc12a13 100644 --- a/src/numpydantic/interface/dask.py +++ b/src/numpydantic/interface/dask.py @@ -54,6 +54,7 @@ class DaskInterface(Interface): name = "dask" input_types = (DaskArray, dict) return_type = DaskArray + json_model = DaskJsonDict @classmethod def check(cls, array: Any) -> bool: @@ -69,18 +70,6 @@ class DaskInterface(Interface): else: return False - def before_validation(self, array: Any) -> DaskArray: - """ - If given a dict (like that from ``model_dump_json(round_trip=True)`` ), - re-cast to dask array - """ - if isinstance(array, dict): - array = DaskJsonDict(**array).to_array_input() - elif isinstance(array, DaskJsonDict): - array = array.to_array_input() - - return array - def get_object_dtype(self, array: NDArrayType) -> DtypeType: """Dask arrays require a compute() call to retrieve a single value""" return type(array.ravel()[0].compute()) diff --git a/src/numpydantic/interface/hdf5.py b/src/numpydantic/interface/hdf5.py index 4ce16ce..67b9899 100644 --- a/src/numpydantic/interface/hdf5.py +++ b/src/numpydantic/interface/hdf5.py @@ -278,6 +278,7 @@ class H5Interface(Interface): name = "hdf5" input_types = (H5ArrayPath, H5Arraylike, H5Proxy) return_type = H5Proxy + json_model = H5JsonDict @classmethod def enabled(cls) -> bool: @@ -326,11 +327,6 @@ class H5Interface(Interface): def before_validation(self, array: Any) -> NDArrayType: """Create an :class:`.H5Proxy` to use throughout validation""" - if isinstance(array, dict): - array = H5JsonDict(**array).to_array_input() - elif isinstance(array, H5JsonDict): - array = array.to_array_input() - if isinstance(array, H5ArrayPath): array = H5Proxy.from_h5array(h5array=array) elif isinstance(array, H5Proxy): diff --git a/src/numpydantic/interface/interface.py b/src/numpydantic/interface/interface.py index ebcb950..2cb61f4 100644 --- a/src/numpydantic/interface/interface.py +++ b/src/numpydantic/interface/interface.py @@ -21,6 +21,9 @@ from numpydantic.shape import check_shape from numpydantic.types import DtypeType, NDArrayType, ShapeType T = TypeVar("T", bound=NDArrayType) +U = TypeVar("U", bound="JsonDict") +V = TypeVar("V") # input type +W = TypeVar("W") # Any type in handle_input class InterfaceMark(TypedDict): @@ -39,7 +42,7 @@ class JsonDict(BaseModel): type: str @abstractmethod - def to_array_input(self) -> Any: + def to_array_input(self) -> V: """ Convert this roundtrip specifier to the relevant input class (one of the ``input_types`` of an interface). @@ -66,6 +69,20 @@ class JsonDict(BaseModel): raise e return False + @classmethod + def handle_input(cls: Type[U], value: Union[dict, U, W]) -> Union[V, W]: + """ + Handle input that is the json serialized roundtrip version + (from :func:`~pydantic.BaseModel.model_dump` with ``round_trip=True``) + converting it to the input format with :meth:`.JsonDict.to_array_input` + or passing it through if not applicable + """ + if isinstance(value, dict): + value = cls(**value).to_array_input() + elif isinstance(value, cls): + value = value.to_array_input() + return value + class Interface(ABC, Generic[T]): """ @@ -86,6 +103,7 @@ class Interface(ABC, Generic[T]): Calls the methods, in order: + * array = :meth:`.deserialize` (array) * array = :meth:`.before_validation` (array) * dtype = :meth:`.get_dtype` (array) - get the dtype from the array, override if eg. the dtype is not contained in ``array.dtype`` @@ -120,6 +138,8 @@ class Interface(ABC, Generic[T]): :class:`.DtypeError` and :class:`.ShapeError` (both of which are children of :class:`.InterfaceError` ) """ + array = self.deserialize(array) + array = self.before_validation(array) dtype = self.get_dtype(array) @@ -135,6 +155,19 @@ class Interface(ABC, Generic[T]): return array + def deserialize(self, array: Any) -> Union[V, Any]: + """ + If given a JSON serialized version of the array, + deserialize it first + + Args: + array: + + Returns: + + """ + return self.json_model.handle_input(array) + def before_validation(self, array: Any) -> NDArrayType: """ Optional step pre-validation that coerces the input into a type that can be @@ -270,6 +303,14 @@ class Interface(ABC, Generic[T]): Short name for this interface """ + @property + @abstractmethod + def json_model(self) -> JsonDict: + """ + The :class:`.JsonDict` model used for roundtripping + JSON serialization + """ + @classmethod @abstractmethod def to_json(cls, array: Type[T], info: SerializationInfo) -> Union[list, JsonDict]: diff --git a/src/numpydantic/interface/numpy.py b/src/numpydantic/interface/numpy.py index a1e0b94..ad97474 100644 --- a/src/numpydantic/interface/numpy.py +++ b/src/numpydantic/interface/numpy.py @@ -44,6 +44,7 @@ class NumpyInterface(Interface): name = "numpy" input_types = (ndarray, list) return_type = ndarray + json_model = NumpyJsonDict priority = -999 """ The numpy interface is usually the interface of last resort. @@ -74,11 +75,6 @@ class NumpyInterface(Interface): Coerce to an ndarray. We have already checked if coercion is possible in :meth:`.check` """ - if isinstance(array, dict): - array = NumpyJsonDict(**array).to_array_input() - elif isinstance(array, NumpyJsonDict): - array = array.to_array_input() - if not isinstance(array, ndarray): array = np.array(array) return array diff --git a/src/numpydantic/interface/video.py b/src/numpydantic/interface/video.py index 23f940d..b214a74 100644 --- a/src/numpydantic/interface/video.py +++ b/src/numpydantic/interface/video.py @@ -221,6 +221,7 @@ class VideoInterface(Interface): name = "video" input_types = (str, Path, VideoCapture, VideoProxy) return_type = VideoProxy + json_model = VideoJsonDict @classmethod def enabled(cls) -> bool: @@ -252,11 +253,7 @@ class VideoInterface(Interface): def before_validation(self, array: Any) -> VideoProxy: """Get a :class:`.VideoProxy` object for this video""" - if isinstance(array, dict): - proxy = VideoJsonDict(**array).to_array_input() - elif isinstance(array, VideoJsonDict): - proxy = array.to_array_input() - elif isinstance(array, VideoCapture): + if isinstance(array, VideoCapture): proxy = VideoProxy(video=array) elif isinstance(array, VideoProxy): proxy = array diff --git a/src/numpydantic/interface/zarr.py b/src/numpydantic/interface/zarr.py index d79491c..41cad03 100644 --- a/src/numpydantic/interface/zarr.py +++ b/src/numpydantic/interface/zarr.py @@ -85,6 +85,7 @@ class ZarrInterface(Interface): name = "zarr" input_types = (Path, ZarrArray, ZarrArrayPath) return_type = ZarrArray + json_model = ZarrJsonDict @classmethod def enabled(cls) -> bool: @@ -95,11 +96,6 @@ class ZarrInterface(Interface): def _get_array( array: Union[ZarrArray, str, dict, ZarrJsonDict, Path, ZarrArrayPath, Sequence] ) -> ZarrArray: - if isinstance(array, dict): - array = ZarrJsonDict(**array).to_array_input() - elif isinstance(array, ZarrJsonDict): - array = array.to_array_input() - if isinstance(array, ZarrArray): return array