roundtripping marked arrays, roundtrip or not

This commit is contained in:
sneakers-the-rat 2024-09-23 15:54:55 -07:00
parent 8cc2574399
commit 708e6e81d8
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
12 changed files with 328 additions and 53 deletions

View file

@ -25,3 +25,7 @@ class NoMatchError(MatchError):
class TooManyMatchesError(MatchError): class TooManyMatchesError(MatchError):
"""Too many matches found by :class:`.Interface.match`""" """Too many matches found by :class:`.Interface.match`"""
class MarkMismatchError(MatchError):
    """
    A serialized :class:`.InterfaceMark` does not match
    the interface that is receiving it
    """

View file

@ -4,16 +4,23 @@ Interfaces between nptyping types and array backends
from numpydantic.interface.dask import DaskInterface from numpydantic.interface.dask import DaskInterface
from numpydantic.interface.hdf5 import H5Interface from numpydantic.interface.hdf5 import H5Interface
from numpydantic.interface.interface import Interface, JsonDict from numpydantic.interface.interface import (
Interface,
InterfaceMark,
JsonDict,
MarkedJson,
)
from numpydantic.interface.numpy import NumpyInterface from numpydantic.interface.numpy import NumpyInterface
from numpydantic.interface.video import VideoInterface from numpydantic.interface.video import VideoInterface
from numpydantic.interface.zarr import ZarrInterface from numpydantic.interface.zarr import ZarrInterface
__all__ = [ __all__ = [
"JsonDict",
"Interface",
"DaskInterface", "DaskInterface",
"H5Interface", "H5Interface",
"Interface",
"InterfaceMark",
"JsonDict",
"MarkedJson",
"NumpyInterface", "NumpyInterface",
"VideoInterface", "VideoInterface",
"ZarrInterface", "ZarrInterface",

View file

@ -120,7 +120,7 @@ class H5Proxy:
annotation_dtype: Optional[DtypeType] = None, annotation_dtype: Optional[DtypeType] = None,
): ):
self._h5f = None self._h5f = None
self.file = Path(file) self.file = Path(file).resolve()
self.path = path self.path = path
self.field = field self.field = field
self._annotation_dtype = annotation_dtype self._annotation_dtype = annotation_dtype
@ -156,6 +156,9 @@ class H5Proxy:
return obj[:] return obj[:]
def __getattr__(self, item: str): def __getattr__(self, item: str):
if item == "__name__":
# special case for H5Proxies that don't refer to a real file during testing
return "H5Proxy"
with h5py.File(self.file, "r") as h5f: with h5py.File(self.file, "r") as h5f:
obj = h5f.get(self.path) obj = h5f.get(self.path)
val = getattr(obj, item) val = getattr(obj, item)

View file

