mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2025-01-09 21:44:27 +00:00
interface cases
This commit is contained in:
parent
ad060ce40d
commit
e701bf6e9b
8 changed files with 542 additions and 123 deletions
|
@ -2,6 +2,15 @@
|
||||||
|
|
||||||
Utilities for testing and 3rd-party interface development.
|
Utilities for testing and 3rd-party interface development.
|
||||||
|
|
||||||
|
Only things that *don't* require pytest go in this module.
|
||||||
|
We want to keep all test-time specific behavior there,
|
||||||
|
and have this just serve as helpers exposed for downstream interface development.
|
||||||
|
|
||||||
|
We want to avoid pytest stuff bleeding in here because then we limit
|
||||||
|
the ability for downstream developers to configure their own tests.
|
||||||
|
|
||||||
|
*(If there is some reason to change this division of labor, just raise an issue and let's chat.)*
|
||||||
|
|
||||||
```{toctree}
|
```{toctree}
|
||||||
cases
|
cases
|
||||||
helpers
|
helpers
|
||||||
|
|
|
@ -131,6 +131,9 @@ markers = [
|
||||||
"zarr: zarr interface",
|
"zarr: zarr interface",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.black]
|
||||||
|
target-version = ["py39", "py310", "py311", "py312"]
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
target-version = "py39"
|
target-version = "py39"
|
||||||
include = ["src/numpydantic/**/*.py", "tests/**/*.py", "pyproject.toml"]
|
include = ["src/numpydantic/**/*.py", "tests/**/*.py", "pyproject.toml"]
|
||||||
|
|
|
@ -3,7 +3,7 @@ Interfaces between nptyping types and array backends
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from numpydantic.interface.dask import DaskInterface
|
from numpydantic.interface.dask import DaskInterface
|
||||||
from numpydantic.interface.hdf5 import H5Interface
|
from numpydantic.interface.hdf5 import H5ArrayPath, H5Interface
|
||||||
from numpydantic.interface.interface import (
|
from numpydantic.interface.interface import (
|
||||||
Interface,
|
Interface,
|
||||||
InterfaceMark,
|
InterfaceMark,
|
||||||
|
@ -12,10 +12,11 @@ from numpydantic.interface.interface import (
|
||||||
)
|
)
|
||||||
from numpydantic.interface.numpy import NumpyInterface
|
from numpydantic.interface.numpy import NumpyInterface
|
||||||
from numpydantic.interface.video import VideoInterface
|
from numpydantic.interface.video import VideoInterface
|
||||||
from numpydantic.interface.zarr import ZarrInterface
|
from numpydantic.interface.zarr import ZarrArrayPath, ZarrInterface
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DaskInterface",
|
"DaskInterface",
|
||||||
|
"H5ArrayPath",
|
||||||
"H5Interface",
|
"H5Interface",
|
||||||
"Interface",
|
"Interface",
|
||||||
"InterfaceMark",
|
"InterfaceMark",
|
||||||
|
@ -23,5 +24,6 @@ __all__ = [
|
||||||
"MarkedJson",
|
"MarkedJson",
|
||||||
"NumpyInterface",
|
"NumpyInterface",
|
||||||
"VideoInterface",
|
"VideoInterface",
|
||||||
|
"ZarrArrayPath",
|
||||||
"ZarrInterface",
|
"ZarrInterface",
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
from numpydantic.testing.helpers import InterfaceCase, ValidationCase
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"InterfaceCase",
|
||||||
|
"ValidationCase",
|
||||||
|
]
|
|
@ -1,12 +1,25 @@
|
||||||
import sys
|
import sys
|
||||||
from typing import Union
|
from collections.abc import Sequence
|
||||||
|
from itertools import product
|
||||||
|
from typing import Generator, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from numpydantic import NDArray, Shape
|
from numpydantic import NDArray, Shape
|
||||||
from numpydantic.dtype import Float, Integer, Number
|
from numpydantic.dtype import Float, Integer, Number
|
||||||
from numpydantic.testing.helpers import ValidationCase
|
from numpydantic.testing.helpers import ValidationCase, merge_cases
|
||||||
|
from numpydantic.testing.interfaces import (
|
||||||
|
DaskCase,
|
||||||
|
HDF5Case,
|
||||||
|
HDF5CompoundCase,
|
||||||
|
NumpyCase,
|
||||||
|
VideoCase,
|
||||||
|
ZarrCase,
|
||||||
|
ZarrDirCase,
|
||||||
|
ZarrNestedCase,
|
||||||
|
ZarrZipCase,
|
||||||
|
)
|
||||||
|
|
||||||
if sys.version_info.minor >= 10:
|
if sys.version_info.minor >= 10:
|
||||||
from typing import TypeAlias
|
from typing import TypeAlias
|
||||||
|
@ -30,6 +43,10 @@ class SubClass(BasicModel):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------
|
||||||
|
# Annotations
|
||||||
|
# --------------------------------------------------
|
||||||
|
|
||||||
RGB_UNION: TypeAlias = Union[
|
RGB_UNION: TypeAlias = Union[
|
||||||
NDArray[Shape["* x, * y"], Number],
|
NDArray[Shape["* x, * y"], Number],
|
||||||
NDArray[Shape["* x, * y, 3 r_g_b"], Number],
|
NDArray[Shape["* x, * y, 3 r_g_b"], Number],
|
||||||
|
@ -42,89 +59,159 @@ STRING: TypeAlias = NDArray[Shape["*, *, *"], str]
|
||||||
MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel]
|
MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel]
|
||||||
UNION_TYPE: TypeAlias = NDArray[Shape["*, *, *"], Union[np.uint32, np.float32]]
|
UNION_TYPE: TypeAlias = NDArray[Shape["*, *, *"], Union[np.uint32, np.float32]]
|
||||||
UNION_PIPE: TypeAlias = NDArray[Shape["*, *, *"], np.uint32 | np.float32]
|
UNION_PIPE: TypeAlias = NDArray[Shape["*, *, *"], np.uint32 | np.float32]
|
||||||
|
|
||||||
|
SHAPE_CASES = (
|
||||||
|
ValidationCase(shape=(10, 10, 10), passes=True, id="valid shape"),
|
||||||
|
ValidationCase(shape=(10, 10), passes=False, id="missing dimension"),
|
||||||
|
ValidationCase(shape=(10, 10, 10, 10), passes=False, id="extra dimension"),
|
||||||
|
ValidationCase(shape=(11, 10, 10), passes=False, id="dimension too large"),
|
||||||
|
ValidationCase(shape=(9, 10, 10), passes=False, id="dimension too small"),
|
||||||
|
ValidationCase(shape=(10, 10, 9), passes=True, id="wildcard smaller"),
|
||||||
|
ValidationCase(shape=(10, 10, 11), passes=True, id="wildcard larger"),
|
||||||
|
ValidationCase(annotation=RGB_UNION, shape=(5, 5), passes=True, id="Union 2D"),
|
||||||
|
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3), passes=True, id="Union 3D"),
|
||||||
|
ValidationCase(
|
||||||
|
annotation=RGB_UNION, shape=(5, 5, 3, 4), passes=True, id="Union 4D"
|
||||||
|
),
|
||||||
|
ValidationCase(
|
||||||
|
annotation=RGB_UNION, shape=(5, 5, 4), passes=False, id="Union incorrect 3D"
|
||||||
|
),
|
||||||
|
ValidationCase(
|
||||||
|
annotation=RGB_UNION, shape=(5, 5, 3, 6), passes=False, id="Union incorrect 4D"
|
||||||
|
),
|
||||||
|
ValidationCase(
|
||||||
|
annotation=RGB_UNION,
|
||||||
|
shape=(5, 5, 4, 6),
|
||||||
|
passes=False,
|
||||||
|
id="Union incorrect both",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
DTYPE_CASES = [
|
DTYPE_CASES = [
|
||||||
ValidationCase(dtype=float, passes=True),
|
ValidationCase(dtype=float, passes=True, id="float"),
|
||||||
ValidationCase(dtype=int, passes=False),
|
ValidationCase(dtype=int, passes=False, id="int"),
|
||||||
ValidationCase(dtype=np.uint8, passes=False),
|
ValidationCase(dtype=np.uint8, passes=False, id="uint8"),
|
||||||
ValidationCase(annotation=NUMBER, dtype=int, passes=True),
|
ValidationCase(annotation=NUMBER, dtype=int, passes=True, id="number-int"),
|
||||||
ValidationCase(annotation=NUMBER, dtype=float, passes=True),
|
ValidationCase(annotation=NUMBER, dtype=float, passes=True, id="number-float"),
|
||||||
ValidationCase(annotation=NUMBER, dtype=np.uint8, passes=True),
|
ValidationCase(annotation=NUMBER, dtype=np.uint8, passes=True, id="number-uint8"),
|
||||||
ValidationCase(annotation=NUMBER, dtype=np.float16, passes=True),
|
ValidationCase(
|
||||||
ValidationCase(annotation=NUMBER, dtype=str, passes=False),
|
annotation=NUMBER, dtype=np.float16, passes=True, id="number-float16"
|
||||||
ValidationCase(annotation=INTEGER, dtype=int, passes=True),
|
),
|
||||||
ValidationCase(annotation=INTEGER, dtype=np.uint8, passes=True),
|
ValidationCase(annotation=NUMBER, dtype=str, passes=False, id="number-str"),
|
||||||
ValidationCase(annotation=INTEGER, dtype=float, passes=False),
|
ValidationCase(annotation=INTEGER, dtype=int, passes=True, id="integer-int"),
|
||||||
ValidationCase(annotation=INTEGER, dtype=np.float32, passes=False),
|
ValidationCase(annotation=INTEGER, dtype=np.uint8, passes=True, id="integer-uint8"),
|
||||||
ValidationCase(annotation=INTEGER, dtype=str, passes=False),
|
ValidationCase(annotation=INTEGER, dtype=float, passes=False, id="integer-float"),
|
||||||
ValidationCase(annotation=FLOAT, dtype=float, passes=True),
|
ValidationCase(
|
||||||
ValidationCase(annotation=FLOAT, dtype=np.float32, passes=True),
|
annotation=INTEGER, dtype=np.float32, passes=False, id="integer-float32"
|
||||||
ValidationCase(annotation=FLOAT, dtype=int, passes=False),
|
),
|
||||||
ValidationCase(annotation=FLOAT, dtype=np.uint8, passes=False),
|
ValidationCase(annotation=INTEGER, dtype=str, passes=False, id="integer-str"),
|
||||||
ValidationCase(annotation=FLOAT, dtype=str, passes=False),
|
ValidationCase(annotation=FLOAT, dtype=float, passes=True, id="float-float"),
|
||||||
ValidationCase(annotation=STRING, dtype=str, passes=True),
|
ValidationCase(annotation=FLOAT, dtype=np.float32, passes=True, id="float-float32"),
|
||||||
ValidationCase(annotation=STRING, dtype=int, passes=False),
|
ValidationCase(annotation=FLOAT, dtype=int, passes=False, id="float-int"),
|
||||||
ValidationCase(annotation=STRING, dtype=float, passes=False),
|
ValidationCase(annotation=FLOAT, dtype=np.uint8, passes=False, id="float-uint8"),
|
||||||
ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True),
|
ValidationCase(annotation=FLOAT, dtype=str, passes=False, id="float-str"),
|
||||||
ValidationCase(annotation=MODEL, dtype=BadModel, passes=False),
|
ValidationCase(annotation=STRING, dtype=str, passes=True, id="str-str"),
|
||||||
ValidationCase(annotation=MODEL, dtype=int, passes=False),
|
ValidationCase(annotation=STRING, dtype=int, passes=False, id="str-int"),
|
||||||
ValidationCase(annotation=MODEL, dtype=SubClass, passes=True),
|
ValidationCase(annotation=STRING, dtype=float, passes=False, id="str-float"),
|
||||||
ValidationCase(annotation=UNION_TYPE, dtype=np.uint32, passes=True),
|
ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True, id="model-model"),
|
||||||
ValidationCase(annotation=UNION_TYPE, dtype=np.float32, passes=True),
|
ValidationCase(annotation=MODEL, dtype=BadModel, passes=False, id="model-badmodel"),
|
||||||
ValidationCase(annotation=UNION_TYPE, dtype=np.uint64, passes=False),
|
ValidationCase(annotation=MODEL, dtype=int, passes=False, id="model-int"),
|
||||||
ValidationCase(annotation=UNION_TYPE, dtype=np.float64, passes=False),
|
ValidationCase(annotation=MODEL, dtype=SubClass, passes=True, id="model-subclass"),
|
||||||
ValidationCase(annotation=UNION_TYPE, dtype=str, passes=False),
|
ValidationCase(
|
||||||
|
annotation=UNION_TYPE, dtype=np.uint32, passes=True, id="union-type-uint32"
|
||||||
|
),
|
||||||
|
ValidationCase(
|
||||||
|
annotation=UNION_TYPE, dtype=np.float32, passes=True, id="union-type-float32"
|
||||||
|
),
|
||||||
|
ValidationCase(
|
||||||
|
annotation=UNION_TYPE, dtype=np.uint64, passes=False, id="union-type-uint64"
|
||||||
|
),
|
||||||
|
ValidationCase(
|
||||||
|
annotation=UNION_TYPE, dtype=np.float64, passes=False, id="union-type-float64"
|
||||||
|
),
|
||||||
|
ValidationCase(annotation=UNION_TYPE, dtype=str, passes=False, id="union-type-str"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
DTYPE_IDS = [
|
|
||||||
"float",
|
|
||||||
"int",
|
|
||||||
"uint8",
|
|
||||||
"number-int",
|
|
||||||
"number-float",
|
|
||||||
"number-uint8",
|
|
||||||
"number-float16",
|
|
||||||
"number-str",
|
|
||||||
"integer-int",
|
|
||||||
"integer-uint8",
|
|
||||||
"integer-float",
|
|
||||||
"integer-float32",
|
|
||||||
"integer-str",
|
|
||||||
"float-float",
|
|
||||||
"float-float32",
|
|
||||||
"float-int",
|
|
||||||
"float-uint8",
|
|
||||||
"float-str",
|
|
||||||
"str-str",
|
|
||||||
"str-int",
|
|
||||||
"str-float",
|
|
||||||
"model-model",
|
|
||||||
"model-badmodel",
|
|
||||||
"model-int",
|
|
||||||
"model-subclass",
|
|
||||||
"union-type-uint32",
|
|
||||||
"union-type-float32",
|
|
||||||
"union-type-uint64",
|
|
||||||
"union-type-float64",
|
|
||||||
"union-type-str",
|
|
||||||
]
|
|
||||||
|
|
||||||
if YES_PIPE:
|
if YES_PIPE:
|
||||||
DTYPE_CASES.extend(
|
DTYPE_CASES.extend(
|
||||||
[
|
[
|
||||||
ValidationCase(annotation=UNION_PIPE, dtype=np.uint32, passes=True),
|
ValidationCase(
|
||||||
ValidationCase(annotation=UNION_PIPE, dtype=np.float32, passes=True),
|
annotation=UNION_PIPE,
|
||||||
ValidationCase(annotation=UNION_PIPE, dtype=np.uint64, passes=False),
|
dtype=np.uint32,
|
||||||
ValidationCase(annotation=UNION_PIPE, dtype=np.float64, passes=False),
|
passes=True,
|
||||||
ValidationCase(annotation=UNION_PIPE, dtype=str, passes=False),
|
id="union-pipe-uint32",
|
||||||
|
),
|
||||||
|
ValidationCase(
|
||||||
|
annotation=UNION_PIPE,
|
||||||
|
dtype=np.float32,
|
||||||
|
passes=True,
|
||||||
|
id="union-pipe-float32",
|
||||||
|
),
|
||||||
|
ValidationCase(
|
||||||
|
annotation=UNION_PIPE,
|
||||||
|
dtype=np.uint64,
|
||||||
|
passes=False,
|
||||||
|
id="union-pipe-uint64",
|
||||||
|
),
|
||||||
|
ValidationCase(
|
||||||
|
annotation=UNION_PIPE,
|
||||||
|
dtype=np.float64,
|
||||||
|
passes=False,
|
||||||
|
id="union-pipe-float64",
|
||||||
|
),
|
||||||
|
ValidationCase(
|
||||||
|
annotation=UNION_PIPE, dtype=str, passes=False, id="union-pipe-str"
|
||||||
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
DTYPE_IDS.extend(
|
|
||||||
[
|
_INTERFACE_CASES = [
|
||||||
"union-pipe-uint32",
|
NumpyCase,
|
||||||
"union-pipe-float32",
|
HDF5Case,
|
||||||
"union-pipe-uint64",
|
HDF5CompoundCase,
|
||||||
"union-pipe-float64",
|
DaskCase,
|
||||||
"union-pipe-str",
|
ZarrCase,
|
||||||
|
ZarrDirCase,
|
||||||
|
ZarrZipCase,
|
||||||
|
ZarrNestedCase,
|
||||||
|
VideoCase,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def merged_product(
|
||||||
|
*args: Sequence[ValidationCase],
|
||||||
|
) -> Generator[ValidationCase, None, None]:
|
||||||
|
"""
|
||||||
|
Generator for the product of the iterators of validation cases,
|
||||||
|
merging each tuple, and respecting if they should be :meth:`.ValidationCase.skip`
|
||||||
|
or not.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
shape_cases = [
|
||||||
|
ValidationCase(shape=(10, 10, 10), passes=True, id="valid shape"),
|
||||||
|
ValidationCase(shape=(10, 10), passes=False, id="missing dimension"),
|
||||||
]
|
]
|
||||||
)
|
dtype_cases = [
|
||||||
|
ValidationCase(dtype=float, passes=True, id="float"),
|
||||||
|
ValidationCase(dtype=int, passes=False, id="int"),
|
||||||
|
]
|
||||||
|
|
||||||
|
iterator = merged_product(shape_cases, dtype_cases))
|
||||||
|
next(iterator)
|
||||||
|
# ValidationCase(shape=(10, 10, 10), dtype=float, passes=True, id="valid shape-float")
|
||||||
|
next(iterator)
|
||||||
|
# ValidationCase(shape=(10, 10, 10), dtype=int, passes=False, id="valid shape-int")
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
iterator = product(*args)
|
||||||
|
for case_tuple in iterator:
|
||||||
|
case = merge_cases(case_tuple)
|
||||||
|
if case.skip():
|
||||||
|
continue
|
||||||
|
yield case
|
||||||
|
|
|
@ -1,10 +1,76 @@
|
||||||
from typing import Any, Tuple, Type, Union
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import BaseModel, ConfigDict, computed_field
|
from pydantic import BaseModel, ConfigDict, ValidationError, computed_field
|
||||||
|
|
||||||
from numpydantic import NDArray, Shape
|
from numpydantic import NDArray, Shape
|
||||||
from numpydantic.dtype import Float
|
from numpydantic.dtype import Float
|
||||||
|
from numpydantic.interface import Interface
|
||||||
|
from numpydantic.types import NDArrayType
|
||||||
|
|
||||||
|
|
||||||
|
class InterfaceCase(ABC):
|
||||||
|
"""
|
||||||
|
An interface test helper that allows a given interface to generate and validate
|
||||||
|
arrays in one of its formats.
|
||||||
|
|
||||||
|
Each instance of "interface test case" should be considered one of the
|
||||||
|
potentially multiple realizations of a given interface.
|
||||||
|
If an interface has multiple formats (eg. zarr's different `store` s),
|
||||||
|
then it should have several test helpers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def interface(self) -> Interface:
|
||||||
|
"""The interface that this helper is for"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def generate_array(
|
||||||
|
cls, case: "ValidationCase", path: Path
|
||||||
|
) -> Optional[NDArrayType]:
|
||||||
|
"""
|
||||||
|
Generate an array from the given validation case.
|
||||||
|
|
||||||
|
Returns ``None`` if an array can't be generated for a specific case.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_array(cls, case: "ValidationCase", path: Path) -> Optional[bool]:
|
||||||
|
"""
|
||||||
|
Validate a generated array against the annotation in the validation case.
|
||||||
|
|
||||||
|
Kept in the InterfaceCase in case an interface has specific
|
||||||
|
needs aside from just validating against a model, but typically left as is.
|
||||||
|
|
||||||
|
Does not raise on Validation errors -
|
||||||
|
returns bool instead for consistency's sake.
|
||||||
|
|
||||||
|
If an array can't be generated for a given case, returns `None`
|
||||||
|
so that the calling function can know to skip rather than fail the case.
|
||||||
|
"""
|
||||||
|
array = cls.generate_array(case, path)
|
||||||
|
if array is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
case.model(array=array)
|
||||||
|
# True if case is supposed to pass, False if it's not...
|
||||||
|
return case.passes
|
||||||
|
except ValidationError:
|
||||||
|
# False if the case is supposed to pass, True if it is...
|
||||||
|
return not case.passes
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def skip(cls, case: "ValidationCase") -> bool:
|
||||||
|
"""
|
||||||
|
Whether a given interface should be skipped for the case
|
||||||
|
"""
|
||||||
|
# Assume an interface case is valid for all other cases
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class ValidationCase(BaseModel):
|
class ValidationCase(BaseModel):
|
||||||
|
@ -15,6 +81,10 @@ class ValidationCase(BaseModel):
|
||||||
test in a given interface
|
test in a given interface
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
id: Optional[str] = None
|
||||||
|
"""
|
||||||
|
String identifying the validation case
|
||||||
|
"""
|
||||||
annotation: Any = NDArray[Shape["10, 10, *"], Float]
|
annotation: Any = NDArray[Shape["10, 10, *"], Float]
|
||||||
"""
|
"""
|
||||||
Array annotation used in the validating model
|
Array annotation used in the validating model
|
||||||
|
@ -24,8 +94,9 @@ class ValidationCase(BaseModel):
|
||||||
"""Shape of the array to validate"""
|
"""Shape of the array to validate"""
|
||||||
dtype: Union[Type, np.dtype] = float
|
dtype: Union[Type, np.dtype] = float
|
||||||
"""Dtype of the array to validate"""
|
"""Dtype of the array to validate"""
|
||||||
passes: bool
|
passes: bool = False
|
||||||
"""Whether the validation should pass or not"""
|
"""Whether the validation should pass or not"""
|
||||||
|
interface: Optional[InterfaceCase] = None
|
||||||
|
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
|
@ -38,3 +109,62 @@ class ValidationCase(BaseModel):
|
||||||
array: annotation
|
array: annotation
|
||||||
|
|
||||||
return Model
|
return Model
|
||||||
|
|
||||||
|
def merge(
|
||||||
|
self, other: Union["ValidationCase", Sequence["ValidationCase"]]
|
||||||
|
) -> "ValidationCase":
|
||||||
|
"""
|
||||||
|
Merge two validation cases
|
||||||
|
|
||||||
|
Dump both, excluding any unset fields, and merge, preferring `other`.
|
||||||
|
|
||||||
|
``valid`` is ``True`` if and only if it is ``True`` in both.
|
||||||
|
"""
|
||||||
|
if isinstance(other, Sequence):
|
||||||
|
return merge_cases(self, *other)
|
||||||
|
|
||||||
|
self_dump = self.model_dump(exclude_unset=True)
|
||||||
|
other_dump = other.model_dump(exclude_unset=True)
|
||||||
|
|
||||||
|
# dumps might not have set `valid`, use only the ones that have
|
||||||
|
valids = [
|
||||||
|
v
|
||||||
|
for v in [self_dump.get("valid", None), other_dump.get("valid", None)]
|
||||||
|
if v is not None
|
||||||
|
]
|
||||||
|
valid = all(valids)
|
||||||
|
|
||||||
|
# combine ids if present
|
||||||
|
ids = "-".join(
|
||||||
|
[
|
||||||
|
str(v)
|
||||||
|
for v in [self_dump.get("id", None), other_dump.get("id", None)]
|
||||||
|
if v is not None
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
merged = {**self_dump, **other_dump}
|
||||||
|
merged["valid"] = valid
|
||||||
|
merged["id"] = ids
|
||||||
|
return ValidationCase(**merged)
|
||||||
|
|
||||||
|
def skip(self) -> bool:
|
||||||
|
"""
|
||||||
|
Whether this case should be skipped
|
||||||
|
(eg. due to the interface case being incompatible
|
||||||
|
with the requested dtype or shape)
|
||||||
|
"""
|
||||||
|
return bool(self.interface is not None and self.interface.skip())
|
||||||
|
|
||||||
|
|
||||||
|
def merge_cases(*args: ValidationCase) -> ValidationCase:
|
||||||
|
"""
|
||||||
|
Merge multiple validation cases
|
||||||
|
"""
|
||||||
|
if len(args) == 1:
|
||||||
|
return args[0]
|
||||||
|
|
||||||
|
case = args[0]
|
||||||
|
for arg in args[1:]:
|
||||||
|
case = case.merge(arg)
|
||||||
|
return case
|
||||||
|
|
218
src/numpydantic/testing/interfaces.py
Normal file
218
src/numpydantic/testing/interfaces.py
Normal file
|
@ -0,0 +1,218 @@
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import dask.array as da
|
||||||
|
import h5py
|
||||||
|
import numpy as np
|
||||||
|
import zarr
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from numpydantic.interface import (
|
||||||
|
DaskInterface,
|
||||||
|
H5ArrayPath,
|
||||||
|
H5Interface,
|
||||||
|
NumpyInterface,
|
||||||
|
VideoInterface,
|
||||||
|
ZarrArrayPath,
|
||||||
|
ZarrInterface,
|
||||||
|
)
|
||||||
|
from numpydantic.testing.helpers import InterfaceCase, ValidationCase
|
||||||
|
|
||||||
|
|
||||||
|
class NumpyCase(InterfaceCase):
|
||||||
|
"""In-memory numpy array"""
|
||||||
|
|
||||||
|
interface = NumpyInterface
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def generate_array(cls, case: "ValidationCase", path: Path) -> np.ndarray:
|
||||||
|
if issubclass(case.dtype, BaseModel):
|
||||||
|
return np.full(shape=case.shape, fill_value=case.dtype(x=1))
|
||||||
|
else:
|
||||||
|
return np.zeros(shape=case.shape, dtype=case.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class _HDF5MetaCase(InterfaceCase):
|
||||||
|
"""Base case for hdf5 cases"""
|
||||||
|
|
||||||
|
interface = H5Interface
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def skip(cls, case: "ValidationCase") -> bool:
|
||||||
|
return not issubclass(case.dtype, BaseModel)
|
||||||
|
|
||||||
|
|
||||||
|
class HDF5Case(_HDF5MetaCase):
|
||||||
|
"""HDF5 Array"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def generate_array(
|
||||||
|
cls, case: "ValidationCase", path: Path
|
||||||
|
) -> Optional[H5ArrayPath]:
|
||||||
|
if cls.skip(case):
|
||||||
|
return None
|
||||||
|
|
||||||
|
hdf5_file = path / "h5f.h5"
|
||||||
|
array_path = (
|
||||||
|
"/" + "_".join([str(s) for s in case.shape]) + "__" + case.dtype.__name__
|
||||||
|
)
|
||||||
|
generator = np.random.default_rng()
|
||||||
|
|
||||||
|
if case.dtype is str:
|
||||||
|
data = generator.random(case.shape).astype(bytes)
|
||||||
|
elif case.dtype is datetime:
|
||||||
|
data = np.empty(case.shape, dtype="S32")
|
||||||
|
data.fill(datetime.now(timezone.utc).isoformat().encode("utf-8"))
|
||||||
|
else:
|
||||||
|
data = generator.random(case.shape).astype(case.dtype)
|
||||||
|
|
||||||
|
h5path = H5ArrayPath(hdf5_file, array_path)
|
||||||
|
|
||||||
|
with h5py.File(hdf5_file, "w") as h5f:
|
||||||
|
_ = h5f.create_dataset(array_path, data=data)
|
||||||
|
return h5path
|
||||||
|
|
||||||
|
|
||||||
|
class HDF5CompoundCase(_HDF5MetaCase):
|
||||||
|
"""HDF5 Array with a fake compound dtype"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def generate_array(
|
||||||
|
cls, case: "ValidationCase", path: Path
|
||||||
|
) -> Optional[H5ArrayPath]:
|
||||||
|
if cls.skip(case):
|
||||||
|
return None
|
||||||
|
|
||||||
|
hdf5_file = path / "h5f.h5"
|
||||||
|
array_path = (
|
||||||
|
"/" + "_".join([str(s) for s in case.shape]) + "__" + case.dtype.__name__
|
||||||
|
)
|
||||||
|
if case.dtype is str:
|
||||||
|
dt = np.dtype([("data", np.dtype("S10")), ("extra", "i8")])
|
||||||
|
data = np.array([("hey", 0)] * np.prod(case.shape), dtype=dt).reshape(
|
||||||
|
case.shape
|
||||||
|
)
|
||||||
|
elif case.dtype is datetime:
|
||||||
|
dt = np.dtype([("data", np.dtype("S32")), ("extra", "i8")])
|
||||||
|
data = np.array(
|
||||||
|
[(datetime.now(timezone.utc).isoformat().encode("utf-8"), 0)]
|
||||||
|
* np.prod(case.shape),
|
||||||
|
dtype=dt,
|
||||||
|
).reshape(case.shape)
|
||||||
|
else:
|
||||||
|
dt = np.dtype([("data", case.dtype), ("extra", "i8")])
|
||||||
|
data = np.zeros(case.shape, dtype=dt)
|
||||||
|
h5path = H5ArrayPath(hdf5_file, array_path, "data")
|
||||||
|
|
||||||
|
with h5py.File(hdf5_file, "w") as h5f:
|
||||||
|
_ = h5f.create_dataset(array_path, data=data)
|
||||||
|
return h5path
|
||||||
|
|
||||||
|
|
||||||
|
class DaskCase(InterfaceCase):
|
||||||
|
"""In-memory dask array"""
|
||||||
|
|
||||||
|
interface = DaskInterface
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def generate_array(cls, case: "ValidationCase", path: Path) -> da.Array:
|
||||||
|
if issubclass(case.dtype, BaseModel):
|
||||||
|
return da.full(shape=case.shape, fill_value=case.dtype(x=1), chunks=-1)
|
||||||
|
else:
|
||||||
|
return da.zeros(shape=case.shape, dtype=case.dtype, chunks=10)
|
||||||
|
|
||||||
|
|
||||||
|
class _ZarrMetaCase(InterfaceCase):
|
||||||
|
"""Shared classmethods for zarr cases"""
|
||||||
|
|
||||||
|
interface = ZarrInterface
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def skip(cls, case: "ValidationCase") -> bool:
|
||||||
|
return not issubclass(case.dtype, BaseModel)
|
||||||
|
|
||||||
|
|
||||||
|
class ZarrCase(_ZarrMetaCase):
|
||||||
|
"""In-memory zarr array"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def generate_array(cls, case: "ValidationCase", path: Path) -> Optional[zarr.Array]:
|
||||||
|
return zarr.zeros(shape=case.shape, dtype=case.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class ZarrDirCase(_ZarrMetaCase):
|
||||||
|
"""On-disk zarr array"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def generate_array(cls, case: "ValidationCase", path: Path) -> ZarrArrayPath:
|
||||||
|
store = zarr.DirectoryStore(str(path / "array.zarr"))
|
||||||
|
return zarr.zeros(shape=case.shape, dtype=case.dtype, store=store)
|
||||||
|
|
||||||
|
|
||||||
|
class ZarrZipCase(_ZarrMetaCase):
|
||||||
|
"""Zarr zip store"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def generate_array(cls, case: "ValidationCase", path: Path) -> ZarrArrayPath:
|
||||||
|
store = zarr.ZipStore(str(path / "array.zarr"), mode="w")
|
||||||
|
return zarr.zeros(shape=case.shape, dtype=case.dtype, store=store)
|
||||||
|
|
||||||
|
|
||||||
|
class ZarrNestedCase(_ZarrMetaCase):
|
||||||
|
"""Nested zarr array"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def generate_array(cls, case: "ValidationCase", path: Path) -> ZarrArrayPath:
|
||||||
|
file = str(path / "nested.zarr")
|
||||||
|
root = zarr.open(file, mode="w")
|
||||||
|
subpath = "a/b/c"
|
||||||
|
_ = root.zeros(subpath, shape=case.shape, dtype=case.dtype)
|
||||||
|
return ZarrArrayPath(file=file, path=subpath)
|
||||||
|
|
||||||
|
|
||||||
|
class VideoCase(InterfaceCase):
|
||||||
|
"""AVI video"""
|
||||||
|
|
||||||
|
interface = VideoInterface
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def generate_array(cls, case: "ValidationCase", path: Path) -> Optional[Path]:
|
||||||
|
if cls.skip(case):
|
||||||
|
return None
|
||||||
|
|
||||||
|
is_color = len(case.shape) == 4
|
||||||
|
frames = case.shape[0]
|
||||||
|
frame_shape = case.shape[1:]
|
||||||
|
|
||||||
|
video_path = path / "test.avi"
|
||||||
|
writer = cv2.VideoWriter(
|
||||||
|
str(video_path),
|
||||||
|
cv2.VideoWriter_fourcc(*"RGBA"), # raw video for testing purposes
|
||||||
|
30,
|
||||||
|
(frame_shape[1], frame_shape[0]),
|
||||||
|
is_color,
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(frames):
|
||||||
|
# make fresh array every time bc opencv eats them
|
||||||
|
array = np.zeros(frame_shape, dtype=np.uint8)
|
||||||
|
if not is_color:
|
||||||
|
array[i, i] = i
|
||||||
|
else:
|
||||||
|
array[i, i, :] = i
|
||||||
|
writer.write(array)
|
||||||
|
writer.release()
|
||||||
|
return video_path
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def skip(cls, case: "ValidationCase") -> bool:
|
||||||
|
"""We really can only handle 3-4 dimensional cases in 8-bit rn lol"""
|
||||||
|
if len(case.shape) < 3 or len(case.shape) > 4:
|
||||||
|
return True
|
||||||
|
if case.dtype not in (int, np.uint8):
|
||||||
|
return True
|
||||||
|
# if we have a color video (ie. shape == 4, needs to be RGB)
|
||||||
|
if len(case.shape) == 4 and case.shape[3] != 3:
|
||||||
|
return True
|
|
@ -1,10 +1,6 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from numpydantic.testing.cases import (
|
from numpydantic.testing.cases import DTYPE_CASES, SHAPE_CASES
|
||||||
DTYPE_CASES,
|
|
||||||
DTYPE_IDS,
|
|
||||||
RGB_UNION,
|
|
||||||
)
|
|
||||||
from numpydantic.testing.helpers import ValidationCase
|
from numpydantic.testing.helpers import ValidationCase
|
||||||
from tests.fixtures import *
|
from tests.fixtures import *
|
||||||
|
|
||||||
|
@ -17,43 +13,11 @@ def pytest_addoption(parser):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(
|
@pytest.fixture(scope="module", params=SHAPE_CASES)
|
||||||
scope="module",
|
|
||||||
params=[
|
|
||||||
ValidationCase(shape=(10, 10, 10), passes=True),
|
|
||||||
ValidationCase(shape=(10, 10), passes=False),
|
|
||||||
ValidationCase(shape=(10, 10, 10, 10), passes=False),
|
|
||||||
ValidationCase(shape=(11, 10, 10), passes=False),
|
|
||||||
ValidationCase(shape=(9, 10, 10), passes=False),
|
|
||||||
ValidationCase(shape=(10, 10, 9), passes=True),
|
|
||||||
ValidationCase(shape=(10, 10, 11), passes=True),
|
|
||||||
ValidationCase(annotation=RGB_UNION, shape=(5, 5), passes=True),
|
|
||||||
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3), passes=True),
|
|
||||||
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3, 4), passes=True),
|
|
||||||
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 4), passes=False),
|
|
||||||
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3, 6), passes=False),
|
|
||||||
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 4, 6), passes=False),
|
|
||||||
],
|
|
||||||
ids=[
|
|
||||||
"valid shape",
|
|
||||||
"missing dimension",
|
|
||||||
"extra dimension",
|
|
||||||
"dimension too large",
|
|
||||||
"dimension too small",
|
|
||||||
"wildcard smaller",
|
|
||||||
"wildcard larger",
|
|
||||||
"Union 2D",
|
|
||||||
"Union 3D",
|
|
||||||
"Union 4D",
|
|
||||||
"Union incorrect 3D",
|
|
||||||
"Union incorrect 4D",
|
|
||||||
"Union incorrect both",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def shape_cases(request) -> ValidationCase:
|
def shape_cases(request) -> ValidationCase:
|
||||||
return request.param
|
return request.param
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module", params=DTYPE_CASES, ids=DTYPE_IDS)
|
@pytest.fixture(scope="module", params=DTYPE_CASES)
|
||||||
def dtype_cases(request) -> ValidationCase:
|
def dtype_cases(request) -> ValidationCase:
|
||||||
return request.param
|
return request.param
|
||||||
|
|
Loading…
Reference in a new issue