From 708e6e81d8b9cb1b61a121003adc7eb6944f00a4 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Mon, 23 Sep 2024 15:54:55 -0700 Subject: [PATCH] roundtripping marked arrays, roundtrip or not --- src/numpydantic/exceptions.py | 4 + src/numpydantic/interface/__init__.py | 13 +- src/numpydantic/interface/hdf5.py | 5 +- src/numpydantic/interface/interface.py | 137 +++++++++++++++++--- src/numpydantic/interface/video.py | 4 +- src/numpydantic/serialization.py | 3 +- tests/conftest.py | 2 - tests/fixtures.py | 5 - tests/test_interface/test_hdf5.py | 26 ++++ tests/test_interface/test_interface_base.py | 50 ++++++- tests/test_interface/test_interfaces.py | 93 ++++++++++--- tests/test_interface/test_video.py | 39 ++++++ 12 files changed, 328 insertions(+), 53 deletions(-) diff --git a/src/numpydantic/exceptions.py b/src/numpydantic/exceptions.py index a61258f..c23b96f 100644 --- a/src/numpydantic/exceptions.py +++ b/src/numpydantic/exceptions.py @@ -25,3 +25,7 @@ class NoMatchError(MatchError): class TooManyMatchesError(MatchError): """Too many matches found by :class:`.Interface.match`""" + + +class MarkMismatchError(MatchError): + """A serialized :class:`.InterfaceMark` doesn't match the receiving interface""" diff --git a/src/numpydantic/interface/__init__.py b/src/numpydantic/interface/__init__.py index c5bd3f2..36c7d97 100644 --- a/src/numpydantic/interface/__init__.py +++ b/src/numpydantic/interface/__init__.py @@ -4,16 +4,23 @@ Interfaces between nptyping types and array backends from numpydantic.interface.dask import DaskInterface from numpydantic.interface.hdf5 import H5Interface -from numpydantic.interface.interface import Interface, JsonDict +from numpydantic.interface.interface import ( + Interface, + InterfaceMark, + JsonDict, + MarkedJson, +) from numpydantic.interface.numpy import NumpyInterface from numpydantic.interface.video import VideoInterface from numpydantic.interface.zarr import ZarrInterface __all__ = [ - "JsonDict", - "Interface", "DaskInterface", "H5Interface", + "Interface", + "InterfaceMark", + "JsonDict", + "MarkedJson", "NumpyInterface", "VideoInterface", "ZarrInterface", diff --git a/src/numpydantic/interface/hdf5.py b/src/numpydantic/interface/hdf5.py index 67b9899..9215ec2 100644 --- a/src/numpydantic/interface/hdf5.py +++ b/src/numpydantic/interface/hdf5.py @@ -120,7 +120,7 @@ class H5Proxy: annotation_dtype: Optional[DtypeType] = None, ): self._h5f = None - self.file = Path(file) + self.file = Path(file).resolve() self.path = path self.field = field self._annotation_dtype = annotation_dtype @@ -156,6 +156,9 @@ class H5Proxy: return obj[:] def __getattr__(self, item: str): + if item == "__name__": + # special case for H5Proxies that don't refer to a real file during testing + return "H5Proxy" with h5py.File(self.file, "r") as h5f: obj = h5f.get(self.path) val = getattr(obj, item) diff --git a/src/numpydantic/interface/interface.py b/src/numpydantic/interface/interface.py index 2cb61f4..6b1e5a4 100644 --- a/src/numpydantic/interface/interface.py +++ b/src/numpydantic/interface/interface.py @@ -3,16 +3,19 @@ Base Interface metaclass """ import inspect +import warnings from abc import ABC, abstractmethod +from functools import lru_cache from importlib.metadata import PackageNotFoundError, version from operator import attrgetter -from typing import Any, Generic, Tuple, Type, TypedDict, TypeVar, Union +from typing import Any, Generic, Optional, Tuple, Type, TypeVar, Union import numpy as np from pydantic import BaseModel, SerializationInfo, ValidationError from numpydantic.exceptions import ( DtypeError, + MarkMismatchError, NoMatchError, ShapeError, TooManyMatchesError, @@ -26,13 +29,49 @@ V = TypeVar("V") # input type W = TypeVar("W") # Any type in handle_input -class InterfaceMark(TypedDict): +class InterfaceMark(BaseModel): """JSON-able mark to be able to round-trip json dumps""" module: str cls: str + name: str version: str + def is_valid(self, cls: Type["Interface"], raise_on_error: bool = False) -> bool: + """ + Check that a given interface matches the mark. + + Args: + cls (Type): Interface type to check + raise_on_error (bool): Raise an ``MarkMismatchError`` when the match + is incorrect + + Returns: + bool + + Raises: + :class:`.MarkMismatchError` if requested by ``raise_on_error`` + for an invalid match + """ + mark = cls.mark_interface() + valid = self == mark + if not valid and raise_on_error: + raise MarkMismatchError( + "Mismatch between serialized mark and current interface, " + f"Serialized: {self}; current: {cls}" + ) + return valid + + def match_by_name(self) -> Optional[Type["Interface"]]: + """ + Try to find a matching interface by its name, returning it if found, + or None if not found. + """ + for i in Interface.interfaces(sort=False): + if i.name == self.name: + return i + return None + class JsonDict(BaseModel): """ @@ -84,6 +123,29 @@ class JsonDict(BaseModel): return value +class MarkedJson(BaseModel): + """ + Model of JSON dumped with an additional interface mark + with ``model_dump_json({'mark_interface': True})`` + """ + + interface: InterfaceMark + value: Union[list, dict] + """ + Inner value of the array, we don't validate for JsonDict here, + that should be downstream from us for performance reasons + """ + + @classmethod + def try_cast(cls, value: Union[V, dict]) -> Union[V, "MarkedJson"]: + """ + Try to cast to MarkedJson if applicable, otherwise return input + """ + if isinstance(value, dict) and "interface" in value and "value" in value: + value = MarkedJson(**value) + return value + + class Interface(ABC, Generic[T]): """ Abstract parent class for interfaces to different array formats @@ -158,14 +220,24 @@ class Interface(ABC, Generic[T]): def deserialize(self, array: Any) -> Union[V, Any]: """ If given a JSON serialized version of the array, - deserialize it first + deserialize it first. - Args: - array: - - Returns: + If a roundtrip-serialized :class:`.JsonDict`, + pass to :meth:`.JsonDict.handle_input`. + If a roundtrip-serialized :class:`.MarkedJson`, + unpack mark, check for validity, warn if not, + and try to continue with validation """ + if isinstance(marked_array := MarkedJson.try_cast(array), MarkedJson): + try: + marked_array.interface.is_valid(self.__class__, raise_on_error=True) + except MarkMismatchError as e: + warnings.warn( + str(e) + "\nAttempting to continue validation...", stacklevel=2 + ) + array = marked_array.value + return self.json_model.handle_input(array) def before_validation(self, array: Any) -> NDArrayType: @@ -274,13 +346,6 @@ class Interface(ABC, Generic[T]): """ return array - def mark_input(self, array: Any) -> Any: - """ - Preserve metadata about the interface and passed input when dumping with - ``round_trip`` - """ - return array - @classmethod @abstractmethod def check(cls, array: Any) -> bool: @@ -320,7 +385,7 @@ class Interface(ABC, Generic[T]): """ @classmethod - def mark_json(cls, array: Union[list, dict]) -> dict: + def mark_json(cls, array: Union[list, dict]) -> MarkedJson: """ When using ``model_dump_json`` with ``mark_interface: True`` in the ``context``, add additional annotations that would allow the serialized array to be @@ -337,7 +402,7 @@ class Interface(ABC, Generic[T]): 'version': '1.2.2'}, 'value': [1.0, 2.0]} """ - return {"interface": cls.mark_interface(), "value": array} + return MarkedJson.model_construct(interface=cls.mark_interface(), value=array) @classmethod def interfaces( @@ -390,6 +455,28 @@ class Interface(ABC, Generic[T]): return tuple(in_types) + @classmethod + def match_mark(cls, array: Any) -> Optional[Type["Interface"]]: + """ + Match a marked JSON dump of this array to the interface that it indicates. + + First find an interface that matches by name, and then run its + ``check`` method, because arrays can be dumped with a mark + but without ``round_trip == True`` (and thus can't necessarily + use the same interface that they were dumped with) + + Returns: + Interface if match found, None otherwise + """ + mark = MarkedJson.try_cast(array) + if not isinstance(mark, MarkedJson): + return None + + interface = mark.interface.match_by_name() + if interface is not None and interface.check(mark.value): + return interface + return None + @classmethod def match(cls, array: Any, fast: bool = False) -> Type["Interface"]: """ @@ -407,11 +494,18 @@ class Interface(ABC, Generic[T]): check each interface (as ordered by its ``priority`` , decreasing), and return on the first match. """ + # Shortcircuit match if this is a marked json dump + array = MarkedJson.try_cast(array) + if (match := cls.match_mark(array)) is not None: + return match + elif isinstance(array, MarkedJson): + array = array.value + # first try and find a non-numpy interface, since the numpy interface # will try and load the array into memory in its check method interfaces = cls.interfaces() - non_np_interfaces = [i for i in interfaces if i.__name__ != "NumpyInterface"] - np_interface = [i for i in interfaces if i.__name__ == "NumpyInterface"][0] + non_np_interfaces = [i for i in interfaces if i.name != "numpy"] + np_interface = [i for i in interfaces if i.name == "numpy"][0] if fast: matches = [] @@ -453,6 +547,7 @@ class Interface(ABC, Generic[T]): return matches[0] @classmethod + @lru_cache(maxsize=32) def mark_interface(cls) -> InterfaceMark: """ Create an interface mark indicating this interface for validation after @@ -470,5 +565,7 @@ class Interface(ABC, Generic[T]): ) except PackageNotFoundError: v = None - interface_name = cls.__name__ - return InterfaceMark(module=interface_module, cls=interface_name, version=v) + + return InterfaceMark( + module=interface_module, cls=cls.__name__, name=cls.name, version=v + ) diff --git a/src/numpydantic/interface/video.py b/src/numpydantic/interface/video.py index b214a74..53a3ba5 100644 --- a/src/numpydantic/interface/video.py +++ b/src/numpydantic/interface/video.py @@ -48,7 +48,7 @@ class VideoProxy: ) if path is not None: - path = Path(path) + path = Path(path).resolve() self.path = path self._video = video # type: Optional[VideoCapture] @@ -200,6 +200,8 @@ class VideoProxy: raise NotImplementedError("Setting pixel values on videos is not supported!") def __getattr__(self, item: str): + if item == "__name__": + return "VideoProxy" return getattr(self.video, item) def __eq__(self, other: "VideoProxy") -> bool: diff --git a/src/numpydantic/serialization.py b/src/numpydantic/serialization.py index f645239..f5c7b35 100644 --- a/src/numpydantic/serialization.py +++ b/src/numpydantic/serialization.py @@ -23,7 +23,8 @@ def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]: if info.context: if info.context.get("mark_interface", False): - array = interface_cls.mark_json(array) + array = interface_cls.mark_json(array).model_dump() + if info.context.get("absolute_paths", False): array = _absolutize_paths(array) else: diff --git a/tests/conftest.py b/tests/conftest.py index 3870be2..0467f25 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -83,7 +83,6 @@ STRING: TypeAlias = NDArray[Shape["*, *, *"], str] MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel] -@pytest.mark.shape @pytest.fixture( scope="module", params=[ @@ -121,7 +120,6 @@ def shape_cases(request) -> ValidationCase: return request.param -@pytest.mark.dtype @pytest.fixture( scope="module", params=[ diff --git a/tests/fixtures.py b/tests/fixtures.py index fb393b5..89359ae 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -104,7 +104,6 @@ def model_blank() -> Type[BaseModel]: return BlankModel -@pytest.mark.hdf5 @pytest.fixture(scope="function") def hdf5_file(tmp_output_dir_func) -> h5py.File: h5f_file = tmp_output_dir_func / "h5f.h5" @@ -113,7 +112,6 @@ def hdf5_file(tmp_output_dir_func) -> h5py.File: h5f.close() -@pytest.mark.hdf5 @pytest.fixture(scope="function") def hdf5_array( hdf5_file, request @@ -156,7 +154,6 @@ def hdf5_array( return _hdf5_array -@pytest.mark.zarr @pytest.fixture(scope="function") def zarr_nested_array(tmp_output_dir_func) -> ZarrArrayPath: """Zarr array within a nested array""" @@ -167,7 +164,6 @@ def zarr_nested_array(tmp_output_dir_func) -> ZarrArrayPath: return ZarrArrayPath(file=file, path=path) -@pytest.mark.zarr @pytest.fixture(scope="function") def zarr_array(tmp_output_dir_func) -> Path: file = tmp_output_dir_func / "array.zarr" @@ -176,7 +172,6 @@ def zarr_array(tmp_output_dir_func) -> Path: return file -@pytest.mark.video @pytest.fixture(scope="function") def avi_video(tmp_path) -> Callable[[Tuple[int, int], int, bool], Path]: video_path = tmp_path / "test.avi" diff --git a/tests/test_interface/test_hdf5.py b/tests/test_interface/test_hdf5.py index f88b411..c412e7a 100644 --- a/tests/test_interface/test_hdf5.py +++ b/tests/test_interface/test_hdf5.py @@ -231,3 +231,29 @@ def test_empty_dataset(dtype, tmp_path): array: NDArray[Any, dtype] _ = MyModel(array=(array_path, "/data")) + + +@pytest.mark.proxy +@pytest.mark.parametrize( + "comparison,valid", + [ + (H5Proxy(file="test_file.h5", path="/subpath", field="sup"), True), + (H5Proxy(file="test_file.h5", path="/subpath"), False), + (H5Proxy(file="different_file.h5", path="/subpath"), False), + (("different_file.h5", "/subpath", "sup"), ValueError), + ("not even a proxy-like thing", ValueError), + ], +) +def test_proxy_eq(comparison, valid): + """ + test the __eq__ method of H5ArrayProxy matches proxies to the same + dataset (and path), or raises a ValueError + """ + proxy_a = H5Proxy(file="test_file.h5", path="/subpath", field="sup") + if valid is True: + assert proxy_a == comparison + elif valid is False: + assert proxy_a != comparison + else: + with pytest.raises(valid): + assert proxy_a == comparison diff --git a/tests/test_interface/test_interface_base.py b/tests/test_interface/test_interface_base.py index baacc60..1c18e73 100644 --- a/tests/test_interface/test_interface_base.py +++ b/tests/test_interface/test_interface_base.py @@ -4,11 +4,26 @@ for tests that should apply to all interfaces, use ``test_interfaces.py`` """ import gc +from typing import Literal import pytest import numpy as np -from numpydantic.interface import Interface +from numpydantic.interface import Interface, JsonDict +from pydantic import ValidationError + +from numpydantic.interface.interface import V + + +class MyJsonDict(JsonDict): + type: Literal["my_json_dict"] + field: str + number: int + + def to_array_input(self) -> V: + dumped = self.model_dump() + dumped["extra_input_param"] = True + return dumped @pytest.fixture(scope="module") @@ -162,3 +177,36 @@ def test_interface_recursive(interfaces): assert issubclass(interfaces.interface3, interfaces.interface1) assert issubclass(interfaces.interface1, Interface) assert interfaces.interface4 in ifaces + + +@pytest.mark.serialization +def test_jsondict_is_valid(): + """ + A JsonDict should return a bool true/false if it is valid or not, + and raise an error when requested + """ + invalid = {"doesnt": "have", "the": "props"} + valid = {"type": "my_json_dict", "field": "a_field", "number": 1} + assert MyJsonDict.is_valid(valid) + assert not MyJsonDict.is_valid(invalid) + with pytest.raises(ValidationError): + assert not MyJsonDict.is_valid(invalid, raise_on_error=True) + + +@pytest.mark.serialization +def test_jsondict_handle_input(): + """ + JsonDict should be able to parse a valid dict and return it to the input format + """ + valid = {"type": "my_json_dict", "field": "a_field", "number": 1} + instantiated = MyJsonDict(**valid) + expected = { + "type": "my_json_dict", + "field": "a_field", + "number": 1, + "extra_input_param": True, + } + + for item in (valid, instantiated): + result = MyJsonDict.handle_input(item) + assert result == expected diff --git a/tests/test_interface/test_interfaces.py b/tests/test_interface/test_interfaces.py index 01709d7..faec0d8 100644 --- a/tests/test_interface/test_interfaces.py +++ b/tests/test_interface/test_interfaces.py @@ -4,10 +4,38 @@ Tests that should be applied to all interfaces import pytest from typing import Callable +from importlib.metadata import version +import json + import numpy as np import dask.array as da from zarr.core import Array as ZarrArray -from numpydantic.interface import Interface +from pydantic import BaseModel + +from numpydantic.interface import Interface, InterfaceMark, MarkedJson + + +def _test_roundtrip(source: BaseModel, target: BaseModel, round_trip: bool): + """Test model equality for roundtrip tests""" + if round_trip: + assert type(target.array) is type(source.array) + if isinstance(source.array, (np.ndarray, ZarrArray)): + assert np.array_equal(target.array, np.array(source.array)) + elif isinstance(source.array, da.Array): + assert np.all(da.equal(target.array, source.array)) + else: + assert target.array == source.array + + assert target.array.dtype == source.array.dtype + else: + assert np.array_equal(target.array, np.array(source.array)) + + +def test_dunder_len(all_interfaces): + """ + Each interface or proxy type should support __len__ + """ + assert len(all_interfaces.array) == all_interfaces.array.shape[0] def test_interface_revalidate(all_interfaces): @@ -52,24 +80,51 @@ def test_interface_roundtrip_json(all_interfaces, round_trip): """ All interfaces should be able to roundtrip to and from json """ - json = all_interfaces.model_dump_json(round_trip=round_trip) - model = all_interfaces.model_validate_json(json) - if round_trip: - assert type(model.array) is type(all_interfaces.array) - if isinstance(all_interfaces.array, (np.ndarray, ZarrArray)): - assert np.array_equal(model.array, np.array(all_interfaces.array)) - elif isinstance(all_interfaces.array, da.Array): - assert np.all(da.equal(model.array, all_interfaces.array)) - else: - assert model.array == all_interfaces.array + dumped_json = all_interfaces.model_dump_json(round_trip=round_trip) + model = all_interfaces.model_validate_json(dumped_json) + _test_roundtrip(all_interfaces, model, round_trip) - assert model.array.dtype == all_interfaces.array.dtype + +@pytest.mark.serialization +@pytest.mark.parametrize("an_interface", Interface.interfaces()) +def test_interface_mark_interface(an_interface): + """ + All interfaces should be able to mark the current version and interface info + """ + mark = an_interface.mark_interface() + assert isinstance(mark, InterfaceMark) + assert mark.name == an_interface.name + assert mark.cls == an_interface.__name__ + assert mark.module == an_interface.__module__ + assert mark.version == version(mark.module.split(".")[0]) + + +@pytest.mark.serialization +@pytest.mark.parametrize("valid", [True, False]) +@pytest.mark.parametrize("round_trip", [True, False]) +@pytest.mark.filterwarnings("ignore:Mismatch between serialized mark") +def test_interface_mark_roundtrip(all_interfaces, valid, round_trip): + """ + All interfaces should be able to roundtrip with the marked interface, + and a mismatch should raise a warning and attempt to proceed + """ + dumped_json = all_interfaces.model_dump_json( + round_trip=round_trip, context={"mark_interface": True} + ) + + data = json.loads(dumped_json) + + # ensure that we are a MarkedJson + _ = MarkedJson.model_validate_json(json.dumps(data["array"])) + + if not valid: + # ruin the version + data["array"]["interface"]["version"] = "v99999999" + dumped_json = json.dumps(data) + + with pytest.warns(match="Mismatch.*"): + model = all_interfaces.model_validate_json(dumped_json) else: - assert np.array_equal(model.array, np.array(all_interfaces.array)) + model = all_interfaces.model_validate_json(dumped_json) - -def test_dunder_len(all_interfaces): - """ - Each interface or proxy type should support __len__ - """ - assert len(all_interfaces.array) == all_interfaces.array.shape[0] + _test_roundtrip(all_interfaces, model, round_trip) diff --git a/tests/test_interface/test_video.py b/tests/test_interface/test_video.py index 44c8b9a..5f03a57 100644 --- a/tests/test_interface/test_video.py +++ b/tests/test_interface/test_video.py @@ -164,3 +164,42 @@ def test_video_close(avi_video): assert instance.array._video is None # reopen assert isinstance(instance.array.video, cv2.VideoCapture) + + +@pytest.mark.proxy +def test_video_not_exists(tmp_path): + """ + A video file that doesn't exist should raise an error + """ + video = VideoProxy(tmp_path / "not_real.avi") + with pytest.raises(FileNotFoundError): + _ = video.video + + +@pytest.mark.proxy +@pytest.mark.parametrize( + "comparison,valid", + [ + (VideoProxy("test_video.avi"), True), + (VideoProxy("not_real_video.avi"), False), + ("not even a video proxy", TypeError), + ], +) +def test_video_proxy_eq(comparison, valid): + """ + Comparing a video proxy's equality should be valid if the path matches + Args: + comparison: + valid: + + Returns: + + """ + proxy_a = VideoProxy("test_video.avi") + if valid is True: + assert proxy_a == comparison + elif valid is False: + assert proxy_a != comparison + else: + with pytest.raises(valid): + assert proxy_a == comparison