refactoring array generation, swapping in the interface case generators

This commit is contained in:
sneakers-the-rat 2024-10-04 00:46:49 -07:00
parent e701bf6e9b
commit 3356738e42
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
9 changed files with 262 additions and 189 deletions

View file

@ -2,15 +2,6 @@
Utilities for testing and 3rd-party interface development. Utilities for testing and 3rd-party interface development.
Only things that *don't* require pytest go in this module.
We want to keep all test-time specific behavior there,
and have this just serve as helpers exposed for downstream interface development.
We want to avoid pytest stuff bleeding in here because then we limit
the ability for downstream developers to configure their own tests.
*(If there is some reason to change this division of labor, just raise an issue and let's chat.)*
```{toctree} ```{toctree}
cases cases
helpers helpers

View file

@ -203,9 +203,19 @@ def merged_product(
iterator = merged_product(shape_cases, dtype_cases)) iterator = merged_product(shape_cases, dtype_cases))
next(iterator) next(iterator)
# ValidationCase(shape=(10, 10, 10), dtype=float, passes=True, id="valid shape-float") # ValidationCase(
# shape=(10, 10, 10),
# dtype=float,
# passes=True,
# id="valid shape-float"
# )
next(iterator) next(iterator)
# ValidationCase(shape=(10, 10, 10), dtype=int, passes=False, id="valid shape-int") # ValidationCase(
# shape=(10, 10, 10),
# dtype=int,
# passes=False,
# id="valid shape-int"
# )
""" """

View file

