From 3356738e42547f3f79504ec3af98d4cc8d29b05b Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Fri, 4 Oct 2024 00:46:49 -0700 Subject: [PATCH] refactoring array generation, swapping in the interface case generators --- docs/api/testing/index.md | 9 -- src/numpydantic/testing/cases.py | 14 ++- src/numpydantic/testing/helpers.py | 95 ++++++++++++---- src/numpydantic/testing/interfaces.py | 150 ++++++++++++++++---------- tests/conftest.py | 20 ++-- tests/fixtures/generation.py | 66 +++--------- tests/test_interface/conftest.py | 44 +++++--- tests/test_interface/test_dask.py | 26 +---- tests/test_interface/test_hdf5.py | 27 +++-- 9 files changed, 262 insertions(+), 189 deletions(-) diff --git a/docs/api/testing/index.md b/docs/api/testing/index.md index 91b670f..687835b 100644 --- a/docs/api/testing/index.md +++ b/docs/api/testing/index.md @@ -2,15 +2,6 @@ Utilities for testing and 3rd-party interface development. -Only things that *don't* require pytest go in this module. -We want to keep all test-time specific behavior there, -and have this just serve as helpers exposed for downstream interface development. - -We want to avoid pytest stuff bleeding in here because then we limit -the ability for downstream developers to configure their own tests. - -*(If there is some reason to change this division of labor, just raise an issue and let's chat.)* - ```{toctree} cases helpers diff --git a/src/numpydantic/testing/cases.py b/src/numpydantic/testing/cases.py index 7eedfb7..b7a850e 100644 --- a/src/numpydantic/testing/cases.py +++ b/src/numpydantic/testing/cases.py @@ -203,9 +203,19 @@ def merged_product( iterator = merged_product(shape_cases, dtype_cases)) next(iterator) - # ValidationCase(shape=(10, 10, 10), dtype=float, passes=True, id="valid shape-float") + # ValidationCase( + # shape=(10, 10, 10), + # dtype=float, + # passes=True, + # id="valid shape-float" + # ) next(iterator) - # ValidationCase(shape=(10, 10, 10), dtype=int, passes=False, id="valid shape-int") + # ValidationCase( + # shape=(10, 10, 10), + # dtype=int, + # passes=False, + # id="valid shape-int" + # ) """ diff --git a/src/numpydantic/testing/helpers.py b/src/numpydantic/testing/helpers.py index 3ab7b22..5abbcd0 100644 --- a/src/numpydantic/testing/helpers.py +++ b/src/numpydantic/testing/helpers.py @@ -9,7 +9,7 @@ from pydantic import BaseModel, ConfigDict, ValidationError, computed_field from numpydantic import NDArray, Shape from numpydantic.dtype import Float from numpydantic.interface import Interface -from numpydantic.types import NDArrayType +from numpydantic.types import DtypeType, NDArrayType class InterfaceCase(ABC): @@ -29,43 +29,64 @@ class InterfaceCase(ABC): """The interface that this helper is for""" @classmethod - @abstractmethod - def generate_array( - cls, case: "ValidationCase", path: Path + def array_from_case( + cls, case: "ValidationCase", path: Optional[Path] = None ) -> Optional[NDArrayType]: """ Generate an array from the given validation case. Returns ``None`` if an array can't be generated for a specific case. """ + return cls.make_array(shape=case.shape, dtype=case.dtype, path=path) @classmethod - def validate_array(cls, case: "ValidationCase", path: Path) -> Optional[bool]: + @abstractmethod + def make_array( + cls, + shape: Tuple[int, ...] = (10, 10), + dtype: DtypeType = float, + path: Optional[Path] = None, + ) -> Optional[NDArrayType]: + """ + Make an array from a shape and dtype, and a path if needed + """ + + @classmethod + def validate_case(cls, case: "ValidationCase", path: Path) -> bool: """ Validate a generated array against the annotation in the validation case. Kept in the InterfaceCase in case an interface has specific needs aside from just validating against a model, but typically left as is. - Does not raise on Validation errors - - returns bool instead for consistency's sake. - If an array can't be generated for a given case, returns `None` so that the calling function can know to skip rather than fail the case. + + Raises exceptions if validation fails (or succeeds when it shouldn't) + + Args: + case (ValidationCase): The validation case to validate. + path (Path): Path to generate arrays into, if any. + + Returns: + ``True`` if array is valid and was supposed to be, + or invalid and wasn't supposed to be """ - array = cls.generate_array(case, path) + import pytest + + array = cls.array_from_case(case, path) if array is None: - return None - try: + pytest.skip() + if case.passes: case.model(array=array) - # True if case is supposed to pass, False if it's not... - return case.passes - except ValidationError: - # False if the case is supposed to pass, True if it is... - return not case.passes + return True + else: + with pytest.raises(ValidationError): + case.model(array=array) + return True @classmethod - def skip(cls, case: "ValidationCase") -> bool: + def skip(cls, shape: Tuple[int, ...], dtype: DtypeType) -> bool: """ Whether a given interface should be skipped for the case """ @@ -97,6 +118,9 @@ class ValidationCase(BaseModel): passes: bool = False """Whether the validation should pass or not""" interface: Optional[InterfaceCase] = None + """The interface test case to generate and validate the array with""" + path: Optional[Path] = None + """The path to generate arrays into, if any.""" model_config = ConfigDict(arbitrary_types_allowed=True) @@ -110,6 +134,39 @@ class ValidationCase(BaseModel): return Model + def validate_case(self, path: Optional[Path] = None) -> bool: + """ + Whether the generated array correctly validated against the annotation, + given the interface + + Args: + path (:class:`pathlib.Path`): Directory to generate array into, if on disk. + + Raises: + ValueError: if an ``interface`` is missing + """ + if self.interface is None: + raise ValueError("Missing an interface") + if path is None: + if self.path: + path = self.path + else: + raise ValueError("Missing a path to generate arrays into") + + return self.interface.validate_case(self, path) + + def array(self, path: Path) -> NDArrayType: + """Generate an array for the validation case if we have an interface to do so""" + if self.interface is None: + raise ValueError("Missing an interface") + if path is None: + if self.path: + path = self.path + else: + raise ValueError("Missing a path to generate arrays into") + + return self.interface.array_from_case(self, path) + def merge( self, other: Union["ValidationCase", Sequence["ValidationCase"]] ) -> "ValidationCase": @@ -154,7 +211,9 @@ class ValidationCase(BaseModel): (eg. due to the interface case being incompatible with the requested dtype or shape) """ - return bool(self.interface is not None and self.interface.skip()) + return bool( + self.interface is not None and self.interface.skip(self.shape, self.dtype) + ) def merge_cases(*args: ValidationCase) -> ValidationCase: diff --git a/src/numpydantic/testing/interfaces.py b/src/numpydantic/testing/interfaces.py index d191b61..c4846ae 100644 --- a/src/numpydantic/testing/interfaces.py +++ b/src/numpydantic/testing/interfaces.py @@ -1,6 +1,6 @@ from datetime import datetime, timezone from pathlib import Path -from typing import Optional +from typing import Optional, Tuple import cv2 import dask.array as da @@ -18,7 +18,8 @@ from numpydantic.interface import ( ZarrArrayPath, ZarrInterface, ) -from numpydantic.testing.helpers import InterfaceCase, ValidationCase +from numpydantic.testing.helpers import InterfaceCase +from numpydantic.types import DtypeType class NumpyCase(InterfaceCase): @@ -27,11 +28,16 @@ class NumpyCase(InterfaceCase): interface = NumpyInterface @classmethod - def generate_array(cls, case: "ValidationCase", path: Path) -> np.ndarray: - if issubclass(case.dtype, BaseModel): - return np.full(shape=case.shape, fill_value=case.dtype(x=1)) + def make_array( + cls, + shape: Tuple[int, ...] = (10, 10), + dtype: DtypeType = float, + path: Optional[Path] = None, + ) -> np.ndarray: + if issubclass(dtype, BaseModel): + return np.full(shape=shape, fill_value=dtype(x=1)) else: - return np.zeros(shape=case.shape, dtype=case.dtype) + return np.zeros(shape=shape, dtype=dtype) class _HDF5MetaCase(InterfaceCase): @@ -40,33 +46,34 @@ class _HDF5MetaCase(InterfaceCase): interface = H5Interface @classmethod - def skip(cls, case: "ValidationCase") -> bool: - return not issubclass(case.dtype, BaseModel) + def skip(cls, shape: Tuple[int, ...], dtype: DtypeType) -> bool: + return issubclass(dtype, BaseModel) class HDF5Case(_HDF5MetaCase): """HDF5 Array""" @classmethod - def generate_array( - cls, case: "ValidationCase", path: Path + def make_array( + cls, + shape: Tuple[int, ...] = (10, 10), + dtype: DtypeType = float, + path: Optional[Path] = None, ) -> Optional[H5ArrayPath]: - if cls.skip(case): + if cls.skip(shape, dtype): return None hdf5_file = path / "h5f.h5" - array_path = ( - "/" + "_".join([str(s) for s in case.shape]) + "__" + case.dtype.__name__ - ) + array_path = "/" + "_".join([str(s) for s in shape]) + "__" + dtype.__name__ generator = np.random.default_rng() - if case.dtype is str: - data = generator.random(case.shape).astype(bytes) - elif case.dtype is datetime: - data = np.empty(case.shape, dtype="S32") + if dtype is str: + data = generator.random(shape).astype(bytes) + elif dtype is datetime: + data = np.empty(shape, dtype="S32") data.fill(datetime.now(timezone.utc).isoformat().encode("utf-8")) else: - data = generator.random(case.shape).astype(case.dtype) + data = generator.random(shape).astype(dtype) h5path = H5ArrayPath(hdf5_file, array_path) @@ -79,31 +86,30 @@ class HDF5CompoundCase(_HDF5MetaCase): """HDF5 Array with a fake compound dtype""" @classmethod - def generate_array( - cls, case: "ValidationCase", path: Path + def make_array( + cls, + shape: Tuple[int, ...] = (10, 10), + dtype: DtypeType = float, + path: Optional[Path] = None, ) -> Optional[H5ArrayPath]: - if cls.skip(case): + if cls.skip(shape, dtype): return None hdf5_file = path / "h5f.h5" - array_path = ( - "/" + "_".join([str(s) for s in case.shape]) + "__" + case.dtype.__name__ - ) - if case.dtype is str: + array_path = "/" + "_".join([str(s) for s in shape]) + "__" + dtype.__name__ + if dtype is str: dt = np.dtype([("data", np.dtype("S10")), ("extra", "i8")]) - data = np.array([("hey", 0)] * np.prod(case.shape), dtype=dt).reshape( - case.shape - ) - elif case.dtype is datetime: + data = np.array([("hey", 0)] * np.prod(shape), dtype=dt).reshape(shape) + elif dtype is datetime: dt = np.dtype([("data", np.dtype("S32")), ("extra", "i8")]) data = np.array( [(datetime.now(timezone.utc).isoformat().encode("utf-8"), 0)] - * np.prod(case.shape), + * np.prod(shape), dtype=dt, - ).reshape(case.shape) + ).reshape(shape) else: - dt = np.dtype([("data", case.dtype), ("extra", "i8")]) - data = np.zeros(case.shape, dtype=dt) + dt = np.dtype([("data", dtype), ("extra", "i8")]) + data = np.zeros(shape, dtype=dt) h5path = H5ArrayPath(hdf5_file, array_path, "data") with h5py.File(hdf5_file, "w") as h5f: @@ -117,11 +123,16 @@ class DaskCase(InterfaceCase): interface = DaskInterface @classmethod - def generate_array(cls, case: "ValidationCase", path: Path) -> da.Array: - if issubclass(case.dtype, BaseModel): - return da.full(shape=case.shape, fill_value=case.dtype(x=1), chunks=-1) + def make_array( + cls, + shape: Tuple[int, ...] = (10, 10), + dtype: DtypeType = float, + path: Optional[Path] = None, + ) -> da.Array: + if issubclass(dtype, BaseModel): + return da.full(shape=shape, fill_value=dtype(x=1), chunks=-1) else: - return da.zeros(shape=case.shape, dtype=case.dtype, chunks=10) + return da.zeros(shape=shape, dtype=dtype, chunks=10) class _ZarrMetaCase(InterfaceCase): @@ -130,45 +141,65 @@ class _ZarrMetaCase(InterfaceCase): interface = ZarrInterface @classmethod - def skip(cls, case: "ValidationCase") -> bool: - return not issubclass(case.dtype, BaseModel) + def skip(cls, shape: Tuple[int, ...], dtype: DtypeType) -> bool: + return not issubclass(dtype, BaseModel) class ZarrCase(_ZarrMetaCase): """In-memory zarr array""" @classmethod - def generate_array(cls, case: "ValidationCase", path: Path) -> Optional[zarr.Array]: - return zarr.zeros(shape=case.shape, dtype=case.dtype) + def make_array( + cls, + shape: Tuple[int, ...] = (10, 10), + dtype: DtypeType = float, + path: Optional[Path] = None, + ) -> Optional[zarr.Array]: + return zarr.zeros(shape=shape, dtype=dtype) class ZarrDirCase(_ZarrMetaCase): """On-disk zarr array""" @classmethod - def generate_array(cls, case: "ValidationCase", path: Path) -> ZarrArrayPath: + def make_array( + cls, + shape: Tuple[int, ...] = (10, 10), + dtype: DtypeType = float, + path: Optional[Path] = None, + ) -> Optional[zarr.Array]: store = zarr.DirectoryStore(str(path / "array.zarr")) - return zarr.zeros(shape=case.shape, dtype=case.dtype, store=store) + return zarr.zeros(shape=shape, dtype=dtype, store=store) class ZarrZipCase(_ZarrMetaCase): """Zarr zip store""" @classmethod - def generate_array(cls, case: "ValidationCase", path: Path) -> ZarrArrayPath: + def make_array( + cls, + shape: Tuple[int, ...] = (10, 10), + dtype: DtypeType = float, + path: Optional[Path] = None, + ) -> Optional[zarr.Array]: store = zarr.ZipStore(str(path / "array.zarr"), mode="w") - return zarr.zeros(shape=case.shape, dtype=case.dtype, store=store) + return zarr.zeros(shape=shape, dtype=dtype, store=store) class ZarrNestedCase(_ZarrMetaCase): """Nested zarr array""" @classmethod - def generate_array(cls, case: "ValidationCase", path: Path) -> ZarrArrayPath: + def make_array( + cls, + shape: Tuple[int, ...] = (10, 10), + dtype: DtypeType = float, + path: Optional[Path] = None, + ) -> ZarrArrayPath: file = str(path / "nested.zarr") root = zarr.open(file, mode="w") subpath = "a/b/c" - _ = root.zeros(subpath, shape=case.shape, dtype=case.dtype) + _ = root.zeros(subpath, shape=shape, dtype=dtype) return ZarrArrayPath(file=file, path=subpath) @@ -178,13 +209,18 @@ class VideoCase(InterfaceCase): interface = VideoInterface @classmethod - def generate_array(cls, case: "ValidationCase", path: Path) -> Optional[Path]: - if cls.skip(case): + def make_array( + cls, + shape: Tuple[int, ...] = (10, 10), + dtype: DtypeType = float, + path: Optional[Path] = None, + ) -> Optional[Path]: + if cls.skip(shape, dtype): return None - is_color = len(case.shape) == 4 - frames = case.shape[0] - frame_shape = case.shape[1:] + is_color = len(shape) == 4 + frames = shape[0] + frame_shape = shape[1:] video_path = path / "test.avi" writer = cv2.VideoWriter( @@ -207,12 +243,12 @@ class VideoCase(InterfaceCase): return video_path @classmethod - def skip(cls, case: "ValidationCase") -> bool: + def skip(cls, shape: Tuple[int, ...], dtype: DtypeType) -> bool: """We really can only handle 3-4 dimensional cases in 8-bit rn lol""" - if len(case.shape) < 3 or len(case.shape) > 4: + if len(shape) < 3 or len(shape) > 4: return True - if case.dtype not in (int, np.uint8): + if dtype not in (int, np.uint8): return True # if we have a color video (ie. shape == 4, needs to be RGB) - if len(case.shape) == 4 and case.shape[3] != 3: + if len(shape) == 4 and shape[3] != 3: return True diff --git a/tests/conftest.py b/tests/conftest.py index c4fdb80..669de7b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,11 +13,19 @@ def pytest_addoption(parser): ) -@pytest.fixture(scope="module", params=SHAPE_CASES) -def shape_cases(request) -> ValidationCase: - return request.param +@pytest.fixture( + scope="function", params=[pytest.param(c, id=c.id) for c in SHAPE_CASES] +) +def shape_cases(request, tmp_output_dir_func) -> ValidationCase: + case: ValidationCase = request.param.model_copy() + case.path = tmp_output_dir_func + return case -@pytest.fixture(scope="module", params=DTYPE_CASES) -def dtype_cases(request) -> ValidationCase: - return request.param +@pytest.fixture( + scope="function", params=[pytest.param(c, id=c.id) for c in DTYPE_CASES] +) +def dtype_cases(request, tmp_output_dir_func) -> ValidationCase: + case: ValidationCase = request.param.model_copy() + case.path = tmp_output_dir_func + return case diff --git a/tests/fixtures/generation.py b/tests/fixtures/generation.py index b1de6ea..798dfac 100644 --- a/tests/fixtures/generation.py +++ b/tests/fixtures/generation.py @@ -1,60 +1,33 @@ -from datetime import datetime, timezone from pathlib import Path from typing import Callable, Tuple, Union -import cv2 -import h5py import numpy as np import pytest import zarr from numpydantic.interface.hdf5 import H5ArrayPath from numpydantic.interface.zarr import ZarrArrayPath +from numpydantic.testing import ValidationCase +from numpydantic.testing.interfaces import HDF5Case, HDF5CompoundCase, VideoCase @pytest.fixture(scope="function") def hdf5_array( request, tmp_output_dir_func ) -> Callable[[Tuple[int, ...], Union[np.dtype, type]], H5ArrayPath]: - hdf5_file = tmp_output_dir_func / "h5f.h5" def _hdf5_array( shape: Tuple[int, ...] = (10, 10), dtype: Union[np.dtype, type] = float, compound: bool = False, ) -> H5ArrayPath: - array_path = "/" + "_".join([str(s) for s in shape]) + "__" + dtype.__name__ - generator = np.random.default_rng() - - if not compound: - if dtype is str: - data = generator.random(shape).astype(bytes) - elif dtype is datetime: - data = np.empty(shape, dtype="S32") - data.fill(datetime.now(timezone.utc).isoformat().encode("utf-8")) - else: - data = generator.random(shape).astype(dtype) - - h5path = H5ArrayPath(hdf5_file, array_path) + if compound: + array: H5ArrayPath = HDF5CompoundCase.make_array( + shape, dtype, tmp_output_dir_func + ) + return array else: - if dtype is str: - dt = np.dtype([("data", np.dtype("S10")), ("extra", "i8")]) - data = np.array([("hey", 0)] * np.prod(shape), dtype=dt).reshape(shape) - elif dtype is datetime: - dt = np.dtype([("data", np.dtype("S32")), ("extra", "i8")]) - data = np.array( - [(datetime.now(timezone.utc).isoformat().encode("utf-8"), 0)] - * np.prod(shape), - dtype=dt, - ).reshape(shape) - else: - dt = np.dtype([("data", dtype), ("extra", "i8")]) - data = np.zeros(shape, dtype=dt) - h5path = H5ArrayPath(hdf5_file, array_path, "data") - - with h5py.File(hdf5_file, "w") as h5f: - _ = h5f.create_dataset(array_path, data=data) - return h5path + return HDF5Case.make_array(shape, dtype, tmp_output_dir_func) return _hdf5_array @@ -79,28 +52,13 @@ def zarr_array(tmp_output_dir_func) -> Path: @pytest.fixture(scope="function") def avi_video(tmp_output_dir_func) -> Callable[[Tuple[int, int], int, bool], Path]: - video_path = tmp_output_dir_func / "test.avi" def _make_video(shape=(100, 50), frames=10, is_color=True) -> Path: - writer = cv2.VideoWriter( - str(video_path), - cv2.VideoWriter_fourcc(*"RGBA"), # raw video for testing purposes - 30, - (shape[1], shape[0]), - is_color, - ) + shape = (frames, *shape) if is_color: shape = (*shape, 3) - - for i in range(frames): - # make fresh array every time bc opencv eats them - array = np.zeros(shape, dtype=np.uint8) - if not is_color: - array[i, i] = i - else: - array[i, i, :] = i - writer.write(array) - writer.release() - return video_path + return VideoCase.array_from_case( + ValidationCase(shape=shape, dtype=np.uint8), tmp_output_dir_func + ) return _make_video diff --git a/tests/test_interface/conftest.py b/tests/test_interface/conftest.py index 3917ae9..ff54e1b 100644 --- a/tests/test_interface/conftest.py +++ b/tests/test_interface/conftest.py @@ -1,12 +1,20 @@ +import inspect from typing import Callable, Tuple, Type -import dask.array as da -import numpy as np import pytest -import zarr from pydantic import BaseModel from numpydantic import NDArray, interface +from numpydantic.testing.helpers import InterfaceCase +from numpydantic.testing.interfaces import ( + DaskCase, + HDF5Case, + NumpyCase, + VideoCase, + ZarrCase, + ZarrDirCase, + ZarrNestedCase, +) @pytest.fixture( @@ -18,47 +26,55 @@ from numpydantic import NDArray, interface id="numpy-list", ), pytest.param( - (np.zeros((3, 4)), interface.NumpyInterface), + (NumpyCase, interface.NumpyInterface), marks=pytest.mark.numpy, id="numpy", ), pytest.param( - ("hdf5_array", interface.H5Interface), + (HDF5Case, interface.H5Interface), marks=pytest.mark.hdf5, id="h5-array-path", ), pytest.param( - (da.random.random((10, 10)), interface.DaskInterface), + (DaskCase, interface.DaskInterface), marks=pytest.mark.dask, id="dask", ), pytest.param( - (zarr.ones((10, 10)), interface.ZarrInterface), + (ZarrCase, interface.ZarrInterface), marks=pytest.mark.zarr, id="zarr-memory", ), pytest.param( - ("zarr_nested_array", interface.ZarrInterface), + (ZarrNestedCase, interface.ZarrInterface), marks=pytest.mark.zarr, id="zarr-nested", ), pytest.param( - ("zarr_array", interface.ZarrInterface), + (ZarrDirCase, interface.ZarrInterface), marks=pytest.mark.zarr, - id="zarr-array", + id="zarr-dir", ), pytest.param( - ("avi_video", interface.VideoInterface), marks=pytest.mark.video, id="video" + (VideoCase, interface.VideoInterface), marks=pytest.mark.video, id="video" ), ], ) -def interface_type(request) -> Tuple[NDArray, Type[interface.Interface]]: +def interface_type( + request, tmp_output_dir_func +) -> Tuple[NDArray, Type[interface.Interface]]: """ Test cases for each interface's ``check`` method - each input should match the provided interface and that interface only """ - if isinstance(request.param[0], str): - return (request.getfixturevalue(request.param[0]), request.param[1]) + + if inspect.isclass(request.param[0]) and issubclass( + request.param[0], InterfaceCase + ): + array = request.param[0].make_array(path=tmp_output_dir_func) + if array is None: + pytest.skip() + return array, request.param[1] else: return request.param diff --git a/tests/test_interface/test_dask.py b/tests/test_interface/test_dask.py index a7c17d4..6b6fc49 100644 --- a/tests/test_interface/test_dask.py +++ b/tests/test_interface/test_dask.py @@ -2,31 +2,13 @@ import json import dask.array as da import pytest -from pydantic import BaseModel, ValidationError -from numpydantic.exceptions import DtypeError, ShapeError from numpydantic.interface import DaskInterface -from numpydantic.testing.helpers import ValidationCase +from numpydantic.testing.interfaces import DaskCase pytestmark = pytest.mark.dask -def dask_array(case: ValidationCase) -> da.Array: - if issubclass(case.dtype, BaseModel): - return da.full(shape=case.shape, fill_value=case.dtype(x=1), chunks=-1) - else: - return da.zeros(shape=case.shape, dtype=case.dtype, chunks=10) - - -def _test_dask_case(case: ValidationCase): - array = dask_array(case) - if case.passes: - case.model(array=array) - else: - with pytest.raises((ValidationError, DtypeError, ShapeError)): - case.model(array=array) - - def test_dask_enabled(): """ We need dask to be available to run these tests :) @@ -43,12 +25,14 @@ def test_dask_check(interface_type): @pytest.mark.shape def test_dask_shape(shape_cases): - _test_dask_case(shape_cases) + shape_cases.interface = DaskCase + shape_cases.validate_case() @pytest.mark.dtype def test_dask_dtype(dtype_cases): - _test_dask_case(dtype_cases) + dtype_cases.interface = DaskCase + dtype_cases.validate_case() @pytest.mark.serialization diff --git a/tests/test_interface/test_hdf5.py b/tests/test_interface/test_hdf5.py index 9dec0b5..f7cea74 100644 --- a/tests/test_interface/test_hdf5.py +++ b/tests/test_interface/test_hdf5.py @@ -12,10 +12,21 @@ from numpydantic.exceptions import DtypeError, ShapeError from numpydantic.interface import H5Interface from numpydantic.interface.hdf5 import H5ArrayPath, H5Proxy from numpydantic.testing.helpers import ValidationCase +from numpydantic.testing.interfaces import HDF5Case, HDF5CompoundCase pytestmark = pytest.mark.hdf5 +@pytest.fixture( + params=[ + pytest.param(HDF5Case, id="hdf5"), + pytest.param(HDF5CompoundCase, id="hdf5-compound"), + ] +) +def hdf5_cases(request): + return request.param + + def hdf5_array_case( case: ValidationCase, array_func, compound: bool = False ) -> H5ArrayPath: @@ -47,8 +58,6 @@ def test_hdf5_enabled(): def test_hdf5_check(interface_type): if interface_type[1] is H5Interface: - if interface_type[0].__name__ == "_hdf5_array": - interface_type = (interface_type[0](), interface_type[1]) assert H5Interface.check(interface_type[0]) if isinstance(interface_type[0], H5ArrayPath): # also test that we can instantiate from a tuple like the H5ArrayPath @@ -74,15 +83,17 @@ def test_hdf5_check_not_hdf5(tmp_path): @pytest.mark.shape -@pytest.mark.parametrize("compound", [True, False]) -def test_hdf5_shape(shape_cases, hdf5_array, compound): - _test_hdf5_case(shape_cases, hdf5_array, compound) +def test_hdf5_shape(shape_cases, hdf5_cases): + shape_cases.interface = hdf5_cases + if shape_cases.skip(): + pytest.skip() + shape_cases.validate_case() @pytest.mark.dtype -@pytest.mark.parametrize("compound", [True, False]) -def test_hdf5_dtype(dtype_cases, hdf5_array, compound): - _test_hdf5_case(dtype_cases, hdf5_array, compound) +def test_hdf5_dtype(dtype_cases, hdf5_cases): + dtype_cases.interface = hdf5_cases + dtype_cases.validate_case() def test_hdf5_dataset_not_exists(hdf5_array, model_blank):