use pydantic model not dataclass

This commit is contained in:
sneakers-the-rat 2024-09-21 21:24:32 -07:00
parent 70bf254ddd
commit 74f03b10bf
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
7 changed files with 16 additions and 33 deletions

View file

@ -2,7 +2,6 @@
Interface for Dask arrays Interface for Dask arrays
""" """
from dataclasses import dataclass
from typing import Any, Iterable, Literal, Optional from typing import Any, Iterable, Literal, Optional
import numpy as np import numpy as np
@ -25,7 +24,6 @@ def _as_tuple(a_list: list | Any) -> tuple:
) )
@dataclass(kw_only=True)
class DaskJsonDict(JsonDict): class DaskJsonDict(JsonDict):
""" """
Round-trip json serialized form of a dask array Round-trip json serialized form of a dask array
@ -78,6 +76,8 @@ class DaskInterface(Interface):
""" """
if isinstance(array, dict): if isinstance(array, dict):
array = DaskJsonDict(**array).to_array_input() array = DaskJsonDict(**array).to_array_input()
elif isinstance(array, DaskJsonDict):
array = array.to_array_input()
return array return array

View file

@ -40,7 +40,6 @@ as ``S32`` isoformatted byte strings (timezones optional) like:
""" """
import sys import sys
from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, TypeVar, Union from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, TypeVar, Union
@ -77,7 +76,6 @@ class H5ArrayPath(NamedTuple):
"""Refer to a specific field within a compound dtype""" """Refer to a specific field within a compound dtype"""
@dataclass
class H5JsonDict(JsonDict): class H5JsonDict(JsonDict):
"""Round-trip Json-able version of an HDF5 dataset""" """Round-trip Json-able version of an HDF5 dataset"""
@ -88,7 +86,7 @@ class H5JsonDict(JsonDict):
def to_array_input(self) -> H5ArrayPath: def to_array_input(self) -> H5ArrayPath:
"""Construct an :class:`.H5ArrayPath`""" """Construct an :class:`.H5ArrayPath`"""
return H5ArrayPath( return H5ArrayPath(
**{k: v for k, v in self.to_dict().items() if k in H5ArrayPath._fields} **{k: v for k, v in self.model_dump().items() if k in H5ArrayPath._fields}
) )
@ -330,6 +328,8 @@ class H5Interface(Interface):
"""Create an :class:`.H5Proxy` to use throughout validation""" """Create an :class:`.H5Proxy` to use throughout validation"""
if isinstance(array, dict): if isinstance(array, dict):
array = H5JsonDict(**array).to_array_input() array = H5JsonDict(**array).to_array_input()
elif isinstance(array, H5JsonDict):
array = array.to_array_input()
if isinstance(array, H5ArrayPath): if isinstance(array, H5ArrayPath):
array = H5Proxy.from_h5array(h5array=array) array = H5Proxy.from_h5array(h5array=array)

View file

