Merge pull request #39 from p2p-ld/bugfix-union-dtypes

Fix JSON Schema generation for union dtypes
This commit is contained in:
Jonny Saunders 2024-12-13 17:33:52 -08:00 committed by GitHub
commit d54698fc0f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 97 additions and 22 deletions

View file

@ -15,20 +15,28 @@ jobs:
matrix: matrix:
platform: ["ubuntu-latest", "macos-latest", "windows-latest"] platform: ["ubuntu-latest", "macos-latest", "windows-latest"]
numpy-version: ["<2.0.0", ">=2.0.0"] numpy-version: ["<2.0.0", ">=2.0.0"]
python-version: ["3.9", "3.10", "3.11", "3.12"] python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
exclude: exclude:
- numpy-version: "<2.0.0" - numpy-version: "<2.0.0"
python-version: "3.10" python-version: "3.10"
- numpy-version: "<2.0.0" - numpy-version: "<2.0.0"
python-version: "3.11" python-version: "3.11"
# Python 3.12 is the last version we support in combination with numpy <2:
# numpy does not provide <2 wheels for Python 3.13 and beyond.
- numpy-version: "<2.0.0"
python-version: "3.13"
- platform: "macos-latest" - platform: "macos-latest"
python-version: "3.10" python-version: "3.10"
- platform: "macos-latest" - platform: "macos-latest"
python-version: "3.11" python-version: "3.11"
- platform: "macos-latest"
python-version: "3.12"
- platform: "windows-latest" - platform: "windows-latest"
python-version: "3.10" python-version: "3.10"
- platform: "windows-latest" - platform: "windows-latest"
python-version: "3.11" python-version: "3.11"
- platform: "windows-latest"
python-version: "3.12"
runs-on: ${{ matrix.platform }} runs-on: ${{ matrix.platform }}

View file