@ -9,7 +9,7 @@ from pydantic import BaseModel, ConfigDict, ValidationError, computed_field
from numpydantic import NDArray, Shape from numpydantic import NDArray, Shape
from numpydantic.dtype import Float from numpydantic.dtype import Float
from numpydantic.interface import Interface from numpydantic.interface import Interface
from numpydantic.types import NDArrayType from numpydantic.types import DtypeType, NDArrayType
class InterfaceCase(ABC): class InterfaceCase(ABC):
@ -29,43 +29,64 @@ class InterfaceCase(ABC):
"""The interface that this helper is for""" """The interface that this helper is for"""
@classmethod @classmethod
@abstractmethod def array_from_case(
def generate_array( cls, case: "ValidationCase", path: Optional[Path] = None
cls, case: "ValidationCase", path: Path
) -> Optional[NDArrayType]: ) -> Optional[NDArrayType]:
""" """
Generate an array from the given validation case. Generate an array from the given validation case.
Returns ``None`` if an array can't be generated for a specific case. Returns ``None`` if an array can't be generated for a specific case.
""" """
return cls.make_array(shape=case.shape, dtype=case.dtype, path=path)
@classmethod @classmethod
def validate_array(cls, case: "ValidationCase", path: Path) -> Optional[bool]: @abstractmethod
def make_array(
cls,
shape: Tuple[int, ...] = (10, 10),
dtype: DtypeType = float,
path: Optional[Path] = None,
) -> Optional[NDArrayType]:
"""
Make an array from a shape and dtype, and a path if needed
"""
@classmethod
def validate_case(cls, case: "ValidationCase", path: Path) -> bool:
""" """
Validate a generated array against the annotation in the validation case. Validate a generated array against the annotation in the validation case.
Kept in the InterfaceCase in case an interface has specific Kept in the InterfaceCase in case an interface has specific
needs aside from just validating against a model, but typically left as is. needs aside from just validating against a model, but typically left as is.
Does not raise on Validation errors -
returns bool instead for consistency's sake.
If an array can't be generated for a given case, returns `None` If an array can't be generated for a given case, returns `None`
so that the calling function can know to skip rather than fail the case. so that the calling function can know to skip rather than fail the case.
Raises exceptions if validation fails (or succeeds when it shouldn't)
Args:
case (ValidationCase): The validation case to validate.
path (Path): Path to generate arrays into, if any.
Returns:
``True`` if array is valid and was supposed to be,
or invalid and wasn't supposed to be
""" """
array = cls.generate_array(case, path) import pytest
array = cls.array_from_case(case, path)
if array is None: if array is None:
return None pytest.skip()
try: if case.passes:
case.model(array=array) case.model(array=array)
# True if case is supposed to pass, False if it's not... return True
return case.passes else:
except ValidationError: with pytest.raises(ValidationError):
# False if the case is supposed to pass, True if it is... case.model(array=array)
return not case.passes return True
@classmethod @classmethod
def skip(cls, case: "ValidationCase") -> bool: def skip(cls, shape: Tuple[int, ...], dtype: DtypeType) -> bool:
""" """
Whether a given interface should be skipped for the case Whether a given interface should be skipped for the case
""" """
@ -97,6 +118,9 @@ class ValidationCase(BaseModel):
passes: bool = False passes: bool = False
"""Whether the validation should pass or not""" """Whether the validation should pass or not"""
interface: Optional[InterfaceCase] = None interface: Optional[InterfaceCase] = None
"""The interface test case to generate and validate the array with"""
path: Optional[Path] = None
"""The path to generate arrays into, if any."""
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
@ -110,6 +134,39 @@ class ValidationCase(BaseModel):
return Model return Model
def validate_case(self, path: Optional[Path] = None) -> bool:
"""
Whether the generated array correctly validated against the annotation,
given the interface
Args:
path (:class:`pathlib.Path`): Directory to generate array into, if on disk.
Raises:
ValueError: if an ``interface`` is missing
"""
if self.interface is None:
raise ValueError("Missing an interface")
if path is None:
if self.path:
path = self.path
else:
raise ValueError("Missing a path to generate arrays into")
return self.interface.validate_case(self, path)
def array(self, path: Path) -> NDArrayType:
"""Generate an array for the validation case if we have an interface to do so"""
if self.interface is None:
raise ValueError("Missing an interface")
if path is None:
if self.path:
path = self.path
else:
raise ValueError("Missing a path to generate arrays into")
return self.interface.array_from_case(self, path)
def merge( def merge(
self, other: Union["ValidationCase", Sequence["ValidationCase"]] self, other: Union["ValidationCase", Sequence["ValidationCase"]]
) -> "ValidationCase": ) -> "ValidationCase":
@ -154,7 +211,9 @@ class ValidationCase(BaseModel):
(eg. due to the interface case being incompatible (eg. due to the interface case being incompatible
with the requested dtype or shape) with the requested dtype or shape)
""" """
return bool(self.interface is not None and self.interface.skip()) return bool(
self.interface is not None and self.interface.skip(self.shape, self.dtype)
)
def merge_cases(*args: ValidationCase) -> ValidationCase: def merge_cases(*args: ValidationCase) -> ValidationCase:

View file

@ -1,6 +1,6 @@
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional, Tuple
import cv2 import cv2
import dask.array as da import dask.array as da
@ -18,7 +18,8 @@ from numpydantic.interface import (
ZarrArrayPath, ZarrArrayPath,
ZarrInterface, ZarrInterface,
) )
from numpydantic.testing.helpers import InterfaceCase, ValidationCase from numpydantic.testing.helpers import InterfaceCase
from numpydantic.types import DtypeType
class NumpyCase(InterfaceCase): class NumpyCase(InterfaceCase):
@ -27,11 +28,16 @@ class NumpyCase(InterfaceCase):
interface = NumpyInterface interface = NumpyInterface
@classmethod @classmethod
def generate_array(cls, case: "ValidationCase", path: Path) -> np.ndarray: def make_array(
if issubclass(case.dtype, BaseModel): cls,
return np.full(shape=case.shape, fill_value=case.dtype(x=1)) shape: Tuple[int, ...] = (10, 10),
dtype: DtypeType = float,
path: Optional[Path] = None,
) -> np.ndarray:
if issubclass(dtype, BaseModel):
return np.full(shape=shape, fill_value=dtype(x=1))
else: else:
return np.zeros(shape=case.shape, dtype=case.dtype) return np.zeros(shape=shape, dtype=dtype)
class _HDF5MetaCase(InterfaceCase): class _HDF5MetaCase(InterfaceCase):
@ -40,33 +46,34 @@ class _HDF5MetaCase(InterfaceCase):
interface = H5Interface interface = H5Interface
@classmethod @classmethod
def skip(cls, case: "ValidationCase") -> bool: def skip(cls, shape: Tuple[int, ...], dtype: DtypeType) -> bool:
return not issubclass(case.dtype, BaseModel) return issubclass(dtype, BaseModel)
class HDF5Case(_HDF5MetaCase): class HDF5Case(_HDF5MetaCase):
"""HDF5 Array""" """HDF5 Array"""
@classmethod @classmethod
def generate_array( def make_array(
cls, case: "ValidationCase", path: Path cls,
shape: Tuple[int, ...] = (10, 10),
dtype: DtypeType = float,
path: Optional[Path] = None,
) -> Optional[H5ArrayPath]: ) -> Optional[H5ArrayPath]:
if cls.skip(case): if cls.skip(shape, dtype):
return None return None
hdf5_file = path / "h5f.h5" hdf5_file = path / "h5f.h5"
array_path = ( array_path = "/" + "_".join([str(s) for s in shape]) + "__" + dtype.__name__
"/" + "_".join([str(s) for s in case.shape]) + "__" + case.dtype.__name__
)
generator = np.random.default_rng() generator = np.random.default_rng()
if case.dtype is str: if dtype is str:
data = generator.random(case.shape).astype(bytes) data = generator.random(shape).astype(bytes)
elif case.dtype is datetime: elif dtype is datetime:
data = np.empty(case.shape, dtype="S32") data = np.empty(shape, dtype="S32")
data.fill(datetime.now(timezone.utc).isoformat().encode("utf-8")) data.fill(datetime.now(timezone.utc).isoformat().encode("utf-8"))
else: else:
data = generator.random(case.shape).astype(case.dtype) data = generator.random(shape).astype(dtype)
h5path = H5ArrayPath(hdf5_file, array_path) h5path = H5ArrayPath(hdf5_file, array_path)
@ -79,31 +86,30 @@ class HDF5CompoundCase(_HDF5MetaCase):
"""HDF5 Array with a fake compound dtype""" """HDF5 Array with a fake compound dtype"""
@classmethod @classmethod
def generate_array( def make_array(
cls, case: "ValidationCase", path: Path cls,
shape: Tuple[int, ...] = (10, 10),
dtype: DtypeType = float,
path: Optional[Path] = None,
) -> Optional[H5ArrayPath]: ) -> Optional[H5ArrayPath]:
if cls.skip(case): if cls.skip(shape, dtype):
return None return None
hdf5_file = path / "h5f.h5" hdf5_file = path / "h5f.h5"
array_path = ( array_path = "/" + "_".join([str(s) for s in shape]) + "__" + dtype.__name__
"/" + "_".join([str(s) for s in case.shape]) + "__" + case.dtype.__name__ if dtype is str:
)
if case.dtype is str:
dt = np.dtype([("data", np.dtype("S10")), ("extra", "i8")]) dt = np.dtype([("data", np.dtype("S10")), ("extra", "i8")])
data = np.array([("hey", 0)] * np.prod(case.shape), dtype=dt).reshape( data = np.array([("hey", 0)] * np.prod(shape), dtype=dt).reshape(shape)
case.shape elif dtype is datetime:
)
elif case.dtype is datetime:
dt = np.dtype([("data", np.dtype("S32")), ("extra", "i8")]) dt = np.dtype([("data", np.dtype("S32")), ("extra", "i8")])
data = np.array( data = np.array(
[(datetime.now(timezone.utc).isoformat().encode("utf-8"), 0)] [(datetime.now(timezone.utc).isoformat().encode("utf-8"), 0)]
* np.prod(case.shape), * np.prod(shape),
dtype=dt, dtype=dt,
).reshape(case.shape) ).reshape(shape)
else: else:
dt = np.dtype([("data", case.dtype), ("extra", "i8")]) dt = np.dtype([("data", dtype), ("extra", "i8")])
data = np.zeros(case.shape, dtype=dt) data = np.zeros(shape, dtype=dt)
h5path = H5ArrayPath(hdf5_file, array_path, "data") h5path = H5ArrayPath(hdf5_file, array_path, "data")
with h5py.File(hdf5_file, "w") as h5f: with h5py.File(hdf5_file, "w") as h5f:
@ -117,11 +123,16 @@ class DaskCase(InterfaceCase):
interface = DaskInterface interface = DaskInterface
@classmethod @classmethod
def generate_array(cls, case: "ValidationCase", path: Path) -> da.Array: def make_array(
if issubclass(case.dtype, BaseModel): cls,
return da.full(shape=case.shape, fill_value=case.dtype(x=1), chunks=-1) shape: Tuple[int, ...] = (10, 10),
dtype: DtypeType = float,
path: Optional[Path] = None,
) -> da.Array:
if issubclass(dtype, BaseModel):
return da.full(shape=shape, fill_value=dtype(x=1), chunks=-1)
else: else:
return da.zeros(shape=case.shape, dtype=case.dtype, chunks=10) return da.zeros(shape=shape, dtype=dtype, chunks=10)
class _ZarrMetaCase(InterfaceCase): class _ZarrMetaCase(InterfaceCase):
@ -130,45 +141,65 @@ class _ZarrMetaCase(InterfaceCase):
interface = ZarrInterface interface = ZarrInterface
@classmethod @classmethod
def skip(cls, case: "ValidationCase") -> bool: def skip(cls, shape: Tuple[int, ...], dtype: DtypeType) -> bool:
return not issubclass(case.dtype, BaseModel) return not issubclass(dtype, BaseModel)
class ZarrCase(_ZarrMetaCase): class ZarrCase(_ZarrMetaCase):
"""In-memory zarr array""" """In-memory zarr array"""
@classmethod @classmethod
def generate_array(cls, case: "ValidationCase", path: Path) -> Optional[zarr.Array]: def make_array(
return zarr.zeros(shape=case.shape, dtype=case.dtype) cls,
shape: Tuple[int, ...] = (10, 10),
dtype: DtypeType = float,
path: Optional[Path] = None,
) -> Optional[zarr.Array]:
return zarr.zeros(shape=shape, dtype=dtype)
class ZarrDirCase(_ZarrMetaCase): class ZarrDirCase(_ZarrMetaCase):
"""On-disk zarr array""" """On-disk zarr array"""
@classmethod @classmethod
def generate_array(cls, case: "ValidationCase", path: Path) -> ZarrArrayPath: def make_array(
cls,
shape: Tuple[int, ...] = (10, 10),
dtype: DtypeType = float,
path: Optional[Path] = None,
) -> Optional[zarr.Array]:
store = zarr.DirectoryStore(str(path / "array.zarr")) store = zarr.DirectoryStore(str(path / "array.zarr"))
return zarr.zeros(shape=case.shape, dtype=case.dtype, store=store) return zarr.zeros(shape=shape, dtype=dtype, store=store)
class ZarrZipCase(_ZarrMetaCase): class ZarrZipCase(_ZarrMetaCase):
"""Zarr zip store""" """Zarr zip store"""
@classmethod @classmethod
def generate_array(cls, case: "ValidationCase", path: Path) -> ZarrArrayPath: def make_array(
cls,
shape: Tuple[int, ...] = (10, 10),
dtype: DtypeType = float,
path: Optional[Path] = None,
) -> Optional[zarr.Array]:
store = zarr.ZipStore(str(path / "array.zarr"), mode="w") store = zarr.ZipStore(str(path / "array.zarr"), mode="w")
return zarr.zeros(shape=case.shape, dtype=case.dtype, store=store) return zarr.zeros(shape=shape, dtype=dtype, store=store)
class ZarrNestedCase(_ZarrMetaCase): class ZarrNestedCase(_ZarrMetaCase):
"""Nested zarr array""" """Nested zarr array"""
@classmethod @classmethod
def generate_array(cls, case: "ValidationCase", path: Path) -> ZarrArrayPath: def make_array(
cls,
shape: Tuple[int, ...] = (10, 10),
dtype: DtypeType = float,
path: Optional[Path] = None,
) -> ZarrArrayPath:
file = str(path / "nested.zarr") file = str(path / "nested.zarr")
root = zarr.open(file, mode="w") root = zarr.open(file, mode="w")
subpath = "a/b/c" subpath = "a/b/c"
_ = root.zeros(subpath, shape=case.shape, dtype=case.dtype) _ = root.zeros(subpath, shape=shape, dtype=dtype)
return ZarrArrayPath(file=file, path=subpath) return ZarrArrayPath(file=file, path=subpath)
@ -178,13 +209,18 @@ class VideoCase(InterfaceCase):
interface = VideoInterface interface = VideoInterface
@classmethod @classmethod
def generate_array(cls, case: "ValidationCase", path: Path) -> Optional[Path]: def make_array(
if cls.skip(case): cls,
shape: Tuple[int, ...] = (10, 10),
dtype: DtypeType = float,
path: Optional[Path] = None,
) -> Optional[Path]:
if cls.skip(shape, dtype):
return None return None
is_color = len(case.shape) == 4 is_color = len(shape) == 4
frames = case.shape[0] frames = shape[0]
frame_shape = case.shape[1:] frame_shape = shape[1:]
video_path = path / "test.avi" video_path = path / "test.avi"
writer = cv2.VideoWriter( writer = cv2.VideoWriter(
@ -207,12 +243,12 @@ class VideoCase(InterfaceCase):
return video_path return video_path
@classmethod @classmethod
def skip(cls, case: "ValidationCase") -> bool: def skip(cls, shape: Tuple[int, ...], dtype: DtypeType) -> bool:
"""We really can only handle 3-4 dimensional cases in 8-bit rn lol""" """We really can only handle 3-4 dimensional cases in 8-bit rn lol"""
if len(case.shape) < 3 or len(case.shape) > 4: if len(shape) < 3 or len(shape) > 4:
return True return True
if case.dtype not in (int, np.uint8): if dtype not in (int, np.uint8):
return True return True
# if we have a color video (ie. shape == 4, needs to be RGB) # if we have a color video (ie. shape == 4, needs to be RGB)
if len(case.shape) == 4 and case.shape[3] != 3: if len(shape) == 4 and shape[3] != 3:
return True return True

View file

@ -13,11 +13,19 @@ def pytest_addoption(parser):
) )
@pytest.fixture(scope="module", params=SHAPE_CASES) @pytest.fixture(
def shape_cases(request) -> ValidationCase: scope="function", params=[pytest.param(c, id=c.id) for c in SHAPE_CASES]
return request.param )
def shape_cases(request, tmp_output_dir_func) -> ValidationCase:
case: ValidationCase = request.param.model_copy()
case.path = tmp_output_dir_func
return case
@pytest.fixture(scope="module", params=DTYPE_CASES) @pytest.fixture(
def dtype_cases(request) -> ValidationCase: scope="function", params=[pytest.param(c, id=c.id) for c in DTYPE_CASES]
return request.param )
def dtype_cases(request, tmp_output_dir_func) -> ValidationCase:
case: ValidationCase = request.param.model_copy()
case.path = tmp_output_dir_func
return case

