mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2024-11-15 03:04:29 +00:00
use pydantic model not dataclass
This commit is contained in:
parent
70bf254ddd
commit
74f03b10bf
7 changed files with 16 additions and 33 deletions
|
@ -2,7 +2,6 @@
|
||||||
Interface for Dask arrays
|
Interface for Dask arrays
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, Iterable, Literal, Optional
|
from typing import Any, Iterable, Literal, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -25,7 +24,6 @@ def _as_tuple(a_list: list | Any) -> tuple:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(kw_only=True)
|
|
||||||
class DaskJsonDict(JsonDict):
|
class DaskJsonDict(JsonDict):
|
||||||
"""
|
"""
|
||||||
Round-trip json serialized form of a dask array
|
Round-trip json serialized form of a dask array
|
||||||
|
@ -78,6 +76,8 @@ class DaskInterface(Interface):
|
||||||
"""
|
"""
|
||||||
if isinstance(array, dict):
|
if isinstance(array, dict):
|
||||||
array = DaskJsonDict(**array).to_array_input()
|
array = DaskJsonDict(**array).to_array_input()
|
||||||
|
elif isinstance(array, DaskJsonDict):
|
||||||
|
array = array.to_array_input()
|
||||||
|
|
||||||
return array
|
return array
|
||||||
|
|
||||||
|
|
|
@ -40,7 +40,6 @@ as ``S32`` isoformatted byte strings (timezones optional) like:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, TypeVar, Union
|
from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, TypeVar, Union
|
||||||
|
@ -77,7 +76,6 @@ class H5ArrayPath(NamedTuple):
|
||||||
"""Refer to a specific field within a compound dtype"""
|
"""Refer to a specific field within a compound dtype"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class H5JsonDict(JsonDict):
|
class H5JsonDict(JsonDict):
|
||||||
"""Round-trip Json-able version of an HDF5 dataset"""
|
"""Round-trip Json-able version of an HDF5 dataset"""
|
||||||
|
|
||||||
|
@ -88,7 +86,7 @@ class H5JsonDict(JsonDict):
|
||||||
def to_array_input(self) -> H5ArrayPath:
|
def to_array_input(self) -> H5ArrayPath:
|
||||||
"""Construct an :class:`.H5ArrayPath`"""
|
"""Construct an :class:`.H5ArrayPath`"""
|
||||||
return H5ArrayPath(
|
return H5ArrayPath(
|
||||||
**{k: v for k, v in self.to_dict().items() if k in H5ArrayPath._fields}
|
**{k: v for k, v in self.model_dump().items() if k in H5ArrayPath._fields}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -330,6 +328,8 @@ class H5Interface(Interface):
|
||||||
"""Create an :class:`.H5Proxy` to use throughout validation"""
|
"""Create an :class:`.H5Proxy` to use throughout validation"""
|
||||||
if isinstance(array, dict):
|
if isinstance(array, dict):
|
||||||
array = H5JsonDict(**array).to_array_input()
|
array = H5JsonDict(**array).to_array_input()
|
||||||
|
elif isinstance(array, H5JsonDict):
|
||||||
|
array = array.to_array_input()
|
||||||
|
|
||||||
if isinstance(array, H5ArrayPath):
|
if isinstance(array, H5ArrayPath):
|
||||||
array = H5Proxy.from_h5array(h5array=array)
|
array = H5Proxy.from_h5array(h5array=array)
|
||||||
|
|
|
@ -4,13 +4,12 @@ Base Interface metaclass
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import asdict, dataclass
|
|
||||||
from importlib.metadata import PackageNotFoundError, version
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
from operator import attrgetter
|
from operator import attrgetter
|
||||||
from typing import Any, Generic, Tuple, Type, TypedDict, TypeVar, Union
|
from typing import Any, Generic, Tuple, Type, TypedDict, TypeVar, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import SerializationInfo, TypeAdapter, ValidationError
|
from pydantic import BaseModel, SerializationInfo, ValidationError
|
||||||
|
|
||||||
from numpydantic.exceptions import (
|
from numpydantic.exceptions import (
|
||||||
DtypeError,
|
DtypeError,
|
||||||
|
@ -32,13 +31,9 @@ class InterfaceMark(TypedDict):
|
||||||
version: str
|
version: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass(kw_only=True)
|
class JsonDict(BaseModel):
|
||||||
class JsonDict:
|
|
||||||
"""
|
"""
|
||||||
Representation of array when dumped with round_trip == True.
|
Representation of array when dumped with round_trip == True.
|
||||||
|
|
||||||
Using a dataclass rather than a pydantic model to not tempt
|
|
||||||
us to use more sophisticated types than can be serialized to json.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str
|
type: str
|
||||||
|
@ -50,18 +45,6 @@ class JsonDict:
|
||||||
(one of the ``input_types`` of an interface).
|
(one of the ``input_types`` of an interface).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
|
||||||
"""
|
|
||||||
Convenience method for casting dataclass to dict,
|
|
||||||
removing None-valued items
|
|
||||||
"""
|
|
||||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_adapter(cls) -> TypeAdapter:
|
|
||||||
"""Convenience method to get a typeadapter for this class"""
|
|
||||||
return TypeAdapter(cls)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_valid(cls, val: dict, raise_on_error: bool = False) -> bool:
|
def is_valid(cls, val: dict, raise_on_error: bool = False) -> bool:
|
||||||
"""
|
"""
|
||||||
|
@ -75,9 +58,8 @@ class JsonDict:
|
||||||
Returns:
|
Returns:
|
||||||
bool - true if valid, false if not
|
bool - true if valid, false if not
|
||||||
"""
|
"""
|
||||||
adapter = cls.get_adapter()
|
|
||||||
try:
|
try:
|
||||||
_ = adapter.validate_python(val)
|
_ = cls.model_validate(val)
|
||||||
return True
|
return True
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
if raise_on_error:
|
if raise_on_error:
|
||||||
|
|
|
@ -2,7 +2,6 @@
|
||||||
Interface to numpy arrays
|
Interface to numpy arrays
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, Literal, Union
|
from typing import Any, Literal, Union
|
||||||
|
|
||||||
from pydantic import SerializationInfo
|
from pydantic import SerializationInfo
|
||||||
|
@ -21,7 +20,6 @@ except ImportError: # pragma: no cover
|
||||||
np = None
|
np = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class NumpyJsonDict(JsonDict):
|
class NumpyJsonDict(JsonDict):
|
||||||
"""
|
"""
|
||||||
JSON-able roundtrip representation of numpy array
|
JSON-able roundtrip representation of numpy array
|
||||||
|
@ -78,6 +76,8 @@ class NumpyInterface(Interface):
|
||||||
"""
|
"""
|
||||||
if isinstance(array, dict):
|
if isinstance(array, dict):
|
||||||
array = NumpyJsonDict(**array).to_array_input()
|
array = NumpyJsonDict(**array).to_array_input()
|
||||||
|
elif isinstance(array, NumpyJsonDict):
|
||||||
|
array = array.to_array_input()
|
||||||
|
|
||||||
if not isinstance(array, ndarray):
|
if not isinstance(array, ndarray):
|
||||||
array = np.array(array)
|
array = np.array(array)
|
||||||
|
|
|
@ -2,7 +2,6 @@
|
||||||
Interface to support treating videos like arrays using OpenCV
|
Interface to support treating videos like arrays using OpenCV
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal, Optional, Tuple, Union
|
from typing import Any, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
|
@ -22,7 +21,6 @@ except ImportError: # pragma: no cover
|
||||||
VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
|
VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
|
||||||
|
|
||||||
|
|
||||||
@dataclass(kw_only=True)
|
|
||||||
class VideoJsonDict(JsonDict):
|
class VideoJsonDict(JsonDict):
|
||||||
"""Json-able roundtrip representation of a video file"""
|
"""Json-able roundtrip representation of a video file"""
|
||||||
|
|
||||||
|
@ -256,6 +254,8 @@ class VideoInterface(Interface):
|
||||||
"""Get a :class:`.VideoProxy` object for this video"""
|
"""Get a :class:`.VideoProxy` object for this video"""
|
||||||
if isinstance(array, dict):
|
if isinstance(array, dict):
|
||||||
proxy = VideoJsonDict(**array).to_array_input()
|
proxy = VideoJsonDict(**array).to_array_input()
|
||||||
|
elif isinstance(array, VideoJsonDict):
|
||||||
|
proxy = array.to_array_input()
|
||||||
elif isinstance(array, VideoCapture):
|
elif isinstance(array, VideoCapture):
|
||||||
proxy = VideoProxy(video=array)
|
proxy = VideoProxy(video=array)
|
||||||
elif isinstance(array, VideoProxy):
|
elif isinstance(array, VideoProxy):
|
||||||
|
|
|
@ -56,7 +56,6 @@ class ZarrArrayPath:
|
||||||
raise ValueError("Only len 1-2 iterables can be used for a ZarrArrayPath")
|
raise ValueError("Only len 1-2 iterables can be used for a ZarrArrayPath")
|
||||||
|
|
||||||
|
|
||||||
@dataclass(kw_only=True)
|
|
||||||
class ZarrJsonDict(JsonDict):
|
class ZarrJsonDict(JsonDict):
|
||||||
"""Round-trip Json-able version of a Zarr Array"""
|
"""Round-trip Json-able version of a Zarr Array"""
|
||||||
|
|
||||||
|
@ -94,10 +93,12 @@ class ZarrInterface(Interface):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_array(
|
def _get_array(
|
||||||
array: Union[ZarrArray, str, Path, ZarrArrayPath, Sequence]
|
array: Union[ZarrArray, str, dict, ZarrJsonDict, Path, ZarrArrayPath, Sequence]
|
||||||
) -> ZarrArray:
|
) -> ZarrArray:
|
||||||
if isinstance(array, dict):
|
if isinstance(array, dict):
|
||||||
array = ZarrJsonDict(**array).to_array_input()
|
array = ZarrJsonDict(**array).to_array_input()
|
||||||
|
elif isinstance(array, ZarrJsonDict):
|
||||||
|
array = array.to_array_input()
|
||||||
|
|
||||||
if isinstance(array, ZarrArray):
|
if isinstance(array, ZarrArray):
|
||||||
return array
|
return array
|
||||||
|
|
|
@ -19,7 +19,7 @@ def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
|
||||||
interface_cls = Interface.match_output(value)
|
interface_cls = Interface.match_output(value)
|
||||||
array = interface_cls.to_json(value, info)
|
array = interface_cls.to_json(value, info)
|
||||||
if isinstance(array, JsonDict):
|
if isinstance(array, JsonDict):
|
||||||
array = array.to_dict()
|
array = array.model_dump(exclude_none=True)
|
||||||
|
|
||||||
if info.context:
|
if info.context:
|
||||||
if info.context.get("mark_interface", False):
|
if info.context.get("mark_interface", False):
|
||||||
|
|
Loading…
Reference in a new issue