@ -1,5 +1,21 @@
# Changelog # Changelog
## Upcoming
**Bugfix**
- [#38](https://github.com/p2p-ld/numpydantic/issues/38), [#39](https://github.com/p2p-ld/numpydantic/pull/39) -
- JSON Schema generation failed when the `dtype` to be embedded in the schema lacked a `__name__` attribute.
An additional check for the presence of `__name__` was added before embedding.
- `NDArray` types were incorrectly cached such that pipe-union dtypes were considered equivalent to `Union[]`
dtypes. An additional tuple with the types of the args was added to the cache key to disambiguate them.
**Testing**
- [#39](https://github.com/p2p-ld/numpydantic/pull/39) - Test that all combinations of shapes, dtypes, and interfaces
can generate JSON schema.
- [#39](https://github.com/p2p-ld/numpydantic/pull/39) - Add Python 3.13 to the testing matrix.
- [#39](https://github.com/p2p-ld/numpydantic/pull/39) - Add an additional `marks` field to `ValidationCase`
for finer-grained control over running tests.
## 1.* ## 1.*
### 1.6.* ### 1.6.*

View file

@ -5,7 +5,7 @@
groups = ["default", "arrays", "dask", "dev", "docs", "hdf5", "tests", "video", "zarr"] groups = ["default", "arrays", "dask", "dev", "docs", "hdf5", "tests", "video", "zarr"]
strategy = ["cross_platform", "inherit_metadata"] strategy = ["cross_platform", "inherit_metadata"]
lock_version = "4.5.0" lock_version = "4.5.0"
content_hash = "sha256:cc2b0fb32896c6df0ad747ddb5dee89af22f5c4c4643ee7a52db47fef30da936" content_hash = "sha256:89ac87e811ecc42bf5117e9c9e4aa6a69011cb7c5c1a630fbfb2643b0045c526"
[[metadata.targets]] [[metadata.targets]]
requires_python = "~=3.9" requires_python = "~=3.9"

View file

@ -129,6 +129,8 @@ markers = [
"numpy: numpy interface", "numpy: numpy interface",
"video: video interface", "video: video interface",
"zarr: zarr interface", "zarr: zarr interface",
"union: union dtypes",
"pipe_union: union dtypes specified with a pipe",
] ]
[tool.black] [tool.black]

View file

@ -204,9 +204,15 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
json_schema = handler(schema["metadata"]) json_schema = handler(schema["metadata"])
json_schema = handler.resolve_ref_schema(json_schema) json_schema = handler.resolve_ref_schema(json_schema)
if not isinstance(dtype, tuple) and dtype.__module__ not in ( if (
not isinstance(dtype, tuple)
and dtype.__module__
not in (
"builtins", "builtins",
"typing", "typing",
"types",
)
and hasattr(dtype, "__name__")
): ):
json_schema["dtype"] = ".".join([dtype.__module__, dtype.__name__]) json_schema["dtype"] = ".".join([dtype.__module__, dtype.__name__])

View file

@ -143,27 +143,35 @@ DTYPE_CASES = [
dtype=np.uint32, dtype=np.uint32,
passes=True, passes=True,
id="union-type-uint32", id="union-type-uint32",
marks={"union"},
), ),
ValidationCase( ValidationCase(
annotation_dtype=UNION_TYPE, annotation_dtype=UNION_TYPE,
dtype=np.float32, dtype=np.float32,
passes=True, passes=True,
id="union-type-float32", id="union-type-float32",
marks={"union"},
), ),
ValidationCase( ValidationCase(
annotation_dtype=UNION_TYPE, annotation_dtype=UNION_TYPE,
dtype=np.uint64, dtype=np.uint64,
passes=False, passes=False,
id="union-type-uint64", id="union-type-uint64",
marks={"union"},
), ),
ValidationCase( ValidationCase(
annotation_dtype=UNION_TYPE, annotation_dtype=UNION_TYPE,
dtype=np.float64, dtype=np.float64,
passes=False, passes=False,
id="union-type-float64", id="union-type-float64",
marks={"union"},
), ),
ValidationCase( ValidationCase(
annotation_dtype=UNION_TYPE, dtype=str, passes=False, id="union-type-str" annotation_dtype=UNION_TYPE,
dtype=str,
passes=False,
id="union-type-str",
marks={"union"},
), ),
] ]
""" """
@ -181,30 +189,35 @@ if YES_PIPE:
dtype=np.uint32, dtype=np.uint32,
passes=True, passes=True,
id="union-pipe-uint32", id="union-pipe-uint32",
marks={"union", "pipe_union"},
), ),
ValidationCase( ValidationCase(
annotation_dtype=UNION_PIPE, annotation_dtype=UNION_PIPE,
dtype=np.float32, dtype=np.float32,
passes=True, passes=True,
id="union-pipe-float32", id="union-pipe-float32",
marks={"union", "pipe_union"},
), ),
ValidationCase( ValidationCase(
annotation_dtype=UNION_PIPE, annotation_dtype=UNION_PIPE,
dtype=np.uint64, dtype=np.uint64,
passes=False, passes=False,
id="union-pipe-uint64", id="union-pipe-uint64",
marks={"union", "pipe_union"},
), ),
ValidationCase( ValidationCase(
annotation_dtype=UNION_PIPE, annotation_dtype=UNION_PIPE,
dtype=np.float64, dtype=np.float64,
passes=False, passes=False,
id="union-pipe-float64", id="union-pipe-float64",
marks={"union", "pipe_union"},
), ),
ValidationCase( ValidationCase(
annotation_dtype=UNION_PIPE, annotation_dtype=UNION_PIPE,
dtype=str, dtype=str,
passes=False, passes=False,
id="union-pipe-str", id="union-pipe-str",
marks={"union", "pipe_union"},
), ),
] ]
) )

View file