View file

@ -1,60 +1,33 @@
from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Callable, Tuple, Union from typing import Callable, Tuple, Union
import cv2
import h5py
import numpy as np import numpy as np
import pytest import pytest
import zarr import zarr
from numpydantic.interface.hdf5 import H5ArrayPath from numpydantic.interface.hdf5 import H5ArrayPath
from numpydantic.interface.zarr import ZarrArrayPath from numpydantic.interface.zarr import ZarrArrayPath
from numpydantic.testing import ValidationCase
from numpydantic.testing.interfaces import HDF5Case, HDF5CompoundCase, VideoCase
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def hdf5_array( def hdf5_array(
request, tmp_output_dir_func request, tmp_output_dir_func
) -> Callable[[Tuple[int, ...], Union[np.dtype, type]], H5ArrayPath]: ) -> Callable[[Tuple[int, ...], Union[np.dtype, type]], H5ArrayPath]:
hdf5_file = tmp_output_dir_func / "h5f.h5"
def _hdf5_array( def _hdf5_array(
shape: Tuple[int, ...] = (10, 10), shape: Tuple[int, ...] = (10, 10),
dtype: Union[np.dtype, type] = float, dtype: Union[np.dtype, type] = float,
compound: bool = False, compound: bool = False,
) -> H5ArrayPath: ) -> H5ArrayPath:
array_path = "/" + "_".join([str(s) for s in shape]) + "__" + dtype.__name__ if compound:
generator = np.random.default_rng() array: H5ArrayPath = HDF5CompoundCase.make_array(
shape, dtype, tmp_output_dir_func
if not compound: )
if dtype is str: return array
data = generator.random(shape).astype(bytes)
elif dtype is datetime:
data = np.empty(shape, dtype="S32")
data.fill(datetime.now(timezone.utc).isoformat().encode("utf-8"))
else:
data = generator.random(shape).astype(dtype)
h5path = H5ArrayPath(hdf5_file, array_path)
else: else:
if dtype is str: return HDF5Case.make_array(shape, dtype, tmp_output_dir_func)
dt = np.dtype([("data", np.dtype("S10")), ("extra", "i8")])
data = np.array([("hey", 0)] * np.prod(shape), dtype=dt).reshape(shape)
elif dtype is datetime:
dt = np.dtype([("data", np.dtype("S32")), ("extra", "i8")])
data = np.array(
[(datetime.now(timezone.utc).isoformat().encode("utf-8"), 0)]
* np.prod(shape),
dtype=dt,
).reshape(shape)
else:
dt = np.dtype([("data", dtype), ("extra", "i8")])
data = np.zeros(shape, dtype=dt)
h5path = H5ArrayPath(hdf5_file, array_path, "data")
with h5py.File(hdf5_file, "w") as h5f:
_ = h5f.create_dataset(array_path, data=data)
return h5path
return _hdf5_array return _hdf5_array
@ -79,28 +52,13 @@ def zarr_array(tmp_output_dir_func) -> Path:
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def avi_video(tmp_output_dir_func) -> Callable[[Tuple[int, int], int, bool], Path]: def avi_video(tmp_output_dir_func) -> Callable[[Tuple[int, int], int, bool], Path]:
video_path = tmp_output_dir_func / "test.avi"
def _make_video(shape=(100, 50), frames=10, is_color=True) -> Path: def _make_video(shape=(100, 50), frames=10, is_color=True) -> Path:
writer = cv2.VideoWriter( shape = (frames, *shape)
str(video_path),
cv2.VideoWriter_fourcc(*"RGBA"), # raw video for testing purposes
30,
(shape[1], shape[0]),
is_color,
)
if is_color: if is_color:
shape = (*shape, 3) shape = (*shape, 3)
return VideoCase.array_from_case(
for i in range(frames): ValidationCase(shape=shape, dtype=np.uint8), tmp_output_dir_func
# make fresh array every time bc opencv eats them )
array = np.zeros(shape, dtype=np.uint8)
if not is_color:
array[i, i] = i
else:
array[i, i, :] = i
writer.write(array)
writer.release()
return video_path
return _make_video return _make_video

View file

@ -1,12 +1,20 @@
import inspect
from typing import Callable, Tuple, Type from typing import Callable, Tuple, Type
import dask.array as da
import numpy as np
import pytest import pytest
import zarr
from pydantic import BaseModel from pydantic import BaseModel
from numpydantic import NDArray, interface from numpydantic import NDArray, interface
from numpydantic.testing.helpers import InterfaceCase
from numpydantic.testing.interfaces import (
DaskCase,
HDF5Case,
NumpyCase,
VideoCase,
ZarrCase,
ZarrDirCase,
ZarrNestedCase,
)
@pytest.fixture( @pytest.fixture(
@ -18,47 +26,55 @@ from numpydantic import NDArray, interface
id="numpy-list", id="numpy-list",
), ),
pytest.param( pytest.param(
(np.zeros((3, 4)), interface.NumpyInterface), (NumpyCase, interface.NumpyInterface),
marks=pytest.mark.numpy, marks=pytest.mark.numpy,
id="numpy", id="numpy",
), ),
pytest.param( pytest.param(
("hdf5_array", interface.H5Interface), (HDF5Case, interface.H5Interface),
marks=pytest.mark.hdf5, marks=pytest.mark.hdf5,
id="h5-array-path", id="h5-array-path",
), ),
pytest.param( pytest.param(
(da.random.random((10, 10)), interface.DaskInterface), (DaskCase, interface.DaskInterface),
marks=pytest.mark.dask, marks=pytest.mark.dask,
id="dask", id="dask",
), ),
pytest.param( pytest.param(
(zarr.ones((10, 10)), interface.ZarrInterface), (ZarrCase, interface.ZarrInterface),
marks=pytest.mark.zarr, marks=pytest.mark.zarr,
id="zarr-memory", id="zarr-memory",
), ),
pytest.param( pytest.param(
("zarr_nested_array", interface.ZarrInterface), (ZarrNestedCase, interface.ZarrInterface),
marks=pytest.mark.zarr, marks=pytest.mark.zarr,
id="zarr-nested", id="zarr-nested",
), ),
pytest.param( pytest.param(
("zarr_array", interface.ZarrInterface), (ZarrDirCase, interface.ZarrInterface),
marks=pytest.mark.zarr, marks=pytest.mark.zarr,
id="zarr-array", id="zarr-dir",
), ),
pytest.param( pytest.param(
("avi_video", interface.VideoInterface), marks=pytest.mark.video, id="video" (VideoCase, interface.VideoInterface), marks=pytest.mark.video, id="video"
), ),
], ],
) )
def interface_type(request) -> Tuple[NDArray, Type[interface.Interface]]: def interface_type(
request, tmp_output_dir_func
) -> Tuple[NDArray, Type[interface.Interface]]:
""" """
Test cases for each interface's ``check`` method - each input should match the Test cases for each interface's ``check`` method - each input should match the
provided interface and that interface only provided interface and that interface only
""" """
if isinstance(request.param[0], str):
return (request.getfixturevalue(request.param[0]), request.param[1]) if inspect.isclass(request.param[0]) and issubclass(
request.param[0], InterfaceCase
):
array = request.param[0].make_array(path=tmp_output_dir_func)
if array is None:
pytest.skip()
return array, request.param[1]
else: else:
return request.param return request.param

