hoo boy. working combinatoric testing.

Split out annotation dtype and shape, swap out all interface tests, fix numpy and dask model casting, make merging models more efficient, correctly parameterize and mark tests!
This commit is contained in:
sneakers-the-rat 2024-10-10 23:56:45 -07:00
parent 5d4f03a8a9
commit 1187b37b2d
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
12 changed files with 482 additions and 278 deletions

View file

@ -5,7 +5,7 @@ Interface for Dask arrays
from typing import Any, Iterable, List, Literal, Optional, Union
import numpy as np
from pydantic import SerializationInfo
from pydantic import BaseModel, SerializationInfo
from numpydantic.interface.interface import Interface, JsonDict
from numpydantic.types import DtypeType, NDArrayType
@ -70,9 +70,33 @@ class DaskInterface(Interface):
else:
return False
def before_validation(self, array: DaskArray) -> NDArrayType:
"""
Try and coerce dicts that should be model objects into the model objects
"""
try:
if issubclass(self.dtype, BaseModel) and isinstance(
array.reshape(-1)[0].compute(), dict
):
def _chunked_to_model(array: np.ndarray) -> np.ndarray:
def _vectorized_to_model(item: Union[dict, BaseModel]) -> BaseModel:
if not isinstance(item, self.dtype):
return self.dtype(**item)
else:
return item
return np.vectorize(_vectorized_to_model)(array)
array = array.map_blocks(_chunked_to_model, dtype=self.dtype)
except TypeError:
# fine, dtype isn't a type
pass
return array
def get_object_dtype(self, array: NDArrayType) -> DtypeType:
"""Dask arrays require a compute() call to retrieve a single value"""
return type(array.ravel()[0].compute())
return type(array.reshape(-1)[0].compute())
@classmethod
def enabled(cls) -> bool:

View file

@ -4,7 +4,7 @@ Interface to numpy arrays
from typing import Any, Literal, Union
from pydantic import SerializationInfo
from pydantic import BaseModel, SerializationInfo
from numpydantic.interface.interface import Interface, JsonDict
@ -59,6 +59,9 @@ class NumpyInterface(Interface):
Check that this is in fact a numpy ndarray or something that can be
coerced to one
"""
if array is None:
return False
if isinstance(array, ndarray):
return True
elif isinstance(array, dict):
@ -77,6 +80,14 @@ class NumpyInterface(Interface):
"""
if not isinstance(array, ndarray):
array = np.array(array)
try:
if issubclass(self.dtype, BaseModel) and isinstance(array.flat[0], dict):
array = np.vectorize(lambda x: self.dtype(**x))(array)
except TypeError:
# fine, dtype isn't a type
pass
return array
@classmethod

View file

@ -63,6 +63,7 @@ class ZarrJsonDict(JsonDict):
type: Literal["zarr"]
file: Optional[str] = None
path: Optional[str] = None
dtype: Optional[str] = None
value: Optional[list] = None
def to_array_input(self) -> Union[ZarrArray, ZarrArrayPath]:
@ -73,7 +74,7 @@ class ZarrJsonDict(JsonDict):
if self.file:
array = ZarrArrayPath(file=self.file, path=self.path)
else:
array = zarr.array(self.value)
array = zarr.array(self.value, dtype=self.dtype)
return array
@ -194,6 +195,7 @@ class ZarrInterface(Interface):
is_file = False
as_json = {"type": cls.name}
as_json["dtype"] = array.dtype.name
if hasattr(array.store, "dir_path"):
is_file = True
as_json["file"] = array.store.dir_path()

View file

