interface cases

This commit is contained in:
sneakers-the-rat 2024-10-03 23:18:18 -07:00
parent ad060ce40d
commit e701bf6e9b
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
8 changed files with 542 additions and 123 deletions

View file

@ -2,6 +2,15 @@
Utilities for testing and 3rd-party interface development. Utilities for testing and 3rd-party interface development.
Only things that *don't* require pytest go in this module.
We want to keep all test-time specific behavior there,
and have this just serve as helpers exposed for downstream interface development.
We want to avoid pytest stuff bleeding in here because then we limit
the ability for downstream developers to configure their own tests.
*(If there is some reason to change this division of labor, just raise an issue and let's chat.)*
```{toctree} ```{toctree}
cases cases
helpers helpers

View file

@ -131,6 +131,9 @@ markers = [
"zarr: zarr interface", "zarr: zarr interface",
] ]
[tool.black]
target-version = ["py39", "py310", "py311", "py312"]
[tool.ruff] [tool.ruff]
target-version = "py39" target-version = "py39"
include = ["src/numpydantic/**/*.py", "tests/**/*.py", "pyproject.toml"] include = ["src/numpydantic/**/*.py", "tests/**/*.py", "pyproject.toml"]

View file

@ -3,7 +3,7 @@ Interfaces between nptyping types and array backends
""" """
from numpydantic.interface.dask import DaskInterface from numpydantic.interface.dask import DaskInterface
from numpydantic.interface.hdf5 import H5Interface from numpydantic.interface.hdf5 import H5ArrayPath, H5Interface
from numpydantic.interface.interface import ( from numpydantic.interface.interface import (
Interface, Interface,
InterfaceMark, InterfaceMark,
@ -12,10 +12,11 @@ from numpydantic.interface.interface import (
) )
from numpydantic.interface.numpy import NumpyInterface from numpydantic.interface.numpy import NumpyInterface
from numpydantic.interface.video import VideoInterface from numpydantic.interface.video import VideoInterface
from numpydantic.interface.zarr import ZarrInterface from numpydantic.interface.zarr import ZarrArrayPath, ZarrInterface
__all__ = [ __all__ = [
"DaskInterface", "DaskInterface",
"H5ArrayPath",
"H5Interface", "H5Interface",
"Interface", "Interface",
"InterfaceMark", "InterfaceMark",
@ -23,5 +24,6 @@ __all__ = [
"MarkedJson", "MarkedJson",
"NumpyInterface", "NumpyInterface",
"VideoInterface", "VideoInterface",
"ZarrArrayPath",
"ZarrInterface", "ZarrInterface",
] ]

View file

@ -0,0 +1,6 @@
from numpydantic.testing.helpers import InterfaceCase, ValidationCase
__all__ = [
"InterfaceCase",
"ValidationCase",
]

View file

@ -1,12 +1,25 @@
import sys import sys
from typing import Union from collections.abc import Sequence
from itertools import product
from typing import Generator, Union
import numpy as np import numpy as np
from pydantic import BaseModel from pydantic import BaseModel
from numpydantic import NDArray, Shape from numpydantic import NDArray, Shape
from numpydantic.dtype import Float, Integer, Number from numpydantic.dtype import Float, Integer, Number
from numpydantic.testing.helpers import ValidationCase from numpydantic.testing.helpers import ValidationCase, merge_cases
from numpydantic.testing.interfaces import (
DaskCase,
HDF5Case,
HDF5CompoundCase,
NumpyCase,
VideoCase,
ZarrCase,
ZarrDirCase,
ZarrNestedCase,
ZarrZipCase,
)
if sys.version_info.minor >= 10: if sys.version_info.minor >= 10:
from typing import TypeAlias from typing import TypeAlias
@ -30,6 +43,10 @@ class SubClass(BasicModel):
pass pass
# --------------------------------------------------
# Annotations
# --------------------------------------------------
RGB_UNION: TypeAlias = Union[ RGB_UNION: TypeAlias = Union[
NDArray[Shape["* x, * y"], Number], NDArray[Shape["* x, * y"], Number],
NDArray[Shape["* x, * y, 3 r_g_b"], Number], NDArray[Shape["* x, * y, 3 r_g_b"], Number],
@ -42,89 +59,159 @@ STRING: TypeAlias = NDArray[Shape["*, *, *"], str]
MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel] MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel]
UNION_TYPE: TypeAlias = NDArray[Shape["*, *, *"], Union[np.uint32, np.float32]] UNION_TYPE: TypeAlias = NDArray[Shape["*, *, *"], Union[np.uint32, np.float32]]
UNION_PIPE: TypeAlias = NDArray[Shape["*, *, *"], np.uint32 | np.float32] UNION_PIPE: TypeAlias = NDArray[Shape["*, *, *"], np.uint32 | np.float32]
SHAPE_CASES = (
ValidationCase(shape=(10, 10, 10), passes=True, id="valid shape"),
ValidationCase(shape=(10, 10), passes=False, id="missing dimension"),
ValidationCase(shape=(10, 10, 10, 10), passes=False, id="extra dimension"),
ValidationCase(shape=(11, 10, 10), passes=False, id="dimension too large"),
ValidationCase(shape=(9, 10, 10), passes=False, id="dimension too small"),
ValidationCase(shape=(10, 10, 9), passes=True, id="wildcard smaller"),
ValidationCase(shape=(10, 10, 11), passes=True, id="wildcard larger"),
ValidationCase(annotation=RGB_UNION, shape=(5, 5), passes=True, id="Union 2D"),
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3), passes=True, id="Union 3D"),
ValidationCase(
annotation=RGB_UNION, shape=(5, 5, 3, 4), passes=True, id="Union 4D"
),
ValidationCase(
annotation=RGB_UNION, shape=(5, 5, 4), passes=False, id="Union incorrect 3D"
),
ValidationCase(
annotation=RGB_UNION, shape=(5, 5, 3, 6), passes=False, id="Union incorrect 4D"
),
ValidationCase(
annotation=RGB_UNION,
shape=(5, 5, 4, 6),
passes=False,
id="Union incorrect both",
),
)
DTYPE_CASES = [ DTYPE_CASES = [
ValidationCase(dtype=float, passes=True), ValidationCase(dtype=float, passes=True, id="float"),
ValidationCase(dtype=int, passes=False), ValidationCase(dtype=int, passes=False, id="int"),
ValidationCase(dtype=np.uint8, passes=False), ValidationCase(dtype=np.uint8, passes=False, id="uint8"),
ValidationCase(annotation=NUMBER, dtype=int, passes=True), ValidationCase(annotation=NUMBER, dtype=int, passes=True, id="number-int"),
ValidationCase(annotation=NUMBER, dtype=float, passes=True), ValidationCase(annotation=NUMBER, dtype=float, passes=True, id="number-float"),
ValidationCase(annotation=NUMBER, dtype=np.uint8, passes=True), ValidationCase(annotation=NUMBER, dtype=np.uint8, passes=True, id="number-uint8"),
ValidationCase(annotation=NUMBER, dtype=np.float16, passes=True), ValidationCase(
ValidationCase(annotation=NUMBER, dtype=str, passes=False), annotation=NUMBER, dtype=np.float16, passes=True, id="number-float16"
ValidationCase(annotation=INTEGER, dtype=int, passes=True), ),
ValidationCase(annotation=INTEGER, dtype=np.uint8, passes=True), ValidationCase(annotation=NUMBER, dtype=str, passes=False, id="number-str"),
ValidationCase(annotation=INTEGER, dtype=float, passes=False), ValidationCase(annotation=INTEGER, dtype=int, passes=True, id="integer-int"),
ValidationCase(annotation=INTEGER, dtype=np.float32, passes=False), ValidationCase(annotation=INTEGER, dtype=np.uint8, passes=True, id="integer-uint8"),
ValidationCase(annotation=INTEGER, dtype=str, passes=False), ValidationCase(annotation=INTEGER, dtype=float, passes=False, id="integer-float"),
ValidationCase(annotation=FLOAT, dtype=float, passes=True), ValidationCase(
ValidationCase(annotation=FLOAT, dtype=np.float32, passes=True), annotation=INTEGER, dtype=np.float32, passes=False, id="integer-float32"
ValidationCase(annotation=FLOAT, dtype=int, passes=False), ),
ValidationCase(annotation=FLOAT, dtype=np.uint8, passes=False), ValidationCase(annotation=INTEGER, dtype=str, passes=False, id="integer-str"),
ValidationCase(annotation=FLOAT, dtype=str, passes=False), ValidationCase(annotation=FLOAT, dtype=float, passes=True, id="float-float"),
ValidationCase(annotation=STRING, dtype=str, passes=True), ValidationCase(annotation=FLOAT, dtype=np.float32, passes=True, id="float-float32"),
ValidationCase(annotation=STRING, dtype=int, passes=False), ValidationCase(annotation=FLOAT, dtype=int, passes=False, id="float-int"),
ValidationCase(annotation=STRING, dtype=float, passes=False), ValidationCase(annotation=FLOAT, dtype=np.uint8, passes=False, id="float-uint8"),
ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True), ValidationCase(annotation=FLOAT, dtype=str, passes=False, id="float-str"),
ValidationCase(annotation=MODEL, dtype=BadModel, passes=False), ValidationCase(annotation=STRING, dtype=str, passes=True, id="str-str"),
ValidationCase(annotation=MODEL, dtype=int, passes=False), ValidationCase(annotation=STRING, dtype=int, passes=False, id="str-int"),
ValidationCase(annotation=MODEL, dtype=SubClass, passes=True), ValidationCase(annotation=STRING, dtype=float, passes=False, id="str-float"),
ValidationCase(annotation=UNION_TYPE, dtype=np.uint32, passes=True), ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True, id="model-model"),
ValidationCase(annotation=UNION_TYPE, dtype=np.float32, passes=True), ValidationCase(annotation=MODEL, dtype=BadModel, passes=False, id="model-badmodel"),
ValidationCase(annotation=UNION_TYPE, dtype=np.uint64, passes=False), ValidationCase(annotation=MODEL, dtype=int, passes=False, id="model-int"),
ValidationCase(annotation=UNION_TYPE, dtype=np.float64, passes=False), ValidationCase(annotation=MODEL, dtype=SubClass, passes=True, id="model-subclass"),
ValidationCase(annotation=UNION_TYPE, dtype=str, passes=False), ValidationCase(
annotation=UNION_TYPE, dtype=np.uint32, passes=True, id="union-type-uint32"
),
ValidationCase(
annotation=UNION_TYPE, dtype=np.float32, passes=True, id="union-type-float32"
),
ValidationCase(
annotation=UNION_TYPE, dtype=np.uint64, passes=False, id="union-type-uint64"
),
ValidationCase(
annotation=UNION_TYPE, dtype=np.float64, passes=False, id="union-type-float64"
),
ValidationCase(annotation=UNION_TYPE, dtype=str, passes=False, id="union-type-str"),
] ]
DTYPE_IDS = [
"float",
"int",
"uint8",
"number-int",
"number-float",
"number-uint8",
"number-float16",
"number-str",
"integer-int",
"integer-uint8",
"integer-float",
"integer-float32",
"integer-str",
"float-float",
"float-float32",
"float-int",
"float-uint8",
"float-str",
"str-str",
"str-int",
"str-float",
"model-model",
"model-badmodel",
"model-int",
"model-subclass",
"union-type-uint32",
"union-type-float32",
"union-type-uint64",
"union-type-float64",
"union-type-str",
]
if YES_PIPE: if YES_PIPE:
DTYPE_CASES.extend( DTYPE_CASES.extend(
[ [
ValidationCase(annotation=UNION_PIPE, dtype=np.uint32, passes=True), ValidationCase(
ValidationCase(annotation=UNION_PIPE, dtype=np.float32, passes=True), annotation=UNION_PIPE,
ValidationCase(annotation=UNION_PIPE, dtype=np.uint64, passes=False), dtype=np.uint32,
ValidationCase(annotation=UNION_PIPE, dtype=np.float64, passes=False), passes=True,
ValidationCase(annotation=UNION_PIPE, dtype=str, passes=False), id="union-pipe-uint32",
),
ValidationCase(
annotation=UNION_PIPE,
dtype=np.float32,
passes=True,
id="union-pipe-float32",
),
ValidationCase(
annotation=UNION_PIPE,
dtype=np.uint64,
passes=False,
id="union-pipe-uint64",
),
ValidationCase(
annotation=UNION_PIPE,
dtype=np.float64,
passes=False,
id="union-pipe-float64",
),
ValidationCase(
annotation=UNION_PIPE, dtype=str, passes=False, id="union-pipe-str"
),
] ]
) )
DTYPE_IDS.extend(
[ _INTERFACE_CASES = [
"union-pipe-uint32", NumpyCase,
"union-pipe-float32", HDF5Case,
"union-pipe-uint64", HDF5CompoundCase,
"union-pipe-float64", DaskCase,
"union-pipe-str", ZarrCase,
ZarrDirCase,
ZarrZipCase,
ZarrNestedCase,
VideoCase,
] ]
)
def merged_product(
*args: Sequence[ValidationCase],
) -> Generator[ValidationCase, None, None]:
"""
Generator for the product of the iterators of validation cases,
merging each tuple, and respecting if they should be :meth:`.ValidationCase.skip`
or not.
Examples:
.. code-block:: python
shape_cases = [
ValidationCase(shape=(10, 10, 10), passes=True, id="valid shape"),
ValidationCase(shape=(10, 10), passes=False, id="missing dimension"),
]
dtype_cases = [
ValidationCase(dtype=float, passes=True, id="float"),
ValidationCase(dtype=int, passes=False, id="int"),
]
iterator = merged_product(shape_cases, dtype_cases))
next(iterator)
# ValidationCase(shape=(10, 10, 10), dtype=float, passes=True, id="valid shape-float")
next(iterator)
# ValidationCase(shape=(10, 10, 10), dtype=int, passes=False, id="valid shape-int")
"""
iterator = product(*args)
for case_tuple in iterator:
case = merge_cases(case_tuple)
if case.skip():
continue
yield case

