mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2024-11-12 17:54:29 +00:00
hoo boy. working combinatoric testing.
Split out annotation dtype and shape, swap out all interface tests, fix numpy and dask model casting, make merging models more efficient, correctly parameterize and mark tests!
This commit is contained in:
parent
5d4f03a8a9
commit
1187b37b2d
12 changed files with 482 additions and 278 deletions
|
@ -5,7 +5,7 @@ Interface for Dask arrays
|
|||
from typing import Any, Iterable, List, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import SerializationInfo
|
||||
from pydantic import BaseModel, SerializationInfo
|
||||
|
||||
from numpydantic.interface.interface import Interface, JsonDict
|
||||
from numpydantic.types import DtypeType, NDArrayType
|
||||
|
@ -70,9 +70,33 @@ class DaskInterface(Interface):
|
|||
else:
|
||||
return False
|
||||
|
||||
def before_validation(self, array: DaskArray) -> NDArrayType:
|
||||
"""
|
||||
Try and coerce dicts that should be model objects into the model objects
|
||||
"""
|
||||
try:
|
||||
if issubclass(self.dtype, BaseModel) and isinstance(
|
||||
array.reshape(-1)[0].compute(), dict
|
||||
):
|
||||
|
||||
def _chunked_to_model(array: np.ndarray) -> np.ndarray:
|
||||
def _vectorized_to_model(item: Union[dict, BaseModel]) -> BaseModel:
|
||||
if not isinstance(item, self.dtype):
|
||||
return self.dtype(**item)
|
||||
else:
|
||||
return item
|
||||
|
||||
return np.vectorize(_vectorized_to_model)(array)
|
||||
|
||||
array = array.map_blocks(_chunked_to_model, dtype=self.dtype)
|
||||
except TypeError:
|
||||
# fine, dtype isn't a type
|
||||
pass
|
||||
return array
|
||||
|
||||
def get_object_dtype(self, array: NDArrayType) -> DtypeType:
|
||||
"""Dask arrays require a compute() call to retrieve a single value"""
|
||||
return type(array.ravel()[0].compute())
|
||||
return type(array.reshape(-1)[0].compute())
|
||||
|
||||
@classmethod
|
||||
def enabled(cls) -> bool:
|
||||
|
|
|
@ -4,7 +4,7 @@ Interface to numpy arrays
|
|||
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import SerializationInfo
|
||||
from pydantic import BaseModel, SerializationInfo
|
||||
|
||||
from numpydantic.interface.interface import Interface, JsonDict
|
||||
|
||||
|
@ -59,6 +59,9 @@ class NumpyInterface(Interface):
|
|||
Check that this is in fact a numpy ndarray or something that can be
|
||||
coerced to one
|
||||
"""
|
||||
if array is None:
|
||||
return False
|
||||
|
||||
if isinstance(array, ndarray):
|
||||
return True
|
||||
elif isinstance(array, dict):
|
||||
|
@ -77,6 +80,14 @@ class NumpyInterface(Interface):
|
|||
"""
|
||||
if not isinstance(array, ndarray):
|
||||
array = np.array(array)
|
||||
|
||||
try:
|
||||
if issubclass(self.dtype, BaseModel) and isinstance(array.flat[0], dict):
|
||||
array = np.vectorize(lambda x: self.dtype(**x))(array)
|
||||
except TypeError:
|
||||
# fine, dtype isn't a type
|
||||
pass
|
||||
|
||||
return array
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -63,6 +63,7 @@ class ZarrJsonDict(JsonDict):
|
|||
type: Literal["zarr"]
|
||||
file: Optional[str] = None
|
||||
path: Optional[str] = None
|
||||
dtype: Optional[str] = None
|
||||
value: Optional[list] = None
|
||||
|
||||
def to_array_input(self) -> Union[ZarrArray, ZarrArrayPath]:
|
||||
|
@ -73,7 +74,7 @@ class ZarrJsonDict(JsonDict):
|
|||
if self.file:
|
||||
array = ZarrArrayPath(file=self.file, path=self.path)
|
||||
else:
|
||||
array = zarr.array(self.value)
|
||||
array = zarr.array(self.value, dtype=self.dtype)
|
||||
return array
|
||||
|
||||
|
||||
|
@ -194,6 +195,7 @@ class ZarrInterface(Interface):
|
|||
is_file = False
|
||||
|
||||
as_json = {"type": cls.name}
|
||||
as_json["dtype"] = array.dtype.name
|
||||
if hasattr(array.store, "dir_path"):
|
||||
is_file = True
|
||||
as_json["file"] = array.store.dir_path()
|
||||
|
|
|
@ -1,14 +1,11 @@
|
|||
import sys
|
||||
from collections.abc import Sequence
|
||||
from itertools import product
|
||||
from typing import Generator, Union
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel
|
||||
|
||||
from numpydantic import NDArray, Shape
|
||||
from numpydantic.dtype import Float, Integer, Number
|
||||
from numpydantic.testing.helpers import ValidationCase, merge_cases
|
||||
from numpydantic.testing.helpers import ValidationCase, merged_product
|
||||
from numpydantic.testing.interfaces import (
|
||||
DaskCase,
|
||||
HDF5Case,
|
||||
|
@ -31,53 +28,6 @@ else:
|
|||
YES_PIPE = False
|
||||
|
||||
|
||||
def merged_product(
|
||||
*args: Sequence[ValidationCase],
|
||||
) -> Generator[ValidationCase, None, None]:
|
||||
"""
|
||||
Generator for the product of the iterators of validation cases,
|
||||
merging each tuple, and respecting if they should be :meth:`.ValidationCase.skip`
|
||||
or not.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
shape_cases = [
|
||||
ValidationCase(shape=(10, 10, 10), passes=True, id="valid shape"),
|
||||
ValidationCase(shape=(10, 10), passes=False, id="missing dimension"),
|
||||
]
|
||||
dtype_cases = [
|
||||
ValidationCase(dtype=float, passes=True, id="float"),
|
||||
ValidationCase(dtype=int, passes=False, id="int"),
|
||||
]
|
||||
|
||||
iterator = merged_product(shape_cases, dtype_cases))
|
||||
next(iterator)
|
||||
# ValidationCase(
|
||||
# shape=(10, 10, 10),
|
||||
# dtype=float,
|
||||
# passes=True,
|
||||
# id="valid shape-float"
|
||||
# )
|
||||
next(iterator)
|
||||
# ValidationCase(
|
||||
# shape=(10, 10, 10),
|
||||
# dtype=int,
|
||||
# passes=False,
|
||||
# id="valid shape-int"
|
||||
# )
|
||||
|
||||
|
||||
"""
|
||||
iterator = product(*args)
|
||||
for case_tuple in iterator:
|
||||
case = merge_cases(case_tuple)
|
||||
if case.skip():
|
||||
continue
|
||||
yield case
|
||||
|
||||
|
||||
class BasicModel(BaseModel):
|
||||
x: int
|
||||
|
||||
|
@ -94,39 +44,40 @@ class SubClass(BasicModel):
|
|||
# Annotations
|
||||
# --------------------------------------------------
|
||||
|
||||
RGB_UNION: TypeAlias = Union[
|
||||
NDArray[Shape["* x, * y"], Number],
|
||||
NDArray[Shape["* x, * y, 3 r_g_b"], Number],
|
||||
NDArray[Shape["* x, * y, 3 r_g_b, 4 r_g_b_a"], Number],
|
||||
]
|
||||
NUMBER: TypeAlias = NDArray[Shape["*, *, *"], Number]
|
||||
INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer]
|
||||
FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float]
|
||||
STRING: TypeAlias = NDArray[Shape["*, *, *"], str]
|
||||
MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel]
|
||||
UNION_TYPE: TypeAlias = NDArray[Shape["*, *, *"], Union[np.uint32, np.float32]]
|
||||
RGB_UNION = (("*", "*"), ("*", "*", 3), ("*", "*", 3, 4))
|
||||
UNION_TYPE: TypeAlias = Union[np.uint32, np.float32]
|
||||
|
||||
SHAPE_CASES = (
|
||||
ValidationCase(shape=(10, 10, 10), passes=True, id="valid shape"),
|
||||
ValidationCase(shape=(10, 10), passes=False, id="missing dimension"),
|
||||
ValidationCase(shape=(10, 10, 10, 10), passes=False, id="extra dimension"),
|
||||
ValidationCase(shape=(11, 10, 10), passes=False, id="dimension too large"),
|
||||
ValidationCase(shape=(9, 10, 10), passes=False, id="dimension too small"),
|
||||
ValidationCase(shape=(10, 10, 9), passes=True, id="wildcard smaller"),
|
||||
ValidationCase(shape=(10, 10, 11), passes=True, id="wildcard larger"),
|
||||
ValidationCase(annotation=RGB_UNION, shape=(5, 5), passes=True, id="Union 2D"),
|
||||
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3), passes=True, id="Union 3D"),
|
||||
ValidationCase(shape=(10, 10, 2, 2), passes=True, id="valid shape"),
|
||||
ValidationCase(shape=(10, 10, 2), passes=False, id="missing dimension"),
|
||||
ValidationCase(shape=(10, 10, 2, 2, 2), passes=False, id="extra dimension"),
|
||||
ValidationCase(shape=(11, 10, 2, 2), passes=False, id="dimension too large"),
|
||||
ValidationCase(shape=(9, 10, 2, 2), passes=False, id="dimension too small"),
|
||||
ValidationCase(shape=(10, 10, 1, 1), passes=True, id="wildcard smaller"),
|
||||
ValidationCase(shape=(10, 10, 3, 3), passes=True, id="wildcard larger"),
|
||||
ValidationCase(
|
||||
annotation=RGB_UNION, shape=(5, 5, 3, 4), passes=True, id="Union 4D"
|
||||
annotation_shape=RGB_UNION, shape=(5, 5), passes=True, id="Union 2D"
|
||||
),
|
||||
ValidationCase(
|
||||
annotation=RGB_UNION, shape=(5, 5, 4), passes=False, id="Union incorrect 3D"
|
||||
annotation_shape=RGB_UNION, shape=(5, 5, 3), passes=True, id="Union 3D"
|
||||
),
|
||||
ValidationCase(
|
||||
annotation=RGB_UNION, shape=(5, 5, 3, 6), passes=False, id="Union incorrect 4D"
|
||||
annotation_shape=RGB_UNION, shape=(5, 5, 3, 4), passes=True, id="Union 4D"
|
||||
),
|
||||
ValidationCase(
|
||||
annotation=RGB_UNION,
|
||||
annotation_shape=RGB_UNION,
|
||||
shape=(5, 5, 4),
|
||||
passes=False,
|
||||
id="Union incorrect 3D",
|
||||
),
|
||||
ValidationCase(
|
||||
annotation_shape=RGB_UNION,
|
||||
shape=(5, 5, 3, 6),
|
||||
passes=False,
|
||||
id="Union incorrect 4D",
|
||||
),
|
||||
ValidationCase(
|
||||
annotation_shape=RGB_UNION,
|
||||
shape=(5, 5, 4, 6),
|
||||
passes=False,
|
||||
id="Union incorrect both",
|
||||
|
@ -138,91 +89,144 @@ DTYPE_CASES = [
|
|||
ValidationCase(dtype=float, passes=True, id="float"),
|
||||
ValidationCase(dtype=int, passes=False, id="int"),
|
||||
ValidationCase(dtype=np.uint8, passes=False, id="uint8"),
|
||||
ValidationCase(annotation=NUMBER, dtype=int, passes=True, id="number-int"),
|
||||
ValidationCase(annotation=NUMBER, dtype=float, passes=True, id="number-float"),
|
||||
ValidationCase(annotation=NUMBER, dtype=np.uint8, passes=True, id="number-uint8"),
|
||||
ValidationCase(annotation_dtype=Number, dtype=int, passes=True, id="number-int"),
|
||||
ValidationCase(
|
||||
annotation=NUMBER, dtype=np.float16, passes=True, id="number-float16"
|
||||
),
|
||||
ValidationCase(annotation=NUMBER, dtype=str, passes=False, id="number-str"),
|
||||
ValidationCase(annotation=INTEGER, dtype=int, passes=True, id="integer-int"),
|
||||
ValidationCase(annotation=INTEGER, dtype=np.uint8, passes=True, id="integer-uint8"),
|
||||
ValidationCase(annotation=INTEGER, dtype=float, passes=False, id="integer-float"),
|
||||
ValidationCase(
|
||||
annotation=INTEGER, dtype=np.float32, passes=False, id="integer-float32"
|
||||
),
|
||||
ValidationCase(annotation=INTEGER, dtype=str, passes=False, id="integer-str"),
|
||||
ValidationCase(annotation=FLOAT, dtype=float, passes=True, id="float-float"),
|
||||
ValidationCase(annotation=FLOAT, dtype=np.float32, passes=True, id="float-float32"),
|
||||
ValidationCase(annotation=FLOAT, dtype=int, passes=False, id="float-int"),
|
||||
ValidationCase(annotation=FLOAT, dtype=np.uint8, passes=False, id="float-uint8"),
|
||||
ValidationCase(annotation=FLOAT, dtype=str, passes=False, id="float-str"),
|
||||
ValidationCase(annotation=STRING, dtype=str, passes=True, id="str-str"),
|
||||
ValidationCase(annotation=STRING, dtype=int, passes=False, id="str-int"),
|
||||
ValidationCase(annotation=STRING, dtype=float, passes=False, id="str-float"),
|
||||
ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True, id="model-model"),
|
||||
ValidationCase(annotation=MODEL, dtype=BadModel, passes=False, id="model-badmodel"),
|
||||
ValidationCase(annotation=MODEL, dtype=int, passes=False, id="model-int"),
|
||||
ValidationCase(annotation=MODEL, dtype=SubClass, passes=True, id="model-subclass"),
|
||||
ValidationCase(
|
||||
annotation=UNION_TYPE, dtype=np.uint32, passes=True, id="union-type-uint32"
|
||||
annotation_dtype=Number, dtype=float, passes=True, id="number-float"
|
||||
),
|
||||
ValidationCase(
|
||||
annotation=UNION_TYPE, dtype=np.float32, passes=True, id="union-type-float32"
|
||||
annotation_dtype=Number, dtype=np.uint8, passes=True, id="number-uint8"
|
||||
),
|
||||
ValidationCase(
|
||||
annotation=UNION_TYPE, dtype=np.uint64, passes=False, id="union-type-uint64"
|
||||
annotation_dtype=Number, dtype=np.float16, passes=True, id="number-float16"
|
||||
),
|
||||
ValidationCase(annotation_dtype=Number, dtype=str, passes=False, id="number-str"),
|
||||
ValidationCase(annotation_dtype=Integer, dtype=int, passes=True, id="integer-int"),
|
||||
ValidationCase(
|
||||
annotation_dtype=Integer, dtype=np.uint8, passes=True, id="integer-uint8"
|
||||
),
|
||||
ValidationCase(
|
||||
annotation=UNION_TYPE, dtype=np.float64, passes=False, id="union-type-float64"
|
||||
annotation_dtype=Integer, dtype=float, passes=False, id="integer-float"
|
||||
),
|
||||
ValidationCase(
|
||||
annotation_dtype=Integer, dtype=np.float32, passes=False, id="integer-float32"
|
||||
),
|
||||
ValidationCase(annotation_dtype=Integer, dtype=str, passes=False, id="integer-str"),
|
||||
ValidationCase(annotation_dtype=Float, dtype=float, passes=True, id="float-float"),
|
||||
ValidationCase(
|
||||
annotation_dtype=Float, dtype=np.float32, passes=True, id="float-float32"
|
||||
),
|
||||
ValidationCase(annotation_dtype=Float, dtype=int, passes=False, id="float-int"),
|
||||
ValidationCase(
|
||||
annotation_dtype=Float, dtype=np.uint8, passes=False, id="float-uint8"
|
||||
),
|
||||
ValidationCase(annotation_dtype=Float, dtype=str, passes=False, id="float-str"),
|
||||
ValidationCase(annotation_dtype=str, dtype=str, passes=True, id="str-str"),
|
||||
ValidationCase(annotation_dtype=str, dtype=int, passes=False, id="str-int"),
|
||||
ValidationCase(annotation_dtype=str, dtype=float, passes=False, id="str-float"),
|
||||
ValidationCase(
|
||||
annotation_dtype=BasicModel, dtype=BasicModel, passes=True, id="model-model"
|
||||
),
|
||||
ValidationCase(
|
||||
annotation_dtype=BasicModel, dtype=BadModel, passes=False, id="model-badmodel"
|
||||
),
|
||||
ValidationCase(
|
||||
annotation_dtype=BasicModel, dtype=int, passes=False, id="model-int"
|
||||
),
|
||||
ValidationCase(
|
||||
annotation_dtype=BasicModel, dtype=SubClass, passes=True, id="model-subclass"
|
||||
),
|
||||
ValidationCase(
|
||||
annotation_dtype=UNION_TYPE,
|
||||
dtype=np.uint32,
|
||||
passes=True,
|
||||
id="union-type-uint32",
|
||||
),
|
||||
ValidationCase(
|
||||
annotation_dtype=UNION_TYPE,
|
||||
dtype=np.float32,
|
||||
passes=True,
|
||||
id="union-type-float32",
|
||||
),
|
||||
ValidationCase(
|
||||
annotation_dtype=UNION_TYPE,
|
||||
dtype=np.uint64,
|
||||
passes=False,
|
||||
id="union-type-uint64",
|
||||
),
|
||||
ValidationCase(
|
||||
annotation_dtype=UNION_TYPE,
|
||||
dtype=np.float64,
|
||||
passes=False,
|
||||
id="union-type-float64",
|
||||
),
|
||||
ValidationCase(
|
||||
annotation_dtype=UNION_TYPE, dtype=str, passes=False, id="union-type-str"
|
||||
),
|
||||
ValidationCase(annotation=UNION_TYPE, dtype=str, passes=False, id="union-type-str"),
|
||||
]
|
||||
|
||||
|
||||
if YES_PIPE:
|
||||
UNION_PIPE: TypeAlias = NDArray[Shape["*, *, *"], np.uint32 | np.float32]
|
||||
UNION_PIPE: TypeAlias = np.uint32 | np.float32
|
||||
|
||||
DTYPE_CASES.extend(
|
||||
[
|
||||
ValidationCase(
|
||||
annotation=UNION_PIPE,
|
||||
annotation_dtype=UNION_PIPE,
|
||||
dtype=np.uint32,
|
||||
passes=True,
|
||||
id="union-pipe-uint32",
|
||||
),
|
||||
ValidationCase(
|
||||
annotation=UNION_PIPE,
|
||||
annotation_dtype=UNION_PIPE,
|
||||
dtype=np.float32,
|
||||
passes=True,
|
||||
id="union-pipe-float32",
|
||||
),
|
||||
ValidationCase(
|
||||
annotation=UNION_PIPE,
|
||||
annotation_dtype=UNION_PIPE,
|
||||
dtype=np.uint64,
|
||||
passes=False,
|
||||
id="union-pipe-uint64",
|
||||
),
|
||||
ValidationCase(
|
||||
annotation=UNION_PIPE,
|
||||
annotation_dtype=UNION_PIPE,
|
||||
dtype=np.float64,
|
||||
passes=False,
|
||||
id="union-pipe-float64",
|
||||
),
|
||||
ValidationCase(
|
||||
annotation=UNION_PIPE, dtype=str, passes=False, id="union-pipe-str"
|
||||
annotation_dtype=UNION_PIPE,
|
||||
dtype=str,
|
||||
passes=False,
|
||||
id="union-pipe-str",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
_INTERFACE_CASES = [
|
||||
NumpyCase,
|
||||
HDF5Case,
|
||||
HDF5CompoundCase,
|
||||
DaskCase,
|
||||
ZarrCase,
|
||||
ZarrDirCase,
|
||||
ZarrZipCase,
|
||||
ZarrNestedCase,
|
||||
VideoCase,
|
||||
INTERFACE_CASES = [
|
||||
ValidationCase(interface=NumpyCase, id="numpy"),
|
||||
ValidationCase(interface=HDF5Case, id="hdf5"),
|
||||
ValidationCase(interface=HDF5CompoundCase, id="hdf5_compound"),
|
||||
ValidationCase(interface=DaskCase, id="dask"),
|
||||
ValidationCase(interface=ZarrCase, id="zarr"),
|
||||
ValidationCase(interface=ZarrDirCase, id="zarr_dir"),
|
||||
ValidationCase(interface=ZarrZipCase, id="zarr_zip"),
|
||||
ValidationCase(interface=ZarrNestedCase, id="zarr_nested"),
|
||||
ValidationCase(interface=VideoCase, id="video"),
|
||||
]
|
||||
|
||||
|
||||
DTYPE_AND_SHAPE_CASES = merged_product(SHAPE_CASES, DTYPE_CASES)
|
||||
DTYPE_AND_SHAPE_CASES_PASSING = merged_product(
|
||||
SHAPE_CASES, DTYPE_CASES, conditions={"passes": True}
|
||||
)
|
||||
|
||||
DTYPE_AND_INTERFACE_CASES = merged_product(INTERFACE_CASES, DTYPE_CASES)
|
||||
DTYPE_AND_INTERFACE_CASES_PASSING = merged_product(
|
||||
INTERFACE_CASES, DTYPE_CASES, conditions={"passes": True}
|
||||
)
|
||||
|
||||
ALL_CASES = merged_product(SHAPE_CASES, DTYPE_CASES, INTERFACE_CASES)
|
||||
ALL_CASES_PASSING = merged_product(
|
||||
SHAPE_CASES, DTYPE_CASES, INTERFACE_CASES, conditions={"passes": True}
|
||||
)
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from functools import reduce
|
||||
from itertools import product
|
||||
from operator import ior
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Tuple, Type, Union
|
||||
from typing import Generator, List, Literal, Optional, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, ConfigDict, ValidationError, computed_field
|
||||
|
@ -101,6 +104,9 @@ class InterfaceCase(ABC):
|
|||
return False
|
||||
|
||||
|
||||
_a_shape_type = Tuple[Union[int, Literal["*"], Literal["..."]], ...]
|
||||
|
||||
|
||||
class ValidationCase(BaseModel):
|
||||
"""
|
||||
Test case for validating an array.
|
||||
|
@ -113,24 +119,56 @@ class ValidationCase(BaseModel):
|
|||
"""
|
||||
String identifying the validation case
|
||||
"""
|
||||
annotation: Any = NDArray[Shape["10, 10, *"], Float]
|
||||
annotation_shape: Union[
|
||||
Tuple[Union[int, str], ...], Tuple[Tuple[Union[int, str], ...], ...]
|
||||
] = (10, 10, "*", "*")
|
||||
"""
|
||||
Array annotation used in the validating model
|
||||
Any typed because the types of type annotations are weird
|
||||
Shape to use in computed annotation used to validate against
|
||||
"""
|
||||
shape: Tuple[int, ...] = (10, 10, 10)
|
||||
annotation_dtype: Union[DtypeType, Sequence[DtypeType]] = Float
|
||||
"""
|
||||
Dtype to use in computed annotation used to validate against
|
||||
"""
|
||||
shape: Tuple[int, ...] = (10, 10, 2, 2)
|
||||
"""Shape of the array to validate"""
|
||||
dtype: Union[Type, np.dtype] = float
|
||||
"""Dtype of the array to validate"""
|
||||
passes: bool = False
|
||||
"""Whether the validation should pass or not"""
|
||||
interface: Optional[InterfaceCase] = None
|
||||
interface: Optional[Type[InterfaceCase]] = None
|
||||
"""The interface test case to generate and validate the array with"""
|
||||
path: Optional[Path] = None
|
||||
"""The path to generate arrays into, if any."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@computed_field()
|
||||
def annotation(self) -> NDArray:
|
||||
"""
|
||||
Annotation used in the model we validate against
|
||||
"""
|
||||
# make a union type if we need to
|
||||
shape_union = all(isinstance(s, Sequence) for s in self.annotation_shape)
|
||||
dtype_union = isinstance(self.annotation_dtype, Sequence) and all(
|
||||
isinstance(s, Sequence) for s in self.annotation_dtype
|
||||
)
|
||||
if shape_union or dtype_union:
|
||||
shape_iter = (
|
||||
self.annotation_shape if shape_union else [self.annotation_shape]
|
||||
)
|
||||
dtype_iter = (
|
||||
self.annotation_dtype if dtype_union else [self.annotation_dtype]
|
||||
)
|
||||
annotations: List[type] = []
|
||||
for shape, dtype in product(shape_iter, dtype_iter):
|
||||
shape_str = ", ".join([str(i) for i in shape])
|
||||
annotations.append(NDArray[Shape[shape_str], dtype])
|
||||
return Union[tuple(annotations)]
|
||||
|
||||
else:
|
||||
shape_str = ", ".join([str(i) for i in self.annotation_shape])
|
||||
return NDArray[Shape[shape_str], self.annotation_dtype]
|
||||
|
||||
@computed_field()
|
||||
def model(self) -> Type[BaseModel]:
|
||||
"""A model with a field ``array`` with the given annotation"""
|
||||
|
@ -186,31 +224,8 @@ class ValidationCase(BaseModel):
|
|||
"""
|
||||
if isinstance(other, Sequence):
|
||||
return merge_cases(self, *other)
|
||||
|
||||
self_dump = self.model_dump(exclude_unset=True)
|
||||
other_dump = other.model_dump(exclude_unset=True)
|
||||
|
||||
# dumps might not have set `valid`, use only the ones that have
|
||||
valids = [
|
||||
v
|
||||
for v in [self_dump.get("valid", None), other_dump.get("valid", None)]
|
||||
if v is not None
|
||||
]
|
||||
valid = all(valids)
|
||||
|
||||
# combine ids if present
|
||||
ids = "-".join(
|
||||
[
|
||||
str(v)
|
||||
for v in [self_dump.get("id", None), other_dump.get("id", None)]
|
||||
if v is not None
|
||||
]
|
||||
)
|
||||
|
||||
merged = {**self_dump, **other_dump}
|
||||
merged["valid"] = valid
|
||||
merged["id"] = ids
|
||||
return ValidationCase(**merged)
|
||||
else:
|
||||
return merge_cases(self, other)
|
||||
|
||||
def skip(self) -> bool:
|
||||
"""
|
||||
|
@ -230,7 +245,73 @@ def merge_cases(*args: ValidationCase) -> ValidationCase:
|
|||
if len(args) == 1:
|
||||
return args[0]
|
||||
|
||||
case = args[0]
|
||||
for arg in args[1:]:
|
||||
case = case.merge(arg)
|
||||
return case
|
||||
dumped = [
|
||||
m.model_dump(exclude_unset=True, exclude={"model", "annotation"}) for m in args
|
||||
]
|
||||
|
||||
# self_dump = self.model_dump(exclude_unset=True)
|
||||
# other_dump = other.model_dump(exclude_unset=True)
|
||||
|
||||
# dumps might not have set `passes`, use only the ones that have
|
||||
passes = [v.get("passes") for v in dumped if "passes" in v]
|
||||
passes = all(passes)
|
||||
|
||||
# combine ids if present
|
||||
ids = "-".join([str(v.get("id")) for v in dumped if "id" in v])
|
||||
|
||||
# merge dicts
|
||||
merged = reduce(ior, dumped, {})
|
||||
merged["passes"] = passes
|
||||
merged["id"] = ids
|
||||
return ValidationCase.model_construct(**merged)
|
||||
|
||||
|
||||
def merged_product(
|
||||
*args: Sequence[ValidationCase], conditions: dict = None
|
||||
) -> Generator[ValidationCase, None, None]:
|
||||
"""
|
||||
Generator for the product of the iterators of validation cases,
|
||||
merging each tuple, and respecting if they should be :meth:`.ValidationCase.skip`
|
||||
or not.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
shape_cases = [
|
||||
ValidationCase(shape=(10, 10, 10), passes=True, id="valid shape"),
|
||||
ValidationCase(shape=(10, 10), passes=False, id="missing dimension"),
|
||||
]
|
||||
dtype_cases = [
|
||||
ValidationCase(dtype=float, passes=True, id="float"),
|
||||
ValidationCase(dtype=int, passes=False, id="int"),
|
||||
]
|
||||
|
||||
iterator = merged_product(shape_cases, dtype_cases))
|
||||
next(iterator)
|
||||
# ValidationCase(
|
||||
# shape=(10, 10, 10),
|
||||
# dtype=float,
|
||||
# passes=True,
|
||||
# id="valid shape-float"
|
||||
# )
|
||||
next(iterator)
|
||||
# ValidationCase(
|
||||
# shape=(10, 10, 10),
|
||||
# dtype=int,
|
||||
# passes=False,
|
||||
# id="valid shape-int"
|
||||
# )
|
||||
|
||||
|
||||
"""
|
||||
iterator = product(*args)
|
||||
for case_tuple in iterator:
|
||||
case = merge_cases(*case_tuple)
|
||||
if case.skip():
|
||||
continue
|
||||
if conditions:
|
||||
matching = all([getattr(case, k, None) == v for k, v in conditions.items()])
|
||||
if not matching:
|
||||
continue
|
||||
yield case
|
||||
|
|
|
@ -154,7 +154,7 @@ class _ZarrMetaCase(InterfaceCase):
|
|||
|
||||
@classmethod
|
||||
def skip(cls, shape: Tuple[int, ...], dtype: DtypeType) -> bool:
|
||||
return issubclass(dtype, BaseModel)
|
||||
return issubclass(dtype, BaseModel) or dtype is str
|
||||
|
||||
|
||||
class ZarrCase(_ZarrMetaCase):
|
||||
|
@ -239,8 +239,8 @@ class VideoCase(InterfaceCase):
|
|||
@classmethod
|
||||
def make_array(
|
||||
cls,
|
||||
shape: Tuple[int, ...] = (10, 10),
|
||||
dtype: DtypeType = float,
|
||||
shape: Tuple[int, ...] = (10, 10, 10, 3),
|
||||
dtype: DtypeType = np.uint8,
|
||||
path: Optional[Path] = None,
|
||||
array: Optional[NDArrayType] = None,
|
||||
) -> Optional[Path]:
|
||||
|
@ -269,20 +269,26 @@ class VideoCase(InterfaceCase):
|
|||
frame = array[i]
|
||||
else:
|
||||
# make fresh array every time bc opencv eats them
|
||||
frame = np.zeros(frame_shape, dtype=np.uint8)
|
||||
if not is_color:
|
||||
frame[i, i] = i
|
||||
else:
|
||||
frame[i, i, :] = i
|
||||
frame = np.full(frame_shape, fill_value=i, dtype=np.uint8)
|
||||
writer.write(frame)
|
||||
writer.release()
|
||||
return video_path
|
||||
|
||||
@classmethod
|
||||
def skip(cls, shape: Tuple[int, ...], dtype: DtypeType) -> bool:
|
||||
"""We really can only handle 3-4 dimensional cases in 8-bit rn lol"""
|
||||
if len(shape) < 3 or len(shape) > 4:
|
||||
"""
|
||||
We really can only handle 4 dimensional cases in 8-bit rn lol
|
||||
|
||||
.. todo::
|
||||
|
||||
Fix shape/writing for grayscale videos
|
||||
|
||||
"""
|
||||
if len(shape) != 4:
|
||||
return True
|
||||
|
||||
# if len(shape) < 3 or len(shape) > 4:
|
||||
# return True
|
||||
if dtype not in (int, np.uint8):
|
||||
return True
|
||||
# if we have a color video (ie. shape == 4, needs to be RGB)
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
import inspect
|
||||
from typing import Callable, Tuple, Type
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from numpydantic import NDArray, interface
|
||||
from numpydantic.testing.helpers import InterfaceCase
|
||||
from numpydantic.testing.cases import (
|
||||
ALL_CASES,
|
||||
ALL_CASES_PASSING,
|
||||
DTYPE_AND_INTERFACE_CASES_PASSING,
|
||||
)
|
||||
from numpydantic.testing.helpers import InterfaceCase, ValidationCase, merge_cases
|
||||
from numpydantic.testing.interfaces import (
|
||||
DaskCase,
|
||||
HDF5Case,
|
||||
|
@ -21,76 +21,130 @@ from numpydantic.testing.interfaces import (
|
|||
scope="function",
|
||||
params=[
|
||||
pytest.param(
|
||||
([[1, 2], [3, 4]], interface.NumpyInterface),
|
||||
marks=pytest.mark.numpy,
|
||||
id="numpy-list",
|
||||
),
|
||||
pytest.param(
|
||||
(NumpyCase, interface.NumpyInterface),
|
||||
NumpyCase,
|
||||
marks=pytest.mark.numpy,
|
||||
id="numpy",
|
||||
),
|
||||
pytest.param(
|
||||
(HDF5Case, interface.H5Interface),
|
||||
HDF5Case,
|
||||
marks=pytest.mark.hdf5,
|
||||
id="h5-array-path",
|
||||
),
|
||||
pytest.param(
|
||||
(DaskCase, interface.DaskInterface),
|
||||
DaskCase,
|
||||
marks=pytest.mark.dask,
|
||||
id="dask",
|
||||
),
|
||||
pytest.param(
|
||||
(ZarrCase, interface.ZarrInterface),
|
||||
ZarrCase,
|
||||
marks=pytest.mark.zarr,
|
||||
id="zarr-memory",
|
||||
),
|
||||
pytest.param(
|
||||
(ZarrNestedCase, interface.ZarrInterface),
|
||||
ZarrNestedCase,
|
||||
marks=pytest.mark.zarr,
|
||||
id="zarr-nested",
|
||||
),
|
||||
pytest.param(
|
||||
(ZarrDirCase, interface.ZarrInterface),
|
||||
ZarrDirCase,
|
||||
marks=pytest.mark.zarr,
|
||||
id="zarr-dir",
|
||||
),
|
||||
pytest.param(
|
||||
(VideoCase, interface.VideoInterface), marks=pytest.mark.video, id="video"
|
||||
),
|
||||
pytest.param(VideoCase, marks=pytest.mark.video, id="video"),
|
||||
],
|
||||
)
|
||||
def interface_type(
|
||||
request, tmp_output_dir_func
|
||||
) -> Tuple[NDArray, Type[interface.Interface]]:
|
||||
def interface_cases(request) -> InterfaceCase:
|
||||
"""
|
||||
Test cases for each interface's ``check`` method - each input should match the
|
||||
provided interface and that interface only
|
||||
Fixture for combinatoric tests across all interface cases
|
||||
"""
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=(
|
||||
pytest.param(p, id=p.id, marks=getattr(pytest.mark, p.interface.interface.name))
|
||||
for p in ALL_CASES
|
||||
)
|
||||
)
|
||||
def all_cases(interface_cases, request) -> ValidationCase:
|
||||
"""
|
||||
Combinatoric testing for all dtype, shape, and interface cases.
|
||||
|
||||
This is a very expensive fixture! Only use it for core functionality
|
||||
that we want to be sure is *very true* in every circumstance,
|
||||
INCLUDING invalid combinations of annotations and arrays.
|
||||
Typically, that means only use this in `test_interfaces.py`
|
||||
"""
|
||||
|
||||
if inspect.isclass(request.param[0]) and issubclass(
|
||||
request.param[0], InterfaceCase
|
||||
):
|
||||
array = request.param[0].make_array(path=tmp_output_dir_func)
|
||||
if array is None:
|
||||
pytest.skip()
|
||||
return array, request.param[1]
|
||||
else:
|
||||
return request.param
|
||||
case = merge_cases(request.param, ValidationCase(interface=interface_cases))
|
||||
if case.skip():
|
||||
pytest.skip()
|
||||
return case
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=(
|
||||
pytest.param(p, id=p.id, marks=getattr(pytest.mark, p.interface.interface.name))
|
||||
for p in ALL_CASES_PASSING
|
||||
)
|
||||
)
|
||||
def all_passing_cases(request) -> ValidationCase:
|
||||
"""
|
||||
Combinatoric testing for all dtype, shape, and interface cases,
|
||||
but only the combinations that we expect to pass.
|
||||
|
||||
This is a very expensive fixture! Only use it for core functionality
|
||||
that we want to be sure is *very true* in every circumstance.
|
||||
Typically, that means only use this in `test_interfaces.py`
|
||||
"""
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def all_interfaces(interface_type) -> BaseModel:
|
||||
def all_cases_instance(all_cases, tmp_output_dir_func):
|
||||
"""
|
||||
An instantiated version of each interface within a basemodel,
|
||||
with the array in an `array` field
|
||||
all_cases but with an instantiated model
|
||||
Args:
|
||||
all_cases:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
array, interface = interface_type
|
||||
if isinstance(array, Callable):
|
||||
array = array()
|
||||
|
||||
class MyModel(BaseModel):
|
||||
array: NDArray
|
||||
|
||||
instance = MyModel(array=array)
|
||||
array = all_cases.array(path=tmp_output_dir_func)
|
||||
instance = all_cases.model(array=array)
|
||||
return instance
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def all_passing_cases_instance(all_passing_cases, tmp_output_dir_func):
|
||||
"""
|
||||
all_cases but with an instantiated model
|
||||
Args:
|
||||
all_cases:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
array = all_passing_cases.array(path=tmp_output_dir_func)
|
||||
instance = all_passing_cases.model(array=array)
|
||||
return instance
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=(
|
||||
pytest.param(p, id=p.id, marks=getattr(pytest.mark, p.interface.interface.name))
|
||||
for p in DTYPE_AND_INTERFACE_CASES_PASSING
|
||||
)
|
||||
)
|
||||
def dtype_by_interface(request):
|
||||
"""
|
||||
Tests for all dtypes by all interfaces
|
||||
"""
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def dtype_by_interface_instance(dtype_by_interface, tmp_output_dir_func):
|
||||
array = dtype_by_interface.array(path=tmp_output_dir_func)
|
||||
instance = dtype_by_interface.model(array=array)
|
||||
return instance
|
||||
|
|
|
@ -16,11 +16,13 @@ def test_dask_enabled():
|
|||
assert DaskInterface.enabled()
|
||||
|
||||
|
||||
def test_dask_check(interface_type):
|
||||
if interface_type[1] is DaskInterface:
|
||||
assert DaskInterface.check(interface_type[0])
|
||||
def test_dask_check(interface_cases, tmp_output_dir_func):
|
||||
array = interface_cases.make_array(path=tmp_output_dir_func)
|
||||
|
||||
if interface_cases.interface is DaskInterface:
|
||||
assert DaskInterface.check(array)
|
||||
else:
|
||||
assert not DaskInterface.check(interface_type[0])
|
||||
assert not DaskInterface.check(array)
|
||||
|
||||
|
||||
@pytest.mark.shape
|
||||
|
|
|
@ -43,14 +43,12 @@ def test_hdf5_dtype(dtype_cases, hdf5_cases):
|
|||
dtype_cases.validate_case()
|
||||
|
||||
|
||||
def test_hdf5_check(interface_type):
|
||||
if interface_type[1] is H5Interface:
|
||||
assert H5Interface.check(interface_type[0])
|
||||
if isinstance(interface_type[0], H5ArrayPath):
|
||||
# also test that we can instantiate from a tuple like the H5ArrayPath
|
||||
assert H5Interface.check((interface_type[0].file, interface_type[0].path))
|
||||
def test_hdf5_check(interface_cases, tmp_output_dir_func):
|
||||
array = interface_cases.make_array(path=tmp_output_dir_func)
|
||||
if interface_cases.interface is H5Interface:
|
||||
assert H5Interface.check(array)
|
||||
else:
|
||||
assert not H5Interface.check(interface_type[0])
|
||||
assert not H5Interface.check(array)
|
||||
|
||||
|
||||
def test_hdf5_check_not_exists():
|
||||
|
|
|
@ -4,7 +4,6 @@ Tests that should be applied to all interfaces
|
|||
|
||||
import json
|
||||
from importlib.metadata import version
|
||||
from typing import Callable
|
||||
|
||||
import dask.array as da
|
||||
import numpy as np
|
||||
|
@ -13,76 +12,98 @@ from pydantic import BaseModel
|
|||
from zarr.core import Array as ZarrArray
|
||||
|
||||
from numpydantic.interface import Interface, InterfaceMark, MarkedJson
|
||||
from numpydantic.testing.helpers import ValidationCase
|
||||
|
||||
|
||||
def _test_roundtrip(source: BaseModel, target: BaseModel, round_trip: bool):
|
||||
def _test_roundtrip(source: BaseModel, target: BaseModel):
|
||||
"""Test model equality for roundtrip tests"""
|
||||
if round_trip:
|
||||
assert type(target.array) is type(source.array)
|
||||
if isinstance(source.array, (np.ndarray, ZarrArray)):
|
||||
assert np.array_equal(target.array, np.array(source.array))
|
||||
elif isinstance(source.array, da.Array):
|
||||
assert np.all(da.equal(target.array, source.array))
|
||||
else:
|
||||
assert target.array == source.array
|
||||
|
||||
assert target.array.dtype == source.array.dtype
|
||||
else:
|
||||
assert type(target.array) is type(source.array)
|
||||
if isinstance(source.array, (np.ndarray, ZarrArray)):
|
||||
assert np.array_equal(target.array, np.array(source.array))
|
||||
elif isinstance(source.array, da.Array):
|
||||
if target.array.dtype == object:
|
||||
# object equality doesn't really work well with dask
|
||||
# just check that the types match
|
||||
target_type = type(target.array.ravel()[0].compute())
|
||||
source_type = type(source.array.ravel()[0].compute())
|
||||
assert target_type is source_type
|
||||
else:
|
||||
assert np.all(da.equal(target.array, source.array))
|
||||
else:
|
||||
assert target.array == source.array
|
||||
|
||||
assert target.array.dtype == source.array.dtype
|
||||
|
||||
|
||||
def test_dunder_len(all_interfaces):
|
||||
def test_dunder_len(interface_cases, tmp_output_dir_func):
|
||||
"""
|
||||
Each interface or proxy type should support __len__
|
||||
"""
|
||||
assert len(all_interfaces.array) == all_interfaces.array.shape[0]
|
||||
case = ValidationCase(interface=interface_cases)
|
||||
if interface_cases.interface.name == "video":
|
||||
case.shape = (10, 10, 2, 3)
|
||||
case.dtype = np.uint8
|
||||
case.annotation_dtype = np.uint8
|
||||
case.annotation_shape = (10, 10, "*", 3)
|
||||
array = case.array(path=tmp_output_dir_func)
|
||||
instance = case.model(array=array)
|
||||
assert len(instance.array) == case.shape[0]
|
||||
|
||||
|
||||
def test_interface_revalidate(all_interfaces):
|
||||
def test_interface_revalidate(all_passing_cases_instance):
|
||||
"""
|
||||
An interface should revalidate with the output of its initial validation
|
||||
|
||||
See: https://github.com/p2p-ld/numpydantic/pull/14
|
||||
"""
|
||||
_ = type(all_interfaces)(array=all_interfaces.array)
|
||||
|
||||
_ = type(all_passing_cases_instance)(array=all_passing_cases_instance.array)
|
||||
|
||||
|
||||
def test_interface_rematch(interface_type):
|
||||
@pytest.mark.xfail
|
||||
def test_interface_rematch(interface_cases, tmp_output_dir_func):
|
||||
"""
|
||||
All interfaces should match the results of the object they return after validation
|
||||
"""
|
||||
array, interface = interface_type
|
||||
if isinstance(array, Callable):
|
||||
array = array()
|
||||
array = interface_cases.make_array(path=tmp_output_dir_func)
|
||||
|
||||
assert Interface.match(interface().validate(array)) is interface
|
||||
assert (
|
||||
Interface.match(interface_cases.interface.validate(array))
|
||||
is interface_cases.interface
|
||||
)
|
||||
|
||||
|
||||
def test_interface_to_numpy_array(all_interfaces):
|
||||
def test_interface_to_numpy_array(dtype_by_interface):
|
||||
"""
|
||||
All interfaces should be able to have the output of their validation stage
|
||||
coerced to a numpy array with np.array()
|
||||
"""
|
||||
_ = np.array(all_interfaces.array)
|
||||
_ = np.array(dtype_by_interface.array)
|
||||
|
||||
|
||||
@pytest.mark.serialization
|
||||
def test_interface_dump_json(all_interfaces):
|
||||
def test_interface_dump_json(dtype_by_interface_instance):
|
||||
"""
|
||||
All interfaces should be able to dump to json
|
||||
"""
|
||||
all_interfaces.model_dump_json()
|
||||
dtype_by_interface_instance.model_dump_json()
|
||||
|
||||
|
||||
@pytest.mark.serialization
|
||||
@pytest.mark.parametrize("round_trip", [True, False])
|
||||
def test_interface_roundtrip_json(all_interfaces, round_trip):
|
||||
def test_interface_roundtrip_json(dtype_by_interface, tmp_output_dir_func):
|
||||
"""
|
||||
All interfaces should be able to roundtrip to and from json
|
||||
"""
|
||||
dumped_json = all_interfaces.model_dump_json(round_trip=round_trip)
|
||||
model = all_interfaces.model_validate_json(dumped_json)
|
||||
_test_roundtrip(all_interfaces, model, round_trip)
|
||||
if "subclass" in dtype_by_interface.id.lower():
|
||||
pytest.xfail()
|
||||
|
||||
array = dtype_by_interface.array(path=tmp_output_dir_func)
|
||||
case = dtype_by_interface.model(array=array)
|
||||
|
||||
dumped_json = case.model_dump_json(round_trip=True)
|
||||
model = case.model_validate_json(dumped_json)
|
||||
_test_roundtrip(case, model)
|
||||
|
||||
|
||||
@pytest.mark.serialization
|
||||
|
@ -101,15 +122,20 @@ def test_interface_mark_interface(an_interface):
|
|||
|
||||
@pytest.mark.serialization
|
||||
@pytest.mark.parametrize("valid", [True, False])
|
||||
@pytest.mark.parametrize("round_trip", [True, False])
|
||||
@pytest.mark.filterwarnings("ignore:Mismatch between serialized mark")
|
||||
def test_interface_mark_roundtrip(all_interfaces, valid, round_trip):
|
||||
def test_interface_mark_roundtrip(dtype_by_interface, valid, tmp_output_dir_func):
|
||||
"""
|
||||
All interfaces should be able to roundtrip with the marked interface,
|
||||
and a mismatch should raise a warning and attempt to proceed
|
||||
"""
|
||||
dumped_json = all_interfaces.model_dump_json(
|
||||
round_trip=round_trip, context={"mark_interface": True}
|
||||
if "subclass" in dtype_by_interface.id.lower():
|
||||
pytest.xfail()
|
||||
|
||||
array = dtype_by_interface.array(path=tmp_output_dir_func)
|
||||
case = dtype_by_interface.model(array=array)
|
||||
|
||||
dumped_json = case.model_dump_json(
|
||||
round_trip=True, context={"mark_interface": True}
|
||||
)
|
||||
|
||||
data = json.loads(dumped_json)
|
||||
|
@ -123,8 +149,8 @@ def test_interface_mark_roundtrip(all_interfaces, valid, round_trip):
|
|||
dumped_json = json.dumps(data)
|
||||
|
||||
with pytest.warns(match="Mismatch.*"):
|
||||
model = all_interfaces.model_validate_json(dumped_json)
|
||||
model = case.model_validate_json(dumped_json)
|
||||
else:
|
||||
model = all_interfaces.model_validate_json(dumped_json)
|
||||
model = case.model_validate_json(dumped_json)
|
||||
|
||||
_test_roundtrip(all_interfaces, model, round_trip)
|
||||
_test_roundtrip(case, model)
|
||||
|
|
|
@ -80,15 +80,12 @@ def test_video_getitem(avi_video):
|
|||
|
||||
instance = MyModel(array=vid)
|
||||
fifth_frame = instance.array[5]
|
||||
# the first frame should have 1's in the 1,1 position
|
||||
# the fifth frame should be all 5s
|
||||
assert (fifth_frame[5, 5, :] == [5, 5, 5]).all()
|
||||
# and nothing in the 6th position
|
||||
assert (fifth_frame[6, 6, :] == [0, 0, 0]).all()
|
||||
|
||||
# slicing should also work as if it were just a numpy array
|
||||
single_slice = instance.array[3, 0:10, 0:5]
|
||||
assert single_slice[3, 3, 0] == 3
|
||||
assert single_slice[4, 4, 0] == 0
|
||||
assert single_slice.shape == (10, 5, 3)
|
||||
|
||||
# also get a range of frames
|
||||
|
@ -96,19 +93,19 @@ def test_video_getitem(avi_video):
|
|||
range_slice = instance.array[3:5]
|
||||
assert range_slice.shape == (2, 100, 50, 3)
|
||||
assert range_slice[0, 3, 3, 0] == 3
|
||||
assert range_slice[0, 4, 4, 0] == 0
|
||||
assert range_slice[1, 4, 4, 0] == 4
|
||||
|
||||
# full range
|
||||
range_slice = instance.array[3:5, 0:10, 0:5]
|
||||
assert range_slice.shape == (2, 10, 5, 3)
|
||||
assert range_slice[0, 3, 3, 0] == 3
|
||||
assert range_slice[0, 4, 4, 0] == 0
|
||||
assert range_slice[1, 4, 4, 0] == 4
|
||||
|
||||
# starting range
|
||||
range_slice = instance.array[6:, 0:10, 0:10]
|
||||
assert range_slice.shape == (4, 10, 10, 3)
|
||||
assert range_slice[-1, 9, 9, 0] == 9
|
||||
assert range_slice[-2, 9, 9, 0] == 0
|
||||
assert range_slice[-2, 9, 9, 0] == 8
|
||||
|
||||
# ending range
|
||||
range_slice = instance.array[:3, 0:5, 0:5]
|
||||
|
@ -119,10 +116,8 @@ def test_video_getitem(avi_video):
|
|||
# second slice should be the second frame (instead of the first)
|
||||
assert range_slice.shape == (3, 6, 6, 3)
|
||||
assert range_slice[1, 2, 2, 0] == 2
|
||||
assert range_slice[1, 3, 3, 0] == 0
|
||||
# and the third should be the fourth (instead of the second)
|
||||
assert range_slice[2, 4, 4, 0] == 4
|
||||
assert range_slice[2, 5, 5, 0] == 0
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
# shouldn't be allowed to set
|
||||
|
|
|
@ -38,14 +38,15 @@ def test_zarr_enabled():
|
|||
assert ZarrInterface.enabled()
|
||||
|
||||
|
||||
def test_zarr_check(interface_type):
|
||||
def test_zarr_check(interface_cases, tmp_output_dir_func):
|
||||
"""
|
||||
We should only use the zarr interface for zarr-like things
|
||||
"""
|
||||
if interface_type[1] is ZarrInterface:
|
||||
assert ZarrInterface.check(interface_type[0])
|
||||
array = interface_cases.make_array(path=tmp_output_dir_func)
|
||||
if interface_cases.interface is ZarrInterface:
|
||||
assert ZarrInterface.check(array)
|
||||
else:
|
||||
assert not ZarrInterface.check(interface_type[0])
|
||||
assert not ZarrInterface.check(array)
|
||||
|
||||
|
||||
@pytest.mark.shape
|
||||
|
|
Loading…
Reference in a new issue