use pydantic model not dataclass

This commit is contained in:
sneakers-the-rat 2024-09-21 21:24:32 -07:00
parent 70bf254ddd
commit 74f03b10bf
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
7 changed files with 16 additions and 33 deletions

View file

@ -2,7 +2,6 @@
Interface for Dask arrays
"""
from dataclasses import dataclass
from typing import Any, Iterable, Literal, Optional
import numpy as np
@ -25,7 +24,6 @@ def _as_tuple(a_list: list | Any) -> tuple:
)
@dataclass(kw_only=True)
class DaskJsonDict(JsonDict):
"""
Round-trip json serialized form of a dask array
@ -78,6 +76,8 @@ class DaskInterface(Interface):
"""
if isinstance(array, dict):
array = DaskJsonDict(**array).to_array_input()
elif isinstance(array, DaskJsonDict):
array = array.to_array_input()
return array

View file

@ -40,7 +40,6 @@ as ``S32`` isoformatted byte strings (timezones optional) like:
"""
import sys
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, TypeVar, Union
@ -77,7 +76,6 @@ class H5ArrayPath(NamedTuple):
"""Refer to a specific field within a compound dtype"""
@dataclass
class H5JsonDict(JsonDict):
"""Round-trip Json-able version of an HDF5 dataset"""
@ -88,7 +86,7 @@ class H5JsonDict(JsonDict):
def to_array_input(self) -> H5ArrayPath:
"""Construct an :class:`.H5ArrayPath`"""
return H5ArrayPath(
**{k: v for k, v in self.to_dict().items() if k in H5ArrayPath._fields}
**{k: v for k, v in self.model_dump().items() if k in H5ArrayPath._fields}
)
@ -330,6 +328,8 @@ class H5Interface(Interface):
"""Create an :class:`.H5Proxy` to use throughout validation"""
if isinstance(array, dict):
array = H5JsonDict(**array).to_array_input()
elif isinstance(array, H5JsonDict):
array = array.to_array_input()
if isinstance(array, H5ArrayPath):
array = H5Proxy.from_h5array(h5array=array)

View file

@ -4,13 +4,12 @@ Base Interface metaclass
import inspect
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from importlib.metadata import PackageNotFoundError, version
from operator import attrgetter
from typing import Any, Generic, Tuple, Type, TypedDict, TypeVar, Union
import numpy as np
from pydantic import SerializationInfo, TypeAdapter, ValidationError
from pydantic import BaseModel, SerializationInfo, ValidationError
from numpydantic.exceptions import (
DtypeError,
@ -32,13 +31,9 @@ class InterfaceMark(TypedDict):
version: str
@dataclass(kw_only=True)
class JsonDict:
class JsonDict(BaseModel):
"""
Representation of array when dumped with round_trip == True.
Using a dataclass rather than a pydantic model to not tempt
us to use more sophisticated types than can be serialized to json.
"""
type: str
@ -50,18 +45,6 @@ class JsonDict:
(one of the ``input_types`` of an interface).
"""
def to_dict(self) -> dict:
"""
Convenience method for casting dataclass to dict,
removing None-valued items
"""
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def get_adapter(cls) -> TypeAdapter:
"""Convenience method to get a typeadapter for this class"""
return TypeAdapter(cls)
@classmethod
def is_valid(cls, val: dict, raise_on_error: bool = False) -> bool:
"""
@ -75,9 +58,8 @@ class JsonDict:
Returns:
bool - true if valid, false if not
"""
adapter = cls.get_adapter()
try:
_ = adapter.validate_python(val)
_ = cls.model_validate(val)
return True
except ValidationError as e:
if raise_on_error:

View file

@ -2,7 +2,6 @@
Interface to numpy arrays
"""
from dataclasses import dataclass
from typing import Any, Literal, Union
from pydantic import SerializationInfo
@ -21,7 +20,6 @@ except ImportError: # pragma: no cover
np = None
@dataclass
class NumpyJsonDict(JsonDict):
"""
JSON-able roundtrip representation of numpy array
@ -78,6 +76,8 @@ class NumpyInterface(Interface):
"""
if isinstance(array, dict):
array = NumpyJsonDict(**array).to_array_input()
elif isinstance(array, NumpyJsonDict):
array = array.to_array_input()
if not isinstance(array, ndarray):
array = np.array(array)

View file

@ -2,7 +2,6 @@
Interface to support treating videos like arrays using OpenCV
"""
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, Optional, Tuple, Union
@ -22,7 +21,6 @@ except ImportError: # pragma: no cover
VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
@dataclass(kw_only=True)
class VideoJsonDict(JsonDict):
"""Json-able roundtrip representation of a video file"""
@ -256,6 +254,8 @@ class VideoInterface(Interface):
"""Get a :class:`.VideoProxy` object for this video"""
if isinstance(array, dict):
proxy = VideoJsonDict(**array).to_array_input()
elif isinstance(array, VideoJsonDict):
proxy = array.to_array_input()
elif isinstance(array, VideoCapture):
proxy = VideoProxy(video=array)
elif isinstance(array, VideoProxy):

View file

@ -56,7 +56,6 @@ class ZarrArrayPath:
raise ValueError("Only len 1-2 iterables can be used for a ZarrArrayPath")
@dataclass(kw_only=True)
class ZarrJsonDict(JsonDict):
"""Round-trip Json-able version of a Zarr Array"""
@ -94,10 +93,12 @@ class ZarrInterface(Interface):
@staticmethod
def _get_array(
array: Union[ZarrArray, str, Path, ZarrArrayPath, Sequence]
array: Union[ZarrArray, str, dict, ZarrJsonDict, Path, ZarrArrayPath, Sequence]
) -> ZarrArray:
if isinstance(array, dict):
array = ZarrJsonDict(**array).to_array_input()
elif isinstance(array, ZarrJsonDict):
array = array.to_array_input()
if isinstance(array, ZarrArray):
return array

View file

@ -19,7 +19,7 @@ def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
interface_cls = Interface.match_output(value)
array = interface_cls.to_json(value, info)
if isinstance(array, JsonDict):
array = array.to_dict()
array = array.model_dump(exclude_none=True)
if info.context:
if info.context.get("mark_interface", False):