View file

@ -2,31 +2,13 @@ import json
import dask.array as da import dask.array as da
import pytest import pytest
from pydantic import BaseModel, ValidationError
from numpydantic.exceptions import DtypeError, ShapeError
from numpydantic.interface import DaskInterface from numpydantic.interface import DaskInterface
from numpydantic.testing.helpers import ValidationCase from numpydantic.testing.interfaces import DaskCase
pytestmark = pytest.mark.dask pytestmark = pytest.mark.dask
def dask_array(case: ValidationCase) -> da.Array:
if issubclass(case.dtype, BaseModel):
return da.full(shape=case.shape, fill_value=case.dtype(x=1), chunks=-1)
else:
return da.zeros(shape=case.shape, dtype=case.dtype, chunks=10)
def _test_dask_case(case: ValidationCase):
array = dask_array(case)
if case.passes:
case.model(array=array)
else:
with pytest.raises((ValidationError, DtypeError, ShapeError)):
case.model(array=array)
def test_dask_enabled(): def test_dask_enabled():
""" """
We need dask to be available to run these tests :) We need dask to be available to run these tests :)
@ -43,12 +25,14 @@ def test_dask_check(interface_type):
@pytest.mark.shape @pytest.mark.shape
def test_dask_shape(shape_cases): def test_dask_shape(shape_cases):
_test_dask_case(shape_cases) shape_cases.interface = DaskCase
shape_cases.validate_case()
@pytest.mark.dtype @pytest.mark.dtype
def test_dask_dtype(dtype_cases): def test_dask_dtype(dtype_cases):
_test_dask_case(dtype_cases) dtype_cases.interface = DaskCase
dtype_cases.validate_case()
@pytest.mark.serialization @pytest.mark.serialization

