add ability to set marks in test cases, add combinatoric test for json schema generation

This commit is contained in:
sneakers-the-rat 2024-12-13 16:36:51 -08:00
parent 0cbdfb2890
commit 8bf2203911
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
7 changed files with 61 additions and 16 deletions

View file

@ -5,7 +5,7 @@
groups = ["default", "arrays", "dask", "dev", "docs", "hdf5", "tests", "video", "zarr"] groups = ["default", "arrays", "dask", "dev", "docs", "hdf5", "tests", "video", "zarr"]
strategy = ["cross_platform", "inherit_metadata"] strategy = ["cross_platform", "inherit_metadata"]
lock_version = "4.5.0" lock_version = "4.5.0"
content_hash = "sha256:cc2b0fb32896c6df0ad747ddb5dee89af22f5c4c4643ee7a52db47fef30da936" content_hash = "sha256:89ac87e811ecc42bf5117e9c9e4aa6a69011cb7c5c1a630fbfb2643b0045c526"
[[metadata.targets]] [[metadata.targets]]
requires_python = "~=3.9" requires_python = "~=3.9"

View file

@ -129,6 +129,8 @@ markers = [
"numpy: numpy interface", "numpy: numpy interface",
"video: video interface", "video: video interface",
"zarr: zarr interface", "zarr: zarr interface",
"union: union dtypes",
"pipe_union: union dtypes specified with a pipe",
] ]
[tool.black] [tool.black]

View file

@ -143,27 +143,35 @@ DTYPE_CASES = [
dtype=np.uint32, dtype=np.uint32,
passes=True, passes=True,
id="union-type-uint32", id="union-type-uint32",
marks={"union"},
), ),
ValidationCase( ValidationCase(
annotation_dtype=UNION_TYPE, annotation_dtype=UNION_TYPE,
dtype=np.float32, dtype=np.float32,
passes=True, passes=True,
id="union-type-float32", id="union-type-float32",
marks={"union"},
), ),
ValidationCase( ValidationCase(
annotation_dtype=UNION_TYPE, annotation_dtype=UNION_TYPE,
dtype=np.uint64, dtype=np.uint64,
passes=False, passes=False,
id="union-type-uint64", id="union-type-uint64",
marks={"union"},
), ),
ValidationCase( ValidationCase(
annotation_dtype=UNION_TYPE, annotation_dtype=UNION_TYPE,
dtype=np.float64, dtype=np.float64,
passes=False, passes=False,
id="union-type-float64", id="union-type-float64",
marks={"union"},
), ),
ValidationCase( ValidationCase(
annotation_dtype=UNION_TYPE, dtype=str, passes=False, id="union-type-str" annotation_dtype=UNION_TYPE,
dtype=str,
passes=False,
id="union-type-str",
marks={"union"},
), ),
] ]
""" """
@ -181,30 +189,35 @@ if YES_PIPE:
dtype=np.uint32, dtype=np.uint32,
passes=True, passes=True,
id="union-pipe-uint32", id="union-pipe-uint32",
marks={"union", "pipe_union"},
), ),
ValidationCase( ValidationCase(
annotation_dtype=UNION_PIPE, annotation_dtype=UNION_PIPE,
dtype=np.float32, dtype=np.float32,
passes=True, passes=True,
id="union-pipe-float32", id="union-pipe-float32",
marks={"union", "pipe_union"},
), ),
ValidationCase( ValidationCase(
annotation_dtype=UNION_PIPE, annotation_dtype=UNION_PIPE,
dtype=np.uint64, dtype=np.uint64,
passes=False, passes=False,
id="union-pipe-uint64", id="union-pipe-uint64",
marks={"union", "pipe_union"},
), ),
ValidationCase( ValidationCase(
annotation_dtype=UNION_PIPE, annotation_dtype=UNION_PIPE,
dtype=np.float64, dtype=np.float64,
passes=False, passes=False,
id="union-pipe-float64", id="union-pipe-float64",
marks={"union", "pipe_union"},
), ),
ValidationCase( ValidationCase(
annotation_dtype=UNION_PIPE, annotation_dtype=UNION_PIPE,
dtype=str, dtype=str,
passes=False, passes=False,
id="union-pipe-str", id="union-pipe-str",
marks={"union", "pipe_union"},
), ),
] ]
) )

View file