View file

@ -1,10 +1,76 @@
from typing import Any, Tuple, Type, Union from abc import ABC, abstractmethod
from collections.abc import Sequence
from pathlib import Path
from typing import Any, Optional, Tuple, Type, Union
import numpy as np import numpy as np
from pydantic import BaseModel, ConfigDict, computed_field from pydantic import BaseModel, ConfigDict, ValidationError, computed_field
from numpydantic import NDArray, Shape from numpydantic import NDArray, Shape
from numpydantic.dtype import Float from numpydantic.dtype import Float
from numpydantic.interface import Interface
from numpydantic.types import NDArrayType
class InterfaceCase(ABC):
"""
An interface test helper that allows a given interface to generate and validate
arrays in one of its formats.
Each instance of "interface test case" should be considered one of the
potentially multiple realizations of a given interface.
If an interface has multiple formats (eg. zarr's different `store` s),
then it should have several test helpers.
"""
@property
@abstractmethod
def interface(self) -> Interface:
"""The interface that this helper is for"""
@classmethod
@abstractmethod
def generate_array(
cls, case: "ValidationCase", path: Path
) -> Optional[NDArrayType]:
"""
Generate an array from the given validation case.
Returns ``None`` if an array can't be generated for a specific case.
"""
@classmethod
def validate_array(cls, case: "ValidationCase", path: Path) -> Optional[bool]:
"""
Validate a generated array against the annotation in the validation case.
Kept in the InterfaceCase in case an interface has specific
needs aside from just validating against a model, but typically left as is.
Does not raise on Validation errors -
returns bool instead for consistency's sake.
If an array can't be generated for a given case, returns `None`
so that the calling function can know to skip rather than fail the case.
"""
array = cls.generate_array(case, path)
if array is None:
return None
try:
case.model(array=array)
# True if case is supposed to pass, False if it's not...
return case.passes
except ValidationError:
# False if the case is supposed to pass, True if it is...
return not case.passes
@classmethod
def skip(cls, case: "ValidationCase") -> bool:
"""
Whether a given interface should be skipped for the case
"""
# Assume an interface case is valid for all other cases
return False
class ValidationCase(BaseModel): class ValidationCase(BaseModel):
@ -15,6 +81,10 @@ class ValidationCase(BaseModel):
test in a given interface test in a given interface
""" """
id: Optional[str] = None
"""
String identifying the validation case
"""
annotation: Any = NDArray[Shape["10, 10, *"], Float] annotation: Any = NDArray[Shape["10, 10, *"], Float]
""" """
Array annotation used in the validating model Array annotation used in the validating model
@ -24,8 +94,9 @@ class ValidationCase(BaseModel):
"""Shape of the array to validate""" """Shape of the array to validate"""
dtype: Union[Type, np.dtype] = float dtype: Union[Type, np.dtype] = float
"""Dtype of the array to validate""" """Dtype of the array to validate"""
passes: bool passes: bool = False
"""Whether the validation should pass or not""" """Whether the validation should pass or not"""
interface: Optional[InterfaceCase] = None
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
@ -38,3 +109,62 @@ class ValidationCase(BaseModel):
array: annotation array: annotation
return Model return Model
def merge(
self, other: Union["ValidationCase", Sequence["ValidationCase"]]
) -> "ValidationCase":
"""
Merge two validation cases
Dump both, excluding any unset fields, and merge, preferring `other`.
``valid`` is ``True`` if and only if it is ``True`` in both.
"""
if isinstance(other, Sequence):
return merge_cases(self, *other)
self_dump = self.model_dump(exclude_unset=True)
other_dump = other.model_dump(exclude_unset=True)
# dumps might not have set `valid`, use only the ones that have
valids = [
v
for v in [self_dump.get("valid", None), other_dump.get("valid", None)]
if v is not None
]
valid = all(valids)
# combine ids if present
ids = "-".join(
[
str(v)
for v in [self_dump.get("id", None), other_dump.get("id", None)]
if v is not None
]
)
merged = {**self_dump, **other_dump}
merged["valid"] = valid
merged["id"] = ids
return ValidationCase(**merged)
def skip(self) -> bool:
"""
Whether this case should be skipped
(eg. due to the interface case being incompatible
with the requested dtype or shape)
"""
return bool(self.interface is not None and self.interface.skip())
def merge_cases(*args: ValidationCase) -> ValidationCase:
"""
Merge multiple validation cases
"""
if len(args) == 1:
return args[0]
case = args[0]
for arg in args[1:]:
case = case.merge(arg)
return case