@ -4,13 +4,12 @@ Base Interface metaclass
import inspect import inspect
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from importlib.metadata import PackageNotFoundError, version from importlib.metadata import PackageNotFoundError, version
from operator import attrgetter from operator import attrgetter
from typing import Any, Generic, Tuple, Type, TypedDict, TypeVar, Union from typing import Any, Generic, Tuple, Type, TypedDict, TypeVar, Union
import numpy as np import numpy as np
from pydantic import SerializationInfo, TypeAdapter, ValidationError from pydantic import BaseModel, SerializationInfo, ValidationError
from numpydantic.exceptions import ( from numpydantic.exceptions import (
DtypeError, DtypeError,
@ -32,13 +31,9 @@ class InterfaceMark(TypedDict):
version: str version: str
@dataclass(kw_only=True) class JsonDict(BaseModel):
class JsonDict:
""" """
Representation of array when dumped with round_trip == True. Representation of array when dumped with round_trip == True.
Using a dataclass rather than a pydantic model to not tempt
us to use more sophisticated types than can be serialized to json.
""" """
type: str type: str
@ -50,18 +45,6 @@ class JsonDict:
(one of the ``input_types`` of an interface). (one of the ``input_types`` of an interface).
""" """
def to_dict(self) -> dict:
"""
Convenience method for casting dataclass to dict,
removing None-valued items
"""
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def get_adapter(cls) -> TypeAdapter:
"""Convenience method to get a typeadapter for this class"""
return TypeAdapter(cls)
@classmethod @classmethod
def is_valid(cls, val: dict, raise_on_error: bool = False) -> bool: def is_valid(cls, val: dict, raise_on_error: bool = False) -> bool:
""" """
@ -75,9 +58,8 @@ class JsonDict:
Returns: Returns:
bool - true if valid, false if not bool - true if valid, false if not
""" """
adapter = cls.get_adapter()
try: try:
_ = adapter.validate_python(val) _ = cls.model_validate(val)
return True return True
except ValidationError as e: except ValidationError as e:
if raise_on_error: if raise_on_error:

View file

@ -2,7 +2,6 @@
Interface to numpy arrays Interface to numpy arrays
""" """
from dataclasses import dataclass
from typing import Any, Literal, Union from typing import Any, Literal, Union
from pydantic import SerializationInfo from pydantic import SerializationInfo
@ -21,7 +20,6 @@ except ImportError: # pragma: no cover
np = None np = None
@dataclass
class NumpyJsonDict(JsonDict): class NumpyJsonDict(JsonDict):
""" """
JSON-able roundtrip representation of numpy array JSON-able roundtrip representation of numpy array
@ -78,6 +76,8 @@ class NumpyInterface(Interface):
""" """
if isinstance(array, dict): if isinstance(array, dict):
array = NumpyJsonDict(**array).to_array_input() array = NumpyJsonDict(**array).to_array_input()
elif isinstance(array, NumpyJsonDict):
array = array.to_array_input()
if not isinstance(array, ndarray): if not isinstance(array, ndarray):
array = np.array(array) array = np.array(array)

View file

@ -2,7 +2,6 @@
Interface to support treating videos like arrays using OpenCV Interface to support treating videos like arrays using OpenCV
""" """
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any, Literal, Optional, Tuple, Union from typing import Any, Literal, Optional, Tuple, Union
@ -22,7 +21,6 @@ except ImportError: # pragma: no cover
VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv") VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
@dataclass(kw_only=True)
class VideoJsonDict(JsonDict): class VideoJsonDict(JsonDict):
"""Json-able roundtrip representation of a video file""" """Json-able roundtrip representation of a video file"""
@ -256,6 +254,8 @@ class VideoInterface(Interface):
"""Get a :class:`.VideoProxy` object for this video""" """Get a :class:`.VideoProxy` object for this video"""
if isinstance(array, dict): if isinstance(array, dict):
proxy = VideoJsonDict(**array).to_array_input() proxy = VideoJsonDict(**array).to_array_input()
elif isinstance(array, VideoJsonDict):
proxy = array.to_array_input()
elif isinstance(array, VideoCapture): elif isinstance(array, VideoCapture):
proxy = VideoProxy(video=array) proxy = VideoProxy(video=array)
elif isinstance(array, VideoProxy): elif isinstance(array, VideoProxy):

View file

@ -56,7 +56,6 @@ class ZarrArrayPath:
raise ValueError("Only len 1-2 iterables can be used for a ZarrArrayPath") raise ValueError("Only len 1-2 iterables can be used for a ZarrArrayPath")
@dataclass(kw_only=True)
class ZarrJsonDict(JsonDict): class ZarrJsonDict(JsonDict):
"""Round-trip Json-able version of a Zarr Array""" """Round-trip Json-able version of a Zarr Array"""
@ -94,10 +93,12 @@ class ZarrInterface(Interface):
@staticmethod @staticmethod
def _get_array( def _get_array(
array: Union[ZarrArray, str, Path, ZarrArrayPath, Sequence] array: Union[ZarrArray, str, dict, ZarrJsonDict, Path, ZarrArrayPath, Sequence]
) -> ZarrArray: ) -> ZarrArray:
if isinstance(array, dict): if isinstance(array, dict):
array = ZarrJsonDict(**array).to_array_input() array = ZarrJsonDict(**array).to_array_input()
elif isinstance(array, ZarrJsonDict):
array = array.to_array_input()
if isinstance(array, ZarrArray): if isinstance(array, ZarrArray):
return array return array

View file

@ -19,7 +19,7 @@ def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
interface_cls = Interface.match_output(value) interface_cls = Interface.match_output(value)
array = interface_cls.to_json(value, info) array = interface_cls.to_json(value, info)
if isinstance(array, JsonDict): if isinstance(array, JsonDict):
array = array.to_dict() array = array.model_dump(exclude_none=True)
if info.context: if info.context:
if info.context.get("mark_interface", False): if info.context.get("mark_interface", False):