@ -3,16 +3,19 @@ Base Interface metaclass
""" """
import inspect import inspect
import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import lru_cache
from importlib.metadata import PackageNotFoundError, version from importlib.metadata import PackageNotFoundError, version
from operator import attrgetter from operator import attrgetter
from typing import Any, Generic, Tuple, Type, TypedDict, TypeVar, Union from typing import Any, Generic, Optional, Tuple, Type, TypeVar, Union
import numpy as np import numpy as np
from pydantic import BaseModel, SerializationInfo, ValidationError from pydantic import BaseModel, SerializationInfo, ValidationError
from numpydantic.exceptions import ( from numpydantic.exceptions import (
DtypeError, DtypeError,
MarkMismatchError,
NoMatchError, NoMatchError,
ShapeError, ShapeError,
TooManyMatchesError, TooManyMatchesError,
@ -26,13 +29,49 @@ V = TypeVar("V") # input type
W = TypeVar("W") # Any type in handle_input W = TypeVar("W") # Any type in handle_input
class InterfaceMark(TypedDict): class InterfaceMark(BaseModel):
"""JSON-able mark to be able to round-trip json dumps""" """JSON-able mark to be able to round-trip json dumps"""
module: str module: str
cls: str cls: str
name: str
version: str version: str
def is_valid(self, cls: Type["Interface"], raise_on_error: bool = False) -> bool:
    """
    Check that a given interface matches the mark.

    The mark produced by ``cls.mark_interface()`` is compared with this
    mark for equality.

    Args:
        cls (Type): Interface type to check
        raise_on_error (bool): Raise a ``MarkMismatchError`` when the match
            is incorrect

    Returns:
        bool: ``True`` when the marks are equal, ``False`` otherwise

    Raises:
        :class:`.MarkMismatchError` if requested by ``raise_on_error``
            for an invalid match
    """
    mark = cls.mark_interface()
    valid = self == mark
    if not valid and raise_on_error:
        # Report the current *mark* rather than the interface class so that
        # both sides of the failed comparison are visible in the message
        raise MarkMismatchError(
            "Mismatch between serialized mark and current interface, "
            f"Serialized: {self}; current: {mark}"
        )
    return valid
def match_by_name(self) -> Optional[Type["Interface"]]:
    """
    Look up the interface whose ``name`` equals the ``name`` of this mark.

    Returns:
        The first interface yielded by ``Interface.interfaces(sort=False)``
        with a matching name, or ``None`` when no interface matches.
    """
    candidates = Interface.interfaces(sort=False)
    return next((iface for iface in candidates if iface.name == self.name), None)
class JsonDict(BaseModel): class JsonDict(BaseModel):
""" """
@ -84,6 +123,29 @@ class JsonDict(BaseModel):
return value return value
class MarkedJson(BaseModel):
    """
    Model of JSON dumped with an additional interface mark
    with ``model_dump_json(context={'mark_interface': True})``
    """

    interface: InterfaceMark
    value: Union[list, dict]
    """
    Inner value of the array, we don't validate for JsonDict here,
    that should be downstream from us for performance reasons
    """

    @classmethod
    def try_cast(cls, value: Union[V, dict]) -> Union[V, "MarkedJson"]:
        """
        Try to cast to MarkedJson if applicable, otherwise return input.

        Only dicts that carry both an ``interface`` and a ``value`` key are
        considered. A dict that has those keys but does not validate as a
        :class:`.MarkedJson` is returned unchanged rather than raising.

        Args:
            value: Any input, possibly a dict shaped like a marked JSON dump

        Returns:
            A :class:`.MarkedJson` when the input validates as one,
            otherwise the input itself
        """
        if isinstance(value, dict) and "interface" in value and "value" in value:
            try:
                return cls.model_validate(value)
            except ValidationError:
                # looks like a marked dump but isn't one - hand the input
                # back untouched so callers can try other handling
                return value
        return value
class Interface(ABC, Generic[T]): class Interface(ABC, Generic[T]):
""" """
Abstract parent class for interfaces to different array formats Abstract parent class for interfaces to different array formats
@ -158,14 +220,24 @@ class Interface(ABC, Generic[T]):
def deserialize(self, array: Any) -> Union[V, Any]: def deserialize(self, array: Any) -> Union[V, Any]:
""" """
If given a JSON serialized version of the array, If given a JSON serialized version of the array,
deserialize it first deserialize it first.
Args: If a roundtrip-serialized :class:`.JsonDict`,
array: pass to :meth:`.JsonDict.handle_input`.
Returns:
If a roundtrip-serialized :class:`.MarkedJson`,
unpack mark, check for validity, warn if not,
and try to continue with validation
""" """
if isinstance(marked_array := MarkedJson.try_cast(array), MarkedJson):
try:
marked_array.interface.is_valid(self.__class__, raise_on_error=True)
except MarkMismatchError as e:
warnings.warn(
str(e) + "\nAttempting to continue validation...", stacklevel=2
)
array = marked_array.value
return self.json_model.handle_input(array) return self.json_model.handle_input(array)
def before_validation(self, array: Any) -> NDArrayType: def before_validation(self, array: Any) -> NDArrayType:
@ -274,13 +346,6 @@ class Interface(ABC, Generic[T]):
""" """
return array return array
def mark_input(self, array: Any) -> Any:
"""
Preserve metadata about the interface and passed input when dumping with
``round_trip``
"""
return array
@classmethod @classmethod
@abstractmethod @abstractmethod
def check(cls, array: Any) -> bool: def check(cls, array: Any) -> bool:
@ -320,7 +385,7 @@ class Interface(ABC, Generic[T]):
""" """
@classmethod @classmethod
def mark_json(cls, array: Union[list, dict]) -> dict: def mark_json(cls, array: Union[list, dict]) -> MarkedJson:
""" """
When using ``model_dump_json`` with ``mark_interface: True`` in the ``context``, When using ``model_dump_json`` with ``mark_interface: True`` in the ``context``,
add additional annotations that would allow the serialized array to be add additional annotations that would allow the serialized array to be
@ -337,7 +402,7 @@ class Interface(ABC, Generic[T]):
'version': '1.2.2'}, 'version': '1.2.2'},
'value': [1.0, 2.0]} 'value': [1.0, 2.0]}
""" """
return {"interface": cls.mark_interface(), "value": array} return MarkedJson.model_construct(interface=cls.mark_interface(), value=array)
@classmethod @classmethod
def interfaces( def interfaces(
@ -390,6 +455,28 @@ class Interface(ABC, Generic[T]):
return tuple(in_types) return tuple(in_types)
@classmethod
def match_mark(cls, array: Any) -> Optional[Type["Interface"]]:
    """
    Resolve a marked JSON dump to the interface that its mark names.

    The interface is first looked up by name, and then its ``check`` method
    is run against the inner value, because arrays can be dumped with a mark
    but without ``round_trip == True`` (and thus can't necessarily
    use the same interface that they were dumped with)

    Returns:
        Interface if match found, None otherwise
    """
    marked = MarkedJson.try_cast(array)
    if isinstance(marked, MarkedJson):
        candidate = marked.interface.match_by_name()
        if candidate is not None and candidate.check(marked.value):
            return candidate
    return None
@classmethod @classmethod
def match(cls, array: Any, fast: bool = False) -> Type["Interface"]: def match(cls, array: Any, fast: bool = False) -> Type["Interface"]:
""" """
@ -407,11 +494,18 @@ class Interface(ABC, Generic[T]):
check each interface (as ordered by its ``priority`` , decreasing), check each interface (as ordered by its ``priority`` , decreasing),
and return on the first match. and return on the first match.
""" """
# Shortcircuit match if this is a marked json dump
array = MarkedJson.try_cast(array)
if (match := cls.match_mark(array)) is not None:
return match
elif isinstance(array, MarkedJson):
array = array.value
# first try and find a non-numpy interface, since the numpy interface # first try and find a non-numpy interface, since the numpy interface
# will try and load the array into memory in its check method # will try and load the array into memory in its check method
interfaces = cls.interfaces() interfaces = cls.interfaces()
non_np_interfaces = [i for i in interfaces if i.__name__ != "NumpyInterface"] non_np_interfaces = [i for i in interfaces if i.name != "numpy"]
np_interface = [i for i in interfaces if i.__name__ == "NumpyInterface"][0] np_interface = [i for i in interfaces if i.name == "numpy"][0]
if fast: if fast:
matches = [] matches = []
@ -453,6 +547,7 @@ class Interface(ABC, Generic[T]):
return matches[0] return matches[0]
@classmethod @classmethod
@lru_cache(maxsize=32)
def mark_interface(cls) -> InterfaceMark: def mark_interface(cls) -> InterfaceMark:
""" """
Create an interface mark indicating this interface for validation after Create an interface mark indicating this interface for validation after
@ -470,5 +565,7 @@ class Interface(ABC, Generic[T]):
) )
except PackageNotFoundError: except PackageNotFoundError:
v = None v = None
interface_name = cls.__name__
return InterfaceMark(module=interface_module, cls=interface_name, version=v) return InterfaceMark(
module=interface_module, cls=cls.__name__, name=cls.name, version=v
)