View file

@ -0,0 +1,218 @@
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
import cv2
import dask.array as da
import h5py
import numpy as np
import zarr
from pydantic import BaseModel
from numpydantic.interface import (
DaskInterface,
H5ArrayPath,
H5Interface,
NumpyInterface,
VideoInterface,
ZarrArrayPath,
ZarrInterface,
)
from numpydantic.testing.helpers import InterfaceCase, ValidationCase
class NumpyCase(InterfaceCase):
"""In-memory numpy array"""
interface = NumpyInterface
@classmethod
def generate_array(cls, case: "ValidationCase", path: Path) -> np.ndarray:
if issubclass(case.dtype, BaseModel):
return np.full(shape=case.shape, fill_value=case.dtype(x=1))
else:
return np.zeros(shape=case.shape, dtype=case.dtype)
class _HDF5MetaCase(InterfaceCase):
"""Base case for hdf5 cases"""
interface = H5Interface
@classmethod
def skip(cls, case: "ValidationCase") -> bool:
return not issubclass(case.dtype, BaseModel)
class HDF5Case(_HDF5MetaCase):
"""HDF5 Array"""
@classmethod
def generate_array(
cls, case: "ValidationCase", path: Path
) -> Optional[H5ArrayPath]:
if cls.skip(case):
return None
hdf5_file = path / "h5f.h5"
array_path = (
"/" + "_".join([str(s) for s in case.shape]) + "__" + case.dtype.__name__
)
generator = np.random.default_rng()
if case.dtype is str:
data = generator.random(case.shape).astype(bytes)
elif case.dtype is datetime:
data = np.empty(case.shape, dtype="S32")
data.fill(datetime.now(timezone.utc).isoformat().encode("utf-8"))
else:
data = generator.random(case.shape).astype(case.dtype)
h5path = H5ArrayPath(hdf5_file, array_path)
with h5py.File(hdf5_file, "w") as h5f:
_ = h5f.create_dataset(array_path, data=data)
return h5path
class HDF5CompoundCase(_HDF5MetaCase):
"""HDF5 Array with a fake compound dtype"""
@classmethod
def generate_array(
cls, case: "ValidationCase", path: Path
) -> Optional[H5ArrayPath]:
if cls.skip(case):
return None
hdf5_file = path / "h5f.h5"
array_path = (
"/" + "_".join([str(s) for s in case.shape]) + "__" + case.dtype.__name__
)
if case.dtype is str:
dt = np.dtype([("data", np.dtype("S10")), ("extra", "i8")])
data = np.array([("hey", 0)] * np.prod(case.shape), dtype=dt).reshape(
case.shape
)
elif case.dtype is datetime:
dt = np.dtype([("data", np.dtype("S32")), ("extra", "i8")])
data = np.array(
[(datetime.now(timezone.utc).isoformat().encode("utf-8"), 0)]
* np.prod(case.shape),
dtype=dt,
).reshape(case.shape)
else:
dt = np.dtype([("data", case.dtype), ("extra", "i8")])
data = np.zeros(case.shape, dtype=dt)
h5path = H5ArrayPath(hdf5_file, array_path, "data")
with h5py.File(hdf5_file, "w") as h5f:
_ = h5f.create_dataset(array_path, data=data)
return h5path
class DaskCase(InterfaceCase):
"""In-memory dask array"""
interface = DaskInterface
@classmethod
def generate_array(cls, case: "ValidationCase", path: Path) -> da.Array:
if issubclass(case.dtype, BaseModel):
return da.full(shape=case.shape, fill_value=case.dtype(x=1), chunks=-1)
else:
return da.zeros(shape=case.shape, dtype=case.dtype, chunks=10)
class _ZarrMetaCase(InterfaceCase):
"""Shared classmethods for zarr cases"""
interface = ZarrInterface
@classmethod
def skip(cls, case: "ValidationCase") -> bool:
return not issubclass(case.dtype, BaseModel)
class ZarrCase(_ZarrMetaCase):
"""In-memory zarr array"""
@classmethod
def generate_array(cls, case: "ValidationCase", path: Path) -> Optional[zarr.Array]:
return zarr.zeros(shape=case.shape, dtype=case.dtype)
class ZarrDirCase(_ZarrMetaCase):
"""On-disk zarr array"""
@classmethod
def generate_array(cls, case: "ValidationCase", path: Path) -> ZarrArrayPath:
store = zarr.DirectoryStore(str(path / "array.zarr"))
return zarr.zeros(shape=case.shape, dtype=case.dtype, store=store)
class ZarrZipCase(_ZarrMetaCase):
"""Zarr zip store"""
@classmethod
def generate_array(cls, case: "ValidationCase", path: Path) -> ZarrArrayPath:
store = zarr.ZipStore(str(path / "array.zarr"), mode="w")
return zarr.zeros(shape=case.shape, dtype=case.dtype, store=store)
class ZarrNestedCase(_ZarrMetaCase):
"""Nested zarr array"""
@classmethod
def generate_array(cls, case: "ValidationCase", path: Path) -> ZarrArrayPath:
file = str(path / "nested.zarr")
root = zarr.open(file, mode="w")
subpath = "a/b/c"
_ = root.zeros(subpath, shape=case.shape, dtype=case.dtype)
return ZarrArrayPath(file=file, path=subpath)
class VideoCase(InterfaceCase):
"""AVI video"""
interface = VideoInterface
@classmethod
def generate_array(cls, case: "ValidationCase", path: Path) -> Optional[Path]:
if cls.skip(case):
return None
is_color = len(case.shape) == 4
frames = case.shape[0]
frame_shape = case.shape[1:]
video_path = path / "test.avi"
writer = cv2.VideoWriter(
str(video_path),
cv2.VideoWriter_fourcc(*"RGBA"), # raw video for testing purposes
30,
(frame_shape[1], frame_shape[0]),
is_color,
)
for i in range(frames):
# make fresh array every time bc opencv eats them
array = np.zeros(frame_shape, dtype=np.uint8)
if not is_color:
array[i, i] = i
else:
array[i, i, :] = i
writer.write(array)
writer.release()
return video_path
@classmethod
def skip(cls, case: "ValidationCase") -> bool:
"""We really can only handle 3-4 dimensional cases in 8-bit rn lol"""
if len(case.shape) < 3 or len(case.shape) > 4:
return True
if case.dtype not in (int, np.uint8):
return True
# if we have a color video (ie. shape == 4, needs to be RGB)
if len(case.shape) == 4 and case.shape[3] != 3:
return True

