mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2025-01-10 05:54:26 +00:00
add array shape ranges! take over shape specification and checking from nptyping
This commit is contained in:
parent
d567ce0194
commit
07ab3d1b76
7 changed files with 350 additions and 22 deletions
|
@ -8,8 +8,7 @@ apply_patches()
|
||||||
|
|
||||||
from numpydantic.ndarray import NDArray
|
from numpydantic.ndarray import NDArray
|
||||||
from numpydantic.meta import update_ndarray_stub
|
from numpydantic.meta import update_ndarray_stub
|
||||||
|
from numpydantic.shape import Shape
|
||||||
from nptyping import Shape
|
|
||||||
|
|
||||||
update_ndarray_stub()
|
update_ndarray_stub()
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,6 @@ from operator import attrgetter
|
||||||
from typing import Any, Generic, Optional, Tuple, Type, TypeVar, Union
|
from typing import Any, Generic, Optional, Tuple, Type, TypeVar, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from nptyping.shape_expression import check_shape
|
|
||||||
from pydantic import SerializationInfo
|
from pydantic import SerializationInfo
|
||||||
|
|
||||||
from numpydantic.exceptions import (
|
from numpydantic.exceptions import (
|
||||||
|
@ -16,6 +15,7 @@ from numpydantic.exceptions import (
|
||||||
ShapeError,
|
ShapeError,
|
||||||
TooManyMatchesError,
|
TooManyMatchesError,
|
||||||
)
|
)
|
||||||
|
from numpydantic.shape import check_shape
|
||||||
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||||
|
|
||||||
T = TypeVar("T", bound=NDArrayType)
|
T = TypeVar("T", bound=NDArrayType)
|
||||||
|
|
|
@ -42,6 +42,8 @@ from numpydantic.types import DtypeType, ShapeType
|
||||||
if TYPE_CHECKING: # pragma: no cover
|
if TYPE_CHECKING: # pragma: no cover
|
||||||
from nptyping.base_meta_classes import SubscriptableMeta
|
from nptyping.base_meta_classes import SubscriptableMeta
|
||||||
|
|
||||||
|
from numpydantic import Shape
|
||||||
|
|
||||||
|
|
||||||
class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
|
class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
|
||||||
"""
|
"""
|
||||||
|
@ -78,6 +80,28 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
|
||||||
except InterfaceError:
|
except InterfaceError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _get_shape(cls, dtype_candidate: Any) -> "Shape":
|
||||||
|
"""
|
||||||
|
Override of base method to use our local definition of shape
|
||||||
|
"""
|
||||||
|
from numpydantic.shape import Shape
|
||||||
|
|
||||||
|
if dtype_candidate is Any or dtype_candidate is Shape:
|
||||||
|
shape = Any
|
||||||
|
elif issubclass(dtype_candidate, Shape):
|
||||||
|
shape = dtype_candidate
|
||||||
|
elif cls._is_literal_like(dtype_candidate):
|
||||||
|
shape_expression = dtype_candidate.__args__[0]
|
||||||
|
shape = Shape[shape_expression]
|
||||||
|
else:
|
||||||
|
raise InvalidArgumentsError(
|
||||||
|
f"Unexpected argument '{dtype_candidate}', expecting"
|
||||||
|
" Shape[<ShapeExpression>]"
|
||||||
|
" or Literal[<ShapeExpression>]"
|
||||||
|
" or typing.Any."
|
||||||
|
)
|
||||||
|
return shape
|
||||||
|
|
||||||
def _get_dtype(cls, dtype_candidate: Any) -> DType:
|
def _get_dtype(cls, dtype_candidate: Any) -> DType:
|
||||||
"""
|
"""
|
||||||
Override of base _get_dtype method to allow for compound tuple types
|
Override of base _get_dtype method to allow for compound tuple types
|
||||||
|
|
|
@ -5,11 +5,10 @@ Helper functions for use with :class:`~numpydantic.NDArray` - see the note in
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
from typing import Any, Callable, Optional, Union
|
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||||
|
|
||||||
import nptyping.structure
|
import nptyping.structure
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from nptyping import Shape
|
|
||||||
from pydantic import SerializationInfo
|
from pydantic import SerializationInfo
|
||||||
from pydantic_core import CoreSchema, core_schema
|
from pydantic_core import CoreSchema, core_schema
|
||||||
from pydantic_core.core_schema import ListSchema, ValidationInfo
|
from pydantic_core.core_schema import ListSchema, ValidationInfo
|
||||||
|
@ -19,6 +18,9 @@ from numpydantic.interface import Interface
|
||||||
from numpydantic.maps import np_to_python
|
from numpydantic.maps import np_to_python
|
||||||
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from numpydantic import Shape
|
||||||
|
|
||||||
_handler_type = Callable[[Any], core_schema.CoreSchema]
|
_handler_type = Callable[[Any], core_schema.CoreSchema]
|
||||||
_UNSUPPORTED_TYPES = (complex,)
|
_UNSUPPORTED_TYPES = (complex,)
|
||||||
|
|
||||||
|
@ -88,7 +90,7 @@ def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
|
||||||
return array_type
|
return array_type
|
||||||
|
|
||||||
|
|
||||||
def list_of_lists_schema(shape: Shape, array_type: CoreSchema) -> ListSchema:
|
def list_of_lists_schema(shape: "Shape", array_type: CoreSchema) -> ListSchema:
|
||||||
"""
|
"""
|
||||||
Make a pydantic JSON schema for an array as a list of lists.
|
Make a pydantic JSON schema for an array as a list of lists.
|
||||||
|
|
||||||
|
@ -98,13 +100,15 @@ def list_of_lists_schema(shape: Shape, array_type: CoreSchema) -> ListSchema:
|
||||||
This function is typically called from :func:`.make_json_schema`
|
This function is typically called from :func:`.make_json_schema`
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
shape (:class:`.Shape` ): Shape determines the depth and max/min elements
|
shape (:class:`~numpydantic.Shape`): Shape determines the depth and max/min
|
||||||
for each layer of list schema
|
elements for each layer of list schema
|
||||||
array_type ( :class:`pydantic_core.CoreSchema` ): The pre-rendered pydantic
|
array_type ( :class:`pydantic_core.CoreSchema` ): The pre-rendered pydantic
|
||||||
core schema to use in the innermost list entry
|
core schema to use in the innermost list entry
|
||||||
"""
|
"""
|
||||||
|
from numpydantic.shape import _is_range
|
||||||
|
|
||||||
shape_parts = shape.__args__[0].split(",")
|
shape_parts = [part.strip() for part in shape.__args__[0].split(",")]
|
||||||
|
# labels, if present
|
||||||
split_parts = [
|
split_parts = [
|
||||||
p.split(" ")[1] if len(p.split(" ")) == 2 else None for p in shape_parts
|
p.split(" ")[1] if len(p.split(" ")) == 2 else None for p in shape_parts
|
||||||
]
|
]
|
||||||
|
@ -129,18 +133,28 @@ def list_of_lists_schema(shape: Shape, array_type: CoreSchema) -> ListSchema:
|
||||||
elif arg == "...":
|
elif arg == "...":
|
||||||
list_schema = _unbounded_shape(inner_schema, metadata=metadata)
|
list_schema = _unbounded_shape(inner_schema, metadata=metadata)
|
||||||
else:
|
else:
|
||||||
try:
|
if _is_range(arg):
|
||||||
arg = int(arg)
|
arg_min, arg_max = arg.split("-")
|
||||||
except ValueError as e:
|
arg_min = None if arg_min == "*" else int(arg_min)
|
||||||
raise ValueError(
|
arg_max = None if arg_max == "*" else int(arg_max)
|
||||||
"Array shapes must be integers, wildcards, or ellipses. "
|
|
||||||
"Shape variables (for declaring that one dimension must be the "
|
else:
|
||||||
"same size as another) are not supported because it is "
|
try:
|
||||||
"impossible to express dynamic minItems/maxItems in JSON Schema. "
|
arg = int(arg)
|
||||||
"See: https://github.com/orgs/json-schema-org/discussions/730"
|
arg_min = arg
|
||||||
) from e
|
arg_max = arg
|
||||||
|
except ValueError as e:
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
"Array shapes must be integers, wildcards, ellipses, or "
|
||||||
|
"ranges. Shape variables (for declaring that one dimension "
|
||||||
|
"must be the same size as another) are not supported because "
|
||||||
|
"it is impossible to express dynamic minItems/maxItems in "
|
||||||
|
"JSON Schema. "
|
||||||
|
"See: https://github.com/orgs/json-schema-org/discussions/730"
|
||||||
|
) from e
|
||||||
list_schema = core_schema.list_schema(
|
list_schema = core_schema.list_schema(
|
||||||
inner_schema, min_length=arg, max_length=arg, metadata=metadata
|
inner_schema, min_length=arg_min, max_length=arg_max, metadata=metadata
|
||||||
)
|
)
|
||||||
return list_schema
|
return list_schema
|
||||||
|
|
||||||
|
|
211
src/numpydantic/shape.py
Normal file
211
src/numpydantic/shape.py
Normal file
|
@ -0,0 +1,211 @@
|
||||||
|
"""
|
||||||
|
Declaration and validation functions for array shapes.
|
||||||
|
|
||||||
|
Mostly a mildly modified version of nptyping's
|
||||||
|
:func:`npytping.shape_expression.check_shape`
|
||||||
|
and its internals to allow for extended syntax, including ranges of shapes.
|
||||||
|
|
||||||
|
Modifications from nptyping:
|
||||||
|
|
||||||
|
- **"..."** - In nptyping, ``'...'`` means "any number of dimensions with the same shape
|
||||||
|
as the last dimension. ie ``Shape[2, ...]`` means "any number of 2-length
|
||||||
|
dimensions. Here ``'...'`` always means "any number of any-shape dimensions"
|
||||||
|
- **Ranges** - (inclusive) shape ranges are allowed. eg. to specify an array
|
||||||
|
where the first dimension can be 2, 3, or 4 length:
|
||||||
|
|
||||||
|
Shape["2-4, ..."]
|
||||||
|
|
||||||
|
To specify a range with an unbounded min or max, use wildcards, eg. for
|
||||||
|
an array with the first dimension at least length 2, and the second dimension
|
||||||
|
at most length 5 (both inclusive):
|
||||||
|
|
||||||
|
Shape["2-*, *-5"]
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
import string
|
||||||
|
from abc import ABC
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
|
from nptyping.base_meta_classes import ContainerMeta
|
||||||
|
from nptyping.error import InvalidShapeError
|
||||||
|
from nptyping.nptyping_type import NPTypingType
|
||||||
|
from nptyping.shape_expression import (
|
||||||
|
get_dimensions,
|
||||||
|
normalize_shape_expression,
|
||||||
|
remove_labels,
|
||||||
|
)
|
||||||
|
from nptyping.typing_ import ShapeExpression, ShapeTuple
|
||||||
|
|
||||||
|
|
||||||
|
class ShapeMeta(ContainerMeta, implementation="Shape"):
|
||||||
|
"""
|
||||||
|
Metaclass that is coupled to nptyping.Shape.
|
||||||
|
|
||||||
|
Overridden from nptyping to use local shape validation function
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _validate_expression(cls, item: str) -> None:
|
||||||
|
validate_shape_expression(item)
|
||||||
|
|
||||||
|
def _normalize_expression(cls, item: str) -> str:
|
||||||
|
return normalize_shape_expression(item)
|
||||||
|
|
||||||
|
def _get_additional_values(cls, item: Any) -> Dict[str, Any]:
|
||||||
|
dim_strings = get_dimensions(item)
|
||||||
|
dim_string_without_labels = remove_labels(dim_strings)
|
||||||
|
return {"prepared_args": dim_string_without_labels}
|
||||||
|
|
||||||
|
|
||||||
|
class Shape(NPTypingType, ABC, metaclass=ShapeMeta):
|
||||||
|
"""
|
||||||
|
A container for shape expressions that describe the shape of an multi
|
||||||
|
dimensional array.
|
||||||
|
|
||||||
|
Simple example:
|
||||||
|
|
||||||
|
>>> Shape['2, 2']
|
||||||
|
Shape['2, 2']
|
||||||
|
|
||||||
|
A Shape can be compared to a typing.Literal. You can use Literals in
|
||||||
|
NDArray as well.
|
||||||
|
|
||||||
|
>>> from typing import Literal
|
||||||
|
|
||||||
|
>>> Shape['2, 2'] == Literal['2, 2']
|
||||||
|
True
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
__args__ = ("*, ...",)
|
||||||
|
prepared_args = ("*", "...")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_shape_expression(shape_expression: Union[ShapeExpression, Any]) -> None:
|
||||||
|
"""
|
||||||
|
CHANGES FROM NPTYPING: Allow ranges
|
||||||
|
"""
|
||||||
|
shape_expression_no_quotes = shape_expression.replace("'", "").replace('"', "")
|
||||||
|
if shape_expression is not Any and not re.match(
|
||||||
|
_REGEX_SHAPE_EXPRESSION, shape_expression_no_quotes
|
||||||
|
):
|
||||||
|
raise InvalidShapeError(
|
||||||
|
f"'{shape_expression}' is not a valid shape expression."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def check_shape(shape: ShapeTuple, target: "Shape") -> bool:
|
||||||
|
"""
|
||||||
|
Check whether the given shape corresponds to the given shape_expression.
|
||||||
|
:param shape: the shape in question.
|
||||||
|
:param target: the shape expression to which shape is tested.
|
||||||
|
:return: True if the given shape corresponds to shape_expression.
|
||||||
|
"""
|
||||||
|
target_shape = _handle_ellipsis(shape, target.prepared_args)
|
||||||
|
return _check_dimensions_against_shape(shape, target_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_dimensions_against_shape(shape: ShapeTuple, target: List[str]) -> bool:
|
||||||
|
# Walk through the shape and test them against the given target,
|
||||||
|
# taking into consideration variables, wildcards, etc.
|
||||||
|
|
||||||
|
if len(shape) != len(target):
|
||||||
|
return False
|
||||||
|
shape_as_strings = (str(dim) for dim in shape)
|
||||||
|
variables: Dict[str, str] = {}
|
||||||
|
for dim, target_dim in zip(shape_as_strings, target):
|
||||||
|
if _is_wildcard(target_dim) or _is_assignable_var(dim, target_dim, variables):
|
||||||
|
continue
|
||||||
|
if _is_range(target_dim) and _check_range(dim, target_dim):
|
||||||
|
continue
|
||||||
|
if dim != target_dim:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_ellipsis(shape: ShapeTuple, target: List[str]) -> List[str]:
|
||||||
|
# Let the ellipsis allows for any number of dimensions by replacing the
|
||||||
|
# ellipsis with the dimension size repeated the number of times that
|
||||||
|
# corresponds to the shape of the instance.
|
||||||
|
if target[-1] == "...":
|
||||||
|
dim_to_repeat = "*"
|
||||||
|
target = target[0:-1]
|
||||||
|
if len(shape) > len(target):
|
||||||
|
difference = len(shape) - len(target)
|
||||||
|
target += difference * [dim_to_repeat]
|
||||||
|
return target
|
||||||
|
|
||||||
|
|
||||||
|
def _is_range(target_dim: str) -> bool:
|
||||||
|
"""Whether the dimension is a range (literally whether it includes a hyphen)"""
|
||||||
|
return "-" in target_dim and len(target_dim.split("-")) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def _check_range(dim: str, target_dim: str) -> bool:
|
||||||
|
"""check whether the given dimension is within the target_dim range"""
|
||||||
|
dim = int(dim)
|
||||||
|
|
||||||
|
range_min, range_max = target_dim.split("-")
|
||||||
|
if _is_wildcard(range_min):
|
||||||
|
return dim <= int(range_max)
|
||||||
|
elif _is_wildcard(range_max):
|
||||||
|
return dim >= int(range_min)
|
||||||
|
else:
|
||||||
|
return int(range_min) <= dim <= int(range_max)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_wildcard(dim: str) -> bool:
|
||||||
|
"""
|
||||||
|
CHANGES FROM NPTYPING: added '*-*' range, which is a wildcard
|
||||||
|
"""
|
||||||
|
# Return whether dim is a wildcard (i.e. the character that takes any
|
||||||
|
# dimension size).
|
||||||
|
return dim == "*" or dim == "*-*"
|
||||||
|
|
||||||
|
|
||||||
|
# CHANGES FROM NPTYPING: Allow ranges
|
||||||
|
_REGEX_SEPARATOR = r"(\s*,\s*)"
|
||||||
|
_REGEX_DIMENSION_SIZE = r"(\s*[0-9]+\s*)"
|
||||||
|
_REGEX_DIMENSION_RANGE = r"(\s*[0-9\*]+-[0-9\*]+\s*)"
|
||||||
|
_REGEX_VARIABLE = r"(\s*\b[A-Z]\w*\s*)"
|
||||||
|
_REGEX_LABEL = r"(\s*\b[a-z]\w*\s*)"
|
||||||
|
_REGEX_LABELS = rf"({_REGEX_LABEL}({_REGEX_SEPARATOR}{_REGEX_LABEL})*)"
|
||||||
|
_REGEX_WILDCARD = r"(\s*\*\s*)"
|
||||||
|
_REGEX_DIMENSION_BREAKDOWN = rf"(\s*\[{_REGEX_LABELS}\]\s*)"
|
||||||
|
_REGEX_DIMENSION = (
|
||||||
|
rf"({_REGEX_DIMENSION_SIZE}"
|
||||||
|
rf"|{_REGEX_DIMENSION_RANGE}"
|
||||||
|
rf"|{_REGEX_VARIABLE}"
|
||||||
|
rf"|{_REGEX_WILDCARD}"
|
||||||
|
rf"|{_REGEX_DIMENSION_BREAKDOWN})"
|
||||||
|
)
|
||||||
|
_REGEX_DIMENSION_WITH_LABEL = rf"({_REGEX_DIMENSION}(\s+{_REGEX_LABEL})*)"
|
||||||
|
_REGEX_DIMENSIONS = (
|
||||||
|
rf"{_REGEX_DIMENSION_WITH_LABEL}({_REGEX_SEPARATOR}{_REGEX_DIMENSION_WITH_LABEL})*"
|
||||||
|
)
|
||||||
|
_REGEX_DIMENSIONS_ELLIPSIS = rf"({_REGEX_DIMENSIONS}{_REGEX_SEPARATOR}\.\.\.\s*)"
|
||||||
|
_REGEX_SHAPE_EXPRESSION = rf"^({_REGEX_DIMENSIONS}|{_REGEX_DIMENSIONS_ELLIPSIS})$"
|
||||||
|
|
||||||
|
# --------------------------------------------------
|
||||||
|
# Below - unchanged from nptyping
|
||||||
|
# --------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _is_assignable_var(dim: str, target_dim: str, variables: Dict[str, str]) -> bool:
|
||||||
|
# Return whether target_dim is a variable and can be assigned with dim.
|
||||||
|
return _is_variable(target_dim) and _can_assign_variable(dim, target_dim, variables)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_variable(dim: str) -> bool:
|
||||||
|
# Return whether dim is a variable.
|
||||||
|
return dim[0] in string.ascii_uppercase
|
||||||
|
|
||||||
|
|
||||||
|
def _can_assign_variable(dim: str, target_dim: str, variables: Dict[str, str]) -> bool:
|
||||||
|
# Check and assign a variable.
|
||||||
|
assignable = variables.get(target_dim) in (None, dim)
|
||||||
|
variables[target_dim] = dim
|
||||||
|
return assignable
|
|
@ -7,9 +7,9 @@ import json
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import BaseModel, ValidationError, Field
|
from pydantic import BaseModel, ValidationError, Field
|
||||||
from nptyping import Shape, Number
|
from nptyping import Number
|
||||||
|
|
||||||
from numpydantic import NDArray
|
from numpydantic import NDArray, Shape
|
||||||
from numpydantic.exceptions import ShapeError, DtypeError
|
from numpydantic.exceptions import ShapeError, DtypeError
|
||||||
from numpydantic import dtype
|
from numpydantic import dtype
|
||||||
|
|
||||||
|
|
80
tests/test_shape.py
Normal file
80
tests/test_shape.py
Normal file
|
@ -0,0 +1,80 @@
|
||||||
|
import pdb
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ValidationError
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from numpydantic import NDArray, Shape
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"shape,valid",
|
||||||
|
[
|
||||||
|
((2, 6), True),
|
||||||
|
((2, 7), True),
|
||||||
|
((3, 6), True),
|
||||||
|
((3, 7), True),
|
||||||
|
((4, 6), True),
|
||||||
|
((4, 7), True),
|
||||||
|
((1, 6), False),
|
||||||
|
((5, 6), False),
|
||||||
|
((2, 5), False),
|
||||||
|
((2, 8), False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_shape_range(shape, valid):
|
||||||
|
"""Specify a dimension with a range of possible sizes"""
|
||||||
|
|
||||||
|
class MyModel(BaseModel):
|
||||||
|
array: NDArray[Shape["2-4, 6-7"], Any]
|
||||||
|
|
||||||
|
if valid:
|
||||||
|
_ = MyModel(array=np.zeros(shape, dtype=np.uint8))
|
||||||
|
else:
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
_ = MyModel(array=np.zeros(shape, dtype=np.uint8))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"shape,valid",
|
||||||
|
[
|
||||||
|
((2, 5), True),
|
||||||
|
((10, 5), True),
|
||||||
|
((2, 2), True),
|
||||||
|
((1, 5), False),
|
||||||
|
((2, 6), False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_shape_wildcard(shape, valid):
|
||||||
|
"""Specify an open-ended minimum or maximum size for a given dimension"""
|
||||||
|
|
||||||
|
class MyModel(BaseModel):
|
||||||
|
array: NDArray[Shape["2-*, *-5"], Any]
|
||||||
|
|
||||||
|
if valid:
|
||||||
|
_ = MyModel(array=np.zeros(shape, dtype=np.uint8))
|
||||||
|
else:
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
_ = MyModel(array=np.zeros(shape, dtype=np.uint8))
|
||||||
|
|
||||||
|
|
||||||
|
def test_range_shape_schema():
|
||||||
|
"""
|
||||||
|
Range shapes should correctly generate JSON Schema
|
||||||
|
"""
|
||||||
|
|
||||||
|
class MyModel(BaseModel):
|
||||||
|
array_range: NDArray[Shape["2-4"], Any]
|
||||||
|
array_range_min: NDArray[Shape["2-*"], Any]
|
||||||
|
array_range_max: NDArray[Shape["*-4"], Any]
|
||||||
|
|
||||||
|
schema = MyModel.model_json_schema()
|
||||||
|
assert schema["properties"]["array_range"]["minItems"] == 2
|
||||||
|
assert schema["properties"]["array_range"]["maxItems"] == 4
|
||||||
|
assert schema["properties"]["array_range_min"]["minItems"] == 2
|
||||||
|
assert "maxItems" not in schema["properties"]["array_range_min"]
|
||||||
|
assert schema["properties"]["array_range_max"]["maxItems"] == 4
|
||||||
|
assert "minItems" not in schema["properties"]["array_range_max"]
|
Loading…
Reference in a new issue