View file

@ -12,10 +12,21 @@ from numpydantic.exceptions import DtypeError, ShapeError
from numpydantic.interface import H5Interface from numpydantic.interface import H5Interface
from numpydantic.interface.hdf5 import H5ArrayPath, H5Proxy from numpydantic.interface.hdf5 import H5ArrayPath, H5Proxy
from numpydantic.testing.helpers import ValidationCase from numpydantic.testing.helpers import ValidationCase
from numpydantic.testing.interfaces import HDF5Case, HDF5CompoundCase
pytestmark = pytest.mark.hdf5 pytestmark = pytest.mark.hdf5
@pytest.fixture(
params=[
pytest.param(HDF5Case, id="hdf5"),
pytest.param(HDF5CompoundCase, id="hdf5-compound"),
]
)
def hdf5_cases(request):
return request.param
def hdf5_array_case( def hdf5_array_case(
case: ValidationCase, array_func, compound: bool = False case: ValidationCase, array_func, compound: bool = False
) -> H5ArrayPath: ) -> H5ArrayPath:
@ -47,8 +58,6 @@ def test_hdf5_enabled():
def test_hdf5_check(interface_type): def test_hdf5_check(interface_type):
if interface_type[1] is H5Interface: if interface_type[1] is H5Interface:
if interface_type[0].__name__ == "_hdf5_array":
interface_type = (interface_type[0](), interface_type[1])
assert H5Interface.check(interface_type[0]) assert H5Interface.check(interface_type[0])
if isinstance(interface_type[0], H5ArrayPath): if isinstance(interface_type[0], H5ArrayPath):
# also test that we can instantiate from a tuple like the H5ArrayPath # also test that we can instantiate from a tuple like the H5ArrayPath
@ -74,15 +83,17 @@ def test_hdf5_check_not_hdf5(tmp_path):
@pytest.mark.shape @pytest.mark.shape
@pytest.mark.parametrize("compound", [True, False]) def test_hdf5_shape(shape_cases, hdf5_cases):
def test_hdf5_shape(shape_cases, hdf5_array, compound): shape_cases.interface = hdf5_cases
_test_hdf5_case(shape_cases, hdf5_array, compound) if shape_cases.skip():
pytest.skip()
shape_cases.validate_case()
@pytest.mark.dtype @pytest.mark.dtype
@pytest.mark.parametrize("compound", [True, False]) def test_hdf5_dtype(dtype_cases, hdf5_cases):
def test_hdf5_dtype(dtype_cases, hdf5_array, compound): dtype_cases.interface = hdf5_cases
_test_hdf5_case(dtype_cases, hdf5_array, compound) dtype_cases.validate_case()
def test_hdf5_dataset_not_exists(hdf5_array, model_blank): def test_hdf5_dataset_not_exists(hdf5_array, model_blank):