View file

@ -1,10 +1,6 @@
import pytest import pytest
from numpydantic.testing.cases import ( from numpydantic.testing.cases import DTYPE_CASES, SHAPE_CASES
DTYPE_CASES,
DTYPE_IDS,
RGB_UNION,
)
from numpydantic.testing.helpers import ValidationCase from numpydantic.testing.helpers import ValidationCase
from tests.fixtures import * from tests.fixtures import *
@ -17,43 +13,11 @@ def pytest_addoption(parser):
) )
@pytest.fixture( @pytest.fixture(scope="module", params=SHAPE_CASES)
scope="module",
params=[
ValidationCase(shape=(10, 10, 10), passes=True),
ValidationCase(shape=(10, 10), passes=False),
ValidationCase(shape=(10, 10, 10, 10), passes=False),
ValidationCase(shape=(11, 10, 10), passes=False),
ValidationCase(shape=(9, 10, 10), passes=False),
ValidationCase(shape=(10, 10, 9), passes=True),
ValidationCase(shape=(10, 10, 11), passes=True),
ValidationCase(annotation=RGB_UNION, shape=(5, 5), passes=True),
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3), passes=True),
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3, 4), passes=True),
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 4), passes=False),
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3, 6), passes=False),
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 4, 6), passes=False),
],
ids=[
"valid shape",
"missing dimension",
"extra dimension",
"dimension too large",
"dimension too small",
"wildcard smaller",
"wildcard larger",
"Union 2D",
"Union 3D",
"Union 4D",
"Union incorrect 3D",
"Union incorrect 4D",
"Union incorrect both",
],
)
def shape_cases(request) -> ValidationCase: def shape_cases(request) -> ValidationCase:
return request.param return request.param
@pytest.fixture(scope="module", params=DTYPE_CASES, ids=DTYPE_IDS) @pytest.fixture(scope="module", params=DTYPE_CASES)
def dtype_cases(request) -> ValidationCase: def dtype_cases(request) -> ValidationCase:
return request.param return request.param