mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2025-01-09 21:44:27 +00:00
add array shape ranges! take over shape specification and checking from nptyping
This commit is contained in:
parent
d567ce0194
commit
07ab3d1b76
7 changed files with 350 additions and 22 deletions
|
@ -8,8 +8,7 @@ apply_patches()
|
|||
|
||||
from numpydantic.ndarray import NDArray
|
||||
from numpydantic.meta import update_ndarray_stub
|
||||
|
||||
from nptyping import Shape
|
||||
from numpydantic.shape import Shape
|
||||
|
||||
update_ndarray_stub()
|
||||
|
||||
|
|
|
@ -7,7 +7,6 @@ from operator import attrgetter
|
|||
from typing import Any, Generic, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
from nptyping.shape_expression import check_shape
|
||||
from pydantic import SerializationInfo
|
||||
|
||||
from numpydantic.exceptions import (
|
||||
|
@ -16,6 +15,7 @@ from numpydantic.exceptions import (
|
|||
ShapeError,
|
||||
TooManyMatchesError,
|
||||
)
|
||||
from numpydantic.shape import check_shape
|
||||
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||
|
||||
T = TypeVar("T", bound=NDArrayType)
|
||||
|
|
|
@ -42,6 +42,8 @@ from numpydantic.types import DtypeType, ShapeType
|
|||
if TYPE_CHECKING: # pragma: no cover
|
||||
from nptyping.base_meta_classes import SubscriptableMeta
|
||||
|
||||
from numpydantic import Shape
|
||||
|
||||
|
||||
class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
|
||||
"""
|
||||
|
@ -78,6 +80,28 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
|
|||
except InterfaceError:
|
||||
return False
|
||||
|
||||
def _get_shape(cls, dtype_candidate: Any) -> "Shape":
|
||||
"""
|
||||
Override of base method to use our local definition of shape
|
||||
"""
|
||||
from numpydantic.shape import Shape
|
||||
|
||||
if dtype_candidate is Any or dtype_candidate is Shape:
|
||||
shape = Any
|
||||
elif issubclass(dtype_candidate, Shape):
|
||||
shape = dtype_candidate
|
||||
elif cls._is_literal_like(dtype_candidate):
|
||||
shape_expression = dtype_candidate.__args__[0]
|
||||
shape = Shape[shape_expression]
|
||||
else:
|
||||
raise InvalidArgumentsError(
|
||||
f"Unexpected argument '{dtype_candidate}', expecting"
|
||||
" Shape[<ShapeExpression>]"
|
||||
" or Literal[<ShapeExpression>]"
|
||||
" or typing.Any."
|
||||
)
|
||||
return shape
|
||||
|
||||
def _get_dtype(cls, dtype_candidate: Any) -> DType:
|
||||
"""
|
||||
Override of base _get_dtype method to allow for compound tuple types
|
||||
|
|
|
@ -5,11 +5,10 @@ Helper functions for use with :class:`~numpydantic.NDArray` - see the note in
|
|||
|
||||
import hashlib
|
||||
import json
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
|
||||
import nptyping.structure
|
||||
import numpy as np
|
||||
from nptyping import Shape
|
||||
from pydantic import SerializationInfo
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
from pydantic_core.core_schema import ListSchema, ValidationInfo
|
||||
|
@ -19,6 +18,9 @@ from numpydantic.interface import Interface
|
|||
from numpydantic.maps import np_to_python
|
||||
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpydantic import Shape
|
||||
|
||||
_handler_type = Callable[[Any], core_schema.CoreSchema]
|
||||
_UNSUPPORTED_TYPES = (complex,)
|
||||
|
||||
|
@ -88,7 +90,7 @@ def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
|
|||
return array_type
|
||||
|
||||
|
||||
def list_of_lists_schema(shape: Shape, array_type: CoreSchema) -> ListSchema:
|
||||
def list_of_lists_schema(shape: "Shape", array_type: CoreSchema) -> ListSchema:
|
||||
"""
|
||||
Make a pydantic JSON schema for an array as a list of lists.
|
||||
|
||||
|
@ -98,13 +100,15 @@ def list_of_lists_schema(shape: Shape, array_type: CoreSchema) -> ListSchema:
|
|||
This function is typically called from :func:`.make_json_schema`
|
||||
|
||||
Args:
|
||||
shape (:class:`.Shape` ): Shape determines the depth and max/min elements
|
||||
for each layer of list schema
|
||||
shape (:class:`~numpydantic.Shape`): Shape determines the depth and max/min
|
||||
elements for each layer of list schema
|
||||
array_type ( :class:`pydantic_core.CoreSchema` ): The pre-rendered pydantic
|
||||
core schema to use in the innermost list entry
|
||||
"""
|
||||
from numpydantic.shape import _is_range
|
||||
|
||||
shape_parts = shape.__args__[0].split(",")
|
||||
shape_parts = [part.strip() for part in shape.__args__[0].split(",")]
|
||||
# labels, if present
|
||||
split_parts = [
|
||||
p.split(" ")[1] if len(p.split(" ")) == 2 else None for p in shape_parts
|
||||
]
|
||||
|
@ -129,18 +133,28 @@ def list_of_lists_schema(shape: Shape, array_type: CoreSchema) -> ListSchema:
|
|||
elif arg == "...":
|
||||
list_schema = _unbounded_shape(inner_schema, metadata=metadata)
|
||||
else:
|
||||
try:
|
||||
arg = int(arg)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
"Array shapes must be integers, wildcards, or ellipses. "
|
||||
"Shape variables (for declaring that one dimension must be the "
|
||||
"same size as another) are not supported because it is "
|
||||
"impossible to express dynamic minItems/maxItems in JSON Schema. "
|
||||
"See: https://github.com/orgs/json-schema-org/discussions/730"
|
||||
) from e
|
||||
if _is_range(arg):
|
||||
arg_min, arg_max = arg.split("-")
|
||||
arg_min = None if arg_min == "*" else int(arg_min)
|
||||
arg_max = None if arg_max == "*" else int(arg_max)
|
||||
|
||||
else:
|
||||
try:
|
||||
arg = int(arg)
|
||||
arg_min = arg
|
||||
arg_max = arg
|
||||
except ValueError as e:
|
||||
|
||||
raise ValueError(
|
||||
"Array shapes must be integers, wildcards, ellipses, or "
|
||||
"ranges. Shape variables (for declaring that one dimension "
|
||||
"must be the same size as another) are not supported because "
|
||||
"it is impossible to express dynamic minItems/maxItems in "
|
||||
"JSON Schema. "
|
||||
"See: https://github.com/orgs/json-schema-org/discussions/730"
|
||||
) from e
|
||||
list_schema = core_schema.list_schema(
|
||||
inner_schema, min_length=arg, max_length=arg, metadata=metadata
|
||||
inner_schema, min_length=arg_min, max_length=arg_max, metadata=metadata
|
||||
)
|
||||
return list_schema
|
||||
|
||||
|
|
211
src/numpydantic/shape.py
Normal file
211
src/numpydantic/shape.py
Normal file
|
@ -0,0 +1,211 @@
|
|||
"""
|
||||
Declaration and validation functions for array shapes.
|
||||
|
||||
Mostly a mildly modified version of nptyping's
|
||||
:func:`npytping.shape_expression.check_shape`
|
||||
and its internals to allow for extended syntax, including ranges of shapes.
|
||||
|
||||
Modifications from nptyping:
|
||||
|
||||
- **"..."** - In nptyping, ``'...'`` means "any number of dimensions with the same shape
|
||||
as the last dimension. ie ``Shape[2, ...]`` means "any number of 2-length
|
||||
dimensions. Here ``'...'`` always means "any number of any-shape dimensions"
|
||||
- **Ranges** - (inclusive) shape ranges are allowed. eg. to specify an array
|
||||
where the first dimension can be 2, 3, or 4 length:
|
||||
|
||||
Shape["2-4, ..."]
|
||||
|
||||
To specify a range with an unbounded min or max, use wildcards, eg. for
|
||||
an array with the first dimension at least length 2, and the second dimension
|
||||
at most length 5 (both inclusive):
|
||||
|
||||
Shape["2-*, *-5"]
|
||||
|
||||
"""
|
||||
|
||||
import re
|
||||
import string
|
||||
from abc import ABC
|
||||
from functools import lru_cache
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from nptyping.base_meta_classes import ContainerMeta
|
||||
from nptyping.error import InvalidShapeError
|
||||
from nptyping.nptyping_type import NPTypingType
|
||||
from nptyping.shape_expression import (
|
||||
get_dimensions,
|
||||
normalize_shape_expression,
|
||||
remove_labels,
|
||||
)
|
||||
from nptyping.typing_ import ShapeExpression, ShapeTuple
|
||||
|
||||
|
||||
class ShapeMeta(ContainerMeta, implementation="Shape"):
|
||||
"""
|
||||
Metaclass that is coupled to nptyping.Shape.
|
||||
|
||||
Overridden from nptyping to use local shape validation function
|
||||
"""
|
||||
|
||||
def _validate_expression(cls, item: str) -> None:
|
||||
validate_shape_expression(item)
|
||||
|
||||
def _normalize_expression(cls, item: str) -> str:
|
||||
return normalize_shape_expression(item)
|
||||
|
||||
def _get_additional_values(cls, item: Any) -> Dict[str, Any]:
|
||||
dim_strings = get_dimensions(item)
|
||||
dim_string_without_labels = remove_labels(dim_strings)
|
||||
return {"prepared_args": dim_string_without_labels}
|
||||
|
||||
|
||||
class Shape(NPTypingType, ABC, metaclass=ShapeMeta):
|
||||
"""
|
||||
A container for shape expressions that describe the shape of an multi
|
||||
dimensional array.
|
||||
|
||||
Simple example:
|
||||
|
||||
>>> Shape['2, 2']
|
||||
Shape['2, 2']
|
||||
|
||||
A Shape can be compared to a typing.Literal. You can use Literals in
|
||||
NDArray as well.
|
||||
|
||||
>>> from typing import Literal
|
||||
|
||||
>>> Shape['2, 2'] == Literal['2, 2']
|
||||
True
|
||||
|
||||
"""
|
||||
|
||||
__args__ = ("*, ...",)
|
||||
prepared_args = ("*", "...")
|
||||
|
||||
|
||||
def validate_shape_expression(shape_expression: Union[ShapeExpression, Any]) -> None:
|
||||
"""
|
||||
CHANGES FROM NPTYPING: Allow ranges
|
||||
"""
|
||||
shape_expression_no_quotes = shape_expression.replace("'", "").replace('"', "")
|
||||
if shape_expression is not Any and not re.match(
|
||||
_REGEX_SHAPE_EXPRESSION, shape_expression_no_quotes
|
||||
):
|
||||
raise InvalidShapeError(
|
||||
f"'{shape_expression}' is not a valid shape expression."
|
||||
)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def check_shape(shape: ShapeTuple, target: "Shape") -> bool:
|
||||
"""
|
||||
Check whether the given shape corresponds to the given shape_expression.
|
||||
:param shape: the shape in question.
|
||||
:param target: the shape expression to which shape is tested.
|
||||
:return: True if the given shape corresponds to shape_expression.
|
||||
"""
|
||||
target_shape = _handle_ellipsis(shape, target.prepared_args)
|
||||
return _check_dimensions_against_shape(shape, target_shape)
|
||||
|
||||
|
||||
def _check_dimensions_against_shape(shape: ShapeTuple, target: List[str]) -> bool:
|
||||
# Walk through the shape and test them against the given target,
|
||||
# taking into consideration variables, wildcards, etc.
|
||||
|
||||
if len(shape) != len(target):
|
||||
return False
|
||||
shape_as_strings = (str(dim) for dim in shape)
|
||||
variables: Dict[str, str] = {}
|
||||
for dim, target_dim in zip(shape_as_strings, target):
|
||||
if _is_wildcard(target_dim) or _is_assignable_var(dim, target_dim, variables):
|
||||
continue
|
||||
if _is_range(target_dim) and _check_range(dim, target_dim):
|
||||
continue
|
||||
if dim != target_dim:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _handle_ellipsis(shape: ShapeTuple, target: List[str]) -> List[str]:
|
||||
# Let the ellipsis allows for any number of dimensions by replacing the
|
||||
# ellipsis with the dimension size repeated the number of times that
|
||||
# corresponds to the shape of the instance.
|
||||
if target[-1] == "...":
|
||||
dim_to_repeat = "*"
|
||||
target = target[0:-1]
|
||||
if len(shape) > len(target):
|
||||
difference = len(shape) - len(target)
|
||||
target += difference * [dim_to_repeat]
|
||||
return target
|
||||
|
||||
|
||||
def _is_range(target_dim: str) -> bool:
|
||||
"""Whether the dimension is a range (literally whether it includes a hyphen)"""
|
||||
return "-" in target_dim and len(target_dim.split("-")) == 2
|
||||
|
||||
|
||||
def _check_range(dim: str, target_dim: str) -> bool:
|
||||
"""check whether the given dimension is within the target_dim range"""
|
||||
dim = int(dim)
|
||||
|
||||
range_min, range_max = target_dim.split("-")
|
||||
if _is_wildcard(range_min):
|
||||
return dim <= int(range_max)
|
||||
elif _is_wildcard(range_max):
|
||||
return dim >= int(range_min)
|
||||
else:
|
||||
return int(range_min) <= dim <= int(range_max)
|
||||
|
||||
|
||||
def _is_wildcard(dim: str) -> bool:
|
||||
"""
|
||||
CHANGES FROM NPTYPING: added '*-*' range, which is a wildcard
|
||||
"""
|
||||
# Return whether dim is a wildcard (i.e. the character that takes any
|
||||
# dimension size).
|
||||
return dim == "*" or dim == "*-*"
|
||||
|
||||
|
||||
# CHANGES FROM NPTYPING: Allow ranges
|
||||
_REGEX_SEPARATOR = r"(\s*,\s*)"
|
||||
_REGEX_DIMENSION_SIZE = r"(\s*[0-9]+\s*)"
|
||||
_REGEX_DIMENSION_RANGE = r"(\s*[0-9\*]+-[0-9\*]+\s*)"
|
||||
_REGEX_VARIABLE = r"(\s*\b[A-Z]\w*\s*)"
|
||||
_REGEX_LABEL = r"(\s*\b[a-z]\w*\s*)"
|
||||
_REGEX_LABELS = rf"({_REGEX_LABEL}({_REGEX_SEPARATOR}{_REGEX_LABEL})*)"
|
||||
_REGEX_WILDCARD = r"(\s*\*\s*)"
|
||||
_REGEX_DIMENSION_BREAKDOWN = rf"(\s*\[{_REGEX_LABELS}\]\s*)"
|
||||
_REGEX_DIMENSION = (
|
||||
rf"({_REGEX_DIMENSION_SIZE}"
|
||||
rf"|{_REGEX_DIMENSION_RANGE}"
|
||||
rf"|{_REGEX_VARIABLE}"
|
||||
rf"|{_REGEX_WILDCARD}"
|
||||
rf"|{_REGEX_DIMENSION_BREAKDOWN})"
|
||||
)
|
||||
_REGEX_DIMENSION_WITH_LABEL = rf"({_REGEX_DIMENSION}(\s+{_REGEX_LABEL})*)"
|
||||
_REGEX_DIMENSIONS = (
|
||||
rf"{_REGEX_DIMENSION_WITH_LABEL}({_REGEX_SEPARATOR}{_REGEX_DIMENSION_WITH_LABEL})*"
|
||||
)
|
||||
_REGEX_DIMENSIONS_ELLIPSIS = rf"({_REGEX_DIMENSIONS}{_REGEX_SEPARATOR}\.\.\.\s*)"
|
||||
_REGEX_SHAPE_EXPRESSION = rf"^({_REGEX_DIMENSIONS}|{_REGEX_DIMENSIONS_ELLIPSIS})$"
|
||||
|
||||
# --------------------------------------------------
|
||||
# Below - unchanged from nptyping
|
||||
# --------------------------------------------------
|
||||
|
||||
|
||||
def _is_assignable_var(dim: str, target_dim: str, variables: Dict[str, str]) -> bool:
|
||||
# Return whether target_dim is a variable and can be assigned with dim.
|
||||
return _is_variable(target_dim) and _can_assign_variable(dim, target_dim, variables)
|
||||
|
||||
|
||||
def _is_variable(dim: str) -> bool:
|
||||
# Return whether dim is a variable.
|
||||
return dim[0] in string.ascii_uppercase
|
||||
|
||||
|
||||
def _can_assign_variable(dim: str, target_dim: str, variables: Dict[str, str]) -> bool:
|
||||
# Check and assign a variable.
|
||||
assignable = variables.get(target_dim) in (None, dim)
|
||||
variables[target_dim] = dim
|
||||
return assignable
|
|
@ -7,9 +7,9 @@ import json
|
|||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, ValidationError, Field
|
||||
from nptyping import Shape, Number
|
||||
from nptyping import Number
|
||||
|
||||
from numpydantic import NDArray
|
||||
from numpydantic import NDArray, Shape
|
||||
from numpydantic.exceptions import ShapeError, DtypeError
|
||||
from numpydantic import dtype
|
||||
|
||||
|
|
80
tests/test_shape.py
Normal file
80
tests/test_shape.py
Normal file
|
@ -0,0 +1,80 @@
|
|||
import pdb
|
||||
|
||||
import pytest
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
import numpy as np
|
||||
|
||||
from numpydantic import NDArray, Shape
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"shape,valid",
|
||||
[
|
||||
((2, 6), True),
|
||||
((2, 7), True),
|
||||
((3, 6), True),
|
||||
((3, 7), True),
|
||||
((4, 6), True),
|
||||
((4, 7), True),
|
||||
((1, 6), False),
|
||||
((5, 6), False),
|
||||
((2, 5), False),
|
||||
((2, 8), False),
|
||||
],
|
||||
)
|
||||
def test_shape_range(shape, valid):
|
||||
"""Specify a dimension with a range of possible sizes"""
|
||||
|
||||
class MyModel(BaseModel):
|
||||
array: NDArray[Shape["2-4, 6-7"], Any]
|
||||
|
||||
if valid:
|
||||
_ = MyModel(array=np.zeros(shape, dtype=np.uint8))
|
||||
else:
|
||||
with pytest.raises(ValidationError):
|
||||
_ = MyModel(array=np.zeros(shape, dtype=np.uint8))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"shape,valid",
|
||||
[
|
||||
((2, 5), True),
|
||||
((10, 5), True),
|
||||
((2, 2), True),
|
||||
((1, 5), False),
|
||||
((2, 6), False),
|
||||
],
|
||||
)
|
||||
def test_shape_wildcard(shape, valid):
|
||||
"""Specify an open-ended minimum or maximum size for a given dimension"""
|
||||
|
||||
class MyModel(BaseModel):
|
||||
array: NDArray[Shape["2-*, *-5"], Any]
|
||||
|
||||
if valid:
|
||||
_ = MyModel(array=np.zeros(shape, dtype=np.uint8))
|
||||
else:
|
||||
with pytest.raises(ValidationError):
|
||||
_ = MyModel(array=np.zeros(shape, dtype=np.uint8))
|
||||
|
||||
|
||||
def test_range_shape_schema():
|
||||
"""
|
||||
Range shapes should correctly generate JSON Schema
|
||||
"""
|
||||
|
||||
class MyModel(BaseModel):
|
||||
array_range: NDArray[Shape["2-4"], Any]
|
||||
array_range_min: NDArray[Shape["2-*"], Any]
|
||||
array_range_max: NDArray[Shape["*-4"], Any]
|
||||
|
||||
schema = MyModel.model_json_schema()
|
||||
assert schema["properties"]["array_range"]["minItems"] == 2
|
||||
assert schema["properties"]["array_range"]["maxItems"] == 4
|
||||
assert schema["properties"]["array_range_min"]["minItems"] == 2
|
||||
assert "maxItems" not in schema["properties"]["array_range_min"]
|
||||
assert schema["properties"]["array_range_max"]["maxItems"] == 4
|
||||
assert "minItems" not in schema["properties"]["array_range_max"]
|
Loading…
Reference in a new issue