mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2025-01-09 21:44:27 +00:00
Merge pull request #39 from p2p-ld/bugfix-union-dtypes
Fix JSON Schema generation for union dtypes
This commit is contained in:
commit
d54698fc0f
11 changed files with 97 additions and 22 deletions
10
.github/workflows/tests.yml
vendored
10
.github/workflows/tests.yml
vendored
|
@ -15,20 +15,28 @@ jobs:
|
||||||
matrix:
|
matrix:
|
||||||
platform: ["ubuntu-latest", "macos-latest", "windows-latest"]
|
platform: ["ubuntu-latest", "macos-latest", "windows-latest"]
|
||||||
numpy-version: ["<2.0.0", ">=2.0.0"]
|
numpy-version: ["<2.0.0", ">=2.0.0"]
|
||||||
python-version: ["3.9", "3.10", "3.11", "3.12"]
|
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||||
exclude:
|
exclude:
|
||||||
- numpy-version: "<2.0.0"
|
- numpy-version: "<2.0.0"
|
||||||
python-version: "3.10"
|
python-version: "3.10"
|
||||||
- numpy-version: "<2.0.0"
|
- numpy-version: "<2.0.0"
|
||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
|
# let's call python 3.12 the last version we're going to support with numpy <2
|
||||||
|
# they don't provide wheels for <2 in 3.13 and beyond.
|
||||||
|
- numpy-version: "<2.0.0"
|
||||||
|
python-version: "3.13"
|
||||||
- platform: "macos-latest"
|
- platform: "macos-latest"
|
||||||
python-version: "3.10"
|
python-version: "3.10"
|
||||||
- platform: "macos-latest"
|
- platform: "macos-latest"
|
||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
|
- platform: "macos-latest"
|
||||||
|
python-version: "3.12"
|
||||||
- platform: "windows-latest"
|
- platform: "windows-latest"
|
||||||
python-version: "3.10"
|
python-version: "3.10"
|
||||||
- platform: "windows-latest"
|
- platform: "windows-latest"
|
||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
|
- platform: "windows-latest"
|
||||||
|
python-version: "3.12"
|
||||||
|
|
||||||
runs-on: ${{ matrix.platform }}
|
runs-on: ${{ matrix.platform }}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,21 @@
|
||||||
# Changelog
|
# Changelog
|
||||||
|
|
||||||
|
## Upcoming
|
||||||
|
|
||||||
|
**Bugfix**
|
||||||
|
- [#38](https://github.com/p2p-ld/numpydantic/issues/38), [#39](https://github.com/p2p-ld/numpydantic/pull/39) -
|
||||||
|
- JSON Schema generation failed when the `dtype` was embedded from dtypes that lack a `__name__` attribute.
|
||||||
|
An additional check was added for presence of `__name__` when embedding.
|
||||||
|
- `NDArray` types were incorrectly cached s.t. pipe-union dtypes were considered equivalent to `Union[]`
|
||||||
|
dtypes. An additional tuple with the type of the args was added to the cache key to disambiguate them.
|
||||||
|
|
||||||
|
**Testing**
|
||||||
|
- [#39](https://github.com/p2p-ld/numpydantic/pull/39) - Test that all combinations of shapes, dtypes, and interfaces
|
||||||
|
can generate JSON schema.
|
||||||
|
- [#39](https://github.com/p2p-ld/numpydantic/pull/39) - Add python 3.13 to the testing matrix.
|
||||||
|
- [#39](https://github.com/p2p-ld/numpydantic/pull/39) - Add an additional `marks` field to ValidationCase
|
||||||
|
for finer-grained control over running tests.
|
||||||
|
|
||||||
## 1.*
|
## 1.*
|
||||||
|
|
||||||
### 1.6.*
|
### 1.6.*
|
||||||
|
|
2
pdm.lock
2
pdm.lock
|
@ -5,7 +5,7 @@
|
||||||
groups = ["default", "arrays", "dask", "dev", "docs", "hdf5", "tests", "video", "zarr"]
|
groups = ["default", "arrays", "dask", "dev", "docs", "hdf5", "tests", "video", "zarr"]
|
||||||
strategy = ["cross_platform", "inherit_metadata"]
|
strategy = ["cross_platform", "inherit_metadata"]
|
||||||
lock_version = "4.5.0"
|
lock_version = "4.5.0"
|
||||||
content_hash = "sha256:cc2b0fb32896c6df0ad747ddb5dee89af22f5c4c4643ee7a52db47fef30da936"
|
content_hash = "sha256:89ac87e811ecc42bf5117e9c9e4aa6a69011cb7c5c1a630fbfb2643b0045c526"
|
||||||
|
|
||||||
[[metadata.targets]]
|
[[metadata.targets]]
|
||||||
requires_python = "~=3.9"
|
requires_python = "~=3.9"
|
||||||
|
|
|
@ -129,6 +129,8 @@ markers = [
|
||||||
"numpy: numpy interface",
|
"numpy: numpy interface",
|
||||||
"video: video interface",
|
"video: video interface",
|
||||||
"zarr: zarr interface",
|
"zarr: zarr interface",
|
||||||
|
"union: union dtypes",
|
||||||
|
"pipe_union: union dtypes specified with a pipe",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.black]
|
[tool.black]
|
||||||
|
|
|
@ -204,9 +204,15 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
|
||||||
json_schema = handler(schema["metadata"])
|
json_schema = handler(schema["metadata"])
|
||||||
json_schema = handler.resolve_ref_schema(json_schema)
|
json_schema = handler.resolve_ref_schema(json_schema)
|
||||||
|
|
||||||
if not isinstance(dtype, tuple) and dtype.__module__ not in (
|
if (
|
||||||
"builtins",
|
not isinstance(dtype, tuple)
|
||||||
"typing",
|
and dtype.__module__
|
||||||
|
not in (
|
||||||
|
"builtins",
|
||||||
|
"typing",
|
||||||
|
"types",
|
||||||
|
)
|
||||||
|
and hasattr(dtype, "__name__")
|
||||||
):
|
):
|
||||||
json_schema["dtype"] = ".".join([dtype.__module__, dtype.__name__])
|
json_schema["dtype"] = ".".join([dtype.__module__, dtype.__name__])
|
||||||
|
|
||||||
|
|
|
@ -143,27 +143,35 @@ DTYPE_CASES = [
|
||||||
dtype=np.uint32,
|
dtype=np.uint32,
|
||||||
passes=True,
|
passes=True,
|
||||||
id="union-type-uint32",
|
id="union-type-uint32",
|
||||||
|
marks={"union"},
|
||||||
),
|
),
|
||||||
ValidationCase(
|
ValidationCase(
|
||||||
annotation_dtype=UNION_TYPE,
|
annotation_dtype=UNION_TYPE,
|
||||||
dtype=np.float32,
|
dtype=np.float32,
|
||||||
passes=True,
|
passes=True,
|
||||||
id="union-type-float32",
|
id="union-type-float32",
|
||||||
|
marks={"union"},
|
||||||
),
|
),
|
||||||
ValidationCase(
|
ValidationCase(
|
||||||
annotation_dtype=UNION_TYPE,
|
annotation_dtype=UNION_TYPE,
|
||||||
dtype=np.uint64,
|
dtype=np.uint64,
|
||||||
passes=False,
|
passes=False,
|
||||||
id="union-type-uint64",
|
id="union-type-uint64",
|
||||||
|
marks={"union"},
|
||||||
),
|
),
|
||||||
ValidationCase(
|
ValidationCase(
|
||||||
annotation_dtype=UNION_TYPE,
|
annotation_dtype=UNION_TYPE,
|
||||||
dtype=np.float64,
|
dtype=np.float64,
|
||||||
passes=False,
|
passes=False,
|
||||||
id="union-type-float64",
|
id="union-type-float64",
|
||||||
|
marks={"union"},
|
||||||
),
|
),
|
||||||
ValidationCase(
|
ValidationCase(
|
||||||
annotation_dtype=UNION_TYPE, dtype=str, passes=False, id="union-type-str"
|
annotation_dtype=UNION_TYPE,
|
||||||
|
dtype=str,
|
||||||
|
passes=False,
|
||||||
|
id="union-type-str",
|
||||||
|
marks={"union"},
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
"""
|
"""
|
||||||
|
@ -181,30 +189,35 @@ if YES_PIPE:
|
||||||
dtype=np.uint32,
|
dtype=np.uint32,
|
||||||
passes=True,
|
passes=True,
|
||||||
id="union-pipe-uint32",
|
id="union-pipe-uint32",
|
||||||
|
marks={"union", "pipe_union"},
|
||||||
),
|
),
|
||||||
ValidationCase(
|
ValidationCase(
|
||||||
annotation_dtype=UNION_PIPE,
|
annotation_dtype=UNION_PIPE,
|
||||||
dtype=np.float32,
|
dtype=np.float32,
|
||||||
passes=True,
|
passes=True,
|
||||||
id="union-pipe-float32",
|
id="union-pipe-float32",
|
||||||
|
marks={"union", "pipe_union"},
|
||||||
),
|
),
|
||||||
ValidationCase(
|
ValidationCase(
|
||||||
annotation_dtype=UNION_PIPE,
|
annotation_dtype=UNION_PIPE,
|
||||||
dtype=np.uint64,
|
dtype=np.uint64,
|
||||||
passes=False,
|
passes=False,
|
||||||
id="union-pipe-uint64",
|
id="union-pipe-uint64",
|
||||||
|
marks={"union", "pipe_union"},
|
||||||
),
|
),
|
||||||
ValidationCase(
|
ValidationCase(
|
||||||
annotation_dtype=UNION_PIPE,
|
annotation_dtype=UNION_PIPE,
|
||||||
dtype=np.float64,
|
dtype=np.float64,
|
||||||
passes=False,
|
passes=False,
|
||||||
id="union-pipe-float64",
|
id="union-pipe-float64",
|
||||||
|
marks={"union", "pipe_union"},
|
||||||
),
|
),
|
||||||
ValidationCase(
|
ValidationCase(
|
||||||
annotation_dtype=UNION_PIPE,
|
annotation_dtype=UNION_PIPE,
|
||||||
dtype=str,
|
dtype=str,
|
||||||
passes=False,
|
passes=False,
|
||||||
id="union-pipe-str",
|
id="union-pipe-str",
|
||||||
|
marks={"union", "pipe_union"},
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -4,16 +4,19 @@ from functools import reduce
|
||||||
from itertools import product
|
from itertools import product
|
||||||
from operator import ior
|
from operator import ior
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Generator, List, Literal, Optional, Tuple, Type, Union
|
from typing import TYPE_CHECKING, Generator, List, Literal, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import BaseModel, ConfigDict, ValidationError, computed_field
|
from pydantic import BaseModel, ConfigDict, Field, ValidationError, computed_field
|
||||||
|
|
||||||
from numpydantic import NDArray, Shape
|
from numpydantic import NDArray, Shape
|
||||||
from numpydantic.dtype import Float
|
from numpydantic.dtype import Float
|
||||||
from numpydantic.interface import Interface
|
from numpydantic.interface import Interface
|
||||||
from numpydantic.types import DtypeType, NDArrayType
|
from numpydantic.types import DtypeType, NDArrayType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from _pytest.mark.structures import MarkDecorator
|
||||||
|
|
||||||
|
|
||||||
class InterfaceCase(ABC):
|
class InterfaceCase(ABC):
|
||||||
"""
|
"""
|
||||||
|
@ -139,6 +142,8 @@ class ValidationCase(BaseModel):
|
||||||
"""The interface test case to generate and validate the array with"""
|
"""The interface test case to generate and validate the array with"""
|
||||||
path: Optional[Path] = None
|
path: Optional[Path] = None
|
||||||
"""The path to generate arrays into, if any."""
|
"""The path to generate arrays into, if any."""
|
||||||
|
marks: set[str] = Field(default_factory=set)
|
||||||
|
"""pytest marks to set for this test case"""
|
||||||
|
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
|
@ -179,6 +184,19 @@ class ValidationCase(BaseModel):
|
||||||
|
|
||||||
return Model
|
return Model
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pytest_marks(self) -> list["MarkDecorator"]:
|
||||||
|
"""
|
||||||
|
Instantiated pytest marks from :attr:`.ValidationCase.marks`
|
||||||
|
plus the interface name.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
marks = self.marks.copy()
|
||||||
|
if self.interface is not None:
|
||||||
|
marks.add(self.interface.interface.name)
|
||||||
|
return [getattr(pytest.mark, m) for m in marks]
|
||||||
|
|
||||||
def validate_case(self, path: Optional[Path] = None) -> bool:
|
def validate_case(self, path: Optional[Path] = None) -> bool:
|
||||||
"""
|
"""
|
||||||
Whether the generated array correctly validated against the annotation,
|
Whether the generated array correctly validated against the annotation,
|
||||||
|
@ -246,7 +264,10 @@ def merge_cases(*args: ValidationCase) -> ValidationCase:
|
||||||
return args[0]
|
return args[0]
|
||||||
|
|
||||||
dumped = [
|
dumped = [
|
||||||
m.model_dump(exclude_unset=True, exclude={"model", "annotation"}) for m in args
|
m.model_dump(
|
||||||
|
exclude_unset=True, exclude={"model", "annotation", "pytest_marks"}
|
||||||
|
)
|
||||||
|
for m in args
|
||||||
]
|
]
|
||||||
|
|
||||||
# self_dump = self.model_dump(exclude_unset=True)
|
# self_dump = self.model_dump(exclude_unset=True)
|
||||||
|
@ -263,6 +284,7 @@ def merge_cases(*args: ValidationCase) -> ValidationCase:
|
||||||
merged = reduce(ior, dumped, {})
|
merged = reduce(ior, dumped, {})
|
||||||
merged["passes"] = passes
|
merged["passes"] = passes
|
||||||
merged["id"] = ids
|
merged["id"] = ids
|
||||||
|
merged["marks"] = set().union(*[v.get("marks", set()) for v in dumped])
|
||||||
return ValidationCase.model_construct(**merged)
|
return ValidationCase.model_construct(**merged)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -120,7 +120,7 @@ class SubscriptableMeta(ABCMeta):
|
||||||
new type is returned for every unique set of arguments.
|
new type is returned for every unique set of arguments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_all_types: Dict[Tuple[type, Tuple[Any, ...]], type] = {}
|
_all_types: Dict[Tuple[type, Tuple[Any, ...], tuple[type, ...]], type] = {}
|
||||||
_parameterized: bool = False
|
_parameterized: bool = False
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -160,7 +160,7 @@ class SubscriptableMeta(ABCMeta):
|
||||||
def _create_type(
|
def _create_type(
|
||||||
cls, args: Tuple[Any, ...], additional_values: Dict[str, Any]
|
cls, args: Tuple[Any, ...], additional_values: Dict[str, Any]
|
||||||
) -> type:
|
) -> type:
|
||||||
key = (cls, args)
|
key = (cls, args, tuple(type(a) for a in args))
|
||||||
if key not in cls._all_types:
|
if key not in cls._all_types:
|
||||||
cls._all_types[key] = type(
|
cls._all_types[key] = type(
|
||||||
cls.__name__,
|
cls.__name__,
|
||||||
|
|
|
@ -14,7 +14,8 @@ def pytest_addoption(parser):
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(
|
@pytest.fixture(
|
||||||
scope="function", params=[pytest.param(c, id=c.id) for c in SHAPE_CASES]
|
scope="function",
|
||||||
|
params=[pytest.param(c, id=c.id, marks=c.pytest_marks) for c in SHAPE_CASES],
|
||||||
)
|
)
|
||||||
def shape_cases(request, tmp_output_dir_func) -> ValidationCase:
|
def shape_cases(request, tmp_output_dir_func) -> ValidationCase:
|
||||||
case: ValidationCase = request.param.model_copy()
|
case: ValidationCase = request.param.model_copy()
|
||||||
|
@ -23,7 +24,8 @@ def shape_cases(request, tmp_output_dir_func) -> ValidationCase:
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(
|
@pytest.fixture(
|
||||||
scope="function", params=[pytest.param(c, id=c.id) for c in DTYPE_CASES]
|
scope="function",
|
||||||
|
params=[pytest.param(c, id=c.id, marks=c.pytest_marks) for c in DTYPE_CASES],
|
||||||
)
|
)
|
||||||
def dtype_cases(request, tmp_output_dir_func) -> ValidationCase:
|
def dtype_cases(request, tmp_output_dir_func) -> ValidationCase:
|
||||||
case: ValidationCase = request.param.model_copy()
|
case: ValidationCase = request.param.model_copy()
|
||||||
|
|
|
@ -61,10 +61,7 @@ def interface_cases(request) -> InterfaceCase:
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(
|
@pytest.fixture(
|
||||||
params=(
|
params=(pytest.param(p, id=p.id, marks=p.pytest_marks) for p in ALL_CASES)
|
||||||
pytest.param(p, id=p.id, marks=getattr(pytest.mark, p.interface.interface.name))
|
|
||||||
for p in ALL_CASES
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
def all_cases(interface_cases, request) -> ValidationCase:
|
def all_cases(interface_cases, request) -> ValidationCase:
|
||||||
"""
|
"""
|
||||||
|
@ -83,10 +80,7 @@ def all_cases(interface_cases, request) -> ValidationCase:
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(
|
@pytest.fixture(
|
||||||
params=(
|
params=(pytest.param(p, id=p.id, marks=p.pytest_marks) for p in ALL_CASES_PASSING)
|
||||||
pytest.param(p, id=p.id, marks=getattr(pytest.mark, p.interface.interface.name))
|
|
||||||
for p in ALL_CASES_PASSING
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
def all_passing_cases(request) -> ValidationCase:
|
def all_passing_cases(request) -> ValidationCase:
|
||||||
"""
|
"""
|
||||||
|
@ -132,7 +126,7 @@ def all_passing_cases_instance(all_passing_cases, tmp_output_dir_func):
|
||||||
|
|
||||||
@pytest.fixture(
|
@pytest.fixture(
|
||||||
params=(
|
params=(
|
||||||
pytest.param(p, id=p.id, marks=getattr(pytest.mark, p.interface.interface.name))
|
pytest.param(p, id=p.id, marks=p.pytest_marks)
|
||||||
for p in DTYPE_AND_INTERFACE_CASES_PASSING
|
for p in DTYPE_AND_INTERFACE_CASES_PASSING
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -61,6 +61,18 @@ def test_interface_revalidate(all_passing_cases_instance):
|
||||||
_ = type(all_passing_cases_instance)(array=all_passing_cases_instance.array)
|
_ = type(all_passing_cases_instance)(array=all_passing_cases_instance.array)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.json_schema
|
||||||
|
def test_interface_jsonschema(all_passing_cases_instance):
|
||||||
|
"""
|
||||||
|
All interfaces should be able to generate json schema
|
||||||
|
for all combinations of dtype and shape
|
||||||
|
|
||||||
|
Note that this does not test for json schema correctness -
|
||||||
|
see ndarray tests for that
|
||||||
|
"""
|
||||||
|
_ = all_passing_cases_instance.model_json_schema()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.xfail
|
@pytest.mark.xfail
|
||||||
def test_interface_rematch(interface_cases, tmp_output_dir_func):
|
def test_interface_rematch(interface_cases, tmp_output_dir_func):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in a new issue