@ -4,16 +4,19 @@ from functools import reduce
from itertools import product from itertools import product
from operator import ior from operator import ior
from pathlib import Path from pathlib import Path
from typing import Generator, List, Literal, Optional, Tuple, Type, Union from typing import TYPE_CHECKING, Generator, List, Literal, Optional, Tuple, Type, Union
import numpy as np import numpy as np
from pydantic import BaseModel, ConfigDict, ValidationError, computed_field from pydantic import BaseModel, ConfigDict, Field, ValidationError, computed_field
from numpydantic import NDArray, Shape from numpydantic import NDArray, Shape
from numpydantic.dtype import Float from numpydantic.dtype import Float
from numpydantic.interface import Interface from numpydantic.interface import Interface
from numpydantic.types import DtypeType, NDArrayType from numpydantic.types import DtypeType, NDArrayType
if TYPE_CHECKING:
from _pytest.mark.structures import MarkDecorator
class InterfaceCase(ABC): class InterfaceCase(ABC):
""" """
@ -139,6 +142,8 @@ class ValidationCase(BaseModel):
"""The interface test case to generate and validate the array with""" """The interface test case to generate and validate the array with"""
path: Optional[Path] = None path: Optional[Path] = None
"""The path to generate arrays into, if any.""" """The path to generate arrays into, if any."""
marks: set[str] = Field(default_factory=set)
"""pytest marks to set for this test case"""
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
@ -179,6 +184,19 @@ class ValidationCase(BaseModel):
return Model return Model
@property
def pytest_marks(self) -> list["MarkDecorator"]:
    """
    Instantiated pytest marks from :attr:`.ValidationCase.marks`
    plus the interface name.

    The mark names are collected into a fresh set, so ``self.marks`` is
    never mutated, and the interface name is added only when an interface
    is set. Marks are returned sorted by name so the result is the same
    on every run rather than depending on set iteration order.

    Returns:
        One ``pytest.mark.<name>`` decorator per unique mark name,
        in ascending name order.
    """
    # Local import: pytest is only needed when marks are actually requested.
    import pytest

    # ``set(...)`` rather than ``.copy()`` so that any iterable of names is
    # accepted, e.g. on instances built without validation.
    names = set(self.marks)
    if self.interface is not None:
        names.add(self.interface.interface.name)
    return [getattr(pytest.mark, name) for name in sorted(names)]
def validate_case(self, path: Optional[Path] = None) -> bool: def validate_case(self, path: Optional[Path] = None) -> bool:
""" """
Whether the generated array correctly validated against the annotation, Whether the generated array correctly validated against the annotation,
@ -246,7 +264,10 @@ def merge_cases(*args: ValidationCase) -> ValidationCase:
return args[0] return args[0]
dumped = [ dumped = [
m.model_dump(exclude_unset=True, exclude={"model", "annotation"}) for m in args m.model_dump(
exclude_unset=True, exclude={"model", "annotation", "pytest_marks"}
)
for m in args
] ]
# self_dump = self.model_dump(exclude_unset=True) # self_dump = self.model_dump(exclude_unset=True)
@ -263,6 +284,7 @@ def merge_cases(*args: ValidationCase) -> ValidationCase:
merged = reduce(ior, dumped, {}) merged = reduce(ior, dumped, {})
merged["passes"] = passes merged["passes"] = passes
merged["id"] = ids merged["id"] = ids
merged["marks"] = set().union(*[v.get("marks", set()) for v in dumped])
return ValidationCase.model_construct(**merged) return ValidationCase.model_construct(**merged)

View file

