diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e0a4e3f..f2c8974 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -33,10 +33,10 @@ jobs: runs-on: ${{ matrix.platform }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} diff --git a/docs/_static/css/notebooks.css b/docs/_static/css/notebooks.css new file mode 100644 index 0000000..b7fcb03 --- /dev/null +++ b/docs/_static/css/notebooks.css @@ -0,0 +1,31 @@ +div.cell.tag_hide-cell details.above-input > summary, +div.cell.tag_hide-input details.above-input > summary, +div.cell.tag_hide-output details.below-input > summary{ + background-color: var(--color-admonition-title-background--admonition-todo); + color: var(--color-content-foreground); + border: unset; + border-left: 2px solid var(--mystnb-source-margin-color); + opacity: unset; + padding: 0.25em 0 0.25em 1em; +} + +div.cell.tag_hide-cell details.above-input > summary > span, +div.cell.tag_hide-input details.above-input > summary > span, +div.cell.tag_hide-output details.below-input > summary > span +{ + opacity: unset; +} + +div.cell details.above-input div.cell_input { + border: unset; + background-color: unset; + border-left: 2px solid var(--mystnb-source-margin-color); +} + +div.cell details.above-input div.cell_input div.highlight { + background: var(--color-admonition-background); +} + +.output.text_html pre { + font-size: 0.8em; +} \ No newline at end of file diff --git a/docs/api/testing/cases.md b/docs/api/testing/cases.md new file mode 100644 index 0000000..784bd62 --- /dev/null +++ b/docs/api/testing/cases.md @@ -0,0 +1,7 @@ +# cases + +```{eval-rst} +.. automodule:: numpydantic.testing.cases + :members: + :undoc-members: +``` \ No newline at end of file diff --git a/docs/api/testing/helpers.md b/docs/api/testing/helpers.md new file mode 100644 index 0000000..084901b --- /dev/null +++ b/docs/api/testing/helpers.md @@ -0,0 +1,7 @@ +# helpers + +```{eval-rst} +.. automodule:: numpydantic.testing.helpers + :members: + :undoc-members: +``` \ No newline at end of file diff --git a/docs/api/testing/index.md b/docs/api/testing/index.md new file mode 100644 index 0000000..7b2a976 --- /dev/null +++ b/docs/api/testing/index.md @@ -0,0 +1,19 @@ +# testing + +Utilities for testing and 3rd-party interface development. + +See also the [narrative testing docs](../../contributing/testing.md) + +```{toctree} +:maxdepth: 2 + +cases +helpers +interfaces +``` + +```{eval-rst} +.. automodule:: numpydantic.testing + :members: + :undoc-members: +``` \ No newline at end of file diff --git a/docs/api/testing/interfaces.md b/docs/api/testing/interfaces.md new file mode 100644 index 0000000..c68c772 --- /dev/null +++ b/docs/api/testing/interfaces.md @@ -0,0 +1,7 @@ +# interfaces + +```{eval-rst} +.. automodule:: numpydantic.testing.interfaces + :members: + :undoc-members: +``` \ No newline at end of file diff --git a/docs/changelog.md b/docs/changelog.md index 75ffcb3..365011a 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -4,6 +4,33 @@ ### 1.6.* +#### 1.6.4 - 24-10-11 - Combinatoric Testing + +PR: https://github.com/p2p-ld/numpydantic/pull/31 + + +We have rewritten our testing system for more rigorous tests, +where before we were limited to only testing dtype or shape cases one at a time, +now we can test all possible combinations together! 
+ +This allows us to have better guarantees for behavior that all interfaces +should support, validating it against all possible dtypes and shapes. + +We also exposed all the helpers and array testing classes for downstream development +so that it would be easier to test and validate any 3rd-party interfaces +that haven't made their way into mainline numpydantic yet - +see the {mod}`numpydantic.testing` module. + +See the [testing documentation](./contributing/testing.md) for more details. + +**Bugfix** +- Previously, numpy and dask arrays with a model dtype would fail json roundtripping + because they wouldn't be correctly cast back to the model type. Now they are. +- Zarr would not dump the dtype of an array when it roundtripped to json, + causing every array to be interpreted as a random integer or float type. + `dtype` is now dumped and used when deserializing. + + #### 1.6.3 - 24-09-26 **Bugfix** diff --git a/docs/conf.py b/docs/conf.py index 963b0b3..32eb942 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -49,6 +49,7 @@ intersphinx_mapping = { html_theme = "furo" html_static_path = ["_static"] +html_css_files = ["css/notebooks.css"] # autodoc autodoc_pydantic_model_show_json_error_strategy = "coerce" diff --git a/docs/contributing/coc.md b/docs/contributing/coc.md new file mode 100644 index 0000000..8504e4c --- /dev/null +++ b/docs/contributing/coc.md @@ -0,0 +1,5 @@ +# Code of Conduct + +```{todo} +jonny write the code of conduct +``` \ No newline at end of file diff --git a/docs/contributing/index.md b/docs/contributing/index.md new file mode 100644 index 0000000..fd5a4c9 --- /dev/null +++ b/docs/contributing/index.md @@ -0,0 +1,8 @@ +# Contributing + +```{toctree} +coc +process +interface +testing +``` \ No newline at end of file diff --git a/docs/contributing/interface.md b/docs/contributing/interface.md new file mode 100644 index 0000000..d3e29a1 --- /dev/null +++ b/docs/contributing/interface.md @@ -0,0 +1,5 @@ +# Writing an Interface + +```{todo} +Jonny write the interface contrib docs +``` \ No newline at end of file diff --git a/docs/contributing/process.md b/docs/contributing/process.md new file mode 100644 index 0000000..51f29ba --- /dev/null +++ b/docs/contributing/process.md @@ -0,0 +1,15 @@ +# Contribution Process + +```{todo} +Jonny write the contribution docs +``` + +### Issues + +### Development Environment + +### Testing + +### Linting + +### Pull Requests \ No newline at end of file diff --git a/docs/contributing/testing.md b/docs/contributing/testing.md new file mode 100644 index 0000000..d7e9be0 --- /dev/null +++ b/docs/contributing/testing.md @@ -0,0 +1,213 @@ +--- +file_format: mystnb +mystnb: + output_stderr: remove + render_text_lexer: python + render_markdown_format: myst +myst: + enable_extensions: ["colon_fence"] +--- +# Testing + +```{code-cell} +--- +tags: [hide-cell] +--- + +from pathlib import Path +from rich.console import Console +from rich.theme import Theme +from rich.style import Style +from rich.color import Color + +theme = Theme({ + "repr.call": Style(color=Color.from_rgb(110,191,38), bold=True), + "repr.attrib_name": Style(color="slate_blue1"), + "repr.number": Style(color="deep_sky_blue1"), + "repr.none": Style(color="bright_magenta", italic=True), + "repr.attrib_name": Style(color="white"), + "repr.tag_contents": Style(color="light_steel_blue"), + "repr.str": Style(color="violet") +}) +console = Console(theme=theme) + +``` + +```{note} +Also see the [`numpydantic.testing` API docs](../api/testing/index.md) +and the [Writing an 
Interface](../interfaces.md) guide +``` + +Numpydantic exposes a system for combinatoric testing across dtypes, shapes, +and interfaces in the {mod}`numpydantic.testing` module. + +These helper classes and functions are included in the distributed package +so they can be used for downstream development of independent interfaces +(though we always welcome contributions!) + +## Validation Cases + +Each test case is parameterized by a {class}`.ValidationCase`. + +The case is intended to be able to be partially filled in so that multiple +validation cases can be merged together, but also used independently +by falling back on default values. + +There are three major parts to a validation case: + +- **Annotation specification:** {attr}`~.ValidationCase.annotation_dtype` and + {attr}`~.ValidationCase.annotation_shape` specifies how the + {class}`.NDArray` {attr}`.ValidationCase.annotation` that is used to test + against is generated +- **Array specification:** {attr}`~.ValidationCase.dtype` and {attr}`~.ValidationCase.shape` + specify that array that will be generated to test against the annotation +- **Interface specification:** An {class}`.InterfaceCase` that refers to + an {class}`.Interface`, and provides array generation and other auxilary logic. + +Typically, one specifies a dtype along with an annotation dtype or +a shape along with an annotation shape (or implicitly against the defaults for either), +along with a value for `passes` that indicates if that combination is valid. + +```{code-cell} +from numpydantic.testing import ValidationCase + +dtype_case = ValidationCase( + id="int_int", + dtype=int, + annotation_dtype=int, + passes=True +) +shape_case = ValidationCase( + id="cool_shape", + shape=(1,2,3), + annotation_shape=(1,"*","2-4"), + passes=True +) + +merged = dtype_case.merge(shape_case) +console.print(merged.model_dump(exclude={'annotation', 'model'}, exclude_unset=True)) +``` + +When merging validation cases, the merged case only `passes` if all the +original cases do. + +```{code-cell} +from numpydantic.testing import ValidationCase + +dtype_case = ValidationCase( + id="int_int", + dtype=int, + annotation_dtype=int, + passes=True +) +shape_case = ValidationCase( + id="uncool_shape", + shape=(1,2,3), + annotation_shape=(9,8,7), + passes=False +) + +merged = dtype_case.merge(shape_case) +console.print(merged.model_dump(exclude={'annotation', 'model'}, exclude_unset=True)) +``` + +We provide a convenience function {func}`.merged_product` for creating a merged product of +multiple sets of test cases. 
+ +For example, you may want to create a set of dtype and shape cases and validate +against all combinations + +```{code-cell} +from numpydantic.testing.helpers import merged_product + +dtype_cases = [ + ValidationCase(dtype=int, annotation_dtype=int, passes=True), + ValidationCase(dtype=int, annotation_dtype=float, passes=False) +] +shape_cases = [ + ValidationCase(shape=(1,2,3), annotation_shape=(1,2,3), passes=True), + ValidationCase(shape=(4,5,6), annotation_shape=(1,2,3), passes=False) +] + +iterator = merged_product(dtype_cases, shape_cases) + +console.print([i.model_dump(exclude_unset=True, exclude={'model', 'annotation'}) for i in iterator]) + +``` + +You can pass constraints to the {func}`.merged_product` iterator to +filter cases that match some value, for example to get only the cases that pass: + +```{code-cell} +iterator = merged_product(dtype_cases, shape_cases, conditions={"passes": True}) +console.print([i.model_dump(exclude_unset=True, exclude={'model', 'annotation'}) for i in iterator]) +``` + +## Interface Cases + +Validation cases can be paired with interface cases that handle +generating arrays for the given interface from the specification in the +validation case. + +Since some array interfaces like Zarr have multiple possible forms +of an array (in memory, on disk, in a zip file, etc.) an interface +may have multiple cases that are important to test against. + +The {meth}`.InterfaceCase.make_array` method does what you'd expect it to, +creating an array, and returning the appropriate input type for the interface: + +```{code-cell} +from numpydantic.testing.interfaces import NumpyCase, ZarrNestedCase + +NumpyCase.make_array(shape=(1,2,3), dtype=float) +``` + +```{code-cell} +ZarrNestedCase.make_array(shape=(1,2,3), dtype=float, path=Path("__tmp__/zarr_dir")) +``` + +Interface cases also define when an interface should skip a given test +parameterization. For example, some array formats can't support arbitrary +object serialization, and the video class can only support 8-bit arrays +of a specific shape + +```{code-cell} +from numpydantic.testing.interfaces import VideoCase + +VideoCase.skip(shape=(1,1), dtype=float) +``` + +This, and the array generation methods are propagated up into +a ValidationCase that contains them + +```{code-cell} +case = ValidationCase(shape=(1,2,3), dtype=float, interface=VideoCase) +case.skip() +``` + +The {func}`.merged_product` iterator automatically excludes any +combinations of interfaces and test parameterizations that should be skipped. + +## Making Fixtures + +Pytest fixtures are a useful way to re-use validation case products. +To keep things tidy, you may want to use marks and ids when creating them +so that you can run tests against specific interfaces or conditions +with the `pytest -m mark` system. 
+ +```python +import pytest + +@pytest.fixture( + params=( + pytest.param( + p, + id=p.id, + marks=getattr(pytest.mark, p.interface.interface.name) + ) + for p in iterator + ) +) +def my_cases(request): + return request.param +``` diff --git a/docs/index.md b/docs/index.md index 127719f..313601e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -514,6 +514,7 @@ api/meta api/schema api/serialization api/types +api/testing/index ``` @@ -523,6 +524,7 @@ api/types :hidden: true changelog +contributing/index development todo ``` diff --git a/pyproject.toml b/pyproject.toml index 3f46a80..bbbce4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "numpydantic" -version = "1.6.3" +version = "1.6.4" description = "Type and shape validation and serialization for arbitrary array types in pydantic models" authors = [ {name = "sneakers-the-rat", email = "sneakers-the-rat@protonmail.com"}, @@ -96,6 +96,12 @@ distribution = true [tool.pdm.build] includes = [] + +[tool.pdm.scripts] +lint = "ruff check" +format = {shell = "ruff check --fix ; black ."} +test = "pytest" + [build-system] requires = ["pdm-backend"] build-backend = "pdm.backend" @@ -125,10 +131,12 @@ markers = [ "zarr: zarr interface", ] +[tool.black] +target-version = ["py39", "py310", "py311", "py312"] + [tool.ruff] target-version = "py39" -include = ["src/numpydantic/**/*.py", "pyproject.toml"] -exclude = ["tests"] +include = ["src/numpydantic/**/*.py", "tests/**/*.py", "pyproject.toml"] [tool.ruff.lint] select = [ @@ -177,6 +185,10 @@ ignore = [ fixable = ["ALL"] +[tool.ruff.lint.per-file-ignores] +"src/numpydantic/testing/*" = ["D", "F722"] +"tests/*" = ["D", "F403", "F722", "ANN", ] + [tool.mypy] plugins = [ "pydantic.mypy" diff --git a/src/numpydantic/interface/__init__.py b/src/numpydantic/interface/__init__.py index 36c7d97..b36608f 100644 --- a/src/numpydantic/interface/__init__.py +++ b/src/numpydantic/interface/__init__.py @@ -3,7 +3,7 @@ Interfaces between nptyping types and array backends """ from numpydantic.interface.dask import DaskInterface -from numpydantic.interface.hdf5 import H5Interface +from numpydantic.interface.hdf5 import H5ArrayPath, H5Interface from numpydantic.interface.interface import ( Interface, InterfaceMark, @@ -12,10 +12,11 @@ from numpydantic.interface.interface import ( ) from numpydantic.interface.numpy import NumpyInterface from numpydantic.interface.video import VideoInterface -from numpydantic.interface.zarr import ZarrInterface +from numpydantic.interface.zarr import ZarrArrayPath, ZarrInterface __all__ = [ "DaskInterface", + "H5ArrayPath", "H5Interface", "Interface", "InterfaceMark", @@ -23,5 +24,6 @@ __all__ = [ "MarkedJson", "NumpyInterface", "VideoInterface", + "ZarrArrayPath", "ZarrInterface", ] diff --git a/src/numpydantic/interface/dask.py b/src/numpydantic/interface/dask.py index 95d0619..62867a0 100644 --- a/src/numpydantic/interface/dask.py +++ b/src/numpydantic/interface/dask.py @@ -5,7 +5,7 @@ Interface for Dask arrays from typing import Any, Iterable, List, Literal, Optional, Union import numpy as np -from pydantic import SerializationInfo +from pydantic import BaseModel, SerializationInfo from numpydantic.interface.interface import Interface, JsonDict from numpydantic.types import DtypeType, NDArrayType @@ -70,9 +70,33 @@ class DaskInterface(Interface): else: return False + def before_validation(self, array: DaskArray) -> NDArrayType: + """ + Try and coerce dicts that should be model objects into the model objects + """ + try: + if issubclass(self.dtype, 
BaseModel) and isinstance( + array.reshape(-1)[0].compute(), dict + ): + + def _chunked_to_model(array: np.ndarray) -> np.ndarray: + def _vectorized_to_model(item: Union[dict, BaseModel]) -> BaseModel: + if not isinstance(item, self.dtype): + return self.dtype(**item) + else: # pragma: no cover + return item + + return np.vectorize(_vectorized_to_model)(array) + + array = array.map_blocks(_chunked_to_model, dtype=self.dtype) + except TypeError: + # fine, dtype isn't a type + pass + return array + def get_object_dtype(self, array: NDArrayType) -> DtypeType: """Dask arrays require a compute() call to retrieve a single value""" - return type(array.ravel()[0].compute()) + return type(array.reshape(-1)[0].compute()) @classmethod def enabled(cls) -> bool: diff --git a/src/numpydantic/interface/numpy.py b/src/numpydantic/interface/numpy.py index 6c84232..20019f7 100644 --- a/src/numpydantic/interface/numpy.py +++ b/src/numpydantic/interface/numpy.py @@ -4,7 +4,7 @@ Interface to numpy arrays from typing import Any, Literal, Union -from pydantic import SerializationInfo +from pydantic import BaseModel, SerializationInfo from numpydantic.interface.interface import Interface, JsonDict @@ -59,6 +59,9 @@ class NumpyInterface(Interface): Check that this is in fact a numpy ndarray or something that can be coerced to one """ + if array is None: + return False + if isinstance(array, ndarray): return True elif isinstance(array, dict): @@ -77,6 +80,14 @@ class NumpyInterface(Interface): """ if not isinstance(array, ndarray): array = np.array(array) + + try: + if issubclass(self.dtype, BaseModel) and isinstance(array.flat[0], dict): + array = np.vectorize(lambda x: self.dtype(**x))(array) + except TypeError: + # fine, dtype isn't a type + pass + return array @classmethod diff --git a/src/numpydantic/interface/zarr.py b/src/numpydantic/interface/zarr.py index 5dc647e..2e3d993 100644 --- a/src/numpydantic/interface/zarr.py +++ b/src/numpydantic/interface/zarr.py @@ -63,6 +63,7 @@ class ZarrJsonDict(JsonDict): type: Literal["zarr"] file: Optional[str] = None path: Optional[str] = None + dtype: Optional[str] = None value: Optional[list] = None def to_array_input(self) -> Union[ZarrArray, ZarrArrayPath]: @@ -73,7 +74,7 @@ class ZarrJsonDict(JsonDict): if self.file: array = ZarrArrayPath(file=self.file, path=self.path) else: - array = zarr.array(self.value) + array = zarr.array(self.value, dtype=self.dtype) return array @@ -194,6 +195,7 @@ class ZarrInterface(Interface): is_file = False as_json = {"type": cls.name} + as_json["dtype"] = array.dtype.name if hasattr(array.store, "dir_path"): is_file = True as_json["file"] = array.store.dir_path() diff --git a/src/numpydantic/ndarray.py b/src/numpydantic/ndarray.py index 6969d44..fa0fdc9 100644 --- a/src/numpydantic/ndarray.py +++ b/src/numpydantic/ndarray.py @@ -152,6 +152,8 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"): result = str(dtype) elif isinstance(dtype, tuple): result = ", ".join([str(dt) for dt in dtype]) + else: + result = str(dtype) return result diff --git a/src/numpydantic/serialization.py b/src/numpydantic/serialization.py index 94c627c..eb5b4bc 100644 --- a/src/numpydantic/serialization.py +++ b/src/numpydantic/serialization.py @@ -80,7 +80,7 @@ def _relativize_paths( ): return v return str(relative_path(path, relative_to)) - except (TypeError, ValueError): + except (TypeError, ValueError, OSError): return v return _walk_and_apply(value, _r_path, skip) @@ -95,7 +95,7 @@ def _absolutize_paths(value: dict, skip: Iterable = tuple()) -> 
dict: if not path.exists(): return v return str(path.resolve()) - except (TypeError, ValueError): + except (TypeError, ValueError, OSError): return v return _walk_and_apply(value, _a_path, skip) diff --git a/src/numpydantic/testing/__init__.py b/src/numpydantic/testing/__init__.py new file mode 100644 index 0000000..1238838 --- /dev/null +++ b/src/numpydantic/testing/__init__.py @@ -0,0 +1,6 @@ +from numpydantic.testing.helpers import InterfaceCase, ValidationCase + +__all__ = [ + "InterfaceCase", + "ValidationCase", +] diff --git a/src/numpydantic/testing/cases.py b/src/numpydantic/testing/cases.py new file mode 100644 index 0000000..d0f44ee --- /dev/null +++ b/src/numpydantic/testing/cases.py @@ -0,0 +1,259 @@ +import sys +from typing import Union + +import numpy as np +from pydantic import BaseModel + +from numpydantic.dtype import Float, Integer, Number +from numpydantic.testing.helpers import ValidationCase, merged_product +from numpydantic.testing.interfaces import ( + DaskCase, + HDF5Case, + HDF5CompoundCase, + NumpyCase, + VideoCase, + ZarrCase, + ZarrDirCase, + ZarrNestedCase, + ZarrZipCase, +) + +if sys.version_info.minor >= 10: + from typing import TypeAlias + + YES_PIPE = True +else: + from typing_extensions import TypeAlias + + YES_PIPE = False + + +class BasicModel(BaseModel): + x: int + + +class BadModel(BaseModel): + x: int + + +class SubClass(BasicModel): + pass + + +# -------------------------------------------------- +# Annotations +# -------------------------------------------------- + +RGB_UNION = (("*", "*"), ("*", "*", 3), ("*", "*", 3, 4)) +UNION_TYPE: TypeAlias = Union[np.uint32, np.float32] + +SHAPE_CASES = ( + ValidationCase(shape=(10, 10, 2, 2), passes=True, id="valid shape"), + ValidationCase(shape=(10, 10, 2), passes=False, id="missing dimension"), + ValidationCase(shape=(10, 10, 2, 2, 2), passes=False, id="extra dimension"), + ValidationCase(shape=(11, 10, 2, 2), passes=False, id="dimension too large"), + ValidationCase(shape=(9, 10, 2, 2), passes=False, id="dimension too small"), + ValidationCase(shape=(10, 10, 1, 1), passes=True, id="wildcard smaller"), + ValidationCase(shape=(10, 10, 3, 3), passes=True, id="wildcard larger"), + ValidationCase( + annotation_shape=RGB_UNION, shape=(5, 5), passes=True, id="Union 2D" + ), + ValidationCase( + annotation_shape=RGB_UNION, shape=(5, 5, 3), passes=True, id="Union 3D" + ), + ValidationCase( + annotation_shape=RGB_UNION, shape=(5, 5, 3, 4), passes=True, id="Union 4D" + ), + ValidationCase( + annotation_shape=RGB_UNION, + shape=(5, 5, 4), + passes=False, + id="Union incorrect 3D", + ), + ValidationCase( + annotation_shape=RGB_UNION, + shape=(5, 5, 3, 6), + passes=False, + id="Union incorrect 4D", + ), + ValidationCase( + annotation_shape=RGB_UNION, + shape=(5, 5, 4, 6), + passes=False, + id="Union incorrect both", + ), +) +""" +Base Shape cases +""" + + +DTYPE_CASES = [ + ValidationCase(dtype=float, passes=True, id="float"), + ValidationCase(dtype=int, passes=False, id="int"), + ValidationCase(dtype=np.uint8, passes=False, id="uint8"), + ValidationCase(annotation_dtype=Number, dtype=int, passes=True, id="number-int"), + ValidationCase( + annotation_dtype=Number, dtype=float, passes=True, id="number-float" + ), + ValidationCase( + annotation_dtype=Number, dtype=np.uint8, passes=True, id="number-uint8" + ), + ValidationCase( + annotation_dtype=Number, dtype=np.float16, passes=True, id="number-float16" + ), + ValidationCase(annotation_dtype=Number, dtype=str, passes=False, id="number-str"), + 
ValidationCase(annotation_dtype=Integer, dtype=int, passes=True, id="integer-int"), + ValidationCase( + annotation_dtype=Integer, dtype=np.uint8, passes=True, id="integer-uint8" + ), + ValidationCase( + annotation_dtype=Integer, dtype=float, passes=False, id="integer-float" + ), + ValidationCase( + annotation_dtype=Integer, dtype=np.float32, passes=False, id="integer-float32" + ), + ValidationCase(annotation_dtype=Integer, dtype=str, passes=False, id="integer-str"), + ValidationCase(annotation_dtype=Float, dtype=float, passes=True, id="float-float"), + ValidationCase( + annotation_dtype=Float, dtype=np.float32, passes=True, id="float-float32" + ), + ValidationCase(annotation_dtype=Float, dtype=int, passes=False, id="float-int"), + ValidationCase( + annotation_dtype=Float, dtype=np.uint8, passes=False, id="float-uint8" + ), + ValidationCase(annotation_dtype=Float, dtype=str, passes=False, id="float-str"), + ValidationCase(annotation_dtype=str, dtype=str, passes=True, id="str-str"), + ValidationCase(annotation_dtype=str, dtype=int, passes=False, id="str-int"), + ValidationCase(annotation_dtype=str, dtype=float, passes=False, id="str-float"), + ValidationCase( + annotation_dtype=BasicModel, dtype=BasicModel, passes=True, id="model-model" + ), + ValidationCase( + annotation_dtype=BasicModel, dtype=BadModel, passes=False, id="model-badmodel" + ), + ValidationCase( + annotation_dtype=BasicModel, dtype=int, passes=False, id="model-int" + ), + ValidationCase( + annotation_dtype=BasicModel, dtype=SubClass, passes=True, id="model-subclass" + ), + ValidationCase( + annotation_dtype=UNION_TYPE, + dtype=np.uint32, + passes=True, + id="union-type-uint32", + ), + ValidationCase( + annotation_dtype=UNION_TYPE, + dtype=np.float32, + passes=True, + id="union-type-float32", + ), + ValidationCase( + annotation_dtype=UNION_TYPE, + dtype=np.uint64, + passes=False, + id="union-type-uint64", + ), + ValidationCase( + annotation_dtype=UNION_TYPE, + dtype=np.float64, + passes=False, + id="union-type-float64", + ), + ValidationCase( + annotation_dtype=UNION_TYPE, dtype=str, passes=False, id="union-type-str" + ), +] +""" +Base Dtype cases +""" + + +if YES_PIPE: + UNION_PIPE: TypeAlias = np.uint32 | np.float32 + + DTYPE_CASES.extend( + [ + ValidationCase( + annotation_dtype=UNION_PIPE, + dtype=np.uint32, + passes=True, + id="union-pipe-uint32", + ), + ValidationCase( + annotation_dtype=UNION_PIPE, + dtype=np.float32, + passes=True, + id="union-pipe-float32", + ), + ValidationCase( + annotation_dtype=UNION_PIPE, + dtype=np.uint64, + passes=False, + id="union-pipe-uint64", + ), + ValidationCase( + annotation_dtype=UNION_PIPE, + dtype=np.float64, + passes=False, + id="union-pipe-float64", + ), + ValidationCase( + annotation_dtype=UNION_PIPE, + dtype=str, + passes=False, + id="union-pipe-str", + ), + ] + ) + +INTERFACE_CASES = [ + ValidationCase(interface=NumpyCase, id="numpy"), + ValidationCase(interface=HDF5Case, id="hdf5"), + ValidationCase(interface=HDF5CompoundCase, id="hdf5_compound"), + ValidationCase(interface=DaskCase, id="dask"), + ValidationCase(interface=ZarrCase, id="zarr"), + ValidationCase(interface=ZarrDirCase, id="zarr_dir"), + ValidationCase(interface=ZarrZipCase, id="zarr_zip"), + ValidationCase(interface=ZarrNestedCase, id="zarr_nested"), + ValidationCase(interface=VideoCase, id="video"), +] +""" +All the interface cases +""" + + +DTYPE_AND_SHAPE_CASES = merged_product(SHAPE_CASES, DTYPE_CASES) +""" +Merged product of dtype and shape cases +""" +DTYPE_AND_SHAPE_CASES_PASSING = merged_product( + 
SHAPE_CASES, DTYPE_CASES, conditions={"passes": True} +) +""" +Merged product of dtype and shape cases that are valid +""" + +DTYPE_AND_INTERFACE_CASES = merged_product(INTERFACE_CASES, DTYPE_CASES) +""" +Merged product of dtype and interface cases +""" +DTYPE_AND_INTERFACE_CASES_PASSING = merged_product( + INTERFACE_CASES, DTYPE_CASES, conditions={"passes": True} +) +""" +Merged product of dtype and interface cases that pass +""" + +ALL_CASES = merged_product(SHAPE_CASES, DTYPE_CASES, INTERFACE_CASES) +""" +Merged product of all cases - dtype, shape, and interface +""" +ALL_CASES_PASSING = merged_product( + SHAPE_CASES, DTYPE_CASES, INTERFACE_CASES, conditions={"passes": True} +) +""" +Merged product of all cases, but only those that pass +""" diff --git a/src/numpydantic/testing/helpers.py b/src/numpydantic/testing/helpers.py new file mode 100644 index 0000000..b337e7d --- /dev/null +++ b/src/numpydantic/testing/helpers.py @@ -0,0 +1,317 @@ +from abc import ABC, abstractmethod +from collections.abc import Sequence +from functools import reduce +from itertools import product +from operator import ior +from pathlib import Path +from typing import Generator, List, Literal, Optional, Tuple, Type, Union + +import numpy as np +from pydantic import BaseModel, ConfigDict, ValidationError, computed_field + +from numpydantic import NDArray, Shape +from numpydantic.dtype import Float +from numpydantic.interface import Interface +from numpydantic.types import DtypeType, NDArrayType + + +class InterfaceCase(ABC): + """ + An interface test helper that allows a given interface to generate and validate + arrays in one of its formats. + + Each instance of "interface test case" should be considered one of the + potentially multiple realizations of a given interface. + If an interface has multiple formats (eg. zarr's different `store` s), + then it should have several test helpers. + """ + + @property + @abstractmethod + def interface(self) -> Interface: + """The interface that this helper is for""" + + @classmethod + def array_from_case( + cls, case: "ValidationCase", path: Optional[Path] = None + ) -> Optional[NDArrayType]: + """ + Generate an array from the given validation case. + + Returns ``None`` if an array can't be generated for a specific case. + """ + return cls.make_array(shape=case.shape, dtype=case.dtype, path=path) + + @classmethod + @abstractmethod + def make_array( + cls, + shape: Tuple[int, ...] = (10, 10), + dtype: DtypeType = float, + path: Optional[Path] = None, + array: Optional[NDArrayType] = None, + ) -> Optional[NDArrayType]: + """ + Make an array from a shape and dtype, and a path if needed + + Args: + shape: shape of the array + dtype: dtype of the array + path: Path, if needed to generate on disk + array: Rather than passing shape and dtype, pass a literal arraylike thing + """ + + @classmethod + def validate_case(cls, case: "ValidationCase", path: Path) -> bool: + """ + Validate a generated array against the annotation in the validation case. + + Kept in the InterfaceCase in case an interface has specific + needs aside from just validating against a model, but typically left as is. + + If an array can't be generated for a given case, returns `None` + so that the calling function can know to skip rather than fail the case. + + Raises exceptions if validation fails (or succeeds when it shouldn't) + + Args: + case (ValidationCase): The validation case to validate. + path (Path): Path to generate arrays into, if any. 
+ + Returns: + ``True`` if array is valid and was supposed to be, + or invalid and wasn't supposed to be + """ + import pytest + + array = cls.array_from_case(case, path) + if array is None: + pytest.skip() + if case.passes: + case.model(array=array) + return True + else: + with pytest.raises(ValidationError): + case.model(array=array) + return True + + @classmethod + def skip(cls, shape: Tuple[int, ...], dtype: DtypeType) -> bool: + """ + Whether a given interface should be skipped for the case + """ + # Assume an interface case is valid for all other cases + return False + + +_a_shape_type = Tuple[Union[int, Literal["*"], Literal["..."]], ...] + + +class ValidationCase(BaseModel): + """ + Test case for validating an array. + + Contains both the validating model and the parameterization for an array to + test in a given interface + """ + + id: Optional[str] = None + """ + String identifying the validation case + """ + annotation_shape: Union[ + Tuple[Union[int, str], ...], Tuple[Tuple[Union[int, str], ...], ...] + ] = (10, 10, "*", "*") + """ + Shape to use in computed annotation used to validate against + """ + annotation_dtype: Union[DtypeType, Sequence[DtypeType]] = Float + """ + Dtype to use in computed annotation used to validate against + """ + shape: Tuple[int, ...] = (10, 10, 2, 2) + """Shape of the array to validate""" + dtype: Union[Type, np.dtype] = float + """Dtype of the array to validate""" + passes: bool = False + """Whether the validation should pass or not""" + interface: Optional[Type[InterfaceCase]] = None + """The interface test case to generate and validate the array with""" + path: Optional[Path] = None + """The path to generate arrays into, if any.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + @computed_field() + def annotation(self) -> NDArray: + """ + Annotation used in the model we validate against + """ + # make a union type if we need to + shape_union = all(isinstance(s, Sequence) for s in self.annotation_shape) + dtype_union = isinstance(self.annotation_dtype, Sequence) and all( + isinstance(s, Sequence) for s in self.annotation_dtype + ) + if shape_union or dtype_union: + shape_iter = ( + self.annotation_shape if shape_union else [self.annotation_shape] + ) + dtype_iter = ( + self.annotation_dtype if dtype_union else [self.annotation_dtype] + ) + annotations: List[type] = [] + for shape, dtype in product(shape_iter, dtype_iter): + shape_str = ", ".join([str(i) for i in shape]) + annotations.append(NDArray[Shape[shape_str], dtype]) + return Union[tuple(annotations)] + + else: + shape_str = ", ".join([str(i) for i in self.annotation_shape]) + return NDArray[Shape[shape_str], self.annotation_dtype] + + @computed_field() + def model(self) -> Type[BaseModel]: + """A model with a field ``array`` with the given annotation""" + annotation = self.annotation + + class Model(BaseModel): + array: annotation + + return Model + + def validate_case(self, path: Optional[Path] = None) -> bool: + """ + Whether the generated array correctly validated against the annotation, + given the interface + + Args: + path (:class:`pathlib.Path`): Directory to generate array into, if on disk. 
+ + Raises: + ValueError: if an ``interface`` is missing + """ + if self.interface is None: # pragma: no cover + raise ValueError("Missing an interface") + if path is None: + if self.path: + path = self.path + else: # pragma: no cover + raise ValueError("Missing a path to generate arrays into") + + return self.interface.validate_case(self, path) + + def array(self, path: Path) -> NDArrayType: + """Generate an array for the validation case if we have an interface to do so""" + if self.interface is None: # pragma: no cover + raise ValueError("Missing an interface") + if path is None: # pragma: no cover + if self.path: + path = self.path + else: + raise ValueError("Missing a path to generate arrays into") + + return self.interface.array_from_case(self, path) + + def merge( + self, other: Union["ValidationCase", Sequence["ValidationCase"]] + ) -> "ValidationCase": + """ + Merge two validation cases + + Dump both, excluding any unset fields, and merge, preferring `other`. + + ``valid`` is ``True`` if and only if it is ``True`` in both. + """ + if isinstance(other, Sequence): + return merge_cases(self, *other) + else: + return merge_cases(self, other) + + def skip(self) -> bool: + """ + Whether this case should be skipped + (eg. due to the interface case being incompatible + with the requested dtype or shape) + """ + return bool( + self.interface is not None and self.interface.skip(self.shape, self.dtype) + ) + + +def merge_cases(*args: ValidationCase) -> ValidationCase: + """ + Merge multiple validation cases + """ + if len(args) == 1: # pragma: no cover + return args[0] + + dumped = [ + m.model_dump(exclude_unset=True, exclude={"model", "annotation"}) for m in args + ] + + # self_dump = self.model_dump(exclude_unset=True) + # other_dump = other.model_dump(exclude_unset=True) + + # dumps might not have set `passes`, use only the ones that have + passes = [v.get("passes") for v in dumped if "passes" in v] + passes = all(passes) + + # combine ids if present + ids = "-".join([str(v.get("id")) for v in dumped if "id" in v]) + + # merge dicts + merged = reduce(ior, dumped, {}) + merged["passes"] = passes + merged["id"] = ids + return ValidationCase.model_construct(**merged) + + +def merged_product( + *args: Sequence[ValidationCase], conditions: dict = None +) -> Generator[ValidationCase, None, None]: + """ + Generator for the product of the iterators of validation cases, + merging each tuple, and respecting if they should be :meth:`.ValidationCase.skip` + or not. + + Examples: + + .. 
code-block:: python + + shape_cases = [ + ValidationCase(shape=(10, 10, 10), passes=True, id="valid shape"), + ValidationCase(shape=(10, 10), passes=False, id="missing dimension"), + ] + dtype_cases = [ + ValidationCase(dtype=float, passes=True, id="float"), + ValidationCase(dtype=int, passes=False, id="int"), + ] + + iterator = merged_product(shape_cases, dtype_cases)) + next(iterator) + # ValidationCase( + # shape=(10, 10, 10), + # dtype=float, + # passes=True, + # id="valid shape-float" + # ) + next(iterator) + # ValidationCase( + # shape=(10, 10, 10), + # dtype=int, + # passes=False, + # id="valid shape-int" + # ) + + + """ + iterator = product(*args) + for case_tuple in iterator: + case = merge_cases(*case_tuple) + if case.skip(): + continue + if conditions: + matching = all([getattr(case, k, None) == v for k, v in conditions.items()]) + if not matching: + continue + yield case diff --git a/src/numpydantic/testing/interfaces.py b/src/numpydantic/testing/interfaces.py new file mode 100644 index 0000000..a11bc6a --- /dev/null +++ b/src/numpydantic/testing/interfaces.py @@ -0,0 +1,295 @@ +from datetime import datetime, timezone +from pathlib import Path +from typing import Optional, Tuple + +import cv2 +import dask.array as da +import h5py +import numpy as np +import zarr +from pydantic import BaseModel + +from numpydantic.interface import ( + DaskInterface, + H5ArrayPath, + H5Interface, + NumpyInterface, + VideoInterface, + ZarrArrayPath, + ZarrInterface, +) +from numpydantic.testing.helpers import InterfaceCase +from numpydantic.types import DtypeType, NDArrayType + + +class NumpyCase(InterfaceCase): + """In-memory numpy array""" + + interface = NumpyInterface + + @classmethod + def make_array( + cls, + shape: Tuple[int, ...] = (10, 10), + dtype: DtypeType = float, + path: Optional[Path] = None, + array: Optional[NDArrayType] = None, + ) -> np.ndarray: + if array is not None: + return np.array(array, dtype=dtype) + elif issubclass(dtype, BaseModel): + return np.full(shape=shape, fill_value=dtype(x=1)) + else: + return np.zeros(shape=shape, dtype=dtype) + + +class _HDF5MetaCase(InterfaceCase): + """Base case for hdf5 cases""" + + interface = H5Interface + + @classmethod + def skip(cls, shape: Tuple[int, ...], dtype: DtypeType) -> bool: + return issubclass(dtype, BaseModel) + + +class HDF5Case(_HDF5MetaCase): + """HDF5 Array""" + + @classmethod + def make_array( + cls, + shape: Tuple[int, ...] = (10, 10), + dtype: DtypeType = float, + path: Optional[Path] = None, + array: Optional[NDArrayType] = None, + ) -> Optional[H5ArrayPath]: + if cls.skip(shape, dtype): # pragma: no cover + return None + + hdf5_file = path / "h5f.h5" + array_path = "/" + "_".join([str(s) for s in shape]) + "__" + dtype.__name__ + generator = np.random.default_rng() + + if array is not None: + data = np.array(array, dtype=dtype) + elif dtype is str: + data = generator.random(shape).astype(bytes) + elif dtype is datetime: + data = np.empty(shape, dtype="S32") + data.fill(datetime.now(timezone.utc).isoformat().encode("utf-8")) + else: + data = generator.random(shape).astype(dtype) + + h5path = H5ArrayPath(hdf5_file, array_path) + + with h5py.File(hdf5_file, "w") as h5f: + _ = h5f.create_dataset(array_path, data=data) + return h5path + + +class HDF5CompoundCase(_HDF5MetaCase): + """HDF5 Array with a fake compound dtype""" + + @classmethod + def make_array( + cls, + shape: Tuple[int, ...] 
= (10, 10), + dtype: DtypeType = float, + path: Optional[Path] = None, + array: Optional[NDArrayType] = None, + ) -> Optional[H5ArrayPath]: + if cls.skip(shape, dtype): # pragma: no cover + return None + + hdf5_file = path / "h5f.h5" + array_path = "/" + "_".join([str(s) for s in shape]) + "__" + dtype.__name__ + if array is not None: + data = np.array(array, dtype=dtype) + elif dtype is str: + dt = np.dtype([("data", np.dtype("S10")), ("extra", "i8")]) + data = np.array([("hey", 0)] * np.prod(shape), dtype=dt).reshape(shape) + elif dtype is datetime: + dt = np.dtype([("data", np.dtype("S32")), ("extra", "i8")]) + data = np.array( + [(datetime.now(timezone.utc).isoformat().encode("utf-8"), 0)] + * np.prod(shape), + dtype=dt, + ).reshape(shape) + else: + dt = np.dtype([("data", dtype), ("extra", "i8")]) + data = np.zeros(shape, dtype=dt) + h5path = H5ArrayPath(hdf5_file, array_path, "data") + + with h5py.File(hdf5_file, "w") as h5f: + _ = h5f.create_dataset(array_path, data=data) + return h5path + + +class DaskCase(InterfaceCase): + """In-memory dask array""" + + interface = DaskInterface + + @classmethod + def make_array( + cls, + shape: Tuple[int, ...] = (10, 10), + dtype: DtypeType = float, + path: Optional[Path] = None, + array: Optional[NDArrayType] = None, + ) -> da.Array: + if array is not None: + return da.array(array, dtype=dtype) + if issubclass(dtype, BaseModel): + return da.full(shape=shape, fill_value=dtype(x=1), chunks=-1) + else: + return da.zeros(shape=shape, dtype=dtype, chunks=10) + + +class _ZarrMetaCase(InterfaceCase): + """Shared classmethods for zarr cases""" + + interface = ZarrInterface + + @classmethod + def skip(cls, shape: Tuple[int, ...], dtype: DtypeType) -> bool: + return issubclass(dtype, BaseModel) or dtype is str + + +class ZarrCase(_ZarrMetaCase): + """In-memory zarr array""" + + @classmethod + def make_array( + cls, + shape: Tuple[int, ...] = (10, 10), + dtype: DtypeType = float, + path: Optional[Path] = None, + array: Optional[NDArrayType] = None, + ) -> Optional[zarr.Array]: + if array is not None: + return zarr.array(array, dtype=dtype, chunks=-1) + else: + return zarr.zeros(shape=shape, dtype=dtype) + + +class ZarrDirCase(_ZarrMetaCase): + """On-disk zarr array""" + + @classmethod + def make_array( + cls, + shape: Tuple[int, ...] = (10, 10), + dtype: DtypeType = float, + path: Optional[Path] = None, + array: Optional[NDArrayType] = None, + ) -> Optional[zarr.Array]: + store = zarr.DirectoryStore(str(path / "array.zarr")) + if array is not None: + return zarr.array(array, dtype=dtype, store=store, chunks=-1) + else: + return zarr.zeros(shape=shape, dtype=dtype, store=store) + + +class ZarrZipCase(_ZarrMetaCase): + """Zarr zip store""" + + @classmethod + def make_array( + cls, + shape: Tuple[int, ...] = (10, 10), + dtype: DtypeType = float, + path: Optional[Path] = None, + array: Optional[NDArrayType] = None, + ) -> Optional[zarr.Array]: + store = zarr.ZipStore(str(path / "array.zarr"), mode="w") + if array is not None: + return zarr.array(array, dtype=dtype, store=store, chunks=-1) + else: + return zarr.zeros(shape=shape, dtype=dtype, store=store) + + +class ZarrNestedCase(_ZarrMetaCase): + """Nested zarr array""" + + @classmethod + def make_array( + cls, + shape: Tuple[int, ...] 
= (10, 10), + dtype: DtypeType = float, + path: Optional[Path] = None, + array: Optional[NDArrayType] = None, + ) -> ZarrArrayPath: + file = str(path / "nested.zarr") + root = zarr.open(file, mode="w") + subpath = "a/b/c" + if array is not None: + _ = root.array(subpath, array, dtype=dtype) + else: + _ = root.zeros(subpath, shape=shape, dtype=dtype) + return ZarrArrayPath(file=file, path=subpath) + + +class VideoCase(InterfaceCase): + """AVI video""" + + interface = VideoInterface + + @classmethod + def make_array( + cls, + shape: Tuple[int, ...] = (10, 10, 10, 3), + dtype: DtypeType = np.uint8, + path: Optional[Path] = None, + array: Optional[NDArrayType] = None, + ) -> Optional[Path]: + if cls.skip(shape, dtype): # pragma: no cover + return None + + if array is not None: + array = np.array(array, dtype=np.uint8) + shape = array.shape + + is_color = len(shape) == 4 + frames = shape[0] + frame_shape = shape[1:] + + video_path = path / "test.avi" + writer = cv2.VideoWriter( + str(video_path), + cv2.VideoWriter_fourcc(*"RGBA"), # raw video for testing purposes + 30, + (frame_shape[1], frame_shape[0]), + is_color, + ) + for i in range(frames): + if array is not None: + frame = array[i] + else: + # make fresh array every time bc opencv eats them + frame = np.full(frame_shape, fill_value=i, dtype=np.uint8) + writer.write(frame) + writer.release() + return video_path + + @classmethod + def skip(cls, shape: Tuple[int, ...], dtype: DtypeType) -> bool: + """ + We really can only handle 4 dimensional cases in 8-bit rn lol + + .. todo:: + + Fix shape/writing for grayscale videos + + """ + if len(shape) != 4: + return True + + # if len(shape) < 3 or len(shape) > 4: + # return True + if dtype not in (int, np.uint8): + return True + # if we have a color video (ie. shape == 4, needs to be RGB) + if len(shape) == 4 and shape[3] != 3: + return True diff --git a/tests/conftest.py b/tests/conftest.py index 96f8a7e..669de7b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,25 +1,9 @@ -import sys - import pytest -from typing import Any, Tuple, Union, Type - -from pydantic import BaseModel, computed_field, ConfigDict -from numpydantic import NDArray, Shape -from numpydantic.ndarray import NDArrayMeta -from numpydantic.dtype import Float, Number, Integer -import numpy as np +from numpydantic.testing.cases import DTYPE_CASES, SHAPE_CASES +from numpydantic.testing.helpers import ValidationCase from tests.fixtures import * -if sys.version_info.minor >= 10: - from typing import TypeAlias - - YES_PIPE = True -else: - from typing_extensions import TypeAlias - - YES_PIPE = False - def pytest_addoption(parser): parser.addoption( @@ -29,191 +13,19 @@ def pytest_addoption(parser): ) -class ValidationCase(BaseModel): - """ - Test case for validating an array. - - Contains both the validating model and the parameterization for an array to - test in a given interface - """ - - annotation: Any = NDArray[Shape["10, 10, *"], Float] - """ - Array annotation used in the validating model - Any typed because the types of type annotations are weird - """ - shape: Tuple[int, ...] 
= (10, 10, 10) - """Shape of the array to validate""" - dtype: Union[Type, np.dtype] = float - """Dtype of the array to validate""" - passes: bool - """Whether the validation should pass or not""" - - model_config = ConfigDict(arbitrary_types_allowed=True) - - @computed_field() - def model(self) -> Type[BaseModel]: - """A model with a field ``array`` with the given annotation""" - annotation = self.annotation - - class Model(BaseModel): - array: annotation - - return Model - - -class BasicModel(BaseModel): - x: int - - -class BadModel(BaseModel): - x: int - - -class SubClass(BasicModel): - pass - - -RGB_UNION: TypeAlias = Union[ - NDArray[Shape["* x, * y"], Number], - NDArray[Shape["* x, * y, 3 r_g_b"], Number], - NDArray[Shape["* x, * y, 3 r_g_b, 4 r_g_b_a"], Number], -] - -NUMBER: TypeAlias = NDArray[Shape["*, *, *"], Number] -INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer] -FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float] -STRING: TypeAlias = NDArray[Shape["*, *, *"], str] -MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel] -UNION_TYPE: TypeAlias = NDArray[Shape["*, *, *"], Union[np.uint32, np.float32]] -if YES_PIPE: - UNION_PIPE: TypeAlias = NDArray[Shape["*, *, *"], np.uint32 | np.float32] +@pytest.fixture( + scope="function", params=[pytest.param(c, id=c.id) for c in SHAPE_CASES] +) +def shape_cases(request, tmp_output_dir_func) -> ValidationCase: + case: ValidationCase = request.param.model_copy() + case.path = tmp_output_dir_func + return case @pytest.fixture( - scope="module", - params=[ - ValidationCase(shape=(10, 10, 10), passes=True), - ValidationCase(shape=(10, 10), passes=False), - ValidationCase(shape=(10, 10, 10, 10), passes=False), - ValidationCase(shape=(11, 10, 10), passes=False), - ValidationCase(shape=(9, 10, 10), passes=False), - ValidationCase(shape=(10, 10, 9), passes=True), - ValidationCase(shape=(10, 10, 11), passes=True), - ValidationCase(annotation=RGB_UNION, shape=(5, 5), passes=True), - ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3), passes=True), - ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3, 4), passes=True), - ValidationCase(annotation=RGB_UNION, shape=(5, 5, 4), passes=False), - ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3, 6), passes=False), - ValidationCase(annotation=RGB_UNION, shape=(5, 5, 4, 6), passes=False), - ], - ids=[ - "valid shape", - "missing dimension", - "extra dimension", - "dimension too large", - "dimension too small", - "wildcard smaller", - "wildcard larger", - "Union 2D", - "Union 3D", - "Union 4D", - "Union incorrect 3D", - "Union incorrect 4D", - "Union incorrect both", - ], + scope="function", params=[pytest.param(c, id=c.id) for c in DTYPE_CASES] ) -def shape_cases(request) -> ValidationCase: - return request.param - - -DTYPE_CASES = [ - ValidationCase(dtype=float, passes=True), - ValidationCase(dtype=int, passes=False), - ValidationCase(dtype=np.uint8, passes=False), - ValidationCase(annotation=NUMBER, dtype=int, passes=True), - ValidationCase(annotation=NUMBER, dtype=float, passes=True), - ValidationCase(annotation=NUMBER, dtype=np.uint8, passes=True), - ValidationCase(annotation=NUMBER, dtype=np.float16, passes=True), - ValidationCase(annotation=NUMBER, dtype=str, passes=False), - ValidationCase(annotation=INTEGER, dtype=int, passes=True), - ValidationCase(annotation=INTEGER, dtype=np.uint8, passes=True), - ValidationCase(annotation=INTEGER, dtype=float, passes=False), - ValidationCase(annotation=INTEGER, dtype=np.float32, passes=False), - ValidationCase(annotation=INTEGER, dtype=str, 
passes=False), - ValidationCase(annotation=FLOAT, dtype=float, passes=True), - ValidationCase(annotation=FLOAT, dtype=np.float32, passes=True), - ValidationCase(annotation=FLOAT, dtype=int, passes=False), - ValidationCase(annotation=FLOAT, dtype=np.uint8, passes=False), - ValidationCase(annotation=FLOAT, dtype=str, passes=False), - ValidationCase(annotation=STRING, dtype=str, passes=True), - ValidationCase(annotation=STRING, dtype=int, passes=False), - ValidationCase(annotation=STRING, dtype=float, passes=False), - ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True), - ValidationCase(annotation=MODEL, dtype=BadModel, passes=False), - ValidationCase(annotation=MODEL, dtype=int, passes=False), - ValidationCase(annotation=MODEL, dtype=SubClass, passes=True), - ValidationCase(annotation=UNION_TYPE, dtype=np.uint32, passes=True), - ValidationCase(annotation=UNION_TYPE, dtype=np.float32, passes=True), - ValidationCase(annotation=UNION_TYPE, dtype=np.uint64, passes=False), - ValidationCase(annotation=UNION_TYPE, dtype=np.float64, passes=False), - ValidationCase(annotation=UNION_TYPE, dtype=str, passes=False), -] - -DTYPE_IDS = [ - "float", - "int", - "uint8", - "number-int", - "number-float", - "number-uint8", - "number-float16", - "number-str", - "integer-int", - "integer-uint8", - "integer-float", - "integer-float32", - "integer-str", - "float-float", - "float-float32", - "float-int", - "float-uint8", - "float-str", - "str-str", - "str-int", - "str-float", - "model-model", - "model-badmodel", - "model-int", - "model-subclass", - "union-type-uint32", - "union-type-float32", - "union-type-uint64", - "union-type-float64", - "union-type-str", -] - -if YES_PIPE: - DTYPE_CASES.extend( - [ - ValidationCase(annotation=UNION_PIPE, dtype=np.uint32, passes=True), - ValidationCase(annotation=UNION_PIPE, dtype=np.float32, passes=True), - ValidationCase(annotation=UNION_PIPE, dtype=np.uint64, passes=False), - ValidationCase(annotation=UNION_PIPE, dtype=np.float64, passes=False), - ValidationCase(annotation=UNION_PIPE, dtype=str, passes=False), - ] - ) - DTYPE_IDS.extend( - [ - "union-pipe-uint32", - "union-pipe-float32", - "union-pipe-uint64", - "union-pipe-float64", - "union-pipe-str", - ] - ) - - -@pytest.fixture(scope="module", params=DTYPE_CASES, ids=DTYPE_IDS) -def dtype_cases(request) -> ValidationCase: - return request.param +def dtype_cases(request, tmp_output_dir_func) -> ValidationCase: + case: ValidationCase = request.param.model_copy() + case.path = tmp_output_dir_func + return case diff --git a/tests/fixtures.py b/tests/fixtures.py deleted file mode 100644 index b090c7d..0000000 --- a/tests/fixtures.py +++ /dev/null @@ -1,197 +0,0 @@ -import shutil -from pathlib import Path -from typing import Any, Callable, Optional, Tuple, Type, Union -from warnings import warn -from datetime import datetime, timezone - -import h5py -import numpy as np -import pytest -from pydantic import BaseModel, Field -import zarr -import cv2 - -from numpydantic.interface.hdf5 import H5ArrayPath -from numpydantic.interface.zarr import ZarrArrayPath -from numpydantic import NDArray, Shape -from numpydantic.maps import python_to_nptyping -from numpydantic.dtype import Number - - -@pytest.fixture(scope="session") -def tmp_output_dir(request: pytest.FixtureRequest) -> Path: - path = Path(__file__).parent.resolve() / "__tmp__" - if path.exists(): - shutil.rmtree(str(path)) - path.mkdir() - - yield path - - if not request.config.getvalue("--with-output"): - try: - shutil.rmtree(str(path)) - except PermissionError as 
e: - # sporadic error on windows machines... - warn( - f"Temporary directory could not be removed due to a permissions error: \n{str(e)}" - ) - - -@pytest.fixture(scope="function") -def tmp_output_dir_func(tmp_output_dir, request: pytest.FixtureRequest) -> Path: - """ - tmp output dir that gets cleared between every function - cleans at the start rather than at cleanup in case the output is to be inspected - """ - subpath = tmp_output_dir / f"__tmpfunc_{request.node.name}__" - if subpath.exists(): - shutil.rmtree(str(subpath)) - subpath.mkdir() - return subpath - - -@pytest.fixture(scope="module") -def tmp_output_dir_mod(tmp_output_dir, request: pytest.FixtureRequest) -> Path: - """ - tmp output dir that gets cleared between every function - cleans at the start rather than at cleanup in case the output is to be inspected - """ - subpath = tmp_output_dir / f"__tmpmod_{request.module}__" - if subpath.exists(): - shutil.rmtree(str(subpath)) - subpath.mkdir() - return subpath - - -@pytest.fixture(scope="function") -def array_model() -> ( - Callable[[Tuple[int, ...], Union[Type, np.dtype]], Type[BaseModel]] -): - def _model( - shape: Tuple[int, ...] = (10, 10), dtype: Union[Type, np.dtype] = float - ) -> Type[BaseModel]: - shape_str = ", ".join([str(s) for s in shape]) - - class MyModel(BaseModel): - array: NDArray[Shape[shape_str], dtype] - - return MyModel - - return _model - - -@pytest.fixture(scope="session") -def model_rgb() -> Type[BaseModel]: - class RGB(BaseModel): - array: Optional[ - Union[ - NDArray[Shape["* x, * y"], Number], - NDArray[Shape["* x, * y, 3 r_g_b"], Number], - NDArray[Shape["* x, * y, 3 r_g_b, 4 r_g_b_a"], Number], - ] - ] = Field(None) - - return RGB - - -@pytest.fixture(scope="session") -def model_blank() -> Type[BaseModel]: - """A model with any shape and dtype""" - - class BlankModel(BaseModel): - array: NDArray[Shape["*, ..."], Any] - - return BlankModel - - -@pytest.fixture(scope="function") -def hdf5_array( - request, tmp_output_dir_func -) -> Callable[[Tuple[int, ...], Union[np.dtype, type]], H5ArrayPath]: - hdf5_file = tmp_output_dir_func / "h5f.h5" - - def _hdf5_array( - shape: Tuple[int, ...] 
= (10, 10), - dtype: Union[np.dtype, type] = float, - compound: bool = False, - ) -> H5ArrayPath: - array_path = "/" + "_".join([str(s) for s in shape]) + "__" + dtype.__name__ - - if not compound: - if dtype is str: - data = np.random.random(shape).astype(bytes) - elif dtype is datetime: - data = np.empty(shape, dtype="S32") - data.fill(datetime.now(timezone.utc).isoformat().encode("utf-8")) - else: - data = np.random.random(shape).astype(dtype) - - h5path = H5ArrayPath(hdf5_file, array_path) - else: - if dtype is str: - dt = np.dtype([("data", np.dtype("S10")), ("extra", "i8")]) - data = np.array([("hey", 0)] * np.prod(shape), dtype=dt).reshape(shape) - elif dtype is datetime: - dt = np.dtype([("data", np.dtype("S32")), ("extra", "i8")]) - data = np.array( - [(datetime.now(timezone.utc).isoformat().encode("utf-8"), 0)] - * np.prod(shape), - dtype=dt, - ).reshape(shape) - else: - dt = np.dtype([("data", dtype), ("extra", "i8")]) - data = np.zeros(shape, dtype=dt) - h5path = H5ArrayPath(hdf5_file, array_path, "data") - - with h5py.File(hdf5_file, "w") as h5f: - _ = h5f.create_dataset(array_path, data=data) - return h5path - - return _hdf5_array - - -@pytest.fixture(scope="function") -def zarr_nested_array(tmp_output_dir_func) -> ZarrArrayPath: - """Zarr array within a nested array""" - file = tmp_output_dir_func / "nested.zarr" - path = "a/b/c" - root = zarr.open(str(file), mode="w") - array = root.zeros(path, shape=(100, 100), chunks=(10, 10)) - return ZarrArrayPath(file=file, path=path) - - -@pytest.fixture(scope="function") -def zarr_array(tmp_output_dir_func) -> Path: - file = tmp_output_dir_func / "array.zarr" - array = zarr.open(str(file), mode="w", shape=(100, 100), chunks=(10, 10)) - array[:] = 0 - return file - - -@pytest.fixture(scope="function") -def avi_video(tmp_output_dir_func) -> Callable[[Tuple[int, int], int, bool], Path]: - video_path = tmp_output_dir_func / "test.avi" - - def _make_video(shape=(100, 50), frames=10, is_color=True) -> Path: - writer = cv2.VideoWriter( - str(video_path), - cv2.VideoWriter_fourcc(*"RGBA"), # raw video for testing purposes - 30, - (shape[1], shape[0]), - is_color, - ) - if is_color: - shape = (*shape, 3) - - for i in range(frames): - # make fresh array every time bc opencv eats them - array = np.zeros(shape, dtype=np.uint8) - if not is_color: - array[i, i] = i - else: - array[i, i, :] = i - writer.write(array) - writer.release() - return video_path - - return _make_video diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py new file mode 100644 index 0000000..f080c4f --- /dev/null +++ b/tests/fixtures/__init__.py @@ -0,0 +1,3 @@ +from .generation import * +from .models import * +from .paths import * diff --git a/tests/fixtures/generation.py b/tests/fixtures/generation.py new file mode 100644 index 0000000..382b0b4 --- /dev/null +++ b/tests/fixtures/generation.py @@ -0,0 +1,63 @@ +from pathlib import Path +from typing import Callable, Tuple, Union + +import numpy as np +import pytest +import zarr + +from numpydantic.interface.hdf5 import H5ArrayPath +from numpydantic.interface.zarr import ZarrArrayPath +from numpydantic.testing.interfaces import HDF5Case, HDF5CompoundCase, VideoCase + + +@pytest.fixture(scope="function") +def hdf5_array( + request, tmp_output_dir_func +) -> Callable[[Tuple[int, ...], Union[np.dtype, type]], H5ArrayPath]: + + def _hdf5_array( + shape: Tuple[int, ...] 
= (10, 10), + dtype: Union[np.dtype, type] = float, + compound: bool = False, + ) -> H5ArrayPath: + if compound: + array: H5ArrayPath = HDF5CompoundCase.make_array( + shape, dtype, tmp_output_dir_func + ) + return array + else: + return HDF5Case.make_array(shape, dtype, tmp_output_dir_func) + + return _hdf5_array + + +@pytest.fixture(scope="function") +def zarr_nested_array(tmp_output_dir_func) -> ZarrArrayPath: + """Zarr array within a nested array""" + file = tmp_output_dir_func / "nested.zarr" + path = "a/b/c" + root = zarr.open(str(file), mode="w") + _ = root.zeros(path, shape=(100, 100), chunks=(10, 10)) + return ZarrArrayPath(file=file, path=path) + + +@pytest.fixture(scope="function") +def zarr_array(tmp_output_dir_func) -> Path: + file = tmp_output_dir_func / "array.zarr" + array = zarr.open(str(file), mode="w", shape=(100, 100), chunks=(10, 10)) + array[:] = 0 + return file + + +@pytest.fixture(scope="function") +def avi_video(tmp_output_dir_func) -> Callable[[Tuple[int, int], int, bool], Path]: + + def _make_video(shape=(100, 50), frames=10, is_color=True) -> Path: + shape = (frames, *shape) + if is_color: + shape = (*shape, 3) + return VideoCase.make_array( + shape=shape, dtype=np.uint8, path=tmp_output_dir_func + ) + + return _make_video diff --git a/tests/fixtures/models.py b/tests/fixtures/models.py new file mode 100644 index 0000000..08f1ac3 --- /dev/null +++ b/tests/fixtures/models.py @@ -0,0 +1,49 @@ +from typing import Any, Callable, Optional, Tuple, Type, Union + +import numpy as np +import pytest +from pydantic import BaseModel, Field + +from numpydantic import NDArray, Shape +from numpydantic.dtype import Number + + +@pytest.fixture(scope="function") +def array_model() -> ( + Callable[[Tuple[int, ...], Union[Type, np.dtype]], Type[BaseModel]] +): + def _model( + shape: Tuple[int, ...] = (10, 10), dtype: Union[Type, np.dtype] = float + ) -> Type[BaseModel]: + shape_str = ", ".join([str(s) for s in shape]) + + class MyModel(BaseModel): + array: NDArray[Shape[shape_str], dtype] + + return MyModel + + return _model + + +@pytest.fixture(scope="session") +def model_rgb() -> Type[BaseModel]: + class RGB(BaseModel): + array: Optional[ + Union[ + NDArray[Shape["* x, * y"], Number], + NDArray[Shape["* x, * y, 3 r_g_b"], Number], + NDArray[Shape["* x, * y, 3 r_g_b, 4 r_g_b_a"], Number], + ] + ] = Field(None) + + return RGB + + +@pytest.fixture(scope="session") +def model_blank() -> Type[BaseModel]: + """A model with any shape and dtype""" + + class BlankModel(BaseModel): + array: NDArray[Shape["*, ..."], Any] + + return BlankModel diff --git a/tests/fixtures/paths.py b/tests/fixtures/paths.py new file mode 100644 index 0000000..2a6b133 --- /dev/null +++ b/tests/fixtures/paths.py @@ -0,0 +1,51 @@ +import shutil +from _warnings import warn +from pathlib import Path + +import pytest + + +@pytest.fixture(scope="session") +def tmp_output_dir(request: pytest.FixtureRequest) -> Path: + path = Path(__file__).parents[1].resolve() / "__tmp__" + if path.exists(): + shutil.rmtree(str(path)) + path.mkdir() + + yield path + + if not request.config.getvalue("--with-output"): + try: + shutil.rmtree(str(path)) + except PermissionError as e: + # sporadic error on windows machines... 
+            warn(
+                "Temporary directory could not be removed due to a permissions error: "
+                f"\n{str(e)}"
+            )
+
+
+@pytest.fixture(scope="function")
+def tmp_output_dir_func(tmp_output_dir, request: pytest.FixtureRequest) -> Path:
+    """
+    tmp output dir that gets cleared between every function
+    cleans at the start rather than at cleanup in case the output is to be inspected
+    """
+    subpath = tmp_output_dir / f"__tmpfunc_{request.node.name}__"
+    if subpath.exists():
+        shutil.rmtree(str(subpath))
+    subpath.mkdir()
+    return subpath
+
+
+@pytest.fixture(scope="module")
+def tmp_output_dir_mod(tmp_output_dir, request: pytest.FixtureRequest) -> Path:
+    """
+    tmp output dir that gets cleared between every module
+    cleans at the start rather than at cleanup in case the output is to be inspected
+    """
+    subpath = tmp_output_dir / f"__tmpmod_{request.module}__"
+    if subpath.exists():
+        shutil.rmtree(str(subpath))
+    subpath.mkdir()
+    return subpath
diff --git a/tests/test_interface/conftest.py b/tests/test_interface/conftest.py
index 5d36fa5..0f1048a 100644
--- a/tests/test_interface/conftest.py
+++ b/tests/test_interface/conftest.py
@@ -1,80 +1,150 @@
 import pytest
-from typing import Callable, Tuple, Type
-import numpy as np
-import dask.array as da
-import zarr
-from pydantic import BaseModel
-
-from numpydantic import interface, NDArray
+from numpydantic.testing.cases import (
+    ALL_CASES,
+    ALL_CASES_PASSING,
+    DTYPE_AND_INTERFACE_CASES_PASSING,
+)
+from numpydantic.testing.helpers import InterfaceCase, ValidationCase, merge_cases
+from numpydantic.testing.interfaces import (
+    DaskCase,
+    HDF5Case,
+    NumpyCase,
+    VideoCase,
+    ZarrCase,
+    ZarrDirCase,
+    ZarrNestedCase,
+)


 @pytest.fixture(
     scope="function",
     params=[
         pytest.param(
-            ([[1, 2], [3, 4]], interface.NumpyInterface),
-            marks=pytest.mark.numpy,
-            id="numpy-list",
-        ),
-        pytest.param(
-            (np.zeros((3, 4)), interface.NumpyInterface),
+            NumpyCase,
             marks=pytest.mark.numpy,
             id="numpy",
         ),
         pytest.param(
-            ("hdf5_array", interface.H5Interface),
+            HDF5Case,
             marks=pytest.mark.hdf5,
             id="h5-array-path",
         ),
         pytest.param(
-            (da.random.random((10, 10)), interface.DaskInterface),
+            DaskCase,
             marks=pytest.mark.dask,
             id="dask",
         ),
         pytest.param(
-            (zarr.ones((10, 10)), interface.ZarrInterface),
+            ZarrCase,
             marks=pytest.mark.zarr,
             id="zarr-memory",
         ),
         pytest.param(
-            ("zarr_nested_array", interface.ZarrInterface),
+            ZarrNestedCase,
             marks=pytest.mark.zarr,
             id="zarr-nested",
        ),
         pytest.param(
-            ("zarr_array", interface.ZarrInterface),
+            ZarrDirCase,
             marks=pytest.mark.zarr,
-            id="zarr-array",
-        ),
-        pytest.param(
-            ("avi_video", interface.VideoInterface), marks=pytest.mark.video, id="video"
+            id="zarr-dir",
         ),
+        pytest.param(VideoCase, marks=pytest.mark.video, id="video"),
     ],
 )
-def interface_type(request) -> Tuple[NDArray, Type[interface.Interface]]:
+def interface_cases(request) -> InterfaceCase:
     """
-    Test cases for each interface's ``check`` method - each input should match the
-    provided interface and that interface only
+    Fixture for combinatoric tests across all interface cases
     """
-    if isinstance(request.param[0], str):
-        return (request.getfixturevalue(request.param[0]), request.param[1])
-    else:
-        return request.param
+    return request.param
+
+
+@pytest.fixture(
+    params=(
+        pytest.param(p, id=p.id, marks=getattr(pytest.mark, p.interface.interface.name))
+        for p in ALL_CASES
+    )
+)
+def all_cases(interface_cases, request) -> ValidationCase:
+    """
+    Combinatoric testing for all dtype, shape, and interface cases.
+
+    This is a very expensive fixture!
Only use it for core functionality + that we want to be sure is *very true* in every circumstance, + INCLUDING invalid combinations of annotations and arrays. + Typically, that means only use this in `test_interfaces.py` + """ + + case = merge_cases(request.param, ValidationCase(interface=interface_cases)) + if case.skip(): + pytest.skip() + return case + + +@pytest.fixture( + params=( + pytest.param(p, id=p.id, marks=getattr(pytest.mark, p.interface.interface.name)) + for p in ALL_CASES_PASSING + ) +) +def all_passing_cases(request) -> ValidationCase: + """ + Combinatoric testing for all dtype, shape, and interface cases, + but only the combinations that we expect to pass. + + This is a very expensive fixture! Only use it for core functionality + that we want to be sure is *very true* in every circumstance. + Typically, that means only use this in `test_interfaces.py` + """ + return request.param @pytest.fixture() -def all_interfaces(interface_type) -> BaseModel: +def all_cases_instance(all_cases, tmp_output_dir_func): """ - An instantiated version of each interface within a basemodel, - with the array in an `array` field + all_cases but with an instantiated model + Args: + all_cases: + + Returns: + """ - array, interface = interface_type - if isinstance(array, Callable): - array = array() - - class MyModel(BaseModel): - array: NDArray - - instance = MyModel(array=array) + array = all_cases.array(path=tmp_output_dir_func) + instance = all_cases.model(array=array) + return instance + + +@pytest.fixture() +def all_passing_cases_instance(all_passing_cases, tmp_output_dir_func): + """ + all_cases but with an instantiated model + Args: + all_cases: + + Returns: + + """ + array = all_passing_cases.array(path=tmp_output_dir_func) + instance = all_passing_cases.model(array=array) + return instance + + +@pytest.fixture( + params=( + pytest.param(p, id=p.id, marks=getattr(pytest.mark, p.interface.interface.name)) + for p in DTYPE_AND_INTERFACE_CASES_PASSING + ) +) +def dtype_by_interface(request): + """ + Tests for all dtypes by all interfaces + """ + return request.param + + +@pytest.fixture() +def dtype_by_interface_instance(dtype_by_interface, tmp_output_dir_func): + array = dtype_by_interface.array(path=tmp_output_dir_func) + instance = dtype_by_interface.model(array=array) return instance diff --git a/tests/test_interface/test_dask.py b/tests/test_interface/test_dask.py index c3b70e0..24a3761 100644 --- a/tests/test_interface/test_dask.py +++ b/tests/test_interface/test_dask.py @@ -1,33 +1,14 @@ -import pytest import json import dask.array as da -from pydantic import BaseModel, ValidationError +import pytest from numpydantic.interface import DaskInterface -from numpydantic.exceptions import DtypeError, ShapeError - -from tests.conftest import ValidationCase +from numpydantic.testing.interfaces import DaskCase pytestmark = pytest.mark.dask -def dask_array(case: ValidationCase) -> da.Array: - if issubclass(case.dtype, BaseModel): - return da.full(shape=case.shape, fill_value=case.dtype(x=1), chunks=-1) - else: - return da.zeros(shape=case.shape, dtype=case.dtype, chunks=10) - - -def _test_dask_case(case: ValidationCase): - array = dask_array(case) - if case.passes: - case.model(array=array) - else: - with pytest.raises((ValidationError, DtypeError, ShapeError)): - case.model(array=array) - - def test_dask_enabled(): """ We need dask to be available to run these tests :) @@ -35,21 +16,25 @@ def test_dask_enabled(): assert DaskInterface.enabled() -def test_dask_check(interface_type): - if 
interface_type[1] is DaskInterface: - assert DaskInterface.check(interface_type[0]) +def test_dask_check(interface_cases, tmp_output_dir_func): + array = interface_cases.make_array(path=tmp_output_dir_func) + + if interface_cases.interface is DaskInterface: + assert DaskInterface.check(array) else: - assert not DaskInterface.check(interface_type[0]) + assert not DaskInterface.check(array) @pytest.mark.shape def test_dask_shape(shape_cases): - _test_dask_case(shape_cases) + shape_cases.interface = DaskCase + shape_cases.validate_case() @pytest.mark.dtype def test_dask_dtype(dtype_cases): - _test_dask_case(dtype_cases) + dtype_cases.interface = DaskCase + dtype_cases.validate_case() @pytest.mark.serialization diff --git a/tests/test_interface/test_hdf5.py b/tests/test_interface/test_hdf5.py index bd94810..42d1a5b 100644 --- a/tests/test_interface/test_hdf5.py +++ b/tests/test_interface/test_hdf5.py @@ -1,61 +1,54 @@ import json -from datetime import datetime, timezone +from datetime import datetime from typing import Any import h5py -import pytest -from pydantic import BaseModel, ValidationError - import numpy as np +import pytest +from pydantic import BaseModel + from numpydantic import NDArray, Shape from numpydantic.interface import H5Interface from numpydantic.interface.hdf5 import H5ArrayPath, H5Proxy -from numpydantic.exceptions import DtypeError, ShapeError - -from tests.conftest import ValidationCase +from numpydantic.testing.interfaces import HDF5Case, HDF5CompoundCase pytestmark = pytest.mark.hdf5 -def hdf5_array_case( - case: ValidationCase, array_func, compound: bool = False -) -> H5ArrayPath: - """ - Args: - case: - array_func: ( the function returned from the `hdf5_array` fixture ) - - Returns: - - """ - if issubclass(case.dtype, BaseModel): - pytest.skip("hdf5 cant support arbitrary python objects") - return array_func(case.shape, case.dtype, compound) - - -def _test_hdf5_case(case: ValidationCase, array_func, compound: bool = False) -> None: - array = hdf5_array_case(case, array_func, compound) - if case.passes: - case.model(array=array) - else: - with pytest.raises((ValidationError, DtypeError, ShapeError)): - case.model(array=array) +@pytest.fixture( + params=[ + pytest.param(HDF5Case, id="hdf5"), + pytest.param(HDF5CompoundCase, id="hdf5-compound"), + ] +) +def hdf5_cases(request): + return request.param def test_hdf5_enabled(): assert H5Interface.enabled() -def test_hdf5_check(interface_type): - if interface_type[1] is H5Interface: - if interface_type[0].__name__ == "_hdf5_array": - interface_type = (interface_type[0](), interface_type[1]) - assert H5Interface.check(interface_type[0]) - if isinstance(interface_type[0], H5ArrayPath): - # also test that we can instantiate from a tuple like the H5ArrayPath - assert H5Interface.check((interface_type[0].file, interface_type[0].path)) +@pytest.mark.shape +def test_hdf5_shape(shape_cases, hdf5_cases): + shape_cases.interface = hdf5_cases + if shape_cases.skip(): + pytest.skip() + shape_cases.validate_case() + + +@pytest.mark.dtype +def test_hdf5_dtype(dtype_cases, hdf5_cases): + dtype_cases.interface = hdf5_cases + dtype_cases.validate_case() + + +def test_hdf5_check(interface_cases, tmp_output_dir_func): + array = interface_cases.make_array(path=tmp_output_dir_func) + if interface_cases.interface is H5Interface: + assert H5Interface.check(array) else: - assert not H5Interface.check(interface_type[0]) + assert not H5Interface.check(array) def test_hdf5_check_not_exists(): @@ -74,18 +67,6 @@ def 
test_hdf5_check_not_hdf5(tmp_path): assert not H5Interface.check(spec) -@pytest.mark.shape -@pytest.mark.parametrize("compound", [True, False]) -def test_hdf5_shape(shape_cases, hdf5_array, compound): - _test_hdf5_case(shape_cases, hdf5_array, compound) - - -@pytest.mark.dtype -@pytest.mark.parametrize("compound", [True, False]) -def test_hdf5_dtype(dtype_cases, hdf5_array, compound): - _test_hdf5_case(dtype_cases, hdf5_array, compound) - - def test_hdf5_dataset_not_exists(hdf5_array, model_blank): array = hdf5_array() with pytest.raises(ValueError) as e: @@ -221,10 +202,7 @@ def test_empty_dataset(dtype, tmp_path): Empty datasets shouldn't choke us during validation """ array_path = tmp_path / "test.h5" - if dtype in (str, datetime): - np_dtype = "S32" - else: - np_dtype = dtype + np_dtype = "S32" if dtype in (str, datetime) else dtype with h5py.File(array_path, "w") as h5f: _ = h5f.create_dataset(name="/data", dtype=np_dtype) diff --git a/tests/test_interface/test_interface_base.py b/tests/test_interface/test_interface_base.py index 0b99ae6..f82d0a5 100644 --- a/tests/test_interface/test_interface_base.py +++ b/tests/test_interface/test_interface_base.py @@ -6,18 +6,17 @@ for tests that should apply to all interfaces, use ``test_interfaces.py`` import gc from typing import Literal -import pytest import numpy as np +import pytest +from pydantic import ValidationError from numpydantic.interface import ( Interface, - JsonDict, InterfaceMark, - NumpyInterface, + JsonDict, MarkedJson, + NumpyInterface, ) -from pydantic import ValidationError - from numpydantic.interface.interface import V @@ -46,9 +45,7 @@ def interfaces(): @classmethod def check(cls, array): cls.checked = True - if isinstance(array, list): - return True - return False + return isinstance(array, list) @classmethod def enabled(cls) -> bool: @@ -94,7 +91,8 @@ def interfaces(): def test_interface_match_error(interfaces): """ - Test that `match` and `match_output` raises errors when no or multiple matches are found + Test that `match` and `match_output` raises errors when no or multiple matches + are found """ with pytest.raises(ValueError) as e: Interface.match([1, 2, 3]) diff --git a/tests/test_interface/test_interfaces.py b/tests/test_interface/test_interfaces.py index faec0d8..a368aee 100644 --- a/tests/test_interface/test_interfaces.py +++ b/tests/test_interface/test_interfaces.py @@ -2,87 +2,108 @@ Tests that should be applied to all interfaces """ -import pytest -from typing import Callable -from importlib.metadata import version import json +from importlib.metadata import version -import numpy as np import dask.array as da -from zarr.core import Array as ZarrArray +import numpy as np +import pytest from pydantic import BaseModel +from zarr.core import Array as ZarrArray from numpydantic.interface import Interface, InterfaceMark, MarkedJson +from numpydantic.testing.helpers import ValidationCase -def _test_roundtrip(source: BaseModel, target: BaseModel, round_trip: bool): +def _test_roundtrip(source: BaseModel, target: BaseModel): """Test model equality for roundtrip tests""" - if round_trip: - assert type(target.array) is type(source.array) - if isinstance(source.array, (np.ndarray, ZarrArray)): - assert np.array_equal(target.array, np.array(source.array)) - elif isinstance(source.array, da.Array): - assert np.all(da.equal(target.array, source.array)) - else: - assert target.array == source.array - assert target.array.dtype == source.array.dtype - else: + assert type(target.array) is type(source.array) + if 
isinstance(source.array, (np.ndarray, ZarrArray)): assert np.array_equal(target.array, np.array(source.array)) + elif isinstance(source.array, da.Array): + if target.array.dtype == object: + # object equality doesn't really work well with dask + # just check that the types match + target_type = type(target.array.ravel()[0].compute()) + source_type = type(source.array.ravel()[0].compute()) + assert target_type is source_type + else: + assert np.all(da.equal(target.array, source.array)) + else: + assert target.array == source.array + + assert target.array.dtype == source.array.dtype -def test_dunder_len(all_interfaces): +def test_dunder_len(interface_cases, tmp_output_dir_func): """ Each interface or proxy type should support __len__ """ - assert len(all_interfaces.array) == all_interfaces.array.shape[0] + case = ValidationCase(interface=interface_cases) + if interface_cases.interface.name == "video": + case.shape = (10, 10, 2, 3) + case.dtype = np.uint8 + case.annotation_dtype = np.uint8 + case.annotation_shape = (10, 10, "*", 3) + array = case.array(path=tmp_output_dir_func) + instance = case.model(array=array) + assert len(instance.array) == case.shape[0] -def test_interface_revalidate(all_interfaces): +def test_interface_revalidate(all_passing_cases_instance): """ An interface should revalidate with the output of its initial validation See: https://github.com/p2p-ld/numpydantic/pull/14 """ - _ = type(all_interfaces)(array=all_interfaces.array) + + _ = type(all_passing_cases_instance)(array=all_passing_cases_instance.array) -def test_interface_rematch(interface_type): +@pytest.mark.xfail +def test_interface_rematch(interface_cases, tmp_output_dir_func): """ All interfaces should match the results of the object they return after validation """ - array, interface = interface_type - if isinstance(array, Callable): - array = array() + array = interface_cases.make_array(path=tmp_output_dir_func) - assert Interface.match(interface().validate(array)) is interface + assert ( + Interface.match(interface_cases.interface.validate(array)) + is interface_cases.interface + ) -def test_interface_to_numpy_array(all_interfaces): +def test_interface_to_numpy_array(dtype_by_interface_instance): """ All interfaces should be able to have the output of their validation stage coerced to a numpy array with np.array() """ - _ = np.array(all_interfaces.array) + _ = np.array(dtype_by_interface_instance.array) @pytest.mark.serialization -def test_interface_dump_json(all_interfaces): +def test_interface_dump_json(dtype_by_interface_instance): """ All interfaces should be able to dump to json """ - all_interfaces.model_dump_json() + dtype_by_interface_instance.model_dump_json() @pytest.mark.serialization -@pytest.mark.parametrize("round_trip", [True, False]) -def test_interface_roundtrip_json(all_interfaces, round_trip): +def test_interface_roundtrip_json(dtype_by_interface, tmp_output_dir_func): """ All interfaces should be able to roundtrip to and from json """ - dumped_json = all_interfaces.model_dump_json(round_trip=round_trip) - model = all_interfaces.model_validate_json(dumped_json) - _test_roundtrip(all_interfaces, model, round_trip) + if "subclass" in dtype_by_interface.id.lower(): + pytest.xfail() + + array = dtype_by_interface.array(path=tmp_output_dir_func) + case = dtype_by_interface.model(array=array) + + dumped_json = case.model_dump_json(round_trip=True) + model = case.model_validate_json(dumped_json) + _test_roundtrip(case, model) @pytest.mark.serialization @@ -101,15 +122,20 @@ def 
test_interface_mark_interface(an_interface): @pytest.mark.serialization @pytest.mark.parametrize("valid", [True, False]) -@pytest.mark.parametrize("round_trip", [True, False]) @pytest.mark.filterwarnings("ignore:Mismatch between serialized mark") -def test_interface_mark_roundtrip(all_interfaces, valid, round_trip): +def test_interface_mark_roundtrip(dtype_by_interface, valid, tmp_output_dir_func): """ All interfaces should be able to roundtrip with the marked interface, and a mismatch should raise a warning and attempt to proceed """ - dumped_json = all_interfaces.model_dump_json( - round_trip=round_trip, context={"mark_interface": True} + if "subclass" in dtype_by_interface.id.lower(): + pytest.xfail() + + array = dtype_by_interface.array(path=tmp_output_dir_func) + case = dtype_by_interface.model(array=array) + + dumped_json = case.model_dump_json( + round_trip=True, context={"mark_interface": True} ) data = json.loads(dumped_json) @@ -123,8 +149,8 @@ def test_interface_mark_roundtrip(all_interfaces, valid, round_trip): dumped_json = json.dumps(data) with pytest.warns(match="Mismatch.*"): - model = all_interfaces.model_validate_json(dumped_json) + model = case.model_validate_json(dumped_json) else: - model = all_interfaces.model_validate_json(dumped_json) + model = case.model_validate_json(dumped_json) - _test_roundtrip(all_interfaces, model, round_trip) + _test_roundtrip(case, model) diff --git a/tests/test_interface/test_numpy.py b/tests/test_interface/test_numpy.py index bfb4c4d..a2a6f24 100644 --- a/tests/test_interface/test_numpy.py +++ b/tests/test_interface/test_numpy.py @@ -1,37 +1,21 @@ import numpy as np import pytest -from pydantic import ValidationError, BaseModel -from numpydantic.exceptions import DtypeError, ShapeError -from tests.conftest import ValidationCase +from numpydantic.testing.cases import NumpyCase pytestmark = pytest.mark.numpy -def numpy_array(case: ValidationCase) -> np.ndarray: - if issubclass(case.dtype, BaseModel): - return np.full(shape=case.shape, fill_value=case.dtype(x=1)) - else: - return np.zeros(shape=case.shape, dtype=case.dtype) - - -def _test_np_case(case: ValidationCase): - array = numpy_array(case) - if case.passes: - case.model(array=array) - else: - with pytest.raises((ValidationError, DtypeError, ShapeError)): - case.model(array=array) - - @pytest.mark.shape def test_numpy_shape(shape_cases): - _test_np_case(shape_cases) + shape_cases.interface = NumpyCase + shape_cases.validate_case() @pytest.mark.dtype def test_numpy_dtype(dtype_cases): - _test_np_case(dtype_cases) + dtype_cases.interface = NumpyCase + dtype_cases.validate_case() def test_numpy_coercion(model_blank): diff --git a/tests/test_interface/test_video.py b/tests/test_interface/test_video.py index 5f03a57..44bf2a5 100644 --- a/tests/test_interface/test_video.py +++ b/tests/test_interface/test_video.py @@ -2,12 +2,10 @@ Needs to be refactored to DRY, but works for now """ -import numpy as np -import pytest - from pathlib import Path -import cv2 +import cv2 +import pytest from pydantic import BaseModel, ValidationError from numpydantic import NDArray, Shape @@ -65,7 +63,7 @@ def test_video_wrong_shape(avi_video): # should correctly validate :) with pytest.raises(ValidationError): - instance = MyModel(array=vid) + _ = MyModel(array=vid) @pytest.mark.proxy @@ -82,15 +80,12 @@ def test_video_getitem(avi_video): instance = MyModel(array=vid) fifth_frame = instance.array[5] - # the first frame should have 1's in the 1,1 position + # the fifth frame should be all 5s assert 
(fifth_frame[5, 5, :] == [5, 5, 5]).all() - # and nothing in the 6th position - assert (fifth_frame[6, 6, :] == [0, 0, 0]).all() # slicing should also work as if it were just a numpy array single_slice = instance.array[3, 0:10, 0:5] assert single_slice[3, 3, 0] == 3 - assert single_slice[4, 4, 0] == 0 assert single_slice.shape == (10, 5, 3) # also get a range of frames @@ -98,19 +93,19 @@ def test_video_getitem(avi_video): range_slice = instance.array[3:5] assert range_slice.shape == (2, 100, 50, 3) assert range_slice[0, 3, 3, 0] == 3 - assert range_slice[0, 4, 4, 0] == 0 + assert range_slice[1, 4, 4, 0] == 4 # full range range_slice = instance.array[3:5, 0:10, 0:5] assert range_slice.shape == (2, 10, 5, 3) assert range_slice[0, 3, 3, 0] == 3 - assert range_slice[0, 4, 4, 0] == 0 + assert range_slice[1, 4, 4, 0] == 4 # starting range range_slice = instance.array[6:, 0:10, 0:10] assert range_slice.shape == (4, 10, 10, 3) assert range_slice[-1, 9, 9, 0] == 9 - assert range_slice[-2, 9, 9, 0] == 0 + assert range_slice[-2, 9, 9, 0] == 8 # ending range range_slice = instance.array[:3, 0:5, 0:5] @@ -121,10 +116,8 @@ def test_video_getitem(avi_video): # second slice should be the second frame (instead of the first) assert range_slice.shape == (3, 6, 6, 3) assert range_slice[1, 2, 2, 0] == 2 - assert range_slice[1, 3, 3, 0] == 0 # and the third should be the fourth (instead of the second) assert range_slice[2, 4, 4, 0] == 4 - assert range_slice[2, 5, 5, 0] == 0 with pytest.raises(NotImplementedError): # shouldn't be allowed to set diff --git a/tests/test_interface/test_zarr.py b/tests/test_interface/test_zarr.py index 6b21b20..f6df2f7 100644 --- a/tests/test_interface/test_zarr.py +++ b/tests/test_interface/test_zarr.py @@ -1,61 +1,21 @@ import json +import numpy as np import pytest -import zarr - -from pydantic import BaseModel, ValidationError -from numcodecs import Pickle from numpydantic.interface import ZarrInterface from numpydantic.interface.zarr import ZarrArrayPath -from numpydantic.exceptions import DtypeError, ShapeError - -from tests.conftest import ValidationCase +from numpydantic.testing.cases import ZarrCase, ZarrDirCase, ZarrNestedCase, ZarrZipCase +from numpydantic.testing.helpers import InterfaceCase pytestmark = pytest.mark.zarr -@pytest.fixture() -def dir_array(tmp_output_dir_func) -> zarr.DirectoryStore: - store = zarr.DirectoryStore(tmp_output_dir_func / "array.zarr") - return store - - -@pytest.fixture() -def zip_array(tmp_output_dir_func) -> zarr.ZipStore: - store = zarr.ZipStore(tmp_output_dir_func / "array.zip", mode="w") - return store - - -@pytest.fixture() -def nested_dir_array(tmp_output_dir_func) -> zarr.NestedDirectoryStore: - store = zarr.NestedDirectoryStore(tmp_output_dir_func / "nested") - return store - - -def _zarr_array(case: ValidationCase, store) -> zarr.core.Array: - if issubclass(case.dtype, BaseModel): - pytest.skip( - f"Zarr can't handle objects properly at the moment, " - "see https://github.com/zarr-developers/zarr-python/issues/2081" - ) - # return zarr.full( - # shape=case.shape, - # fill_value=case.dtype(x=1), - # dtype=object, - # object_codec=Pickle(), - # ) - else: - return zarr.zeros(shape=case.shape, dtype=case.dtype, store=store) - - -def _test_zarr_case(case: ValidationCase, store): - array = _zarr_array(case, store) - if case.passes: - case.model(array=array) - else: - with pytest.raises((ValidationError, DtypeError, ShapeError)): - case.model(array=array) +@pytest.fixture( + params=[ZarrCase, ZarrZipCase, ZarrDirCase, ZarrNestedCase], +) 
+def zarr_case(request) -> InterfaceCase: + return request.param @pytest.fixture( @@ -78,24 +38,29 @@ def test_zarr_enabled(): assert ZarrInterface.enabled() -def test_zarr_check(interface_type): +def test_zarr_check(interface_cases, tmp_output_dir_func): """ We should only use the zarr interface for zarr-like things """ - if interface_type[1] is ZarrInterface: - assert ZarrInterface.check(interface_type[0]) + array = interface_cases.make_array(path=tmp_output_dir_func) + if interface_cases.interface is ZarrInterface: + assert ZarrInterface.check(array) else: - assert not ZarrInterface.check(interface_type[0]) + assert not ZarrInterface.check(array) @pytest.mark.shape -def test_zarr_shape(store, shape_cases): - _test_zarr_case(shape_cases, store) +def test_zarr_shape(shape_cases, zarr_case): + shape_cases.interface = zarr_case + shape_cases.validate_case() @pytest.mark.dtype -def test_zarr_dtype(dtype_cases, store): - _test_zarr_case(dtype_cases, store) +def test_zarr_dtype(dtype_cases, zarr_case): + dtype_cases.interface = zarr_case + if dtype_cases.skip(): + pytest.skip() + dtype_cases.validate_case() @pytest.mark.parametrize("array", ["zarr_nested_array", "zarr_array"]) @@ -103,14 +68,14 @@ def test_zarr_from_tuple(array, model_blank, request): """Should be able to do the same validation logic from tuples as an input""" array = request.getfixturevalue(array) if isinstance(array, ZarrArrayPath): - instance = model_blank(array=(array.file, array.path)) + _ = model_blank(array=(array.file, array.path)) else: - instance = model_blank(array=(array,)) + _ = model_blank(array=(array,)) def test_zarr_from_path(zarr_array, model_blank): """Should be able to just pass a path""" - instance = model_blank(array=zarr_array) + _ = model_blank(array=zarr_array) def test_zarr_array_path_from_iterable(zarr_array): @@ -129,7 +94,7 @@ def test_zarr_array_path_from_iterable(zarr_array): @pytest.mark.serialization @pytest.mark.parametrize("dump_array", [True, False]) @pytest.mark.parametrize("roundtrip", [True, False]) -def test_zarr_to_json(store, model_blank, roundtrip, dump_array): +def test_zarr_to_json(zarr_case, model_blank, roundtrip, dump_array, tmp_path): expected_fields = ( "Type", "Data type", @@ -139,9 +104,9 @@ def test_zarr_to_json(store, model_blank, roundtrip, dump_array): "Store type", "hexdigest", ) - lol_array = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] + lol_array = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=int) - array = zarr.array(lol_array, store=store) + array = zarr_case.make_array(array=lol_array, dtype=int, path=tmp_path) instance = model_blank(array=array) context = {"dump_array": dump_array} @@ -151,7 +116,7 @@ def test_zarr_to_json(store, model_blank, roundtrip, dump_array): if roundtrip: if dump_array: - assert as_json["value"] == lol_array + assert np.array_equal(as_json["value"], lol_array) else: if as_json.get("file", False): assert "array" not in as_json @@ -161,4 +126,4 @@ def test_zarr_to_json(store, model_blank, roundtrip, dump_array): assert len(as_json["info"]["hexdigest"]) == 40 else: - assert as_json == lol_array + assert np.array_equal(as_json, lol_array) diff --git a/tests/test_meta.py b/tests/test_meta.py index c137633..c5d2a51 100644 --- a/tests/test_meta.py +++ b/tests/test_meta.py @@ -1,4 +1,5 @@ import sys + import pytest from numpydantic import NDArray @@ -40,4 +41,4 @@ def test_stub_revealed_type(): """ Check that the revealed type matches the stub """ - type = reveal_type(NDArray) + _ = reveal_type(NDArray) diff --git a/tests/test_ndarray.py 
b/tests/test_ndarray.py index cda092c..18d035b 100644 --- a/tests/test_ndarray.py +++ b/tests/test_ndarray.py @@ -1,16 +1,13 @@ -import pytest - -from typing import Union, Optional, Any import json +from typing import Any, Optional, Union import numpy as np -from pydantic import BaseModel, ValidationError, Field +import pytest +from pydantic import BaseModel, Field, ValidationError - -from numpydantic import NDArray, Shape -from numpydantic.exceptions import ShapeError, DtypeError -from numpydantic import dtype +from numpydantic import NDArray, Shape, dtype from numpydantic.dtype import Number +from numpydantic.exceptions import DtypeError @pytest.mark.json_schema @@ -28,15 +25,15 @@ def test_ndarray_type(): assert schema["properties"]["array"]["minItems"] == 2 # models should instantiate correctly! - instance = Model(array=np.zeros((2, 3))) + _ = Model(array=np.zeros((2, 3))) with pytest.raises(ValidationError): - instance = Model(array=np.zeros((4, 6))) + _ = Model(array=np.zeros((4, 6))) with pytest.raises(ValidationError): - instance = Model(array=np.ones((2, 3), dtype=bool)) + _ = Model(array=np.ones((2, 3), dtype=bool)) - instance = Model(array=np.zeros((2, 3)), array_any=np.ones((3, 4, 5))) + _ = Model(array=np.zeros((2, 3)), array_any=np.ones((3, 4, 5))) @pytest.mark.dtype @@ -93,6 +90,8 @@ def test_schema_number(): def test_ndarray_union(): + generator = np.random.default_rng() + class Model(BaseModel): array: Optional[ Union[ @@ -102,22 +101,22 @@ def test_ndarray_union(): ] ] = Field(None) - instance = Model() - instance = Model(array=np.random.random((5, 10))) - instance = Model(array=np.random.random((5, 10, 3))) - instance = Model(array=np.random.random((5, 10, 3, 4))) + _ = Model() + _ = Model(array=generator.random((5, 10))) + _ = Model(array=generator.random((5, 10, 3))) + _ = Model(array=generator.random((5, 10, 3, 4))) with pytest.raises(ValidationError): - instance = Model(array=np.random.random((5,))) + _ = Model(array=generator.random((5,))) with pytest.raises(ValidationError): - instance = Model(array=np.random.random((5, 10, 4))) + _ = Model(array=generator.random((5, 10, 4))) with pytest.raises(ValidationError): - instance = Model(array=np.random.random((5, 10, 3, 6))) + _ = Model(array=generator.random((5, 10, 3, 6))) with pytest.raises(ValidationError): - instance = Model(array=np.random.random((5, 10, 4, 6))) + _ = Model(array=generator.random((5, 10, 4, 6))) @pytest.mark.shape @@ -127,15 +126,16 @@ def test_ndarray_unparameterized(dtype): """ NDArray without any parameters is any shape, any type """ + generator = np.random.default_rng() class Model(BaseModel): array: NDArray # not very sophisticated fuzzing of "any shape" test_cases = 10 - for i in range(test_cases): - n_dimensions = np.random.randint(1, 8) - dim_sizes = np.random.randint(1, 7, size=n_dimensions) + for _ in range(test_cases): + n_dimensions = generator.integers(1, 8) + dim_sizes = generator.integers(1, 7, size=n_dimensions) _ = Model(array=np.zeros(dim_sizes, dtype=dtype)) @@ -144,15 +144,16 @@ def test_ndarray_any(): """ using :class:`typing.Any` in for the shape means any shape """ + generator = np.random.default_rng() class Model(BaseModel): array: NDArray[Any, np.uint8] # not very sophisticated fuzzing of "any shape" test_cases = 100 - for i in range(test_cases): - n_dimensions = np.random.randint(1, 8) - dim_sizes = np.random.randint(1, 16, size=n_dimensions) + for _ in range(test_cases): + n_dimensions = generator.integers(1, 8) + dim_sizes = generator.integers(1, 16, size=n_dimensions) _ 
= Model(array=np.zeros(dim_sizes, dtype=np.uint8))
@@ -191,7 +192,7 @@ def test_ndarray_serialize():
     class Model(BaseModel):
         array: NDArray[Any, Number]

-    mod = Model(array=np.random.random((3, 3)))
+    mod = Model(array=np.random.default_rng().random((3, 3)))
     mod_str = mod.model_dump_json()
     mod_json = json.loads(mod_str)
     assert isinstance(mod_json["array"], list)
diff --git a/tests/test_serialization.py b/tests/test_serialization.py
index a42ee7c..17f2cfe 100644
--- a/tests/test_serialization.py
+++ b/tests/test_serialization.py
@@ -3,14 +3,15 @@ Test serialization-specific functionality that doesn't need to be applied
 across every interface (use test_interface/test_interfaces for that
 """

-import h5py
-import pytest
+import json
 from pathlib import Path
 from typing import Callable
-import numpy as np
-import json
-from numpydantic.serialization import _walk_and_apply, _relativize_paths, relative_path
+import h5py
+import numpy as np
+import pytest
+
+from numpydantic.serialization import _relativize_paths, _walk_and_apply, relative_path

 pytestmark = pytest.mark.serialization
@@ -115,7 +116,8 @@ def test_absolute_path(hdf5_at_path, tmp_output_dir, model_blank):

 def test_walk_and_apply():
     """
-    Walk and apply should recursively apply a function to everything in a nesty structure
+    Walk and apply should recursively apply a function to everything in a
+    nested structure
     """
     test = {
         "a": 1,
diff --git a/tests/test_shape.py b/tests/test_shape.py
index b521054..c2ce279 100644
--- a/tests/test_shape.py
+++ b/tests/test_shape.py
@@ -1,9 +1,8 @@
-import pytest
-
 from typing import Any
-from pydantic import BaseModel, ValidationError
 import numpy as np
+import pytest
+from pydantic import BaseModel, ValidationError

 from numpydantic import NDArray, Shape
diff --git a/tests/test_testing_helpers.py b/tests/test_testing_helpers.py
new file mode 100644
index 0000000..3f93a41
--- /dev/null
+++ b/tests/test_testing_helpers.py
@@ -0,0 +1,60 @@
+"""
+Tests for the testing helpers lmao
+"""
+
+import numpy as np
+import pytest
+from pydantic import BaseModel
+
+from numpydantic import NDArray, Shape
+from numpydantic.testing.cases import INTERFACE_CASES
+from numpydantic.testing.helpers import ValidationCase
+from numpydantic.testing.interfaces import NumpyCase
+
+
+def test_validation_case_merge():
+    case_1 = ValidationCase(id="1", interface=NumpyCase, passes=False)
+    case_2 = ValidationCase(id="2", dtype=str, passes=True)
+    case_3 = ValidationCase(id="3", shape=(1, 2, 3), passes=True)
+
+    merged_simple = case_2.merge(case_3)
+    assert merged_simple.dtype == case_2.dtype
+    assert merged_simple.shape == case_3.shape
+
+    merged_multi = case_1.merge([case_2, case_3])
+    assert merged_multi.dtype == case_2.dtype
+    assert merged_multi.shape == case_3.shape
+    assert merged_multi.interface == case_1.interface
+
+    # passes should be true only if all the cases are
+    assert merged_simple.passes
+    assert not merged_multi.passes
+
+    # ids should merge
+    assert merged_simple.id == "2-3"
+    assert merged_multi.id == "1-2-3"
+
+
+@pytest.mark.parametrize(
+    "interface",
+    [
+        pytest.param(
+            i.interface, marks=getattr(pytest.mark, i.interface.interface.name)
+        )
+        for i in INTERFACE_CASES
+        if i.id not in ("hdf5_compound",)
+    ],
+)
+def test_make_array(interface, tmp_output_dir_func):
+    """
+    An interface case can generate an array from params or a given array
+
+    Not testing correctness here, that's what the rest of the testing does.
+ """ + arr = np.zeros((10, 10, 2, 3), dtype=np.uint8) + arr = interface.make_array(array=arr, dtype=np.uint8, path=tmp_output_dir_func) + + class MyModel(BaseModel): + array: NDArray[Shape["10, 10, 2, 3"], np.uint8] + + _ = MyModel(array=arr)