@ -4,16 +4,19 @@ from functools import reduce
from itertools import product from itertools import product
from operator import ior from operator import ior
from pathlib import Path from pathlib import Path
from typing import Generator, List, Literal, Optional, Tuple, Type, Union from typing import TYPE_CHECKING, Generator, List, Literal, Optional, Tuple, Type, Union
import numpy as np import numpy as np
from pydantic import BaseModel, ConfigDict, ValidationError, computed_field from pydantic import BaseModel, ConfigDict, Field, ValidationError, computed_field
from numpydantic import NDArray, Shape from numpydantic import NDArray, Shape
from numpydantic.dtype import Float from numpydantic.dtype import Float
from numpydantic.interface import Interface from numpydantic.interface import Interface
from numpydantic.types import DtypeType, NDArrayType from numpydantic.types import DtypeType, NDArrayType
if TYPE_CHECKING:
from _pytest.mark.structures import MarkDecorator
class InterfaceCase(ABC): class InterfaceCase(ABC):
""" """
@ -139,6 +142,8 @@ class ValidationCase(BaseModel):
"""The interface test case to generate and validate the array with""" """The interface test case to generate and validate the array with"""
path: Optional[Path] = None path: Optional[Path] = None
"""The path to generate arrays into, if any.""" """The path to generate arrays into, if any."""
marks: set[str] = Field(default_factory=set)
"""pytest marks to set for this test case"""
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
@ -179,6 +184,19 @@ class ValidationCase(BaseModel):
return Model return Model
@property
def pytest_marks(self) -> list["MarkDecorator"]:
"""
Instantiated pytest marks from :attr:`.ValidationCase.marks`
plus the interface name.
"""
import pytest
marks = self.marks.copy()
if self.interface is not None:
marks.add(self.interface.interface.name)
return [getattr(pytest.mark, m) for m in marks]
def validate_case(self, path: Optional[Path] = None) -> bool: def validate_case(self, path: Optional[Path] = None) -> bool:
""" """
Whether the generated array correctly validated against the annotation, Whether the generated array correctly validated against the annotation,
@ -246,7 +264,10 @@ def merge_cases(*args: ValidationCase) -> ValidationCase:
return args[0] return args[0]
dumped = [ dumped = [
m.model_dump(exclude_unset=True, exclude={"model", "annotation"}) for m in args m.model_dump(
exclude_unset=True, exclude={"model", "annotation", "pytest_marks"}
)
for m in args
] ]
# self_dump = self.model_dump(exclude_unset=True) # self_dump = self.model_dump(exclude_unset=True)
@ -263,6 +284,7 @@ def merge_cases(*args: ValidationCase) -> ValidationCase:
merged = reduce(ior, dumped, {}) merged = reduce(ior, dumped, {})
merged["passes"] = passes merged["passes"] = passes
merged["id"] = ids merged["id"] = ids
merged["marks"] = set().union(*[v.get("marks", set()) for v in dumped])
return ValidationCase.model_construct(**merged) return ValidationCase.model_construct(**merged)

View file

@ -14,7 +14,8 @@ def pytest_addoption(parser):
@pytest.fixture( @pytest.fixture(
scope="function", params=[pytest.param(c, id=c.id) for c in SHAPE_CASES] scope="function",
params=[pytest.param(c, id=c.id, marks=c.pytest_marks) for c in SHAPE_CASES],
) )
def shape_cases(request, tmp_output_dir_func) -> ValidationCase: def shape_cases(request, tmp_output_dir_func) -> ValidationCase:
case: ValidationCase = request.param.model_copy() case: ValidationCase = request.param.model_copy()
@ -23,7 +24,8 @@ def shape_cases(request, tmp_output_dir_func) -> ValidationCase:
@pytest.fixture( @pytest.fixture(
scope="function", params=[pytest.param(c, id=c.id) for c in DTYPE_CASES] scope="function",
params=[pytest.param(c, id=c.id, marks=c.pytest_marks) for c in DTYPE_CASES],
) )
def dtype_cases(request, tmp_output_dir_func) -> ValidationCase: def dtype_cases(request, tmp_output_dir_func) -> ValidationCase:
case: ValidationCase = request.param.model_copy() case: ValidationCase = request.param.model_copy()

View file

@ -61,10 +61,7 @@ def interface_cases(request) -> InterfaceCase:
@pytest.fixture( @pytest.fixture(
params=( params=(pytest.param(p, id=p.id, marks=p.pytest_marks) for p in ALL_CASES)
pytest.param(p, id=p.id, marks=getattr(pytest.mark, p.interface.interface.name))
for p in ALL_CASES
)
) )
def all_cases(interface_cases, request) -> ValidationCase: def all_cases(interface_cases, request) -> ValidationCase:
""" """
@ -83,10 +80,7 @@ def all_cases(interface_cases, request) -> ValidationCase:
@pytest.fixture( @pytest.fixture(
params=( params=(pytest.param(p, id=p.id, marks=p.pytest_marks) for p in ALL_CASES_PASSING)
pytest.param(p, id=p.id, marks=getattr(pytest.mark, p.interface.interface.name))
for p in ALL_CASES_PASSING
)
) )
def all_passing_cases(request) -> ValidationCase: def all_passing_cases(request) -> ValidationCase:
""" """
@ -132,7 +126,7 @@ def all_passing_cases_instance(all_passing_cases, tmp_output_dir_func):
@pytest.fixture( @pytest.fixture(
params=( params=(
pytest.param(p, id=p.id, marks=getattr(pytest.mark, p.interface.interface.name)) pytest.param(p, id=p.id, marks=p.pytest_marks)
for p in DTYPE_AND_INTERFACE_CASES_PASSING for p in DTYPE_AND_INTERFACE_CASES_PASSING
) )
) )

View file

@ -61,6 +61,18 @@ def test_interface_revalidate(all_passing_cases_instance):
_ = type(all_passing_cases_instance)(array=all_passing_cases_instance.array) _ = type(all_passing_cases_instance)(array=all_passing_cases_instance.array)
@pytest.mark.json_schema
def test_interface_jsonschema(all_passing_cases_instance):
"""
All interfaces should be able to generate json schema
for all combinations of dtype and shape
Note that this does not test for json schema correctness -
see ndarray tests for that
"""
_ = all_passing_cases_instance.model_json_schema()
@pytest.mark.xfail @pytest.mark.xfail
def test_interface_rematch(interface_cases, tmp_output_dir_func): def test_interface_rematch(interface_cases, tmp_output_dir_func):
""" """