@ -120,7 +120,7 @@ class SubscriptableMeta(ABCMeta):
new type is returned for every unique set of arguments. new type is returned for every unique set of arguments.
""" """
_all_types: Dict[Tuple[type, Tuple[Any, ...]], type] = {} _all_types: Dict[Tuple[type, Tuple[Any, ...], tuple[type, ...]], type] = {}
_parameterized: bool = False _parameterized: bool = False
@abstractmethod @abstractmethod
@ -160,7 +160,7 @@ class SubscriptableMeta(ABCMeta):
def _create_type( def _create_type(
cls, args: Tuple[Any, ...], additional_values: Dict[str, Any] cls, args: Tuple[Any, ...], additional_values: Dict[str, Any]
) -> type: ) -> type:
key = (cls, args) key = (cls, args, tuple(type(a) for a in args))
if key not in cls._all_types: if key not in cls._all_types:
cls._all_types[key] = type( cls._all_types[key] = type(
cls.__name__, cls.__name__,

View file

@ -14,7 +14,8 @@ def pytest_addoption(parser):
@pytest.fixture( @pytest.fixture(
scope="function", params=[pytest.param(c, id=c.id) for c in SHAPE_CASES] scope="function",
params=[pytest.param(c, id=c.id, marks=c.pytest_marks) for c in SHAPE_CASES],
) )
def shape_cases(request, tmp_output_dir_func) -> ValidationCase: def shape_cases(request, tmp_output_dir_func) -> ValidationCase:
case: ValidationCase = request.param.model_copy() case: ValidationCase = request.param.model_copy()
@ -23,7 +24,8 @@ def shape_cases(request, tmp_output_dir_func) -> ValidationCase:
@pytest.fixture( @pytest.fixture(
scope="function", params=[pytest.param(c, id=c.id) for c in DTYPE_CASES] scope="function",
params=[pytest.param(c, id=c.id, marks=c.pytest_marks) for c in DTYPE_CASES],
) )
def dtype_cases(request, tmp_output_dir_func) -> ValidationCase: def dtype_cases(request, tmp_output_dir_func) -> ValidationCase:
case: ValidationCase = request.param.model_copy() case: ValidationCase = request.param.model_copy()

View file

@ -61,10 +61,7 @@ def interface_cases(request) -> InterfaceCase:
@pytest.fixture( @pytest.fixture(
params=( params=(pytest.param(p, id=p.id, marks=p.pytest_marks) for p in ALL_CASES)
pytest.param(p, id=p.id, marks=getattr(pytest.mark, p.interface.interface.name))
for p in ALL_CASES
)
) )
def all_cases(interface_cases, request) -> ValidationCase: def all_cases(interface_cases, request) -> ValidationCase:
""" """
@ -83,10 +80,7 @@ def all_cases(interface_cases, request) -> ValidationCase:
@pytest.fixture( @pytest.fixture(
params=( params=(pytest.param(p, id=p.id, marks=p.pytest_marks) for p in ALL_CASES_PASSING)
pytest.param(p, id=p.id, marks=getattr(pytest.mark, p.interface.interface.name))
for p in ALL_CASES_PASSING
)
) )
def all_passing_cases(request) -> ValidationCase: def all_passing_cases(request) -> ValidationCase:
""" """
@ -132,7 +126,7 @@ def all_passing_cases_instance(all_passing_cases, tmp_output_dir_func):
@pytest.fixture( @pytest.fixture(
params=( params=(
pytest.param(p, id=p.id, marks=getattr(pytest.mark, p.interface.interface.name)) pytest.param(p, id=p.id, marks=p.pytest_marks)
for p in DTYPE_AND_INTERFACE_CASES_PASSING for p in DTYPE_AND_INTERFACE_CASES_PASSING
) )
) )

View file

@ -61,6 +61,18 @@ def test_interface_revalidate(all_passing_cases_instance):
_ = type(all_passing_cases_instance)(array=all_passing_cases_instance.array) _ = type(all_passing_cases_instance)(array=all_passing_cases_instance.array)
@pytest.mark.json_schema
def test_interface_jsonschema(all_passing_cases_instance):
    """
    Smoke test: generating JSON schema must succeed for every passing
    combination of interface, dtype, and shape.

    Only the absence of an exception is checked here; the correctness
    of the generated schema is covered by the ndarray tests.
    """
    # The return value is deliberately discarded - see the docstring.
    all_passing_cases_instance.model_json_schema()
@pytest.mark.xfail @pytest.mark.xfail
def test_interface_rematch(interface_cases, tmp_output_dir_func): def test_interface_rematch(interface_cases, tmp_output_dir_func):
""" """