View file

@ -48,7 +48,7 @@ class VideoProxy:
) )
if path is not None: if path is not None:
path = Path(path) path = Path(path).resolve()
self.path = path self.path = path
self._video = video # type: Optional[VideoCapture] self._video = video # type: Optional[VideoCapture]
@ -200,6 +200,8 @@ class VideoProxy:
raise NotImplementedError("Setting pixel values on videos is not supported!") raise NotImplementedError("Setting pixel values on videos is not supported!")
def __getattr__(self, item: str): def __getattr__(self, item: str):
if item == "__name__":
return "VideoProxy"
return getattr(self.video, item) return getattr(self.video, item)
def __eq__(self, other: "VideoProxy") -> bool: def __eq__(self, other: "VideoProxy") -> bool:

View file

@ -23,7 +23,8 @@ def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
if info.context: if info.context:
if info.context.get("mark_interface", False): if info.context.get("mark_interface", False):
array = interface_cls.mark_json(array) array = interface_cls.mark_json(array).model_dump()
if info.context.get("absolute_paths", False): if info.context.get("absolute_paths", False):
array = _absolutize_paths(array) array = _absolutize_paths(array)
else: else:

View file

@ -83,7 +83,6 @@ STRING: TypeAlias = NDArray[Shape["*, *, *"], str]
MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel] MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel]
@pytest.mark.shape
@pytest.fixture( @pytest.fixture(
scope="module", scope="module",
params=[ params=[
@ -121,7 +120,6 @@ def shape_cases(request) -> ValidationCase:
return request.param return request.param
@pytest.mark.dtype
@pytest.fixture( @pytest.fixture(
scope="module", scope="module",
params=[ params=[

View file

@ -104,7 +104,6 @@ def model_blank() -> Type[BaseModel]:
return BlankModel return BlankModel
@pytest.mark.hdf5
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def hdf5_file(tmp_output_dir_func) -> h5py.File: def hdf5_file(tmp_output_dir_func) -> h5py.File:
h5f_file = tmp_output_dir_func / "h5f.h5" h5f_file = tmp_output_dir_func / "h5f.h5"
@ -113,7 +112,6 @@ def hdf5_file(tmp_output_dir_func) -> h5py.File:
h5f.close() h5f.close()
@pytest.mark.hdf5
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def hdf5_array( def hdf5_array(
hdf5_file, request hdf5_file, request
@ -156,7 +154,6 @@ def hdf5_array(
return _hdf5_array return _hdf5_array
@pytest.mark.zarr
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def zarr_nested_array(tmp_output_dir_func) -> ZarrArrayPath: def zarr_nested_array(tmp_output_dir_func) -> ZarrArrayPath:
"""Zarr array within a nested array""" """Zarr array within a nested array"""
@ -167,7 +164,6 @@ def zarr_nested_array(tmp_output_dir_func) -> ZarrArrayPath:
return ZarrArrayPath(file=file, path=path) return ZarrArrayPath(file=file, path=path)
@pytest.mark.zarr
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def zarr_array(tmp_output_dir_func) -> Path: def zarr_array(tmp_output_dir_func) -> Path:
file = tmp_output_dir_func / "array.zarr" file = tmp_output_dir_func / "array.zarr"
@ -176,7 +172,6 @@ def zarr_array(tmp_output_dir_func) -> Path:
return file return file
@pytest.mark.video
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def avi_video(tmp_path) -> Callable[[Tuple[int, int], int, bool], Path]: def avi_video(tmp_path) -> Callable[[Tuple[int, int], int, bool], Path]:
video_path = tmp_path / "test.avi" video_path = tmp_path / "test.avi"

View file

@ -231,3 +231,29 @@ def test_empty_dataset(dtype, tmp_path):
array: NDArray[Any, dtype] array: NDArray[Any, dtype]
_ = MyModel(array=(array_path, "/data")) _ = MyModel(array=(array_path, "/data"))
@pytest.mark.proxy
@pytest.mark.parametrize(
    "comparison,valid",
    [
        (H5Proxy(file="test_file.h5", path="/subpath", field="sup"), True),
        (H5Proxy(file="test_file.h5", path="/subpath"), False),
        (H5Proxy(file="different_file.h5", path="/subpath"), False),
        (("different_file.h5", "/subpath", "sup"), ValueError),
        ("not even a proxy-like thing", ValueError),
    ],
)
def test_proxy_eq(comparison, valid):
    """
    The ``__eq__`` method of H5Proxy should match proxies to the same
    dataset (and path), or raise a ValueError

    ``valid`` is either a bool giving the expected equality,
    or the exception type the comparison is expected to raise.
    """
    proxy_a = H5Proxy(file="test_file.h5", path="/subpath", field="sup")
    if isinstance(valid, bool):
        if valid:
            assert proxy_a == comparison
        else:
            assert proxy_a != comparison
    else:
        with pytest.raises(valid):
            assert proxy_a == comparison

View file

@ -4,11 +4,26 @@ for tests that should apply to all interfaces, use ``test_interfaces.py``
""" """
import gc import gc
from typing import Literal
import pytest import pytest
import numpy as np import numpy as np
from numpydantic.interface import Interface from numpydantic.interface import Interface, JsonDict
from pydantic import ValidationError
from numpydantic.interface.interface import V
class MyJsonDict(JsonDict):
    """JsonDict used in tests that adds an extra key to its array input"""

    type: Literal["my_json_dict"]
    field: str
    number: int

    def to_array_input(self) -> V:
        """Dump the model to a dict and flag it with ``extra_input_param``"""
        return {**self.model_dump(), "extra_input_param": True}
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
@ -162,3 +177,36 @@ def test_interface_recursive(interfaces):
assert issubclass(interfaces.interface3, interfaces.interface1) assert issubclass(interfaces.interface3, interfaces.interface1)
assert issubclass(interfaces.interface1, Interface) assert issubclass(interfaces.interface1, Interface)
assert interfaces.interface4 in ifaces assert interfaces.interface4 in ifaces
@pytest.mark.serialization
def test_jsondict_is_valid():
    """
    ``JsonDict.is_valid`` should report validity as a bool,
    and raise a ValidationError when ``raise_on_error`` is requested
    """
    good = {"type": "my_json_dict", "field": "a_field", "number": 1}
    bad = {"doesnt": "have", "the": "props"}

    assert MyJsonDict.is_valid(good)
    assert not MyJsonDict.is_valid(bad)

    with pytest.raises(ValidationError):
        assert not MyJsonDict.is_valid(bad, raise_on_error=True)
@pytest.mark.serialization
def test_jsondict_handle_input():
    """
    ``JsonDict.handle_input`` should accept both a valid dict and an
    instantiated model, and return the array-input form for each
    """
    valid = {"type": "my_json_dict", "field": "a_field", "number": 1}
    expected = {**valid, "extra_input_param": True}

    candidates = (valid, MyJsonDict(**valid))
    for candidate in candidates:
        assert MyJsonDict.handle_input(candidate) == expected

View file

@ -4,10 +4,38 @@ Tests that should be applied to all interfaces
import pytest import pytest
from typing import Callable from typing import Callable
from importlib.metadata import version
import json
import numpy as np import numpy as np
import dask.array as da import dask.array as da
from zarr.core import Array as ZarrArray from zarr.core import Array as ZarrArray
from numpydantic.interface import Interface from pydantic import BaseModel
from numpydantic.interface import Interface, InterfaceMark, MarkedJson
def _test_roundtrip(source: BaseModel, target: BaseModel, round_trip: bool):
    """
    Assert that ``target`` holds the same array as ``source``.

    Without ``round_trip`` only the array values are compared. With it,
    the array type and dtype must match as well, and the comparison
    used depends on the kind of array in ``source``.
    """
    original, restored = source.array, target.array

    if not round_trip:
        assert np.array_equal(restored, np.array(original))
        return

    assert type(restored) is type(original)
    if isinstance(original, (np.ndarray, ZarrArray)):
        assert np.array_equal(restored, np.array(original))
    elif isinstance(original, da.Array):
        assert np.all(da.equal(restored, original))
    else:
        assert restored == original

    assert restored.dtype == original.dtype
def test_dunder_len(all_interfaces):
    """
    ``len()`` of each interface or proxy array should equal
    the size of its first dimension
    """
    first_dim = all_interfaces.array.shape[0]
    assert len(all_interfaces.array) == first_dim
def test_interface_revalidate(all_interfaces): def test_interface_revalidate(all_interfaces):
@ -52,24 +80,51 @@ def test_interface_roundtrip_json(all_interfaces, round_trip):
""" """
All interfaces should be able to roundtrip to and from json All interfaces should be able to roundtrip to and from json
""" """
json = all_interfaces.model_dump_json(round_trip=round_trip) dumped_json = all_interfaces.model_dump_json(round_trip=round_trip)
model = all_interfaces.model_validate_json(json) model = all_interfaces.model_validate_json(dumped_json)
if round_trip: _test_roundtrip(all_interfaces, model, round_trip)
assert type(model.array) is type(all_interfaces.array)
if isinstance(all_interfaces.array, (np.ndarray, ZarrArray)):
assert np.array_equal(model.array, np.array(all_interfaces.array))
elif isinstance(all_interfaces.array, da.Array):
assert np.all(da.equal(model.array, all_interfaces.array))
else:
assert model.array == all_interfaces.array
assert model.array.dtype == all_interfaces.array.dtype
else:
assert np.array_equal(model.array, np.array(all_interfaces.array))
def test_dunder_len(all_interfaces): @pytest.mark.serialization
@pytest.mark.parametrize("an_interface", Interface.interfaces())
def test_interface_mark_interface(an_interface):
""" """
Each interface or proxy type should support __len__ All interfaces should be able to mark the current version and interface info
""" """
assert len(all_interfaces.array) == all_interfaces.array.shape[0] mark = an_interface.mark_interface()
assert isinstance(mark, InterfaceMark)
assert mark.name == an_interface.name
assert mark.cls == an_interface.__name__
assert mark.module == an_interface.__module__
assert mark.version == version(mark.module.split(".")[0])
@pytest.mark.serialization
@pytest.mark.parametrize("valid", [True, False])
@pytest.mark.parametrize("round_trip", [True, False])
@pytest.mark.filterwarnings("ignore:Mismatch between serialized mark")
def test_interface_mark_roundtrip(all_interfaces, valid, round_trip):
    """
    Every interface should roundtrip through JSON dumped with an interface
    mark, and a mismatched mark should emit a warning and then
    attempt to proceed with validation
    """
    dumped_json = all_interfaces.model_dump_json(
        round_trip=round_trip, context={"mark_interface": True}
    )
    data = json.loads(dumped_json)

    # the dumped array must itself parse as a MarkedJson
    _ = MarkedJson.model_validate_json(json.dumps(data["array"]))

    if valid:
        model = all_interfaces.model_validate_json(dumped_json)
    else:
        # corrupt the version so the mark no longer matches
        data["array"]["interface"]["version"] = "v99999999"
        dumped_json = json.dumps(data)

        with pytest.warns(match="Mismatch.*"):
            model = all_interfaces.model_validate_json(dumped_json)

    _test_roundtrip(all_interfaces, model, round_trip)

View file

@ -164,3 +164,42 @@ def test_video_close(avi_video):
assert instance.array._video is None assert instance.array._video is None
# reopen # reopen
assert isinstance(instance.array.video, cv2.VideoCapture) assert isinstance(instance.array.video, cv2.VideoCapture)
@pytest.mark.proxy
def test_video_not_exists(tmp_path):
    """
    Accessing the video of a proxy whose file doesn't exist
    should raise a FileNotFoundError
    """
    missing_path = tmp_path / "not_real.avi"
    video = VideoProxy(missing_path)
    with pytest.raises(FileNotFoundError):
        _ = video.video
@pytest.mark.proxy
@pytest.mark.parametrize(
    "comparison,valid",
    [
        (VideoProxy("test_video.avi"), True),
        (VideoProxy("not_real_video.avi"), False),
        ("not even a video proxy", TypeError),
    ],
)
def test_video_proxy_eq(comparison, valid):
    """
    Comparing a video proxy's equality should be valid if the path matches

    Args:
        comparison: object compared against a proxy for ``test_video.avi``
        valid: expected equality as a bool, or the exception type
            that the comparison is expected to raise
    """
    proxy_a = VideoProxy("test_video.avi")
    if isinstance(valid, bool):
        if valid:
            assert proxy_a == comparison
        else:
            assert proxy_a != comparison
    else:
        with pytest.raises(valid):
            assert proxy_a == comparison