mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2024-11-14 18:54:28 +00:00
refactoring array generation, swapping in the interface case generators
This commit is contained in:
parent
e701bf6e9b
commit
3356738e42
9 changed files with 262 additions and 189 deletions
|
@ -2,15 +2,6 @@
|
||||||
|
|
||||||
Utilities for testing and 3rd-party interface development.
|
Utilities for testing and 3rd-party interface development.
|
||||||
|
|
||||||
Only things that *don't* require pytest go in this module.
|
|
||||||
We want to keep all test-time specific behavior there,
|
|
||||||
and have this just serve as helpers exposed for downstream interface development.
|
|
||||||
|
|
||||||
We want to avoid pytest stuff bleeding in here because then we limit
|
|
||||||
the ability for downstream developers to configure their own tests.
|
|
||||||
|
|
||||||
*(If there is some reason to change this division of labor, just raise an issue and let's chat.)*
|
|
||||||
|
|
||||||
```{toctree}
|
```{toctree}
|
||||||
cases
|
cases
|
||||||
helpers
|
helpers
|
||||||
|
|
|
@ -203,9 +203,19 @@ def merged_product(
|
||||||
|
|
||||||
iterator = merged_product(shape_cases, dtype_cases))
|
iterator = merged_product(shape_cases, dtype_cases))
|
||||||
next(iterator)
|
next(iterator)
|
||||||
# ValidationCase(shape=(10, 10, 10), dtype=float, passes=True, id="valid shape-float")
|
# ValidationCase(
|
||||||
|
# shape=(10, 10, 10),
|
||||||
|
# dtype=float,
|
||||||
|
# passes=True,
|
||||||
|
# id="valid shape-float"
|
||||||
|
# )
|
||||||
next(iterator)
|
next(iterator)
|
||||||
# ValidationCase(shape=(10, 10, 10), dtype=int, passes=False, id="valid shape-int")
|
# ValidationCase(
|
||||||
|
# shape=(10, 10, 10),
|
||||||
|
# dtype=int,
|
||||||
|
# passes=False,
|
||||||
|
# id="valid shape-int"
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -9,7 +9,7 @@ from pydantic import BaseModel, ConfigDict, ValidationError, computed_field
|
||||||
from numpydantic import NDArray, Shape
|
from numpydantic import NDArray, Shape
|
||||||
from numpydantic.dtype import Float
|
from numpydantic.dtype import Float
|
||||||
from numpydantic.interface import Interface
|
from numpydantic.interface import Interface
|
||||||
from numpydantic.types import NDArrayType
|
from numpydantic.types import DtypeType, NDArrayType
|
||||||
|
|
||||||
|
|
||||||
class InterfaceCase(ABC):
|
class InterfaceCase(ABC):
|
||||||
|
@ -29,43 +29,64 @@ class InterfaceCase(ABC):
|
||||||
"""The interface that this helper is for"""
|
"""The interface that this helper is for"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
def array_from_case(
|
||||||
def generate_array(
|
cls, case: "ValidationCase", path: Optional[Path] = None
|
||||||
cls, case: "ValidationCase", path: Path
|
|
||||||
) -> Optional[NDArrayType]:
|
) -> Optional[NDArrayType]:
|
||||||
"""
|
"""
|
||||||
Generate an array from the given validation case.
|
Generate an array from the given validation case.
|
||||||
|
|
||||||
Returns ``None`` if an array can't be generated for a specific case.
|
Returns ``None`` if an array can't be generated for a specific case.
|
||||||
"""
|
"""
|
||||||
|
return cls.make_array(shape=case.shape, dtype=case.dtype, path=path)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_array(cls, case: "ValidationCase", path: Path) -> Optional[bool]:
|
@abstractmethod
|
||||||
|
def make_array(
|
||||||
|
cls,
|
||||||
|
shape: Tuple[int, ...] = (10, 10),
|
||||||
|
dtype: DtypeType = float,
|
||||||
|
path: Optional[Path] = None,
|
||||||
|
) -> Optional[NDArrayType]:
|
||||||
|
"""
|
||||||
|
Make an array from a shape and dtype, and a path if needed
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_case(cls, case: "ValidationCase", path: Path) -> bool:
|
||||||
"""
|
"""
|
||||||
Validate a generated array against the annotation in the validation case.
|
Validate a generated array against the annotation in the validation case.
|
||||||
|
|
||||||
Kept in the InterfaceCase in case an interface has specific
|
Kept in the InterfaceCase in case an interface has specific
|
||||||
needs aside from just validating against a model, but typically left as is.
|
needs aside from just validating against a model, but typically left as is.
|
||||||
|
|
||||||
Does not raise on Validation errors -
|
|
||||||
returns bool instead for consistency's sake.
|
|
||||||
|
|
||||||
If an array can't be generated for a given case, returns `None`
|
If an array can't be generated for a given case, returns `None`
|
||||||
so that the calling function can know to skip rather than fail the case.
|
so that the calling function can know to skip rather than fail the case.
|
||||||
|
|
||||||
|
Raises exceptions if validation fails (or succeeds when it shouldn't)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
case (ValidationCase): The validation case to validate.
|
||||||
|
path (Path): Path to generate arrays into, if any.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
``True`` if array is valid and was supposed to be,
|
||||||
|
or invalid and wasn't supposed to be
|
||||||
"""
|
"""
|
||||||
array = cls.generate_array(case, path)
|
import pytest
|
||||||
|
|
||||||
|
array = cls.array_from_case(case, path)
|
||||||
if array is None:
|
if array is None:
|
||||||
return None
|
pytest.skip()
|
||||||
try:
|
if case.passes:
|
||||||
case.model(array=array)
|
case.model(array=array)
|
||||||
# True if case is supposed to pass, False if it's not...
|
return True
|
||||||
return case.passes
|
else:
|
||||||
except ValidationError:
|
with pytest.raises(ValidationError):
|
||||||
# False if the case is supposed to pass, True if it is...
|
case.model(array=array)
|
||||||
return not case.passes
|
return True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def skip(cls, case: "ValidationCase") -> bool:
|
def skip(cls, shape: Tuple[int, ...], dtype: DtypeType) -> bool:
|
||||||
"""
|
"""
|
||||||
Whether a given interface should be skipped for the case
|
Whether a given interface should be skipped for the case
|
||||||
"""
|
"""
|
||||||
|
@ -97,6 +118,9 @@ class ValidationCase(BaseModel):
|
||||||
passes: bool = False
|
passes: bool = False
|
||||||
"""Whether the validation should pass or not"""
|
"""Whether the validation should pass or not"""
|
||||||
interface: Optional[InterfaceCase] = None
|
interface: Optional[InterfaceCase] = None
|
||||||
|
"""The interface test case to generate and validate the array with"""
|
||||||
|
path: Optional[Path] = None
|
||||||
|
"""The path to generate arrays into, if any."""
|
||||||
|
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
|
@ -110,6 +134,39 @@ class ValidationCase(BaseModel):
|
||||||
|
|
||||||
return Model
|
return Model
|
||||||
|
|
||||||
|
def validate_case(self, path: Optional[Path] = None) -> bool:
|
||||||
|
"""
|
||||||
|
Whether the generated array correctly validated against the annotation,
|
||||||
|
given the interface
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path (:class:`pathlib.Path`): Directory to generate array into, if on disk.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if an ``interface`` is missing
|
||||||
|
"""
|
||||||
|
if self.interface is None:
|
||||||
|
raise ValueError("Missing an interface")
|
||||||
|
if path is None:
|
||||||
|
if self.path:
|
||||||
|
path = self.path
|
||||||
|
else:
|
||||||
|
raise ValueError("Missing a path to generate arrays into")
|
||||||
|
|
||||||
|
return self.interface.validate_case(self, path)
|
||||||
|
|
||||||
|
def array(self, path: Path) -> NDArrayType:
|
||||||
|
"""Generate an array for the validation case if we have an interface to do so"""
|
||||||
|
if self.interface is None:
|
||||||
|
raise ValueError("Missing an interface")
|
||||||
|
if path is None:
|
||||||
|
if self.path:
|
||||||
|
path = self.path
|
||||||
|
else:
|
||||||
|
raise ValueError("Missing a path to generate arrays into")
|
||||||
|
|
||||||
|
return self.interface.array_from_case(self, path)
|
||||||
|
|
||||||
def merge(
|
def merge(
|
||||||
self, other: Union["ValidationCase", Sequence["ValidationCase"]]
|
self, other: Union["ValidationCase", Sequence["ValidationCase"]]
|
||||||
) -> "ValidationCase":
|
) -> "ValidationCase":
|
||||||
|
@ -154,7 +211,9 @@ class ValidationCase(BaseModel):
|
||||||
(eg. due to the interface case being incompatible
|
(eg. due to the interface case being incompatible
|
||||||
with the requested dtype or shape)
|
with the requested dtype or shape)
|
||||||
"""
|
"""
|
||||||
return bool(self.interface is not None and self.interface.skip())
|
return bool(
|
||||||
|
self.interface is not None and self.interface.skip(self.shape, self.dtype)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def merge_cases(*args: ValidationCase) -> ValidationCase:
|
def merge_cases(*args: ValidationCase) -> ValidationCase:
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import dask.array as da
|
import dask.array as da
|
||||||
|
@ -18,7 +18,8 @@ from numpydantic.interface import (
|
||||||
ZarrArrayPath,
|
ZarrArrayPath,
|
||||||
ZarrInterface,
|
ZarrInterface,
|
||||||
)
|
)
|
||||||
from numpydantic.testing.helpers import InterfaceCase, ValidationCase
|
from numpydantic.testing.helpers import InterfaceCase
|
||||||
|
from numpydantic.types import DtypeType
|
||||||
|
|
||||||
|
|
||||||
class NumpyCase(InterfaceCase):
|
class NumpyCase(InterfaceCase):
|
||||||
|
@ -27,11 +28,16 @@ class NumpyCase(InterfaceCase):
|
||||||
interface = NumpyInterface
|
interface = NumpyInterface
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_array(cls, case: "ValidationCase", path: Path) -> np.ndarray:
|
def make_array(
|
||||||
if issubclass(case.dtype, BaseModel):
|
cls,
|
||||||
return np.full(shape=case.shape, fill_value=case.dtype(x=1))
|
shape: Tuple[int, ...] = (10, 10),
|
||||||
|
dtype: DtypeType = float,
|
||||||
|
path: Optional[Path] = None,
|
||||||
|
) -> np.ndarray:
|
||||||
|
if issubclass(dtype, BaseModel):
|
||||||
|
return np.full(shape=shape, fill_value=dtype(x=1))
|
||||||
else:
|
else:
|
||||||
return np.zeros(shape=case.shape, dtype=case.dtype)
|
return np.zeros(shape=shape, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
class _HDF5MetaCase(InterfaceCase):
|
class _HDF5MetaCase(InterfaceCase):
|
||||||
|
@ -40,33 +46,34 @@ class _HDF5MetaCase(InterfaceCase):
|
||||||
interface = H5Interface
|
interface = H5Interface
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def skip(cls, case: "ValidationCase") -> bool:
|
def skip(cls, shape: Tuple[int, ...], dtype: DtypeType) -> bool:
|
||||||
return not issubclass(case.dtype, BaseModel)
|
return issubclass(dtype, BaseModel)
|
||||||
|
|
||||||
|
|
||||||
class HDF5Case(_HDF5MetaCase):
|
class HDF5Case(_HDF5MetaCase):
|
||||||
"""HDF5 Array"""
|
"""HDF5 Array"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_array(
|
def make_array(
|
||||||
cls, case: "ValidationCase", path: Path
|
cls,
|
||||||
|
shape: Tuple[int, ...] = (10, 10),
|
||||||
|
dtype: DtypeType = float,
|
||||||
|
path: Optional[Path] = None,
|
||||||
) -> Optional[H5ArrayPath]:
|
) -> Optional[H5ArrayPath]:
|
||||||
if cls.skip(case):
|
if cls.skip(shape, dtype):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
hdf5_file = path / "h5f.h5"
|
hdf5_file = path / "h5f.h5"
|
||||||
array_path = (
|
array_path = "/" + "_".join([str(s) for s in shape]) + "__" + dtype.__name__
|
||||||
"/" + "_".join([str(s) for s in case.shape]) + "__" + case.dtype.__name__
|
|
||||||
)
|
|
||||||
generator = np.random.default_rng()
|
generator = np.random.default_rng()
|
||||||
|
|
||||||
if case.dtype is str:
|
if dtype is str:
|
||||||
data = generator.random(case.shape).astype(bytes)
|
data = generator.random(shape).astype(bytes)
|
||||||
elif case.dtype is datetime:
|
elif dtype is datetime:
|
||||||
data = np.empty(case.shape, dtype="S32")
|
data = np.empty(shape, dtype="S32")
|
||||||
data.fill(datetime.now(timezone.utc).isoformat().encode("utf-8"))
|
data.fill(datetime.now(timezone.utc).isoformat().encode("utf-8"))
|
||||||
else:
|
else:
|
||||||
data = generator.random(case.shape).astype(case.dtype)
|
data = generator.random(shape).astype(dtype)
|
||||||
|
|
||||||
h5path = H5ArrayPath(hdf5_file, array_path)
|
h5path = H5ArrayPath(hdf5_file, array_path)
|
||||||
|
|
||||||
|
@ -79,31 +86,30 @@ class HDF5CompoundCase(_HDF5MetaCase):
|
||||||
"""HDF5 Array with a fake compound dtype"""
|
"""HDF5 Array with a fake compound dtype"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_array(
|
def make_array(
|
||||||
cls, case: "ValidationCase", path: Path
|
cls,
|
||||||
|
shape: Tuple[int, ...] = (10, 10),
|
||||||
|
dtype: DtypeType = float,
|
||||||
|
path: Optional[Path] = None,
|
||||||
) -> Optional[H5ArrayPath]:
|
) -> Optional[H5ArrayPath]:
|
||||||
if cls.skip(case):
|
if cls.skip(shape, dtype):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
hdf5_file = path / "h5f.h5"
|
hdf5_file = path / "h5f.h5"
|
||||||
array_path = (
|
array_path = "/" + "_".join([str(s) for s in shape]) + "__" + dtype.__name__
|
||||||
"/" + "_".join([str(s) for s in case.shape]) + "__" + case.dtype.__name__
|
if dtype is str:
|
||||||
)
|
|
||||||
if case.dtype is str:
|
|
||||||
dt = np.dtype([("data", np.dtype("S10")), ("extra", "i8")])
|
dt = np.dtype([("data", np.dtype("S10")), ("extra", "i8")])
|
||||||
data = np.array([("hey", 0)] * np.prod(case.shape), dtype=dt).reshape(
|
data = np.array([("hey", 0)] * np.prod(shape), dtype=dt).reshape(shape)
|
||||||
case.shape
|
elif dtype is datetime:
|
||||||
)
|
|
||||||
elif case.dtype is datetime:
|
|
||||||
dt = np.dtype([("data", np.dtype("S32")), ("extra", "i8")])
|
dt = np.dtype([("data", np.dtype("S32")), ("extra", "i8")])
|
||||||
data = np.array(
|
data = np.array(
|
||||||
[(datetime.now(timezone.utc).isoformat().encode("utf-8"), 0)]
|
[(datetime.now(timezone.utc).isoformat().encode("utf-8"), 0)]
|
||||||
* np.prod(case.shape),
|
* np.prod(shape),
|
||||||
dtype=dt,
|
dtype=dt,
|
||||||
).reshape(case.shape)
|
).reshape(shape)
|
||||||
else:
|
else:
|
||||||
dt = np.dtype([("data", case.dtype), ("extra", "i8")])
|
dt = np.dtype([("data", dtype), ("extra", "i8")])
|
||||||
data = np.zeros(case.shape, dtype=dt)
|
data = np.zeros(shape, dtype=dt)
|
||||||
h5path = H5ArrayPath(hdf5_file, array_path, "data")
|
h5path = H5ArrayPath(hdf5_file, array_path, "data")
|
||||||
|
|
||||||
with h5py.File(hdf5_file, "w") as h5f:
|
with h5py.File(hdf5_file, "w") as h5f:
|
||||||
|
@ -117,11 +123,16 @@ class DaskCase(InterfaceCase):
|
||||||
interface = DaskInterface
|
interface = DaskInterface
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_array(cls, case: "ValidationCase", path: Path) -> da.Array:
|
def make_array(
|
||||||
if issubclass(case.dtype, BaseModel):
|
cls,
|
||||||
return da.full(shape=case.shape, fill_value=case.dtype(x=1), chunks=-1)
|
shape: Tuple[int, ...] = (10, 10),
|
||||||
|
dtype: DtypeType = float,
|
||||||
|
path: Optional[Path] = None,
|
||||||
|
) -> da.Array:
|
||||||
|
if issubclass(dtype, BaseModel):
|
||||||
|
return da.full(shape=shape, fill_value=dtype(x=1), chunks=-1)
|
||||||
else:
|
else:
|
||||||
return da.zeros(shape=case.shape, dtype=case.dtype, chunks=10)
|
return da.zeros(shape=shape, dtype=dtype, chunks=10)
|
||||||
|
|
||||||
|
|
||||||
class _ZarrMetaCase(InterfaceCase):
|
class _ZarrMetaCase(InterfaceCase):
|
||||||
|
@ -130,45 +141,65 @@ class _ZarrMetaCase(InterfaceCase):
|
||||||
interface = ZarrInterface
|
interface = ZarrInterface
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def skip(cls, case: "ValidationCase") -> bool:
|
def skip(cls, shape: Tuple[int, ...], dtype: DtypeType) -> bool:
|
||||||
return not issubclass(case.dtype, BaseModel)
|
return not issubclass(dtype, BaseModel)
|
||||||
|
|
||||||
|
|
||||||
class ZarrCase(_ZarrMetaCase):
|
class ZarrCase(_ZarrMetaCase):
|
||||||
"""In-memory zarr array"""
|
"""In-memory zarr array"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_array(cls, case: "ValidationCase", path: Path) -> Optional[zarr.Array]:
|
def make_array(
|
||||||
return zarr.zeros(shape=case.shape, dtype=case.dtype)
|
cls,
|
||||||
|
shape: Tuple[int, ...] = (10, 10),
|
||||||
|
dtype: DtypeType = float,
|
||||||
|
path: Optional[Path] = None,
|
||||||
|
) -> Optional[zarr.Array]:
|
||||||
|
return zarr.zeros(shape=shape, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
class ZarrDirCase(_ZarrMetaCase):
|
class ZarrDirCase(_ZarrMetaCase):
|
||||||
"""On-disk zarr array"""
|
"""On-disk zarr array"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_array(cls, case: "ValidationCase", path: Path) -> ZarrArrayPath:
|
def make_array(
|
||||||
|
cls,
|
||||||
|
shape: Tuple[int, ...] = (10, 10),
|
||||||
|
dtype: DtypeType = float,
|
||||||
|
path: Optional[Path] = None,
|
||||||
|
) -> Optional[zarr.Array]:
|
||||||
store = zarr.DirectoryStore(str(path / "array.zarr"))
|
store = zarr.DirectoryStore(str(path / "array.zarr"))
|
||||||
return zarr.zeros(shape=case.shape, dtype=case.dtype, store=store)
|
return zarr.zeros(shape=shape, dtype=dtype, store=store)
|
||||||
|
|
||||||
|
|
||||||
class ZarrZipCase(_ZarrMetaCase):
|
class ZarrZipCase(_ZarrMetaCase):
|
||||||
"""Zarr zip store"""
|
"""Zarr zip store"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_array(cls, case: "ValidationCase", path: Path) -> ZarrArrayPath:
|
def make_array(
|
||||||
|
cls,
|
||||||
|
shape: Tuple[int, ...] = (10, 10),
|
||||||
|
dtype: DtypeType = float,
|
||||||
|
path: Optional[Path] = None,
|
||||||
|
) -> Optional[zarr.Array]:
|
||||||
store = zarr.ZipStore(str(path / "array.zarr"), mode="w")
|
store = zarr.ZipStore(str(path / "array.zarr"), mode="w")
|
||||||
return zarr.zeros(shape=case.shape, dtype=case.dtype, store=store)
|
return zarr.zeros(shape=shape, dtype=dtype, store=store)
|
||||||
|
|
||||||
|
|
||||||
class ZarrNestedCase(_ZarrMetaCase):
|
class ZarrNestedCase(_ZarrMetaCase):
|
||||||
"""Nested zarr array"""
|
"""Nested zarr array"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_array(cls, case: "ValidationCase", path: Path) -> ZarrArrayPath:
|
def make_array(
|
||||||
|
cls,
|
||||||
|
shape: Tuple[int, ...] = (10, 10),
|
||||||
|
dtype: DtypeType = float,
|
||||||
|
path: Optional[Path] = None,
|
||||||
|
) -> ZarrArrayPath:
|
||||||
file = str(path / "nested.zarr")
|
file = str(path / "nested.zarr")
|
||||||
root = zarr.open(file, mode="w")
|
root = zarr.open(file, mode="w")
|
||||||
subpath = "a/b/c"
|
subpath = "a/b/c"
|
||||||
_ = root.zeros(subpath, shape=case.shape, dtype=case.dtype)
|
_ = root.zeros(subpath, shape=shape, dtype=dtype)
|
||||||
return ZarrArrayPath(file=file, path=subpath)
|
return ZarrArrayPath(file=file, path=subpath)
|
||||||
|
|
||||||
|
|
||||||
|
@ -178,13 +209,18 @@ class VideoCase(InterfaceCase):
|
||||||
interface = VideoInterface
|
interface = VideoInterface
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_array(cls, case: "ValidationCase", path: Path) -> Optional[Path]:
|
def make_array(
|
||||||
if cls.skip(case):
|
cls,
|
||||||
|
shape: Tuple[int, ...] = (10, 10),
|
||||||
|
dtype: DtypeType = float,
|
||||||
|
path: Optional[Path] = None,
|
||||||
|
) -> Optional[Path]:
|
||||||
|
if cls.skip(shape, dtype):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
is_color = len(case.shape) == 4
|
is_color = len(shape) == 4
|
||||||
frames = case.shape[0]
|
frames = shape[0]
|
||||||
frame_shape = case.shape[1:]
|
frame_shape = shape[1:]
|
||||||
|
|
||||||
video_path = path / "test.avi"
|
video_path = path / "test.avi"
|
||||||
writer = cv2.VideoWriter(
|
writer = cv2.VideoWriter(
|
||||||
|
@ -207,12 +243,12 @@ class VideoCase(InterfaceCase):
|
||||||
return video_path
|
return video_path
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def skip(cls, case: "ValidationCase") -> bool:
|
def skip(cls, shape: Tuple[int, ...], dtype: DtypeType) -> bool:
|
||||||
"""We really can only handle 3-4 dimensional cases in 8-bit rn lol"""
|
"""We really can only handle 3-4 dimensional cases in 8-bit rn lol"""
|
||||||
if len(case.shape) < 3 or len(case.shape) > 4:
|
if len(shape) < 3 or len(shape) > 4:
|
||||||
return True
|
return True
|
||||||
if case.dtype not in (int, np.uint8):
|
if dtype not in (int, np.uint8):
|
||||||
return True
|
return True
|
||||||
# if we have a color video (ie. shape == 4, needs to be RGB)
|
# if we have a color video (ie. shape == 4, needs to be RGB)
|
||||||
if len(case.shape) == 4 and case.shape[3] != 3:
|
if len(shape) == 4 and shape[3] != 3:
|
||||||
return True
|
return True
|
||||||
|
|
|
@ -13,11 +13,19 @@ def pytest_addoption(parser):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module", params=SHAPE_CASES)
|
@pytest.fixture(
|
||||||
def shape_cases(request) -> ValidationCase:
|
scope="function", params=[pytest.param(c, id=c.id) for c in SHAPE_CASES]
|
||||||
return request.param
|
)
|
||||||
|
def shape_cases(request, tmp_output_dir_func) -> ValidationCase:
|
||||||
|
case: ValidationCase = request.param.model_copy()
|
||||||
|
case.path = tmp_output_dir_func
|
||||||
|
return case
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module", params=DTYPE_CASES)
|
@pytest.fixture(
|
||||||
def dtype_cases(request) -> ValidationCase:
|
scope="function", params=[pytest.param(c, id=c.id) for c in DTYPE_CASES]
|
||||||
return request.param
|
)
|
||||||
|
def dtype_cases(request, tmp_output_dir_func) -> ValidationCase:
|
||||||
|
case: ValidationCase = request.param.model_copy()
|
||||||
|
case.path = tmp_output_dir_func
|
||||||
|
return case
|
||||||
|
|
66
tests/fixtures/generation.py
vendored
66
tests/fixtures/generation.py
vendored
|
@ -1,60 +1,33 @@
|
||||||
from datetime import datetime, timezone
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Tuple, Union
|
from typing import Callable, Tuple, Union
|
||||||
|
|
||||||
import cv2
|
|
||||||
import h5py
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import zarr
|
import zarr
|
||||||
|
|
||||||
from numpydantic.interface.hdf5 import H5ArrayPath
|
from numpydantic.interface.hdf5 import H5ArrayPath
|
||||||
from numpydantic.interface.zarr import ZarrArrayPath
|
from numpydantic.interface.zarr import ZarrArrayPath
|
||||||
|
from numpydantic.testing import ValidationCase
|
||||||
|
from numpydantic.testing.interfaces import HDF5Case, HDF5CompoundCase, VideoCase
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def hdf5_array(
|
def hdf5_array(
|
||||||
request, tmp_output_dir_func
|
request, tmp_output_dir_func
|
||||||
) -> Callable[[Tuple[int, ...], Union[np.dtype, type]], H5ArrayPath]:
|
) -> Callable[[Tuple[int, ...], Union[np.dtype, type]], H5ArrayPath]:
|
||||||
hdf5_file = tmp_output_dir_func / "h5f.h5"
|
|
||||||
|
|
||||||
def _hdf5_array(
|
def _hdf5_array(
|
||||||
shape: Tuple[int, ...] = (10, 10),
|
shape: Tuple[int, ...] = (10, 10),
|
||||||
dtype: Union[np.dtype, type] = float,
|
dtype: Union[np.dtype, type] = float,
|
||||||
compound: bool = False,
|
compound: bool = False,
|
||||||
) -> H5ArrayPath:
|
) -> H5ArrayPath:
|
||||||
array_path = "/" + "_".join([str(s) for s in shape]) + "__" + dtype.__name__
|
if compound:
|
||||||
generator = np.random.default_rng()
|
array: H5ArrayPath = HDF5CompoundCase.make_array(
|
||||||
|
shape, dtype, tmp_output_dir_func
|
||||||
if not compound:
|
)
|
||||||
if dtype is str:
|
return array
|
||||||
data = generator.random(shape).astype(bytes)
|
|
||||||
elif dtype is datetime:
|
|
||||||
data = np.empty(shape, dtype="S32")
|
|
||||||
data.fill(datetime.now(timezone.utc).isoformat().encode("utf-8"))
|
|
||||||
else:
|
|
||||||
data = generator.random(shape).astype(dtype)
|
|
||||||
|
|
||||||
h5path = H5ArrayPath(hdf5_file, array_path)
|
|
||||||
else:
|
else:
|
||||||
if dtype is str:
|
return HDF5Case.make_array(shape, dtype, tmp_output_dir_func)
|
||||||
dt = np.dtype([("data", np.dtype("S10")), ("extra", "i8")])
|
|
||||||
data = np.array([("hey", 0)] * np.prod(shape), dtype=dt).reshape(shape)
|
|
||||||
elif dtype is datetime:
|
|
||||||
dt = np.dtype([("data", np.dtype("S32")), ("extra", "i8")])
|
|
||||||
data = np.array(
|
|
||||||
[(datetime.now(timezone.utc).isoformat().encode("utf-8"), 0)]
|
|
||||||
* np.prod(shape),
|
|
||||||
dtype=dt,
|
|
||||||
).reshape(shape)
|
|
||||||
else:
|
|
||||||
dt = np.dtype([("data", dtype), ("extra", "i8")])
|
|
||||||
data = np.zeros(shape, dtype=dt)
|
|
||||||
h5path = H5ArrayPath(hdf5_file, array_path, "data")
|
|
||||||
|
|
||||||
with h5py.File(hdf5_file, "w") as h5f:
|
|
||||||
_ = h5f.create_dataset(array_path, data=data)
|
|
||||||
return h5path
|
|
||||||
|
|
||||||
return _hdf5_array
|
return _hdf5_array
|
||||||
|
|
||||||
|
@ -79,28 +52,13 @@ def zarr_array(tmp_output_dir_func) -> Path:
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def avi_video(tmp_output_dir_func) -> Callable[[Tuple[int, int], int, bool], Path]:
|
def avi_video(tmp_output_dir_func) -> Callable[[Tuple[int, int], int, bool], Path]:
|
||||||
video_path = tmp_output_dir_func / "test.avi"
|
|
||||||
|
|
||||||
def _make_video(shape=(100, 50), frames=10, is_color=True) -> Path:
|
def _make_video(shape=(100, 50), frames=10, is_color=True) -> Path:
|
||||||
writer = cv2.VideoWriter(
|
shape = (frames, *shape)
|
||||||
str(video_path),
|
|
||||||
cv2.VideoWriter_fourcc(*"RGBA"), # raw video for testing purposes
|
|
||||||
30,
|
|
||||||
(shape[1], shape[0]),
|
|
||||||
is_color,
|
|
||||||
)
|
|
||||||
if is_color:
|
if is_color:
|
||||||
shape = (*shape, 3)
|
shape = (*shape, 3)
|
||||||
|
return VideoCase.array_from_case(
|
||||||
for i in range(frames):
|
ValidationCase(shape=shape, dtype=np.uint8), tmp_output_dir_func
|
||||||
# make fresh array every time bc opencv eats them
|
)
|
||||||
array = np.zeros(shape, dtype=np.uint8)
|
|
||||||
if not is_color:
|
|
||||||
array[i, i] = i
|
|
||||||
else:
|
|
||||||
array[i, i, :] = i
|
|
||||||
writer.write(array)
|
|
||||||
writer.release()
|
|
||||||
return video_path
|
|
||||||
|
|
||||||
return _make_video
|
return _make_video
|
||||||
|
|
|
@ -1,12 +1,20 @@
|
||||||
|
import inspect
|
||||||
from typing import Callable, Tuple, Type
|
from typing import Callable, Tuple, Type
|
||||||
|
|
||||||
import dask.array as da
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
import pytest
|
||||||
import zarr
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from numpydantic import NDArray, interface
|
from numpydantic import NDArray, interface
|
||||||
|
from numpydantic.testing.helpers import InterfaceCase
|
||||||
|
from numpydantic.testing.interfaces import (
|
||||||
|
DaskCase,
|
||||||
|
HDF5Case,
|
||||||
|
NumpyCase,
|
||||||
|
VideoCase,
|
||||||
|
ZarrCase,
|
||||||
|
ZarrDirCase,
|
||||||
|
ZarrNestedCase,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(
|
@pytest.fixture(
|
||||||
|
@ -18,47 +26,55 @@ from numpydantic import NDArray, interface
|
||||||
id="numpy-list",
|
id="numpy-list",
|
||||||
),
|
),
|
||||||
pytest.param(
|
pytest.param(
|
||||||
(np.zeros((3, 4)), interface.NumpyInterface),
|
(NumpyCase, interface.NumpyInterface),
|
||||||
marks=pytest.mark.numpy,
|
marks=pytest.mark.numpy,
|
||||||
id="numpy",
|
id="numpy",
|
||||||
),
|
),
|
||||||
pytest.param(
|
pytest.param(
|
||||||
("hdf5_array", interface.H5Interface),
|
(HDF5Case, interface.H5Interface),
|
||||||
marks=pytest.mark.hdf5,
|
marks=pytest.mark.hdf5,
|
||||||
id="h5-array-path",
|
id="h5-array-path",
|
||||||
),
|
),
|
||||||
pytest.param(
|
pytest.param(
|
||||||
(da.random.random((10, 10)), interface.DaskInterface),
|
(DaskCase, interface.DaskInterface),
|
||||||
marks=pytest.mark.dask,
|
marks=pytest.mark.dask,
|
||||||
id="dask",
|
id="dask",
|
||||||
),
|
),
|
||||||
pytest.param(
|
pytest.param(
|
||||||
(zarr.ones((10, 10)), interface.ZarrInterface),
|
(ZarrCase, interface.ZarrInterface),
|
||||||
marks=pytest.mark.zarr,
|
marks=pytest.mark.zarr,
|
||||||
id="zarr-memory",
|
id="zarr-memory",
|
||||||
),
|
),
|
||||||
pytest.param(
|
pytest.param(
|
||||||
("zarr_nested_array", interface.ZarrInterface),
|
(ZarrNestedCase, interface.ZarrInterface),
|
||||||
marks=pytest.mark.zarr,
|
marks=pytest.mark.zarr,
|
||||||
id="zarr-nested",
|
id="zarr-nested",
|
||||||
),
|
),
|
||||||
pytest.param(
|
pytest.param(
|
||||||
("zarr_array", interface.ZarrInterface),
|
(ZarrDirCase, interface.ZarrInterface),
|
||||||
marks=pytest.mark.zarr,
|
marks=pytest.mark.zarr,
|
||||||
id="zarr-array",
|
id="zarr-dir",
|
||||||
),
|
),
|
||||||
pytest.param(
|
pytest.param(
|
||||||
("avi_video", interface.VideoInterface), marks=pytest.mark.video, id="video"
|
(VideoCase, interface.VideoInterface), marks=pytest.mark.video, id="video"
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def interface_type(request) -> Tuple[NDArray, Type[interface.Interface]]:
|
def interface_type(
|
||||||
|
request, tmp_output_dir_func
|
||||||
|
) -> Tuple[NDArray, Type[interface.Interface]]:
|
||||||
"""
|
"""
|
||||||
Test cases for each interface's ``check`` method - each input should match the
|
Test cases for each interface's ``check`` method - each input should match the
|
||||||
provided interface and that interface only
|
provided interface and that interface only
|
||||||
"""
|
"""
|
||||||
if isinstance(request.param[0], str):
|
|
||||||
return (request.getfixturevalue(request.param[0]), request.param[1])
|
if inspect.isclass(request.param[0]) and issubclass(
|
||||||
|
request.param[0], InterfaceCase
|
||||||
|
):
|
||||||
|
array = request.param[0].make_array(path=tmp_output_dir_func)
|
||||||
|
if array is None:
|
||||||
|
pytest.skip()
|
||||||
|
return array, request.param[1]
|
||||||
else:
|
else:
|
||||||
return request.param
|
return request.param
|
||||||
|
|
||||||
|
|
|
@ -2,31 +2,13 @@ import json
|
||||||
|
|
||||||
import dask.array as da
|
import dask.array as da
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel, ValidationError
|
|
||||||
|
|
||||||
from numpydantic.exceptions import DtypeError, ShapeError
|
|
||||||
from numpydantic.interface import DaskInterface
|
from numpydantic.interface import DaskInterface
|
||||||
from numpydantic.testing.helpers import ValidationCase
|
from numpydantic.testing.interfaces import DaskCase
|
||||||
|
|
||||||
pytestmark = pytest.mark.dask
|
pytestmark = pytest.mark.dask
|
||||||
|
|
||||||
|
|
||||||
def dask_array(case: ValidationCase) -> da.Array:
|
|
||||||
if issubclass(case.dtype, BaseModel):
|
|
||||||
return da.full(shape=case.shape, fill_value=case.dtype(x=1), chunks=-1)
|
|
||||||
else:
|
|
||||||
return da.zeros(shape=case.shape, dtype=case.dtype, chunks=10)
|
|
||||||
|
|
||||||
|
|
||||||
def _test_dask_case(case: ValidationCase):
|
|
||||||
array = dask_array(case)
|
|
||||||
if case.passes:
|
|
||||||
case.model(array=array)
|
|
||||||
else:
|
|
||||||
with pytest.raises((ValidationError, DtypeError, ShapeError)):
|
|
||||||
case.model(array=array)
|
|
||||||
|
|
||||||
|
|
||||||
def test_dask_enabled():
|
def test_dask_enabled():
|
||||||
"""
|
"""
|
||||||
We need dask to be available to run these tests :)
|
We need dask to be available to run these tests :)
|
||||||
|
@ -43,12 +25,14 @@ def test_dask_check(interface_type):
|
||||||
|
|
||||||
@pytest.mark.shape
|
@pytest.mark.shape
|
||||||
def test_dask_shape(shape_cases):
|
def test_dask_shape(shape_cases):
|
||||||
_test_dask_case(shape_cases)
|
shape_cases.interface = DaskCase
|
||||||
|
shape_cases.validate_case()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dtype
|
@pytest.mark.dtype
|
||||||
def test_dask_dtype(dtype_cases):
|
def test_dask_dtype(dtype_cases):
|
||||||
_test_dask_case(dtype_cases)
|
dtype_cases.interface = DaskCase
|
||||||
|
dtype_cases.validate_case()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.serialization
|
@pytest.mark.serialization
|
||||||
|
|
|
@ -12,10 +12,21 @@ from numpydantic.exceptions import DtypeError, ShapeError
|
||||||
from numpydantic.interface import H5Interface
|
from numpydantic.interface import H5Interface
|
||||||
from numpydantic.interface.hdf5 import H5ArrayPath, H5Proxy
|
from numpydantic.interface.hdf5 import H5ArrayPath, H5Proxy
|
||||||
from numpydantic.testing.helpers import ValidationCase
|
from numpydantic.testing.helpers import ValidationCase
|
||||||
|
from numpydantic.testing.interfaces import HDF5Case, HDF5CompoundCase
|
||||||
|
|
||||||
pytestmark = pytest.mark.hdf5
|
pytestmark = pytest.mark.hdf5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(
|
||||||
|
params=[
|
||||||
|
pytest.param(HDF5Case, id="hdf5"),
|
||||||
|
pytest.param(HDF5CompoundCase, id="hdf5-compound"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def hdf5_cases(request):
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
|
||||||
def hdf5_array_case(
|
def hdf5_array_case(
|
||||||
case: ValidationCase, array_func, compound: bool = False
|
case: ValidationCase, array_func, compound: bool = False
|
||||||
) -> H5ArrayPath:
|
) -> H5ArrayPath:
|
||||||
|
@ -47,8 +58,6 @@ def test_hdf5_enabled():
|
||||||
|
|
||||||
def test_hdf5_check(interface_type):
|
def test_hdf5_check(interface_type):
|
||||||
if interface_type[1] is H5Interface:
|
if interface_type[1] is H5Interface:
|
||||||
if interface_type[0].__name__ == "_hdf5_array":
|
|
||||||
interface_type = (interface_type[0](), interface_type[1])
|
|
||||||
assert H5Interface.check(interface_type[0])
|
assert H5Interface.check(interface_type[0])
|
||||||
if isinstance(interface_type[0], H5ArrayPath):
|
if isinstance(interface_type[0], H5ArrayPath):
|
||||||
# also test that we can instantiate from a tuple like the H5ArrayPath
|
# also test that we can instantiate from a tuple like the H5ArrayPath
|
||||||
|
@ -74,15 +83,17 @@ def test_hdf5_check_not_hdf5(tmp_path):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.shape
|
@pytest.mark.shape
|
||||||
@pytest.mark.parametrize("compound", [True, False])
|
def test_hdf5_shape(shape_cases, hdf5_cases):
|
||||||
def test_hdf5_shape(shape_cases, hdf5_array, compound):
|
shape_cases.interface = hdf5_cases
|
||||||
_test_hdf5_case(shape_cases, hdf5_array, compound)
|
if shape_cases.skip():
|
||||||
|
pytest.skip()
|
||||||
|
shape_cases.validate_case()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dtype
|
@pytest.mark.dtype
|
||||||
@pytest.mark.parametrize("compound", [True, False])
|
def test_hdf5_dtype(dtype_cases, hdf5_cases):
|
||||||
def test_hdf5_dtype(dtype_cases, hdf5_array, compound):
|
dtype_cases.interface = hdf5_cases
|
||||||
_test_hdf5_case(dtype_cases, hdf5_array, compound)
|
dtype_cases.validate_case()
|
||||||
|
|
||||||
|
|
||||||
def test_hdf5_dataset_not_exists(hdf5_array, model_blank):
|
def test_hdf5_dataset_not_exists(hdf5_array, model_blank):
|
||||||
|
|
Loading…
Reference in a new issue