@ -1,14 +1,11 @@
import sys
from collections.abc import Sequence
from itertools import product
from typing import Generator, Union
from typing import Union
import numpy as np
from pydantic import BaseModel
from numpydantic import NDArray, Shape
from numpydantic.dtype import Float, Integer, Number
from numpydantic.testing.helpers import ValidationCase, merge_cases
from numpydantic.testing.helpers import ValidationCase, merged_product
from numpydantic.testing.interfaces import (
DaskCase,
HDF5Case,
@ -31,53 +28,6 @@ else:
YES_PIPE = False
def merged_product(
*args: Sequence[ValidationCase],
) -> Generator[ValidationCase, None, None]:
"""
Generator for the product of the iterators of validation cases,
merging each tuple, and respecting if they should be :meth:`.ValidationCase.skip`
or not.
Examples:
.. code-block:: python
shape_cases = [
ValidationCase(shape=(10, 10, 10), passes=True, id="valid shape"),
ValidationCase(shape=(10, 10), passes=False, id="missing dimension"),
]
dtype_cases = [
ValidationCase(dtype=float, passes=True, id="float"),
ValidationCase(dtype=int, passes=False, id="int"),
]
iterator = merged_product(shape_cases, dtype_cases))
next(iterator)
# ValidationCase(
# shape=(10, 10, 10),
# dtype=float,
# passes=True,
# id="valid shape-float"
# )
next(iterator)
# ValidationCase(
# shape=(10, 10, 10),
# dtype=int,
# passes=False,
# id="valid shape-int"
# )
"""
iterator = product(*args)
for case_tuple in iterator:
case = merge_cases(case_tuple)
if case.skip():
continue
yield case
class BasicModel(BaseModel):
x: int
@ -94,39 +44,40 @@ class SubClass(BasicModel):
# Annotations
# --------------------------------------------------
RGB_UNION: TypeAlias = Union[
NDArray[Shape["* x, * y"], Number],
NDArray[Shape["* x, * y, 3 r_g_b"], Number],
NDArray[Shape["* x, * y, 3 r_g_b, 4 r_g_b_a"], Number],
]
NUMBER: TypeAlias = NDArray[Shape["*, *, *"], Number]
INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer]
FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float]
STRING: TypeAlias = NDArray[Shape["*, *, *"], str]
MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel]
UNION_TYPE: TypeAlias = NDArray[Shape["*, *, *"], Union[np.uint32, np.float32]]
RGB_UNION = (("*", "*"), ("*", "*", 3), ("*", "*", 3, 4))
UNION_TYPE: TypeAlias = Union[np.uint32, np.float32]
SHAPE_CASES = (
ValidationCase(shape=(10, 10, 10), passes=True, id="valid shape"),
ValidationCase(shape=(10, 10), passes=False, id="missing dimension"),
ValidationCase(shape=(10, 10, 10, 10), passes=False, id="extra dimension"),
ValidationCase(shape=(11, 10, 10), passes=False, id="dimension too large"),
ValidationCase(shape=(9, 10, 10), passes=False, id="dimension too small"),
ValidationCase(shape=(10, 10, 9), passes=True, id="wildcard smaller"),
ValidationCase(shape=(10, 10, 11), passes=True, id="wildcard larger"),
ValidationCase(annotation=RGB_UNION, shape=(5, 5), passes=True, id="Union 2D"),
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3), passes=True, id="Union 3D"),
ValidationCase(shape=(10, 10, 2, 2), passes=True, id="valid shape"),
ValidationCase(shape=(10, 10, 2), passes=False, id="missing dimension"),
ValidationCase(shape=(10, 10, 2, 2, 2), passes=False, id="extra dimension"),
ValidationCase(shape=(11, 10, 2, 2), passes=False, id="dimension too large"),
ValidationCase(shape=(9, 10, 2, 2), passes=False, id="dimension too small"),
ValidationCase(shape=(10, 10, 1, 1), passes=True, id="wildcard smaller"),
ValidationCase(shape=(10, 10, 3, 3), passes=True, id="wildcard larger"),
ValidationCase(
annotation=RGB_UNION, shape=(5, 5, 3, 4), passes=True, id="Union 4D"
annotation_shape=RGB_UNION, shape=(5, 5), passes=True, id="Union 2D"
),
ValidationCase(
annotation=RGB_UNION, shape=(5, 5, 4), passes=False, id="Union incorrect 3D"
annotation_shape=RGB_UNION, shape=(5, 5, 3), passes=True, id="Union 3D"
),
ValidationCase(
annotation=RGB_UNION, shape=(5, 5, 3, 6), passes=False, id="Union incorrect 4D"
annotation_shape=RGB_UNION, shape=(5, 5, 3, 4), passes=True, id="Union 4D"
),
ValidationCase(
annotation=RGB_UNION,
annotation_shape=RGB_UNION,
shape=(5, 5, 4),
passes=False,
id="Union incorrect 3D",
),
ValidationCase(
annotation_shape=RGB_UNION,
shape=(5, 5, 3, 6),
passes=False,
id="Union incorrect 4D",
),
ValidationCase(
annotation_shape=RGB_UNION,
shape=(5, 5, 4, 6),
passes=False,
id="Union incorrect both",
@ -138,91 +89,144 @@ DTYPE_CASES = [
ValidationCase(dtype=float, passes=True, id="float"),
ValidationCase(dtype=int, passes=False, id="int"),
ValidationCase(dtype=np.uint8, passes=False, id="uint8"),
ValidationCase(annotation=NUMBER, dtype=int, passes=True, id="number-int"),
ValidationCase(annotation=NUMBER, dtype=float, passes=True, id="number-float"),
ValidationCase(annotation=NUMBER, dtype=np.uint8, passes=True, id="number-uint8"),
ValidationCase(annotation_dtype=Number, dtype=int, passes=True, id="number-int"),
ValidationCase(
annotation=NUMBER, dtype=np.float16, passes=True, id="number-float16"
),
ValidationCase(annotation=NUMBER, dtype=str, passes=False, id="number-str"),
ValidationCase(annotation=INTEGER, dtype=int, passes=True, id="integer-int"),
ValidationCase(annotation=INTEGER, dtype=np.uint8, passes=True, id="integer-uint8"),
ValidationCase(annotation=INTEGER, dtype=float, passes=False, id="integer-float"),
ValidationCase(
annotation=INTEGER, dtype=np.float32, passes=False, id="integer-float32"
),
ValidationCase(annotation=INTEGER, dtype=str, passes=False, id="integer-str"),
ValidationCase(annotation=FLOAT, dtype=float, passes=True, id="float-float"),
ValidationCase(annotation=FLOAT, dtype=np.float32, passes=True, id="float-float32"),
ValidationCase(annotation=FLOAT, dtype=int, passes=False, id="float-int"),
ValidationCase(annotation=FLOAT, dtype=np.uint8, passes=False, id="float-uint8"),
ValidationCase(annotation=FLOAT, dtype=str, passes=False, id="float-str"),
ValidationCase(annotation=STRING, dtype=str, passes=True, id="str-str"),
ValidationCase(annotation=STRING, dtype=int, passes=False, id="str-int"),
ValidationCase(annotation=STRING, dtype=float, passes=False, id="str-float"),
ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True, id="model-model"),
ValidationCase(annotation=MODEL, dtype=BadModel, passes=False, id="model-badmodel"),
ValidationCase(annotation=MODEL, dtype=int, passes=False, id="model-int"),
ValidationCase(annotation=MODEL, dtype=SubClass, passes=True, id="model-subclass"),
ValidationCase(
annotation=UNION_TYPE, dtype=np.uint32, passes=True, id="union-type-uint32"
annotation_dtype=Number, dtype=float, passes=True, id="number-float"
),
ValidationCase(
annotation=UNION_TYPE, dtype=np.float32, passes=True, id="union-type-float32"
annotation_dtype=Number, dtype=np.uint8, passes=True, id="number-uint8"
),
ValidationCase(
annotation=UNION_TYPE, dtype=np.uint64, passes=False, id="union-type-uint64"
annotation_dtype=Number, dtype=np.float16, passes=True, id="number-float16"
),
ValidationCase(annotation_dtype=Number, dtype=str, passes=False, id="number-str"),
ValidationCase(annotation_dtype=Integer, dtype=int, passes=True, id="integer-int"),
ValidationCase(
annotation_dtype=Integer, dtype=np.uint8, passes=True, id="integer-uint8"
),
ValidationCase(
annotation=UNION_TYPE, dtype=np.float64, passes=False, id="union-type-float64"
annotation_dtype=Integer, dtype=float, passes=False, id="integer-float"
),
ValidationCase(
annotation_dtype=Integer, dtype=np.float32, passes=False, id="integer-float32"
),
ValidationCase(annotation_dtype=Integer, dtype=str, passes=False, id="integer-str"),
ValidationCase(annotation_dtype=Float, dtype=float, passes=True, id="float-float"),
ValidationCase(
annotation_dtype=Float, dtype=np.float32, passes=True, id="float-float32"
),
ValidationCase(annotation_dtype=Float, dtype=int, passes=False, id="float-int"),
ValidationCase(
annotation_dtype=Float, dtype=np.uint8, passes=False, id="float-uint8"
),
ValidationCase(annotation_dtype=Float, dtype=str, passes=False, id="float-str"),
ValidationCase(annotation_dtype=str, dtype=str, passes=True, id="str-str"),
ValidationCase(annotation_dtype=str, dtype=int, passes=False, id="str-int"),
ValidationCase(annotation_dtype=str, dtype=float, passes=False, id="str-float"),
ValidationCase(
annotation_dtype=BasicModel, dtype=BasicModel, passes=True, id="model-model"
),
ValidationCase(
annotation_dtype=BasicModel, dtype=BadModel, passes=False, id="model-badmodel"
),
ValidationCase(
annotation_dtype=BasicModel, dtype=int, passes=False, id="model-int"
),
ValidationCase(
annotation_dtype=BasicModel, dtype=SubClass, passes=True, id="model-subclass"
),
ValidationCase(
annotation_dtype=UNION_TYPE,
dtype=np.uint32,
passes=True,
id="union-type-uint32",
),
ValidationCase(
annotation_dtype=UNION_TYPE,
dtype=np.float32,
passes=True,
id="union-type-float32",
),
ValidationCase(
annotation_dtype=UNION_TYPE,
dtype=np.uint64,
passes=False,
id="union-type-uint64",
),
ValidationCase(
annotation_dtype=UNION_TYPE,
dtype=np.float64,
passes=False,
id="union-type-float64",
),
ValidationCase(
annotation_dtype=UNION_TYPE, dtype=str, passes=False, id="union-type-str"
),
ValidationCase(annotation=UNION_TYPE, dtype=str, passes=False, id="union-type-str"),
]
if YES_PIPE:
UNION_PIPE: TypeAlias = NDArray[Shape["*, *, *"], np.uint32 | np.float32]
UNION_PIPE: TypeAlias = np.uint32 | np.float32
DTYPE_CASES.extend(
[
ValidationCase(
annotation=UNION_PIPE,
annotation_dtype=UNION_PIPE,
dtype=np.uint32,
passes=True,
id="union-pipe-uint32",
),
ValidationCase(
annotation=UNION_PIPE,
annotation_dtype=UNION_PIPE,
dtype=np.float32,
passes=True,
id="union-pipe-float32",
),
ValidationCase(
annotation=UNION_PIPE,
annotation_dtype=UNION_PIPE,
dtype=np.uint64,
passes=False,
id="union-pipe-uint64",
),
ValidationCase(
annotation=UNION_PIPE,
annotation_dtype=UNION_PIPE,
dtype=np.float64,
passes=False,
id="union-pipe-float64",
),
ValidationCase(
annotation=UNION_PIPE, dtype=str, passes=False, id="union-pipe-str"
annotation_dtype=UNION_PIPE,
dtype=str,
passes=False,
id="union-pipe-str",
),
]
)
_INTERFACE_CASES = [
NumpyCase,
HDF5Case,
HDF5CompoundCase,
DaskCase,
ZarrCase,
ZarrDirCase,
ZarrZipCase,
ZarrNestedCase,
VideoCase,
INTERFACE_CASES = [
ValidationCase(interface=NumpyCase, id="numpy"),
ValidationCase(interface=HDF5Case, id="hdf5"),
ValidationCase(interface=HDF5CompoundCase, id="hdf5_compound"),
ValidationCase(interface=DaskCase, id="dask"),
ValidationCase(interface=ZarrCase, id="zarr"),
ValidationCase(interface=ZarrDirCase, id="zarr_dir"),
ValidationCase(interface=ZarrZipCase, id="zarr_zip"),
ValidationCase(interface=ZarrNestedCase, id="zarr_nested"),
ValidationCase(interface=VideoCase, id="video"),
]
DTYPE_AND_SHAPE_CASES = merged_product(SHAPE_CASES, DTYPE_CASES)
DTYPE_AND_SHAPE_CASES_PASSING = merged_product(
SHAPE_CASES, DTYPE_CASES, conditions={"passes": True}
)
DTYPE_AND_INTERFACE_CASES = merged_product(INTERFACE_CASES, DTYPE_CASES)
DTYPE_AND_INTERFACE_CASES_PASSING = merged_product(
INTERFACE_CASES, DTYPE_CASES, conditions={"passes": True}
)
ALL_CASES = merged_product(SHAPE_CASES, DTYPE_CASES, INTERFACE_CASES)
ALL_CASES_PASSING = merged_product(
SHAPE_CASES, DTYPE_CASES, INTERFACE_CASES, conditions={"passes": True}
)

View file

@ -1,7 +1,10 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from functools import reduce
from itertools import product
from operator import ior
from pathlib import Path
from typing import Any, Optional, Tuple, Type, Union
from typing import Generator, List, Literal, Optional, Tuple, Type, Union
import numpy as np
from pydantic import BaseModel, ConfigDict, ValidationError, computed_field
@ -101,6 +104,9 @@ class InterfaceCase(ABC):
return False
_a_shape_type = Tuple[Union[int, Literal["*"], Literal["..."]], ...]
class ValidationCase(BaseModel):
"""
Test case for validating an array.
@ -113,24 +119,56 @@ class ValidationCase(BaseModel):
"""
String identifying the validation case
"""
annotation: Any = NDArray[Shape["10, 10, *"], Float]
annotation_shape: Union[
Tuple[Union[int, str], ...], Tuple[Tuple[Union[int, str], ...], ...]
] = (10, 10, "*", "*")
"""
Array annotation used in the validating model
Any typed because the types of type annotations are weird
Shape to use in computed annotation used to validate against
"""
shape: Tuple[int, ...] = (10, 10, 10)
annotation_dtype: Union[DtypeType, Sequence[DtypeType]] = Float
"""
Dtype to use in computed annotation used to validate against
"""
shape: Tuple[int, ...] = (10, 10, 2, 2)
"""Shape of the array to validate"""
dtype: Union[Type, np.dtype] = float
"""Dtype of the array to validate"""
passes: bool = False
"""Whether the validation should pass or not"""
interface: Optional[InterfaceCase] = None
interface: Optional[Type[InterfaceCase]] = None
"""The interface test case to generate and validate the array with"""
path: Optional[Path] = None
"""The path to generate arrays into, if any."""
model_config = ConfigDict(arbitrary_types_allowed=True)
@computed_field()
def annotation(self) -> NDArray:
"""
Annotation used in the model we validate against
"""
# make a union type if we need to
shape_union = all(isinstance(s, Sequence) for s in self.annotation_shape)
dtype_union = isinstance(self.annotation_dtype, Sequence) and all(
isinstance(s, Sequence) for s in self.annotation_dtype
)
if shape_union or dtype_union:
shape_iter = (
self.annotation_shape if shape_union else [self.annotation_shape]
)
dtype_iter = (
self.annotation_dtype if dtype_union else [self.annotation_dtype]
)
annotations: List[type] = []
for shape, dtype in product(shape_iter, dtype_iter):
shape_str = ", ".join([str(i) for i in shape])
annotations.append(NDArray[Shape[shape_str], dtype])
return Union[tuple(annotations)]
else:
shape_str = ", ".join([str(i) for i in self.annotation_shape])
return NDArray[Shape[shape_str], self.annotation_dtype]
@computed_field()
def model(self) -> Type[BaseModel]:
"""A model with a field ``array`` with the given annotation"""
@ -186,31 +224,8 @@ class ValidationCase(BaseModel):
"""
if isinstance(other, Sequence):
return merge_cases(self, *other)
self_dump = self.model_dump(exclude_unset=True)
other_dump = other.model_dump(exclude_unset=True)
# dumps might not have set `valid`, use only the ones that have
valids = [
v
for v in [self_dump.get("valid", None), other_dump.get("valid", None)]
if v is not None
]
valid = all(valids)
# combine ids if present
ids = "-".join(
[
str(v)
for v in [self_dump.get("id", None), other_dump.get("id", None)]
if v is not None
]
)
merged = {**self_dump, **other_dump}
merged["valid"] = valid
merged["id"] = ids
return ValidationCase(**merged)
else:
return merge_cases(self, other)
def skip(self) -> bool:
"""
@ -230,7 +245,73 @@ def merge_cases(*args: ValidationCase) -> ValidationCase:
if len(args) == 1:
return args[0]
case = args[0]
for arg in args[1:]:
case = case.merge(arg)
return case
dumped = [
m.model_dump(exclude_unset=True, exclude={"model", "annotation"}) for m in args
]
# self_dump = self.model_dump(exclude_unset=True)
# other_dump = other.model_dump(exclude_unset=True)
# dumps might not have set `passes`, use only the ones that have
passes = [v.get("passes") for v in dumped if "passes" in v]
passes = all(passes)
# combine ids if present
ids = "-".join([str(v.get("id")) for v in dumped if "id" in v])
# merge dicts
merged = reduce(ior, dumped, {})
merged["passes"] = passes
merged["id"] = ids
return ValidationCase.model_construct(**merged)
def merged_product(
*args: Sequence[ValidationCase], conditions: dict = None
) -> Generator[ValidationCase, None, None]:
"""
Generator for the product of the iterators of validation cases,
merging each tuple, and respecting if they should be :meth:`.ValidationCase.skip`
or not.
Examples:
.. code-block:: python
shape_cases = [
ValidationCase(shape=(10, 10, 10), passes=True, id="valid shape"),
ValidationCase(shape=(10, 10), passes=False, id="missing dimension"),
]
dtype_cases = [
ValidationCase(dtype=float, passes=True, id="float"),
ValidationCase(dtype=int, passes=False, id="int"),
]
iterator = merged_product(shape_cases, dtype_cases))
next(iterator)
# ValidationCase(
# shape=(10, 10, 10),
# dtype=float,
# passes=True,
# id="valid shape-float"
# )
next(iterator)
# ValidationCase(
# shape=(10, 10, 10),
# dtype=int,
# passes=False,
# id="valid shape-int"
# )
"""
iterator = product(*args)
for case_tuple in iterator:
case = merge_cases(*case_tuple)
if case.skip():
continue
if conditions:
matching = all([getattr(case, k, None) == v for k, v in conditions.items()])
if not matching:
continue
yield case

View file

@ -154,7 +154,7 @@ class _ZarrMetaCase(InterfaceCase):
@classmethod
def skip(cls, shape: Tuple[int, ...], dtype: DtypeType) -> bool:
return issubclass(dtype, BaseModel)
return issubclass(dtype, BaseModel) or dtype is str
class ZarrCase(_ZarrMetaCase):
@ -239,8 +239,8 @@ class VideoCase(InterfaceCase):
@classmethod
def make_array(
cls,
shape: Tuple[int, ...] = (10, 10),
dtype: DtypeType = float,
shape: Tuple[int, ...] = (10, 10, 10, 3),
dtype: DtypeType = np.uint8,
path: Optional[Path] = None,
array: Optional[NDArrayType] = None,
) -> Optional[Path]:
@ -269,20 +269,26 @@ class VideoCase(InterfaceCase):
frame = array[i]
else:
# make fresh array every time bc opencv eats them
frame = np.zeros(frame_shape, dtype=np.uint8)
if not is_color:
frame[i, i] = i
else:
frame[i, i, :] = i
frame = np.full(frame_shape, fill_value=i, dtype=np.uint8)
writer.write(frame)
writer.release()
return video_path
@classmethod
def skip(cls, shape: Tuple[int, ...], dtype: DtypeType) -> bool:
"""We really can only handle 3-4 dimensional cases in 8-bit rn lol"""
if len(shape) < 3 or len(shape) > 4:
"""
We really can only handle 4 dimensional cases in 8-bit rn lol
.. todo::
Fix shape/writing for grayscale videos
"""
if len(shape) != 4:
return True
# if len(shape) < 3 or len(shape) > 4:
# return True
if dtype not in (int, np.uint8):
return True
# if we have a color video (ie. shape == 4, needs to be RGB)

View file

@ -1,11 +1,11 @@
import inspect
from typing import Callable, Tuple, Type
import pytest
from pydantic import BaseModel
from numpydantic import NDArray, interface
from numpydantic.testing.helpers import InterfaceCase
from numpydantic.testing.cases import (
ALL_CASES,
ALL_CASES_PASSING,
DTYPE_AND_INTERFACE_CASES_PASSING,
)
from numpydantic.testing.helpers import InterfaceCase, ValidationCase, merge_cases
from numpydantic.testing.interfaces import (
DaskCase,
HDF5Case,
@ -21,76 +21,130 @@ from numpydantic.testing.interfaces import (
scope="function",
params=[
pytest.param(
([[1, 2], [3, 4]], interface.NumpyInterface),
marks=pytest.mark.numpy,
id="numpy-list",
),
pytest.param(
(NumpyCase, interface.NumpyInterface),
NumpyCase,
marks=pytest.mark.numpy,
id="numpy",
),
pytest.param(
(HDF5Case, interface.H5Interface),
HDF5Case,
marks=pytest.mark.hdf5,
id="h5-array-path",
),
pytest.param(
(DaskCase, interface.DaskInterface),
DaskCase,
marks=pytest.mark.dask,
id="dask",
),
pytest.param(
(ZarrCase, interface.ZarrInterface),
ZarrCase,
marks=pytest.mark.zarr,
id="zarr-memory",
),
pytest.param(
(ZarrNestedCase, interface.ZarrInterface),
ZarrNestedCase,
marks=pytest.mark.zarr,
id="zarr-nested",
),
pytest.param(
(ZarrDirCase, interface.ZarrInterface),
ZarrDirCase,
marks=pytest.mark.zarr,
id="zarr-dir",
),
pytest.param(
(VideoCase, interface.VideoInterface), marks=pytest.mark.video, id="video"
),
pytest.param(VideoCase, marks=pytest.mark.video, id="video"),
],
)
def interface_type(
request, tmp_output_dir_func
) -> Tuple[NDArray, Type[interface.Interface]]:
def interface_cases(request) -> InterfaceCase:
"""
Test cases for each interface's ``check`` method - each input should match the
provided interface and that interface only
Fixture for combinatoric tests across all interface cases
"""
return request.param
@pytest.fixture(
params=(
pytest.param(p, id=p.id, marks=getattr(pytest.mark, p.interface.interface.name))
for p in ALL_CASES
)
)
def all_cases(interface_cases, request) -> ValidationCase:
"""
Combinatoric testing for all dtype, shape, and interface cases.
This is a very expensive fixture! Only use it for core functionality
that we want to be sure is *very true* in every circumstance,
INCLUDING invalid combinations of annotations and arrays.
Typically, that means only use this in `test_interfaces.py`
"""
if inspect.isclass(request.param[0]) and issubclass(
request.param[0], InterfaceCase
):
array = request.param[0].make_array(path=tmp_output_dir_func)
if array is None:
pytest.skip()
return array, request.param[1]
else:
return request.param
case = merge_cases(request.param, ValidationCase(interface=interface_cases))
if case.skip():
pytest.skip()
return case
@pytest.fixture(
params=(
pytest.param(p, id=p.id, marks=getattr(pytest.mark, p.interface.interface.name))
for p in ALL_CASES_PASSING
)
)
def all_passing_cases(request) -> ValidationCase:
"""
Combinatoric testing for all dtype, shape, and interface cases,
but only the combinations that we expect to pass.
This is a very expensive fixture! Only use it for core functionality
that we want to be sure is *very true* in every circumstance.
Typically, that means only use this in `test_interfaces.py`
"""
return request.param
@pytest.fixture()
def all_interfaces(interface_type) -> BaseModel:
def all_cases_instance(all_cases, tmp_output_dir_func):
"""
An instantiated version of each interface within a basemodel,
with the array in an `array` field
all_cases but with an instantiated model
Args:
all_cases:
Returns:
"""
array, interface = interface_type
if isinstance(array, Callable):
array = array()
class MyModel(BaseModel):
array: NDArray
instance = MyModel(array=array)
array = all_cases.array(path=tmp_output_dir_func)
instance = all_cases.model(array=array)
return instance
@pytest.fixture()
def all_passing_cases_instance(all_passing_cases, tmp_output_dir_func):
"""
all_cases but with an instantiated model
Args:
all_cases:
Returns:
"""
array = all_passing_cases.array(path=tmp_output_dir_func)
instance = all_passing_cases.model(array=array)
return instance
@pytest.fixture(
params=(
pytest.param(p, id=p.id, marks=getattr(pytest.mark, p.interface.interface.name))
for p in DTYPE_AND_INTERFACE_CASES_PASSING
)
)
def dtype_by_interface(request):
"""
Tests for all dtypes by all interfaces
"""
return request.param
@pytest.fixture()
def dtype_by_interface_instance(dtype_by_interface, tmp_output_dir_func):
array = dtype_by_interface.array(path=tmp_output_dir_func)
instance = dtype_by_interface.model(array=array)
return instance

View file

@ -16,11 +16,13 @@ def test_dask_enabled():
assert DaskInterface.enabled()
def test_dask_check(interface_type):
if interface_type[1] is DaskInterface:
assert DaskInterface.check(interface_type[0])
def test_dask_check(interface_cases, tmp_output_dir_func):
array = interface_cases.make_array(path=tmp_output_dir_func)
if interface_cases.interface is DaskInterface:
assert DaskInterface.check(array)
else:
assert not DaskInterface.check(interface_type[0])
assert not DaskInterface.check(array)
@pytest.mark.shape

View file

@ -43,14 +43,12 @@ def test_hdf5_dtype(dtype_cases, hdf5_cases):
dtype_cases.validate_case()
def test_hdf5_check(interface_type):
if interface_type[1] is H5Interface:
assert H5Interface.check(interface_type[0])
if isinstance(interface_type[0], H5ArrayPath):
# also test that we can instantiate from a tuple like the H5ArrayPath
assert H5Interface.check((interface_type[0].file, interface_type[0].path))
def test_hdf5_check(interface_cases, tmp_output_dir_func):
array = interface_cases.make_array(path=tmp_output_dir_func)
if interface_cases.interface is H5Interface:
assert H5Interface.check(array)
else:
assert not H5Interface.check(interface_type[0])
assert not H5Interface.check(array)
def test_hdf5_check_not_exists():

View file

@ -4,7 +4,6 @@ Tests that should be applied to all interfaces
import json
from importlib.metadata import version
from typing import Callable
import dask.array as da
import numpy as np
@ -13,76 +12,98 @@ from pydantic import BaseModel
from zarr.core import Array as ZarrArray
from numpydantic.interface import Interface, InterfaceMark, MarkedJson
from numpydantic.testing.helpers import ValidationCase
def _test_roundtrip(source: BaseModel, target: BaseModel, round_trip: bool):
def _test_roundtrip(source: BaseModel, target: BaseModel):
"""Test model equality for roundtrip tests"""
if round_trip:
assert type(target.array) is type(source.array)
if isinstance(source.array, (np.ndarray, ZarrArray)):
assert np.array_equal(target.array, np.array(source.array))
elif isinstance(source.array, da.Array):
assert np.all(da.equal(target.array, source.array))
else:
assert target.array == source.array
assert target.array.dtype == source.array.dtype
else:
assert type(target.array) is type(source.array)
if isinstance(source.array, (np.ndarray, ZarrArray)):
assert np.array_equal(target.array, np.array(source.array))
elif isinstance(source.array, da.Array):
if target.array.dtype == object:
# object equality doesn't really work well with dask
# just check that the types match
target_type = type(target.array.ravel()[0].compute())
source_type = type(source.array.ravel()[0].compute())
assert target_type is source_type
else:
assert np.all(da.equal(target.array, source.array))
else:
assert target.array == source.array
assert target.array.dtype == source.array.dtype
def test_dunder_len(all_interfaces):
def test_dunder_len(interface_cases, tmp_output_dir_func):
"""
Each interface or proxy type should support __len__
"""
assert len(all_interfaces.array) == all_interfaces.array.shape[0]
case = ValidationCase(interface=interface_cases)
if interface_cases.interface.name == "video":
case.shape = (10, 10, 2, 3)
case.dtype = np.uint8
case.annotation_dtype = np.uint8
case.annotation_shape = (10, 10, "*", 3)
array = case.array(path=tmp_output_dir_func)
instance = case.model(array=array)
assert len(instance.array) == case.shape[0]
def test_interface_revalidate(all_interfaces):
def test_interface_revalidate(all_passing_cases_instance):
"""
An interface should revalidate with the output of its initial validation
See: https://github.com/p2p-ld/numpydantic/pull/14
"""
_ = type(all_interfaces)(array=all_interfaces.array)
_ = type(all_passing_cases_instance)(array=all_passing_cases_instance.array)
def test_interface_rematch(interface_type):
@pytest.mark.xfail
def test_interface_rematch(interface_cases, tmp_output_dir_func):
"""
All interfaces should match the results of the object they return after validation
"""
array, interface = interface_type
if isinstance(array, Callable):
array = array()
array = interface_cases.make_array(path=tmp_output_dir_func)
assert Interface.match(interface().validate(array)) is interface
assert (
Interface.match(interface_cases.interface.validate(array))
is interface_cases.interface
)
def test_interface_to_numpy_array(all_interfaces):
def test_interface_to_numpy_array(dtype_by_interface):
"""
All interfaces should be able to have the output of their validation stage
coerced to a numpy array with np.array()
"""
_ = np.array(all_interfaces.array)
_ = np.array(dtype_by_interface.array)
@pytest.mark.serialization
def test_interface_dump_json(all_interfaces):
def test_interface_dump_json(dtype_by_interface_instance):
"""
All interfaces should be able to dump to json
"""
all_interfaces.model_dump_json()
dtype_by_interface_instance.model_dump_json()
@pytest.mark.serialization
@pytest.mark.parametrize("round_trip", [True, False])
def test_interface_roundtrip_json(all_interfaces, round_trip):
def test_interface_roundtrip_json(dtype_by_interface, tmp_output_dir_func):
"""
All interfaces should be able to roundtrip to and from json
"""
dumped_json = all_interfaces.model_dump_json(round_trip=round_trip)
model = all_interfaces.model_validate_json(dumped_json)
_test_roundtrip(all_interfaces, model, round_trip)
if "subclass" in dtype_by_interface.id.lower():
pytest.xfail()
array = dtype_by_interface.array(path=tmp_output_dir_func)
case = dtype_by_interface.model(array=array)
dumped_json = case.model_dump_json(round_trip=True)
model = case.model_validate_json(dumped_json)
_test_roundtrip(case, model)
@pytest.mark.serialization
@ -101,15 +122,20 @@ def test_interface_mark_interface(an_interface):
@pytest.mark.serialization
@pytest.mark.parametrize("valid", [True, False])
@pytest.mark.parametrize("round_trip", [True, False])
@pytest.mark.filterwarnings("ignore:Mismatch between serialized mark")
def test_interface_mark_roundtrip(all_interfaces, valid, round_trip):
def test_interface_mark_roundtrip(dtype_by_interface, valid, tmp_output_dir_func):
"""
All interfaces should be able to roundtrip with the marked interface,
and a mismatch should raise a warning and attempt to proceed
"""
dumped_json = all_interfaces.model_dump_json(
round_trip=round_trip, context={"mark_interface": True}
if "subclass" in dtype_by_interface.id.lower():
pytest.xfail()
array = dtype_by_interface.array(path=tmp_output_dir_func)
case = dtype_by_interface.model(array=array)
dumped_json = case.model_dump_json(
round_trip=True, context={"mark_interface": True}
)
data = json.loads(dumped_json)
@ -123,8 +149,8 @@ def test_interface_mark_roundtrip(all_interfaces, valid, round_trip):
dumped_json = json.dumps(data)
with pytest.warns(match="Mismatch.*"):
model = all_interfaces.model_validate_json(dumped_json)
model = case.model_validate_json(dumped_json)
else:
model = all_interfaces.model_validate_json(dumped_json)
model = case.model_validate_json(dumped_json)
_test_roundtrip(all_interfaces, model, round_trip)
_test_roundtrip(case, model)

View file

@ -80,15 +80,12 @@ def test_video_getitem(avi_video):
instance = MyModel(array=vid)
fifth_frame = instance.array[5]
# the first frame should have 1's in the 1,1 position
# the fifth frame should be all 5s
assert (fifth_frame[5, 5, :] == [5, 5, 5]).all()
# and nothing in the 6th position
assert (fifth_frame[6, 6, :] == [0, 0, 0]).all()
# slicing should also work as if it were just a numpy array
single_slice = instance.array[3, 0:10, 0:5]
assert single_slice[3, 3, 0] == 3
assert single_slice[4, 4, 0] == 0
assert single_slice.shape == (10, 5, 3)
# also get a range of frames
@ -96,19 +93,19 @@ def test_video_getitem(avi_video):
range_slice = instance.array[3:5]
assert range_slice.shape == (2, 100, 50, 3)
assert range_slice[0, 3, 3, 0] == 3
assert range_slice[0, 4, 4, 0] == 0
assert range_slice[1, 4, 4, 0] == 4
# full range
range_slice = instance.array[3:5, 0:10, 0:5]
assert range_slice.shape == (2, 10, 5, 3)
assert range_slice[0, 3, 3, 0] == 3
assert range_slice[0, 4, 4, 0] == 0
assert range_slice[1, 4, 4, 0] == 4
# starting range
range_slice = instance.array[6:, 0:10, 0:10]
assert range_slice.shape == (4, 10, 10, 3)
assert range_slice[-1, 9, 9, 0] == 9
assert range_slice[-2, 9, 9, 0] == 0
assert range_slice[-2, 9, 9, 0] == 8
# ending range
range_slice = instance.array[:3, 0:5, 0:5]
@ -119,10 +116,8 @@ def test_video_getitem(avi_video):
# second slice should be the second frame (instead of the first)
assert range_slice.shape == (3, 6, 6, 3)
assert range_slice[1, 2, 2, 0] == 2
assert range_slice[1, 3, 3, 0] == 0
# and the third should be the fourth (instead of the second)
assert range_slice[2, 4, 4, 0] == 4
assert range_slice[2, 5, 5, 0] == 0
with pytest.raises(NotImplementedError):
# shouldn't be allowed to set

View file

@ -38,14 +38,15 @@ def test_zarr_enabled():
assert ZarrInterface.enabled()
def test_zarr_check(interface_type):
def test_zarr_check(interface_cases, tmp_output_dir_func):
"""
We should only use the zarr interface for zarr-like things
"""
if interface_type[1] is ZarrInterface:
assert ZarrInterface.check(interface_type[0])
array = interface_cases.make_array(path=tmp_output_dir_func)
if interface_cases.interface is ZarrInterface:
assert ZarrInterface.check(array)
else:
assert not ZarrInterface.check(interface_type[0])
assert not ZarrInterface.check(array)
@pytest.mark.shape