mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2024-11-12 17:54:29 +00:00
roundtripping marked arrays, roundtrip or not
This commit is contained in:
parent
8cc2574399
commit
708e6e81d8
12 changed files with 328 additions and 53 deletions
|
@ -25,3 +25,7 @@ class NoMatchError(MatchError):
|
|||
|
||||
class TooManyMatchesError(MatchError):
|
||||
"""Too many matches found by :class:`.Interface.match`"""
|
||||
|
||||
|
||||
class MarkMismatchError(MatchError):
|
||||
"""A serialized :class:`.InterfaceMark` doesn't match the receiving interface"""
|
||||
|
|
|
@ -4,16 +4,23 @@ Interfaces between nptyping types and array backends
|
|||
|
||||
from numpydantic.interface.dask import DaskInterface
|
||||
from numpydantic.interface.hdf5 import H5Interface
|
||||
from numpydantic.interface.interface import Interface, JsonDict
|
||||
from numpydantic.interface.interface import (
|
||||
Interface,
|
||||
InterfaceMark,
|
||||
JsonDict,
|
||||
MarkedJson,
|
||||
)
|
||||
from numpydantic.interface.numpy import NumpyInterface
|
||||
from numpydantic.interface.video import VideoInterface
|
||||
from numpydantic.interface.zarr import ZarrInterface
|
||||
|
||||
__all__ = [
|
||||
"JsonDict",
|
||||
"Interface",
|
||||
"DaskInterface",
|
||||
"H5Interface",
|
||||
"Interface",
|
||||
"InterfaceMark",
|
||||
"JsonDict",
|
||||
"MarkedJson",
|
||||
"NumpyInterface",
|
||||
"VideoInterface",
|
||||
"ZarrInterface",
|
||||
|
|
|
@ -120,7 +120,7 @@ class H5Proxy:
|
|||
annotation_dtype: Optional[DtypeType] = None,
|
||||
):
|
||||
self._h5f = None
|
||||
self.file = Path(file)
|
||||
self.file = Path(file).resolve()
|
||||
self.path = path
|
||||
self.field = field
|
||||
self._annotation_dtype = annotation_dtype
|
||||
|
@ -156,6 +156,9 @@ class H5Proxy:
|
|||
return obj[:]
|
||||
|
||||
def __getattr__(self, item: str):
|
||||
if item == "__name__":
|
||||
# special case for H5Proxies that don't refer to a real file during testing
|
||||
return "H5Proxy"
|
||||
with h5py.File(self.file, "r") as h5f:
|
||||
obj = h5f.get(self.path)
|
||||
val = getattr(obj, item)
|
||||
|
|
|
@ -3,16 +3,19 @@ Base Interface metaclass
|
|||
"""
|
||||
|
||||
import inspect
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import lru_cache
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from operator import attrgetter
|
||||
from typing import Any, Generic, Tuple, Type, TypedDict, TypeVar, Union
|
||||
from typing import Any, Generic, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, SerializationInfo, ValidationError
|
||||
|
||||
from numpydantic.exceptions import (
|
||||
DtypeError,
|
||||
MarkMismatchError,
|
||||
NoMatchError,
|
||||
ShapeError,
|
||||
TooManyMatchesError,
|
||||
|
@ -26,13 +29,49 @@ V = TypeVar("V") # input type
|
|||
W = TypeVar("W") # Any type in handle_input
|
||||
|
||||
|
||||
class InterfaceMark(TypedDict):
|
||||
class InterfaceMark(BaseModel):
|
||||
"""JSON-able mark to be able to round-trip json dumps"""
|
||||
|
||||
module: str
|
||||
cls: str
|
||||
name: str
|
||||
version: str
|
||||
|
||||
def is_valid(self, cls: Type["Interface"], raise_on_error: bool = False) -> bool:
|
||||
"""
|
||||
Check that a given interface matches the mark.
|
||||
|
||||
Args:
|
||||
cls (Type): Interface type to check
|
||||
raise_on_error (bool): Raise an ``MarkMismatchError`` when the match
|
||||
is incorrect
|
||||
|
||||
Returns:
|
||||
bool
|
||||
|
||||
Raises:
|
||||
:class:`.MarkMismatchError` if requested by ``raise_on_error``
|
||||
for an invalid match
|
||||
"""
|
||||
mark = cls.mark_interface()
|
||||
valid = self == mark
|
||||
if not valid and raise_on_error:
|
||||
raise MarkMismatchError(
|
||||
"Mismatch between serialized mark and current interface, "
|
||||
f"Serialized: {self}; current: {cls}"
|
||||
)
|
||||
return valid
|
||||
|
||||
def match_by_name(self) -> Optional[Type["Interface"]]:
|
||||
"""
|
||||
Try to find a matching interface by its name, returning it if found,
|
||||
or None if not found.
|
||||
"""
|
||||
for i in Interface.interfaces(sort=False):
|
||||
if i.name == self.name:
|
||||
return i
|
||||
return None
|
||||
|
||||
|
||||
class JsonDict(BaseModel):
|
||||
"""
|
||||
|
@ -84,6 +123,29 @@ class JsonDict(BaseModel):
|
|||
return value
|
||||
|
||||
|
||||
class MarkedJson(BaseModel):
|
||||
"""
|
||||
Model of JSON dumped with an additional interface mark
|
||||
with ``model_dump_json({'mark_interface': True})``
|
||||
"""
|
||||
|
||||
interface: InterfaceMark
|
||||
value: Union[list, dict]
|
||||
"""
|
||||
Inner value of the array, we don't validate for JsonDict here,
|
||||
that should be downstream from us for performance reasons
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def try_cast(cls, value: Union[V, dict]) -> Union[V, "MarkedJson"]:
|
||||
"""
|
||||
Try to cast to MarkedJson if applicable, otherwise return input
|
||||
"""
|
||||
if isinstance(value, dict) and "interface" in value and "value" in value:
|
||||
value = MarkedJson(**value)
|
||||
return value
|
||||
|
||||
|
||||
class Interface(ABC, Generic[T]):
|
||||
"""
|
||||
Abstract parent class for interfaces to different array formats
|
||||
|
@ -158,14 +220,24 @@ class Interface(ABC, Generic[T]):
|
|||
def deserialize(self, array: Any) -> Union[V, Any]:
|
||||
"""
|
||||
If given a JSON serialized version of the array,
|
||||
deserialize it first
|
||||
deserialize it first.
|
||||
|
||||
Args:
|
||||
array:
|
||||
|
||||
Returns:
|
||||
If a roundtrip-serialized :class:`.JsonDict`,
|
||||
pass to :meth:`.JsonDict.handle_input`.
|
||||
|
||||
If a roundtrip-serialized :class:`.MarkedJson`,
|
||||
unpack mark, check for validity, warn if not,
|
||||
and try to continue with validation
|
||||
"""
|
||||
if isinstance(marked_array := MarkedJson.try_cast(array), MarkedJson):
|
||||
try:
|
||||
marked_array.interface.is_valid(self.__class__, raise_on_error=True)
|
||||
except MarkMismatchError as e:
|
||||
warnings.warn(
|
||||
str(e) + "\nAttempting to continue validation...", stacklevel=2
|
||||
)
|
||||
array = marked_array.value
|
||||
|
||||
return self.json_model.handle_input(array)
|
||||
|
||||
def before_validation(self, array: Any) -> NDArrayType:
|
||||
|
@ -274,13 +346,6 @@ class Interface(ABC, Generic[T]):
|
|||
"""
|
||||
return array
|
||||
|
||||
def mark_input(self, array: Any) -> Any:
|
||||
"""
|
||||
Preserve metadata about the interface and passed input when dumping with
|
||||
``round_trip``
|
||||
"""
|
||||
return array
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def check(cls, array: Any) -> bool:
|
||||
|
@ -320,7 +385,7 @@ class Interface(ABC, Generic[T]):
|
|||
"""
|
||||
|
||||
@classmethod
|
||||
def mark_json(cls, array: Union[list, dict]) -> dict:
|
||||
def mark_json(cls, array: Union[list, dict]) -> MarkedJson:
|
||||
"""
|
||||
When using ``model_dump_json`` with ``mark_interface: True`` in the ``context``,
|
||||
add additional annotations that would allow the serialized array to be
|
||||
|
@ -337,7 +402,7 @@ class Interface(ABC, Generic[T]):
|
|||
'version': '1.2.2'},
|
||||
'value': [1.0, 2.0]}
|
||||
"""
|
||||
return {"interface": cls.mark_interface(), "value": array}
|
||||
return MarkedJson.model_construct(interface=cls.mark_interface(), value=array)
|
||||
|
||||
@classmethod
|
||||
def interfaces(
|
||||
|
@ -390,6 +455,28 @@ class Interface(ABC, Generic[T]):
|
|||
|
||||
return tuple(in_types)
|
||||
|
||||
@classmethod
|
||||
def match_mark(cls, array: Any) -> Optional[Type["Interface"]]:
|
||||
"""
|
||||
Match a marked JSON dump of this array to the interface that it indicates.
|
||||
|
||||
First find an interface that matches by name, and then run its
|
||||
``check`` method, because arrays can be dumped with a mark
|
||||
but without ``round_trip == True`` (and thus can't necessarily
|
||||
use the same interface that they were dumped with)
|
||||
|
||||
Returns:
|
||||
Interface if match found, None otherwise
|
||||
"""
|
||||
mark = MarkedJson.try_cast(array)
|
||||
if not isinstance(mark, MarkedJson):
|
||||
return None
|
||||
|
||||
interface = mark.interface.match_by_name()
|
||||
if interface is not None and interface.check(mark.value):
|
||||
return interface
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def match(cls, array: Any, fast: bool = False) -> Type["Interface"]:
|
||||
"""
|
||||
|
@ -407,11 +494,18 @@ class Interface(ABC, Generic[T]):
|
|||
check each interface (as ordered by its ``priority`` , decreasing),
|
||||
and return on the first match.
|
||||
"""
|
||||
# Shortcircuit match if this is a marked json dump
|
||||
array = MarkedJson.try_cast(array)
|
||||
if (match := cls.match_mark(array)) is not None:
|
||||
return match
|
||||
elif isinstance(array, MarkedJson):
|
||||
array = array.value
|
||||
|
||||
# first try and find a non-numpy interface, since the numpy interface
|
||||
# will try and load the array into memory in its check method
|
||||
interfaces = cls.interfaces()
|
||||
non_np_interfaces = [i for i in interfaces if i.__name__ != "NumpyInterface"]
|
||||
np_interface = [i for i in interfaces if i.__name__ == "NumpyInterface"][0]
|
||||
non_np_interfaces = [i for i in interfaces if i.name != "numpy"]
|
||||
np_interface = [i for i in interfaces if i.name == "numpy"][0]
|
||||
|
||||
if fast:
|
||||
matches = []
|
||||
|
@ -453,6 +547,7 @@ class Interface(ABC, Generic[T]):
|
|||
return matches[0]
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=32)
|
||||
def mark_interface(cls) -> InterfaceMark:
|
||||
"""
|
||||
Create an interface mark indicating this interface for validation after
|
||||
|
@ -470,5 +565,7 @@ class Interface(ABC, Generic[T]):
|
|||
)
|
||||
except PackageNotFoundError:
|
||||
v = None
|
||||
interface_name = cls.__name__
|
||||
return InterfaceMark(module=interface_module, cls=interface_name, version=v)
|
||||
|
||||
return InterfaceMark(
|
||||
module=interface_module, cls=cls.__name__, name=cls.name, version=v
|
||||
)
|
||||
|
|
|
@ -48,7 +48,7 @@ class VideoProxy:
|
|||
)
|
||||
|
||||
if path is not None:
|
||||
path = Path(path)
|
||||
path = Path(path).resolve()
|
||||
self.path = path
|
||||
|
||||
self._video = video # type: Optional[VideoCapture]
|
||||
|
@ -200,6 +200,8 @@ class VideoProxy:
|
|||
raise NotImplementedError("Setting pixel values on videos is not supported!")
|
||||
|
||||
def __getattr__(self, item: str):
|
||||
if item == "__name__":
|
||||
return "VideoProxy"
|
||||
return getattr(self.video, item)
|
||||
|
||||
def __eq__(self, other: "VideoProxy") -> bool:
|
||||
|
|
|
@ -23,7 +23,8 @@ def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
|
|||
|
||||
if info.context:
|
||||
if info.context.get("mark_interface", False):
|
||||
array = interface_cls.mark_json(array)
|
||||
array = interface_cls.mark_json(array).model_dump()
|
||||
|
||||
if info.context.get("absolute_paths", False):
|
||||
array = _absolutize_paths(array)
|
||||
else:
|
||||
|
|
|
@ -83,7 +83,6 @@ STRING: TypeAlias = NDArray[Shape["*, *, *"], str]
|
|||
MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel]
|
||||
|
||||
|
||||
@pytest.mark.shape
|
||||
@pytest.fixture(
|
||||
scope="module",
|
||||
params=[
|
||||
|
@ -121,7 +120,6 @@ def shape_cases(request) -> ValidationCase:
|
|||
return request.param
|
||||
|
||||
|
||||
@pytest.mark.dtype
|
||||
@pytest.fixture(
|
||||
scope="module",
|
||||
params=[
|
||||
|
|
|
@ -104,7 +104,6 @@ def model_blank() -> Type[BaseModel]:
|
|||
return BlankModel
|
||||
|
||||
|
||||
@pytest.mark.hdf5
|
||||
@pytest.fixture(scope="function")
|
||||
def hdf5_file(tmp_output_dir_func) -> h5py.File:
|
||||
h5f_file = tmp_output_dir_func / "h5f.h5"
|
||||
|
@ -113,7 +112,6 @@ def hdf5_file(tmp_output_dir_func) -> h5py.File:
|
|||
h5f.close()
|
||||
|
||||
|
||||
@pytest.mark.hdf5
|
||||
@pytest.fixture(scope="function")
|
||||
def hdf5_array(
|
||||
hdf5_file, request
|
||||
|
@ -156,7 +154,6 @@ def hdf5_array(
|
|||
return _hdf5_array
|
||||
|
||||
|
||||
@pytest.mark.zarr
|
||||
@pytest.fixture(scope="function")
|
||||
def zarr_nested_array(tmp_output_dir_func) -> ZarrArrayPath:
|
||||
"""Zarr array within a nested array"""
|
||||
|
@ -167,7 +164,6 @@ def zarr_nested_array(tmp_output_dir_func) -> ZarrArrayPath:
|
|||
return ZarrArrayPath(file=file, path=path)
|
||||
|
||||
|
||||
@pytest.mark.zarr
|
||||
@pytest.fixture(scope="function")
|
||||
def zarr_array(tmp_output_dir_func) -> Path:
|
||||
file = tmp_output_dir_func / "array.zarr"
|
||||
|
@ -176,7 +172,6 @@ def zarr_array(tmp_output_dir_func) -> Path:
|
|||
return file
|
||||
|
||||
|
||||
@pytest.mark.video
|
||||
@pytest.fixture(scope="function")
|
||||
def avi_video(tmp_path) -> Callable[[Tuple[int, int], int, bool], Path]:
|
||||
video_path = tmp_path / "test.avi"
|
||||
|
|
|
@ -231,3 +231,29 @@ def test_empty_dataset(dtype, tmp_path):
|
|||
array: NDArray[Any, dtype]
|
||||
|
||||
_ = MyModel(array=(array_path, "/data"))
|
||||
|
||||
|
||||
@pytest.mark.proxy
|
||||
@pytest.mark.parametrize(
|
||||
"comparison,valid",
|
||||
[
|
||||
(H5Proxy(file="test_file.h5", path="/subpath", field="sup"), True),
|
||||
(H5Proxy(file="test_file.h5", path="/subpath"), False),
|
||||
(H5Proxy(file="different_file.h5", path="/subpath"), False),
|
||||
(("different_file.h5", "/subpath", "sup"), ValueError),
|
||||
("not even a proxy-like thing", ValueError),
|
||||
],
|
||||
)
|
||||
def test_proxy_eq(comparison, valid):
|
||||
"""
|
||||
test the __eq__ method of H5ArrayProxy matches proxies to the same
|
||||
dataset (and path), or raises a ValueError
|
||||
"""
|
||||
proxy_a = H5Proxy(file="test_file.h5", path="/subpath", field="sup")
|
||||
if valid is True:
|
||||
assert proxy_a == comparison
|
||||
elif valid is False:
|
||||
assert proxy_a != comparison
|
||||
else:
|
||||
with pytest.raises(valid):
|
||||
assert proxy_a == comparison
|
||||
|
|
|
@ -4,11 +4,26 @@ for tests that should apply to all interfaces, use ``test_interfaces.py``
|
|||
"""
|
||||
|
||||
import gc
|
||||
from typing import Literal
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from numpydantic.interface import Interface
|
||||
from numpydantic.interface import Interface, JsonDict
|
||||
from pydantic import ValidationError
|
||||
|
||||
from numpydantic.interface.interface import V
|
||||
|
||||
|
||||
class MyJsonDict(JsonDict):
|
||||
type: Literal["my_json_dict"]
|
||||
field: str
|
||||
number: int
|
||||
|
||||
def to_array_input(self) -> V:
|
||||
dumped = self.model_dump()
|
||||
dumped["extra_input_param"] = True
|
||||
return dumped
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
@ -162,3 +177,36 @@ def test_interface_recursive(interfaces):
|
|||
assert issubclass(interfaces.interface3, interfaces.interface1)
|
||||
assert issubclass(interfaces.interface1, Interface)
|
||||
assert interfaces.interface4 in ifaces
|
||||
|
||||
|
||||
@pytest.mark.serialization
|
||||
def test_jsondict_is_valid():
|
||||
"""
|
||||
A JsonDict should return a bool true/false if it is valid or not,
|
||||
and raise an error when requested
|
||||
"""
|
||||
invalid = {"doesnt": "have", "the": "props"}
|
||||
valid = {"type": "my_json_dict", "field": "a_field", "number": 1}
|
||||
assert MyJsonDict.is_valid(valid)
|
||||
assert not MyJsonDict.is_valid(invalid)
|
||||
with pytest.raises(ValidationError):
|
||||
assert not MyJsonDict.is_valid(invalid, raise_on_error=True)
|
||||
|
||||
|
||||
@pytest.mark.serialization
|
||||
def test_jsondict_handle_input():
|
||||
"""
|
||||
JsonDict should be able to parse a valid dict and return it to the input format
|
||||
"""
|
||||
valid = {"type": "my_json_dict", "field": "a_field", "number": 1}
|
||||
instantiated = MyJsonDict(**valid)
|
||||
expected = {
|
||||
"type": "my_json_dict",
|
||||
"field": "a_field",
|
||||
"number": 1,
|
||||
"extra_input_param": True,
|
||||
}
|
||||
|
||||
for item in (valid, instantiated):
|
||||
result = MyJsonDict.handle_input(item)
|
||||
assert result == expected
|
||||
|
|
|
@ -4,10 +4,38 @@ Tests that should be applied to all interfaces
|
|||
|
||||
import pytest
|
||||
from typing import Callable
|
||||
from importlib.metadata import version
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
import dask.array as da
|
||||
from zarr.core import Array as ZarrArray
|
||||
from numpydantic.interface import Interface
|
||||
from pydantic import BaseModel
|
||||
|
||||
from numpydantic.interface import Interface, InterfaceMark, MarkedJson
|
||||
|
||||
|
||||
def _test_roundtrip(source: BaseModel, target: BaseModel, round_trip: bool):
|
||||
"""Test model equality for roundtrip tests"""
|
||||
if round_trip:
|
||||
assert type(target.array) is type(source.array)
|
||||
if isinstance(source.array, (np.ndarray, ZarrArray)):
|
||||
assert np.array_equal(target.array, np.array(source.array))
|
||||
elif isinstance(source.array, da.Array):
|
||||
assert np.all(da.equal(target.array, source.array))
|
||||
else:
|
||||
assert target.array == source.array
|
||||
|
||||
assert target.array.dtype == source.array.dtype
|
||||
else:
|
||||
assert np.array_equal(target.array, np.array(source.array))
|
||||
|
||||
|
||||
def test_dunder_len(all_interfaces):
|
||||
"""
|
||||
Each interface or proxy type should support __len__
|
||||
"""
|
||||
assert len(all_interfaces.array) == all_interfaces.array.shape[0]
|
||||
|
||||
|
||||
def test_interface_revalidate(all_interfaces):
|
||||
|
@ -52,24 +80,51 @@ def test_interface_roundtrip_json(all_interfaces, round_trip):
|
|||
"""
|
||||
All interfaces should be able to roundtrip to and from json
|
||||
"""
|
||||
json = all_interfaces.model_dump_json(round_trip=round_trip)
|
||||
model = all_interfaces.model_validate_json(json)
|
||||
if round_trip:
|
||||
assert type(model.array) is type(all_interfaces.array)
|
||||
if isinstance(all_interfaces.array, (np.ndarray, ZarrArray)):
|
||||
assert np.array_equal(model.array, np.array(all_interfaces.array))
|
||||
elif isinstance(all_interfaces.array, da.Array):
|
||||
assert np.all(da.equal(model.array, all_interfaces.array))
|
||||
else:
|
||||
assert model.array == all_interfaces.array
|
||||
dumped_json = all_interfaces.model_dump_json(round_trip=round_trip)
|
||||
model = all_interfaces.model_validate_json(dumped_json)
|
||||
_test_roundtrip(all_interfaces, model, round_trip)
|
||||
|
||||
assert model.array.dtype == all_interfaces.array.dtype
|
||||
|
||||
@pytest.mark.serialization
|
||||
@pytest.mark.parametrize("an_interface", Interface.interfaces())
|
||||
def test_interface_mark_interface(an_interface):
|
||||
"""
|
||||
All interfaces should be able to mark the current version and interface info
|
||||
"""
|
||||
mark = an_interface.mark_interface()
|
||||
assert isinstance(mark, InterfaceMark)
|
||||
assert mark.name == an_interface.name
|
||||
assert mark.cls == an_interface.__name__
|
||||
assert mark.module == an_interface.__module__
|
||||
assert mark.version == version(mark.module.split(".")[0])
|
||||
|
||||
|
||||
@pytest.mark.serialization
|
||||
@pytest.mark.parametrize("valid", [True, False])
|
||||
@pytest.mark.parametrize("round_trip", [True, False])
|
||||
@pytest.mark.filterwarnings("ignore:Mismatch between serialized mark")
|
||||
def test_interface_mark_roundtrip(all_interfaces, valid, round_trip):
|
||||
"""
|
||||
All interfaces should be able to roundtrip with the marked interface,
|
||||
and a mismatch should raise a warning and attempt to proceed
|
||||
"""
|
||||
dumped_json = all_interfaces.model_dump_json(
|
||||
round_trip=round_trip, context={"mark_interface": True}
|
||||
)
|
||||
|
||||
data = json.loads(dumped_json)
|
||||
|
||||
# ensure that we are a MarkedJson
|
||||
_ = MarkedJson.model_validate_json(json.dumps(data["array"]))
|
||||
|
||||
if not valid:
|
||||
# ruin the version
|
||||
data["array"]["interface"]["version"] = "v99999999"
|
||||
dumped_json = json.dumps(data)
|
||||
|
||||
with pytest.warns(match="Mismatch.*"):
|
||||
model = all_interfaces.model_validate_json(dumped_json)
|
||||
else:
|
||||
assert np.array_equal(model.array, np.array(all_interfaces.array))
|
||||
model = all_interfaces.model_validate_json(dumped_json)
|
||||
|
||||
|
||||
def test_dunder_len(all_interfaces):
|
||||
"""
|
||||
Each interface or proxy type should support __len__
|
||||
"""
|
||||
assert len(all_interfaces.array) == all_interfaces.array.shape[0]
|
||||
_test_roundtrip(all_interfaces, model, round_trip)
|
||||
|
|
|
@ -164,3 +164,42 @@ def test_video_close(avi_video):
|
|||
assert instance.array._video is None
|
||||
# reopen
|
||||
assert isinstance(instance.array.video, cv2.VideoCapture)
|
||||
|
||||
|
||||
@pytest.mark.proxy
|
||||
def test_video_not_exists(tmp_path):
|
||||
"""
|
||||
A video file that doesn't exist should raise an error
|
||||
"""
|
||||
video = VideoProxy(tmp_path / "not_real.avi")
|
||||
with pytest.raises(FileNotFoundError):
|
||||
_ = video.video
|
||||
|
||||
|
||||
@pytest.mark.proxy
|
||||
@pytest.mark.parametrize(
|
||||
"comparison,valid",
|
||||
[
|
||||
(VideoProxy("test_video.avi"), True),
|
||||
(VideoProxy("not_real_video.avi"), False),
|
||||
("not even a video proxy", TypeError),
|
||||
],
|
||||
)
|
||||
def test_video_proxy_eq(comparison, valid):
|
||||
"""
|
||||
Comparing a video proxy's equality should be valid if the path matches
|
||||
Args:
|
||||
comparison:
|
||||
valid:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
proxy_a = VideoProxy("test_video.avi")
|
||||
if valid is True:
|
||||
assert proxy_a == comparison
|
||||
elif valid is False:
|
||||
assert proxy_a != comparison
|
||||
else:
|
||||
with pytest.raises(valid):
|
||||
assert proxy_a == comparison
|
||||
|
|
Loading…
Reference in a new issue