mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2025-01-10 05:54:26 +00:00
roundtripping marked arrays, roundtrip or not
This commit is contained in:
parent
8cc2574399
commit
708e6e81d8
12 changed files with 328 additions and 53 deletions
|
@ -25,3 +25,7 @@ class NoMatchError(MatchError):
|
||||||
|
|
||||||
class TooManyMatchesError(MatchError):
|
class TooManyMatchesError(MatchError):
|
||||||
"""Too many matches found by :class:`.Interface.match`"""
|
"""Too many matches found by :class:`.Interface.match`"""
|
||||||
|
|
||||||
|
|
||||||
|
class MarkMismatchError(MatchError):
|
||||||
|
"""A serialized :class:`.InterfaceMark` doesn't match the receiving interface"""
|
||||||
|
|
|
@ -4,16 +4,23 @@ Interfaces between nptyping types and array backends
|
||||||
|
|
||||||
from numpydantic.interface.dask import DaskInterface
|
from numpydantic.interface.dask import DaskInterface
|
||||||
from numpydantic.interface.hdf5 import H5Interface
|
from numpydantic.interface.hdf5 import H5Interface
|
||||||
from numpydantic.interface.interface import Interface, JsonDict
|
from numpydantic.interface.interface import (
|
||||||
|
Interface,
|
||||||
|
InterfaceMark,
|
||||||
|
JsonDict,
|
||||||
|
MarkedJson,
|
||||||
|
)
|
||||||
from numpydantic.interface.numpy import NumpyInterface
|
from numpydantic.interface.numpy import NumpyInterface
|
||||||
from numpydantic.interface.video import VideoInterface
|
from numpydantic.interface.video import VideoInterface
|
||||||
from numpydantic.interface.zarr import ZarrInterface
|
from numpydantic.interface.zarr import ZarrInterface
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"JsonDict",
|
|
||||||
"Interface",
|
|
||||||
"DaskInterface",
|
"DaskInterface",
|
||||||
"H5Interface",
|
"H5Interface",
|
||||||
|
"Interface",
|
||||||
|
"InterfaceMark",
|
||||||
|
"JsonDict",
|
||||||
|
"MarkedJson",
|
||||||
"NumpyInterface",
|
"NumpyInterface",
|
||||||
"VideoInterface",
|
"VideoInterface",
|
||||||
"ZarrInterface",
|
"ZarrInterface",
|
||||||
|
|
|
@ -120,7 +120,7 @@ class H5Proxy:
|
||||||
annotation_dtype: Optional[DtypeType] = None,
|
annotation_dtype: Optional[DtypeType] = None,
|
||||||
):
|
):
|
||||||
self._h5f = None
|
self._h5f = None
|
||||||
self.file = Path(file)
|
self.file = Path(file).resolve()
|
||||||
self.path = path
|
self.path = path
|
||||||
self.field = field
|
self.field = field
|
||||||
self._annotation_dtype = annotation_dtype
|
self._annotation_dtype = annotation_dtype
|
||||||
|
@ -156,6 +156,9 @@ class H5Proxy:
|
||||||
return obj[:]
|
return obj[:]
|
||||||
|
|
||||||
def __getattr__(self, item: str):
|
def __getattr__(self, item: str):
|
||||||
|
if item == "__name__":
|
||||||
|
# special case for H5Proxies that don't refer to a real file during testing
|
||||||
|
return "H5Proxy"
|
||||||
with h5py.File(self.file, "r") as h5f:
|
with h5py.File(self.file, "r") as h5f:
|
||||||
obj = h5f.get(self.path)
|
obj = h5f.get(self.path)
|
||||||
val = getattr(obj, item)
|
val = getattr(obj, item)
|
||||||
|
|
|
@ -3,16 +3,19 @@ Base Interface metaclass
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from functools import lru_cache
|
||||||
from importlib.metadata import PackageNotFoundError, version
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
from operator import attrgetter
|
from operator import attrgetter
|
||||||
from typing import Any, Generic, Tuple, Type, TypedDict, TypeVar, Union
|
from typing import Any, Generic, Optional, Tuple, Type, TypeVar, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import BaseModel, SerializationInfo, ValidationError
|
from pydantic import BaseModel, SerializationInfo, ValidationError
|
||||||
|
|
||||||
from numpydantic.exceptions import (
|
from numpydantic.exceptions import (
|
||||||
DtypeError,
|
DtypeError,
|
||||||
|
MarkMismatchError,
|
||||||
NoMatchError,
|
NoMatchError,
|
||||||
ShapeError,
|
ShapeError,
|
||||||
TooManyMatchesError,
|
TooManyMatchesError,
|
||||||
|
@ -26,13 +29,49 @@ V = TypeVar("V") # input type
|
||||||
W = TypeVar("W") # Any type in handle_input
|
W = TypeVar("W") # Any type in handle_input
|
||||||
|
|
||||||
|
|
||||||
class InterfaceMark(TypedDict):
|
class InterfaceMark(BaseModel):
|
||||||
"""JSON-able mark to be able to round-trip json dumps"""
|
"""JSON-able mark to be able to round-trip json dumps"""
|
||||||
|
|
||||||
module: str
|
module: str
|
||||||
cls: str
|
cls: str
|
||||||
|
name: str
|
||||||
version: str
|
version: str
|
||||||
|
|
||||||
|
def is_valid(self, cls: Type["Interface"], raise_on_error: bool = False) -> bool:
|
||||||
|
"""
|
||||||
|
Check that a given interface matches the mark.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cls (Type): Interface type to check
|
||||||
|
raise_on_error (bool): Raise an ``MarkMismatchError`` when the match
|
||||||
|
is incorrect
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
:class:`.MarkMismatchError` if requested by ``raise_on_error``
|
||||||
|
for an invalid match
|
||||||
|
"""
|
||||||
|
mark = cls.mark_interface()
|
||||||
|
valid = self == mark
|
||||||
|
if not valid and raise_on_error:
|
||||||
|
raise MarkMismatchError(
|
||||||
|
"Mismatch between serialized mark and current interface, "
|
||||||
|
f"Serialized: {self}; current: {cls}"
|
||||||
|
)
|
||||||
|
return valid
|
||||||
|
|
||||||
|
def match_by_name(self) -> Optional[Type["Interface"]]:
|
||||||
|
"""
|
||||||
|
Try to find a matching interface by its name, returning it if found,
|
||||||
|
or None if not found.
|
||||||
|
"""
|
||||||
|
for i in Interface.interfaces(sort=False):
|
||||||
|
if i.name == self.name:
|
||||||
|
return i
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class JsonDict(BaseModel):
|
class JsonDict(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -84,6 +123,29 @@ class JsonDict(BaseModel):
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class MarkedJson(BaseModel):
|
||||||
|
"""
|
||||||
|
Model of JSON dumped with an additional interface mark
|
||||||
|
with ``model_dump_json({'mark_interface': True})``
|
||||||
|
"""
|
||||||
|
|
||||||
|
interface: InterfaceMark
|
||||||
|
value: Union[list, dict]
|
||||||
|
"""
|
||||||
|
Inner value of the array, we don't validate for JsonDict here,
|
||||||
|
that should be downstream from us for performance reasons
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def try_cast(cls, value: Union[V, dict]) -> Union[V, "MarkedJson"]:
|
||||||
|
"""
|
||||||
|
Try to cast to MarkedJson if applicable, otherwise return input
|
||||||
|
"""
|
||||||
|
if isinstance(value, dict) and "interface" in value and "value" in value:
|
||||||
|
value = MarkedJson(**value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
class Interface(ABC, Generic[T]):
|
class Interface(ABC, Generic[T]):
|
||||||
"""
|
"""
|
||||||
Abstract parent class for interfaces to different array formats
|
Abstract parent class for interfaces to different array formats
|
||||||
|
@ -158,14 +220,24 @@ class Interface(ABC, Generic[T]):
|
||||||
def deserialize(self, array: Any) -> Union[V, Any]:
|
def deserialize(self, array: Any) -> Union[V, Any]:
|
||||||
"""
|
"""
|
||||||
If given a JSON serialized version of the array,
|
If given a JSON serialized version of the array,
|
||||||
deserialize it first
|
deserialize it first.
|
||||||
|
|
||||||
Args:
|
If a roundtrip-serialized :class:`.JsonDict`,
|
||||||
array:
|
pass to :meth:`.JsonDict.handle_input`.
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
|
If a roundtrip-serialized :class:`.MarkedJson`,
|
||||||
|
unpack mark, check for validity, warn if not,
|
||||||
|
and try to continue with validation
|
||||||
"""
|
"""
|
||||||
|
if isinstance(marked_array := MarkedJson.try_cast(array), MarkedJson):
|
||||||
|
try:
|
||||||
|
marked_array.interface.is_valid(self.__class__, raise_on_error=True)
|
||||||
|
except MarkMismatchError as e:
|
||||||
|
warnings.warn(
|
||||||
|
str(e) + "\nAttempting to continue validation...", stacklevel=2
|
||||||
|
)
|
||||||
|
array = marked_array.value
|
||||||
|
|
||||||
return self.json_model.handle_input(array)
|
return self.json_model.handle_input(array)
|
||||||
|
|
||||||
def before_validation(self, array: Any) -> NDArrayType:
|
def before_validation(self, array: Any) -> NDArrayType:
|
||||||
|
@ -274,13 +346,6 @@ class Interface(ABC, Generic[T]):
|
||||||
"""
|
"""
|
||||||
return array
|
return array
|
||||||
|
|
||||||
def mark_input(self, array: Any) -> Any:
|
|
||||||
"""
|
|
||||||
Preserve metadata about the interface and passed input when dumping with
|
|
||||||
``round_trip``
|
|
||||||
"""
|
|
||||||
return array
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def check(cls, array: Any) -> bool:
|
def check(cls, array: Any) -> bool:
|
||||||
|
@ -320,7 +385,7 @@ class Interface(ABC, Generic[T]):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def mark_json(cls, array: Union[list, dict]) -> dict:
|
def mark_json(cls, array: Union[list, dict]) -> MarkedJson:
|
||||||
"""
|
"""
|
||||||
When using ``model_dump_json`` with ``mark_interface: True`` in the ``context``,
|
When using ``model_dump_json`` with ``mark_interface: True`` in the ``context``,
|
||||||
add additional annotations that would allow the serialized array to be
|
add additional annotations that would allow the serialized array to be
|
||||||
|
@ -337,7 +402,7 @@ class Interface(ABC, Generic[T]):
|
||||||
'version': '1.2.2'},
|
'version': '1.2.2'},
|
||||||
'value': [1.0, 2.0]}
|
'value': [1.0, 2.0]}
|
||||||
"""
|
"""
|
||||||
return {"interface": cls.mark_interface(), "value": array}
|
return MarkedJson.model_construct(interface=cls.mark_interface(), value=array)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def interfaces(
|
def interfaces(
|
||||||
|
@ -390,6 +455,28 @@ class Interface(ABC, Generic[T]):
|
||||||
|
|
||||||
return tuple(in_types)
|
return tuple(in_types)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def match_mark(cls, array: Any) -> Optional[Type["Interface"]]:
|
||||||
|
"""
|
||||||
|
Match a marked JSON dump of this array to the interface that it indicates.
|
||||||
|
|
||||||
|
First find an interface that matches by name, and then run its
|
||||||
|
``check`` method, because arrays can be dumped with a mark
|
||||||
|
but without ``round_trip == True`` (and thus can't necessarily
|
||||||
|
use the same interface that they were dumped with)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Interface if match found, None otherwise
|
||||||
|
"""
|
||||||
|
mark = MarkedJson.try_cast(array)
|
||||||
|
if not isinstance(mark, MarkedJson):
|
||||||
|
return None
|
||||||
|
|
||||||
|
interface = mark.interface.match_by_name()
|
||||||
|
if interface is not None and interface.check(mark.value):
|
||||||
|
return interface
|
||||||
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def match(cls, array: Any, fast: bool = False) -> Type["Interface"]:
|
def match(cls, array: Any, fast: bool = False) -> Type["Interface"]:
|
||||||
"""
|
"""
|
||||||
|
@ -407,11 +494,18 @@ class Interface(ABC, Generic[T]):
|
||||||
check each interface (as ordered by its ``priority`` , decreasing),
|
check each interface (as ordered by its ``priority`` , decreasing),
|
||||||
and return on the first match.
|
and return on the first match.
|
||||||
"""
|
"""
|
||||||
|
# Shortcircuit match if this is a marked json dump
|
||||||
|
array = MarkedJson.try_cast(array)
|
||||||
|
if (match := cls.match_mark(array)) is not None:
|
||||||
|
return match
|
||||||
|
elif isinstance(array, MarkedJson):
|
||||||
|
array = array.value
|
||||||
|
|
||||||
# first try and find a non-numpy interface, since the numpy interface
|
# first try and find a non-numpy interface, since the numpy interface
|
||||||
# will try and load the array into memory in its check method
|
# will try and load the array into memory in its check method
|
||||||
interfaces = cls.interfaces()
|
interfaces = cls.interfaces()
|
||||||
non_np_interfaces = [i for i in interfaces if i.__name__ != "NumpyInterface"]
|
non_np_interfaces = [i for i in interfaces if i.name != "numpy"]
|
||||||
np_interface = [i for i in interfaces if i.__name__ == "NumpyInterface"][0]
|
np_interface = [i for i in interfaces if i.name == "numpy"][0]
|
||||||
|
|
||||||
if fast:
|
if fast:
|
||||||
matches = []
|
matches = []
|
||||||
|
@ -453,6 +547,7 @@ class Interface(ABC, Generic[T]):
|
||||||
return matches[0]
|
return matches[0]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@lru_cache(maxsize=32)
|
||||||
def mark_interface(cls) -> InterfaceMark:
|
def mark_interface(cls) -> InterfaceMark:
|
||||||
"""
|
"""
|
||||||
Create an interface mark indicating this interface for validation after
|
Create an interface mark indicating this interface for validation after
|
||||||
|
@ -470,5 +565,7 @@ class Interface(ABC, Generic[T]):
|
||||||
)
|
)
|
||||||
except PackageNotFoundError:
|
except PackageNotFoundError:
|
||||||
v = None
|
v = None
|
||||||
interface_name = cls.__name__
|
|
||||||
return InterfaceMark(module=interface_module, cls=interface_name, version=v)
|
return InterfaceMark(
|
||||||
|
module=interface_module, cls=cls.__name__, name=cls.name, version=v
|
||||||
|
)
|
||||||
|
|
|
@ -48,7 +48,7 @@ class VideoProxy:
|
||||||
)
|
)
|
||||||
|
|
||||||
if path is not None:
|
if path is not None:
|
||||||
path = Path(path)
|
path = Path(path).resolve()
|
||||||
self.path = path
|
self.path = path
|
||||||
|
|
||||||
self._video = video # type: Optional[VideoCapture]
|
self._video = video # type: Optional[VideoCapture]
|
||||||
|
@ -200,6 +200,8 @@ class VideoProxy:
|
||||||
raise NotImplementedError("Setting pixel values on videos is not supported!")
|
raise NotImplementedError("Setting pixel values on videos is not supported!")
|
||||||
|
|
||||||
def __getattr__(self, item: str):
|
def __getattr__(self, item: str):
|
||||||
|
if item == "__name__":
|
||||||
|
return "VideoProxy"
|
||||||
return getattr(self.video, item)
|
return getattr(self.video, item)
|
||||||
|
|
||||||
def __eq__(self, other: "VideoProxy") -> bool:
|
def __eq__(self, other: "VideoProxy") -> bool:
|
||||||
|
|
|
@ -23,7 +23,8 @@ def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
|
||||||
|
|
||||||
if info.context:
|
if info.context:
|
||||||
if info.context.get("mark_interface", False):
|
if info.context.get("mark_interface", False):
|
||||||
array = interface_cls.mark_json(array)
|
array = interface_cls.mark_json(array).model_dump()
|
||||||
|
|
||||||
if info.context.get("absolute_paths", False):
|
if info.context.get("absolute_paths", False):
|
||||||
array = _absolutize_paths(array)
|
array = _absolutize_paths(array)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -83,7 +83,6 @@ STRING: TypeAlias = NDArray[Shape["*, *, *"], str]
|
||||||
MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel]
|
MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.shape
|
|
||||||
@pytest.fixture(
|
@pytest.fixture(
|
||||||
scope="module",
|
scope="module",
|
||||||
params=[
|
params=[
|
||||||
|
@ -121,7 +120,6 @@ def shape_cases(request) -> ValidationCase:
|
||||||
return request.param
|
return request.param
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dtype
|
|
||||||
@pytest.fixture(
|
@pytest.fixture(
|
||||||
scope="module",
|
scope="module",
|
||||||
params=[
|
params=[
|
||||||
|
|
|
@ -104,7 +104,6 @@ def model_blank() -> Type[BaseModel]:
|
||||||
return BlankModel
|
return BlankModel
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.hdf5
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def hdf5_file(tmp_output_dir_func) -> h5py.File:
|
def hdf5_file(tmp_output_dir_func) -> h5py.File:
|
||||||
h5f_file = tmp_output_dir_func / "h5f.h5"
|
h5f_file = tmp_output_dir_func / "h5f.h5"
|
||||||
|
@ -113,7 +112,6 @@ def hdf5_file(tmp_output_dir_func) -> h5py.File:
|
||||||
h5f.close()
|
h5f.close()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.hdf5
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def hdf5_array(
|
def hdf5_array(
|
||||||
hdf5_file, request
|
hdf5_file, request
|
||||||
|
@ -156,7 +154,6 @@ def hdf5_array(
|
||||||
return _hdf5_array
|
return _hdf5_array
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.zarr
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def zarr_nested_array(tmp_output_dir_func) -> ZarrArrayPath:
|
def zarr_nested_array(tmp_output_dir_func) -> ZarrArrayPath:
|
||||||
"""Zarr array within a nested array"""
|
"""Zarr array within a nested array"""
|
||||||
|
@ -167,7 +164,6 @@ def zarr_nested_array(tmp_output_dir_func) -> ZarrArrayPath:
|
||||||
return ZarrArrayPath(file=file, path=path)
|
return ZarrArrayPath(file=file, path=path)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.zarr
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def zarr_array(tmp_output_dir_func) -> Path:
|
def zarr_array(tmp_output_dir_func) -> Path:
|
||||||
file = tmp_output_dir_func / "array.zarr"
|
file = tmp_output_dir_func / "array.zarr"
|
||||||
|
@ -176,7 +172,6 @@ def zarr_array(tmp_output_dir_func) -> Path:
|
||||||
return file
|
return file
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.video
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def avi_video(tmp_path) -> Callable[[Tuple[int, int], int, bool], Path]:
|
def avi_video(tmp_path) -> Callable[[Tuple[int, int], int, bool], Path]:
|
||||||
video_path = tmp_path / "test.avi"
|
video_path = tmp_path / "test.avi"
|
||||||
|
|
|
@ -231,3 +231,29 @@ def test_empty_dataset(dtype, tmp_path):
|
||||||
array: NDArray[Any, dtype]
|
array: NDArray[Any, dtype]
|
||||||
|
|
||||||
_ = MyModel(array=(array_path, "/data"))
|
_ = MyModel(array=(array_path, "/data"))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.proxy
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"comparison,valid",
|
||||||
|
[
|
||||||
|
(H5Proxy(file="test_file.h5", path="/subpath", field="sup"), True),
|
||||||
|
(H5Proxy(file="test_file.h5", path="/subpath"), False),
|
||||||
|
(H5Proxy(file="different_file.h5", path="/subpath"), False),
|
||||||
|
(("different_file.h5", "/subpath", "sup"), ValueError),
|
||||||
|
("not even a proxy-like thing", ValueError),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_proxy_eq(comparison, valid):
|
||||||
|
"""
|
||||||
|
test the __eq__ method of H5ArrayProxy matches proxies to the same
|
||||||
|
dataset (and path), or raises a ValueError
|
||||||
|
"""
|
||||||
|
proxy_a = H5Proxy(file="test_file.h5", path="/subpath", field="sup")
|
||||||
|
if valid is True:
|
||||||
|
assert proxy_a == comparison
|
||||||
|
elif valid is False:
|
||||||
|
assert proxy_a != comparison
|
||||||
|
else:
|
||||||
|
with pytest.raises(valid):
|
||||||
|
assert proxy_a == comparison
|
||||||
|
|
|
@ -4,11 +4,26 @@ for tests that should apply to all interfaces, use ``test_interfaces.py``
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from numpydantic.interface import Interface
|
from numpydantic.interface import Interface, JsonDict
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from numpydantic.interface.interface import V
|
||||||
|
|
||||||
|
|
||||||
|
class MyJsonDict(JsonDict):
|
||||||
|
type: Literal["my_json_dict"]
|
||||||
|
field: str
|
||||||
|
number: int
|
||||||
|
|
||||||
|
def to_array_input(self) -> V:
|
||||||
|
dumped = self.model_dump()
|
||||||
|
dumped["extra_input_param"] = True
|
||||||
|
return dumped
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@ -162,3 +177,36 @@ def test_interface_recursive(interfaces):
|
||||||
assert issubclass(interfaces.interface3, interfaces.interface1)
|
assert issubclass(interfaces.interface3, interfaces.interface1)
|
||||||
assert issubclass(interfaces.interface1, Interface)
|
assert issubclass(interfaces.interface1, Interface)
|
||||||
assert interfaces.interface4 in ifaces
|
assert interfaces.interface4 in ifaces
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.serialization
|
||||||
|
def test_jsondict_is_valid():
|
||||||
|
"""
|
||||||
|
A JsonDict should return a bool true/false if it is valid or not,
|
||||||
|
and raise an error when requested
|
||||||
|
"""
|
||||||
|
invalid = {"doesnt": "have", "the": "props"}
|
||||||
|
valid = {"type": "my_json_dict", "field": "a_field", "number": 1}
|
||||||
|
assert MyJsonDict.is_valid(valid)
|
||||||
|
assert not MyJsonDict.is_valid(invalid)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
assert not MyJsonDict.is_valid(invalid, raise_on_error=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.serialization
|
||||||
|
def test_jsondict_handle_input():
|
||||||
|
"""
|
||||||
|
JsonDict should be able to parse a valid dict and return it to the input format
|
||||||
|
"""
|
||||||
|
valid = {"type": "my_json_dict", "field": "a_field", "number": 1}
|
||||||
|
instantiated = MyJsonDict(**valid)
|
||||||
|
expected = {
|
||||||
|
"type": "my_json_dict",
|
||||||
|
"field": "a_field",
|
||||||
|
"number": 1,
|
||||||
|
"extra_input_param": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
for item in (valid, instantiated):
|
||||||
|
result = MyJsonDict.handle_input(item)
|
||||||
|
assert result == expected
|
||||||
|
|
|
@ -4,10 +4,38 @@ Tests that should be applied to all interfaces
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
from importlib.metadata import version
|
||||||
|
import json
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import dask.array as da
|
import dask.array as da
|
||||||
from zarr.core import Array as ZarrArray
|
from zarr.core import Array as ZarrArray
|
||||||
from numpydantic.interface import Interface
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from numpydantic.interface import Interface, InterfaceMark, MarkedJson
|
||||||
|
|
||||||
|
|
||||||
|
def _test_roundtrip(source: BaseModel, target: BaseModel, round_trip: bool):
|
||||||
|
"""Test model equality for roundtrip tests"""
|
||||||
|
if round_trip:
|
||||||
|
assert type(target.array) is type(source.array)
|
||||||
|
if isinstance(source.array, (np.ndarray, ZarrArray)):
|
||||||
|
assert np.array_equal(target.array, np.array(source.array))
|
||||||
|
elif isinstance(source.array, da.Array):
|
||||||
|
assert np.all(da.equal(target.array, source.array))
|
||||||
|
else:
|
||||||
|
assert target.array == source.array
|
||||||
|
|
||||||
|
assert target.array.dtype == source.array.dtype
|
||||||
|
else:
|
||||||
|
assert np.array_equal(target.array, np.array(source.array))
|
||||||
|
|
||||||
|
|
||||||
|
def test_dunder_len(all_interfaces):
|
||||||
|
"""
|
||||||
|
Each interface or proxy type should support __len__
|
||||||
|
"""
|
||||||
|
assert len(all_interfaces.array) == all_interfaces.array.shape[0]
|
||||||
|
|
||||||
|
|
||||||
def test_interface_revalidate(all_interfaces):
|
def test_interface_revalidate(all_interfaces):
|
||||||
|
@ -52,24 +80,51 @@ def test_interface_roundtrip_json(all_interfaces, round_trip):
|
||||||
"""
|
"""
|
||||||
All interfaces should be able to roundtrip to and from json
|
All interfaces should be able to roundtrip to and from json
|
||||||
"""
|
"""
|
||||||
json = all_interfaces.model_dump_json(round_trip=round_trip)
|
dumped_json = all_interfaces.model_dump_json(round_trip=round_trip)
|
||||||
model = all_interfaces.model_validate_json(json)
|
model = all_interfaces.model_validate_json(dumped_json)
|
||||||
if round_trip:
|
_test_roundtrip(all_interfaces, model, round_trip)
|
||||||
assert type(model.array) is type(all_interfaces.array)
|
|
||||||
if isinstance(all_interfaces.array, (np.ndarray, ZarrArray)):
|
|
||||||
assert np.array_equal(model.array, np.array(all_interfaces.array))
|
|
||||||
elif isinstance(all_interfaces.array, da.Array):
|
|
||||||
assert np.all(da.equal(model.array, all_interfaces.array))
|
|
||||||
else:
|
|
||||||
assert model.array == all_interfaces.array
|
|
||||||
|
|
||||||
assert model.array.dtype == all_interfaces.array.dtype
|
|
||||||
else:
|
|
||||||
assert np.array_equal(model.array, np.array(all_interfaces.array))
|
|
||||||
|
|
||||||
|
|
||||||
def test_dunder_len(all_interfaces):
|
@pytest.mark.serialization
|
||||||
|
@pytest.mark.parametrize("an_interface", Interface.interfaces())
|
||||||
|
def test_interface_mark_interface(an_interface):
|
||||||
"""
|
"""
|
||||||
Each interface or proxy type should support __len__
|
All interfaces should be able to mark the current version and interface info
|
||||||
"""
|
"""
|
||||||
assert len(all_interfaces.array) == all_interfaces.array.shape[0]
|
mark = an_interface.mark_interface()
|
||||||
|
assert isinstance(mark, InterfaceMark)
|
||||||
|
assert mark.name == an_interface.name
|
||||||
|
assert mark.cls == an_interface.__name__
|
||||||
|
assert mark.module == an_interface.__module__
|
||||||
|
assert mark.version == version(mark.module.split(".")[0])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.serialization
|
||||||
|
@pytest.mark.parametrize("valid", [True, False])
|
||||||
|
@pytest.mark.parametrize("round_trip", [True, False])
|
||||||
|
@pytest.mark.filterwarnings("ignore:Mismatch between serialized mark")
|
||||||
|
def test_interface_mark_roundtrip(all_interfaces, valid, round_trip):
|
||||||
|
"""
|
||||||
|
All interfaces should be able to roundtrip with the marked interface,
|
||||||
|
and a mismatch should raise a warning and attempt to proceed
|
||||||
|
"""
|
||||||
|
dumped_json = all_interfaces.model_dump_json(
|
||||||
|
round_trip=round_trip, context={"mark_interface": True}
|
||||||
|
)
|
||||||
|
|
||||||
|
data = json.loads(dumped_json)
|
||||||
|
|
||||||
|
# ensure that we are a MarkedJson
|
||||||
|
_ = MarkedJson.model_validate_json(json.dumps(data["array"]))
|
||||||
|
|
||||||
|
if not valid:
|
||||||
|
# ruin the version
|
||||||
|
data["array"]["interface"]["version"] = "v99999999"
|
||||||
|
dumped_json = json.dumps(data)
|
||||||
|
|
||||||
|
with pytest.warns(match="Mismatch.*"):
|
||||||
|
model = all_interfaces.model_validate_json(dumped_json)
|
||||||
|
else:
|
||||||
|
model = all_interfaces.model_validate_json(dumped_json)
|
||||||
|
|
||||||
|
_test_roundtrip(all_interfaces, model, round_trip)
|
||||||
|
|
|
@ -164,3 +164,42 @@ def test_video_close(avi_video):
|
||||||
assert instance.array._video is None
|
assert instance.array._video is None
|
||||||
# reopen
|
# reopen
|
||||||
assert isinstance(instance.array.video, cv2.VideoCapture)
|
assert isinstance(instance.array.video, cv2.VideoCapture)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.proxy
|
||||||
|
def test_video_not_exists(tmp_path):
|
||||||
|
"""
|
||||||
|
A video file that doesn't exist should raise an error
|
||||||
|
"""
|
||||||
|
video = VideoProxy(tmp_path / "not_real.avi")
|
||||||
|
with pytest.raises(FileNotFoundError):
|
||||||
|
_ = video.video
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.proxy
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"comparison,valid",
|
||||||
|
[
|
||||||
|
(VideoProxy("test_video.avi"), True),
|
||||||
|
(VideoProxy("not_real_video.avi"), False),
|
||||||
|
("not even a video proxy", TypeError),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_video_proxy_eq(comparison, valid):
|
||||||
|
"""
|
||||||
|
Comparing a video proxy's equality should be valid if the path matches
|
||||||
|
Args:
|
||||||
|
comparison:
|
||||||
|
valid:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
proxy_a = VideoProxy("test_video.avi")
|
||||||
|
if valid is True:
|
||||||
|
assert proxy_a == comparison
|
||||||
|
elif valid is False:
|
||||||
|
assert proxy_a != comparison
|
||||||
|
else:
|
||||||
|
with pytest.raises(valid):
|
||||||
|
assert proxy_a == comparison
|
||||||
|
|
Loading…
Reference in a new issue