mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2024-11-14 18:54:28 +00:00
use pydantic model not dataclass
This commit is contained in:
parent
70bf254ddd
commit
74f03b10bf
7 changed files with 16 additions and 33 deletions
|
@ -2,7 +2,6 @@
|
|||
Interface for Dask arrays
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Iterable, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
|
@ -25,7 +24,6 @@ def _as_tuple(a_list: list | Any) -> tuple:
|
|||
)
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class DaskJsonDict(JsonDict):
|
||||
"""
|
||||
Round-trip json serialized form of a dask array
|
||||
|
@ -78,6 +76,8 @@ class DaskInterface(Interface):
|
|||
"""
|
||||
if isinstance(array, dict):
|
||||
array = DaskJsonDict(**array).to_array_input()
|
||||
elif isinstance(array, DaskJsonDict):
|
||||
array = array.to_array_input()
|
||||
|
||||
return array
|
||||
|
||||
|
|
|
@ -40,7 +40,6 @@ as ``S32`` isoformatted byte strings (timezones optional) like:
|
|||
"""
|
||||
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, TypeVar, Union
|
||||
|
@ -77,7 +76,6 @@ class H5ArrayPath(NamedTuple):
|
|||
"""Refer to a specific field within a compound dtype"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class H5JsonDict(JsonDict):
|
||||
"""Round-trip Json-able version of an HDF5 dataset"""
|
||||
|
||||
|
@ -88,7 +86,7 @@ class H5JsonDict(JsonDict):
|
|||
def to_array_input(self) -> H5ArrayPath:
|
||||
"""Construct an :class:`.H5ArrayPath`"""
|
||||
return H5ArrayPath(
|
||||
**{k: v for k, v in self.to_dict().items() if k in H5ArrayPath._fields}
|
||||
**{k: v for k, v in self.model_dump().items() if k in H5ArrayPath._fields}
|
||||
)
|
||||
|
||||
|
||||
|
@ -330,6 +328,8 @@ class H5Interface(Interface):
|
|||
"""Create an :class:`.H5Proxy` to use throughout validation"""
|
||||
if isinstance(array, dict):
|
||||
array = H5JsonDict(**array).to_array_input()
|
||||
elif isinstance(array, H5JsonDict):
|
||||
array = array.to_array_input()
|
||||
|
||||
if isinstance(array, H5ArrayPath):
|
||||
array = H5Proxy.from_h5array(h5array=array)
|
||||
|
|
|
@ -4,13 +4,12 @@ Base Interface metaclass
|
|||
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import asdict, dataclass
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from operator import attrgetter
|
||||
from typing import Any, Generic, Tuple, Type, TypedDict, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import SerializationInfo, TypeAdapter, ValidationError
|
||||
from pydantic import BaseModel, SerializationInfo, ValidationError
|
||||
|
||||
from numpydantic.exceptions import (
|
||||
DtypeError,
|
||||
|
@ -32,13 +31,9 @@ class InterfaceMark(TypedDict):
|
|||
version: str
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class JsonDict:
|
||||
class JsonDict(BaseModel):
|
||||
"""
|
||||
Representation of array when dumped with round_trip == True.
|
||||
|
||||
Using a dataclass rather than a pydantic model to not tempt
|
||||
us to use more sophisticated types than can be serialized to json.
|
||||
"""
|
||||
|
||||
type: str
|
||||
|
@ -50,18 +45,6 @@ class JsonDict:
|
|||
(one of the ``input_types`` of an interface).
|
||||
"""
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""
|
||||
Convenience method for casting dataclass to dict,
|
||||
removing None-valued items
|
||||
"""
|
||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||
|
||||
@classmethod
|
||||
def get_adapter(cls) -> TypeAdapter:
|
||||
"""Convenience method to get a typeadapter for this class"""
|
||||
return TypeAdapter(cls)
|
||||
|
||||
@classmethod
|
||||
def is_valid(cls, val: dict, raise_on_error: bool = False) -> bool:
|
||||
"""
|
||||
|
@ -75,9 +58,8 @@ class JsonDict:
|
|||
Returns:
|
||||
bool - true if valid, false if not
|
||||
"""
|
||||
adapter = cls.get_adapter()
|
||||
try:
|
||||
_ = adapter.validate_python(val)
|
||||
_ = cls.model_validate(val)
|
||||
return True
|
||||
except ValidationError as e:
|
||||
if raise_on_error:
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
Interface to numpy arrays
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import SerializationInfo
|
||||
|
@ -21,7 +20,6 @@ except ImportError: # pragma: no cover
|
|||
np = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class NumpyJsonDict(JsonDict):
|
||||
"""
|
||||
JSON-able roundtrip representation of numpy array
|
||||
|
@ -78,6 +76,8 @@ class NumpyInterface(Interface):
|
|||
"""
|
||||
if isinstance(array, dict):
|
||||
array = NumpyJsonDict(**array).to_array_input()
|
||||
elif isinstance(array, NumpyJsonDict):
|
||||
array = array.to_array_input()
|
||||
|
||||
if not isinstance(array, ndarray):
|
||||
array = np.array(array)
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
Interface to support treating videos like arrays using OpenCV
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, Optional, Tuple, Union
|
||||
|
||||
|
@ -22,7 +21,6 @@ except ImportError: # pragma: no cover
|
|||
VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class VideoJsonDict(JsonDict):
|
||||
"""Json-able roundtrip representation of a video file"""
|
||||
|
||||
|
@ -256,6 +254,8 @@ class VideoInterface(Interface):
|
|||
"""Get a :class:`.VideoProxy` object for this video"""
|
||||
if isinstance(array, dict):
|
||||
proxy = VideoJsonDict(**array).to_array_input()
|
||||
elif isinstance(array, VideoJsonDict):
|
||||
proxy = array.to_array_input()
|
||||
elif isinstance(array, VideoCapture):
|
||||
proxy = VideoProxy(video=array)
|
||||
elif isinstance(array, VideoProxy):
|
||||
|
|
|
@ -56,7 +56,6 @@ class ZarrArrayPath:
|
|||
raise ValueError("Only len 1-2 iterables can be used for a ZarrArrayPath")
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class ZarrJsonDict(JsonDict):
|
||||
"""Round-trip Json-able version of a Zarr Array"""
|
||||
|
||||
|
@ -94,10 +93,12 @@ class ZarrInterface(Interface):
|
|||
|
||||
@staticmethod
|
||||
def _get_array(
|
||||
array: Union[ZarrArray, str, Path, ZarrArrayPath, Sequence]
|
||||
array: Union[ZarrArray, str, dict, ZarrJsonDict, Path, ZarrArrayPath, Sequence]
|
||||
) -> ZarrArray:
|
||||
if isinstance(array, dict):
|
||||
array = ZarrJsonDict(**array).to_array_input()
|
||||
elif isinstance(array, ZarrJsonDict):
|
||||
array = array.to_array_input()
|
||||
|
||||
if isinstance(array, ZarrArray):
|
||||
return array
|
||||
|
|
|
@ -19,7 +19,7 @@ def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
|
|||
interface_cls = Interface.match_output(value)
|
||||
array = interface_cls.to_json(value, info)
|
||||
if isinstance(array, JsonDict):
|
||||
array = array.to_dict()
|
||||
array = array.model_dump(exclude_none=True)
|
||||
|
||||
if info.context:
|
||||
if info.context.get("mark_interface", False):
|
||||
|
|
Loading…
Reference in a new issue