add array shape ranges! take over shape specification and checking from nptyping

This commit is contained in:
sneakers-the-rat 2024-06-14 22:38:13 -07:00
parent d567ce0194
commit 07ab3d1b76
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
7 changed files with 350 additions and 22 deletions

View file

@ -8,8 +8,7 @@ apply_patches()
from numpydantic.ndarray import NDArray from numpydantic.ndarray import NDArray
from numpydantic.meta import update_ndarray_stub from numpydantic.meta import update_ndarray_stub
from numpydantic.shape import Shape
from nptyping import Shape
update_ndarray_stub() update_ndarray_stub()

View file

@ -7,7 +7,6 @@ from operator import attrgetter
from typing import Any, Generic, Optional, Tuple, Type, TypeVar, Union from typing import Any, Generic, Optional, Tuple, Type, TypeVar, Union
import numpy as np import numpy as np
from nptyping.shape_expression import check_shape
from pydantic import SerializationInfo from pydantic import SerializationInfo
from numpydantic.exceptions import ( from numpydantic.exceptions import (
@ -16,6 +15,7 @@ from numpydantic.exceptions import (
ShapeError, ShapeError,
TooManyMatchesError, TooManyMatchesError,
) )
from numpydantic.shape import check_shape
from numpydantic.types import DtypeType, NDArrayType, ShapeType from numpydantic.types import DtypeType, NDArrayType, ShapeType
T = TypeVar("T", bound=NDArrayType) T = TypeVar("T", bound=NDArrayType)

View file

@ -42,6 +42,8 @@ from numpydantic.types import DtypeType, ShapeType
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
from nptyping.base_meta_classes import SubscriptableMeta from nptyping.base_meta_classes import SubscriptableMeta
from numpydantic import Shape
class NDArrayMeta(_NDArrayMeta, implementation="NDArray"): class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
""" """
@ -78,6 +80,28 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
except InterfaceError: except InterfaceError:
return False return False
def _get_shape(cls, dtype_candidate: Any) -> "Shape":
"""
Override of base method to use our local definition of shape
"""
from numpydantic.shape import Shape
if dtype_candidate is Any or dtype_candidate is Shape:
shape = Any
elif issubclass(dtype_candidate, Shape):
shape = dtype_candidate
elif cls._is_literal_like(dtype_candidate):
shape_expression = dtype_candidate.__args__[0]
shape = Shape[shape_expression]
else:
raise InvalidArgumentsError(
f"Unexpected argument '{dtype_candidate}', expecting"
" Shape[<ShapeExpression>]"
" or Literal[<ShapeExpression>]"
" or typing.Any."
)
return shape
def _get_dtype(cls, dtype_candidate: Any) -> DType: def _get_dtype(cls, dtype_candidate: Any) -> DType:
""" """
Override of base _get_dtype method to allow for compound tuple types Override of base _get_dtype method to allow for compound tuple types

View file

@ -5,11 +5,10 @@ Helper functions for use with :class:`~numpydantic.NDArray` - see the note in
import hashlib import hashlib
import json import json
from typing import Any, Callable, Optional, Union from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import nptyping.structure import nptyping.structure
import numpy as np import numpy as np
from nptyping import Shape
from pydantic import SerializationInfo from pydantic import SerializationInfo
from pydantic_core import CoreSchema, core_schema from pydantic_core import CoreSchema, core_schema
from pydantic_core.core_schema import ListSchema, ValidationInfo from pydantic_core.core_schema import ListSchema, ValidationInfo
@ -19,6 +18,9 @@ from numpydantic.interface import Interface
from numpydantic.maps import np_to_python from numpydantic.maps import np_to_python
from numpydantic.types import DtypeType, NDArrayType, ShapeType from numpydantic.types import DtypeType, NDArrayType, ShapeType
if TYPE_CHECKING:
from numpydantic import Shape
_handler_type = Callable[[Any], core_schema.CoreSchema] _handler_type = Callable[[Any], core_schema.CoreSchema]
_UNSUPPORTED_TYPES = (complex,) _UNSUPPORTED_TYPES = (complex,)
@ -88,7 +90,7 @@ def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
return array_type return array_type
def list_of_lists_schema(shape: Shape, array_type: CoreSchema) -> ListSchema: def list_of_lists_schema(shape: "Shape", array_type: CoreSchema) -> ListSchema:
""" """
Make a pydantic JSON schema for an array as a list of lists. Make a pydantic JSON schema for an array as a list of lists.
@ -98,13 +100,15 @@ def list_of_lists_schema(shape: Shape, array_type: CoreSchema) -> ListSchema:
This function is typically called from :func:`.make_json_schema` This function is typically called from :func:`.make_json_schema`
Args: Args:
shape (:class:`.Shape` ): Shape determines the depth and max/min elements shape (:class:`~numpydantic.Shape`): Shape determines the depth and max/min
for each layer of list schema elements for each layer of list schema
array_type ( :class:`pydantic_core.CoreSchema` ): The pre-rendered pydantic array_type ( :class:`pydantic_core.CoreSchema` ): The pre-rendered pydantic
core schema to use in the innermost list entry core schema to use in the innermost list entry
""" """
from numpydantic.shape import _is_range
shape_parts = shape.__args__[0].split(",") shape_parts = [part.strip() for part in shape.__args__[0].split(",")]
# labels, if present
split_parts = [ split_parts = [
p.split(" ")[1] if len(p.split(" ")) == 2 else None for p in shape_parts p.split(" ")[1] if len(p.split(" ")) == 2 else None for p in shape_parts
] ]
@ -128,19 +132,29 @@ def list_of_lists_schema(shape: Shape, array_type: CoreSchema) -> ListSchema:
list_schema = core_schema.list_schema(inner_schema, metadata=metadata) list_schema = core_schema.list_schema(inner_schema, metadata=metadata)
elif arg == "...": elif arg == "...":
list_schema = _unbounded_shape(inner_schema, metadata=metadata) list_schema = _unbounded_shape(inner_schema, metadata=metadata)
else:
if _is_range(arg):
arg_min, arg_max = arg.split("-")
arg_min = None if arg_min == "*" else int(arg_min)
arg_max = None if arg_max == "*" else int(arg_max)
else: else:
try: try:
arg = int(arg) arg = int(arg)
arg_min = arg
arg_max = arg
except ValueError as e: except ValueError as e:
raise ValueError( raise ValueError(
"Array shapes must be integers, wildcards, or ellipses. " "Array shapes must be integers, wildcards, ellipses, or "
"Shape variables (for declaring that one dimension must be the " "ranges. Shape variables (for declaring that one dimension "
"same size as another) are not supported because it is " "must be the same size as another) are not supported because "
"impossible to express dynamic minItems/maxItems in JSON Schema. " "it is impossible to express dynamic minItems/maxItems in "
"JSON Schema. "
"See: https://github.com/orgs/json-schema-org/discussions/730" "See: https://github.com/orgs/json-schema-org/discussions/730"
) from e ) from e
list_schema = core_schema.list_schema( list_schema = core_schema.list_schema(
inner_schema, min_length=arg, max_length=arg, metadata=metadata inner_schema, min_length=arg_min, max_length=arg_max, metadata=metadata
) )
return list_schema return list_schema

211
src/numpydantic/shape.py Normal file
View file

@ -0,0 +1,211 @@
"""
Declaration and validation functions for array shapes.
Mostly a mildly modified version of nptyping's
:func:`npytping.shape_expression.check_shape`
and its internals to allow for extended syntax, including ranges of shapes.
Modifications from nptyping:
- **"..."** - In nptyping, ``'...'`` means "any number of dimensions with the same shape
as the last dimension. ie ``Shape[2, ...]`` means "any number of 2-length
dimensions. Here ``'...'`` always means "any number of any-shape dimensions"
- **Ranges** - (inclusive) shape ranges are allowed. eg. to specify an array
where the first dimension can be 2, 3, or 4 length:
Shape["2-4, ..."]
To specify a range with an unbounded min or max, use wildcards, eg. for
an array with the first dimension at least length 2, and the second dimension
at most length 5 (both inclusive):
Shape["2-*, *-5"]
"""
import re
import string
from abc import ABC
from functools import lru_cache
from typing import Any, Dict, List, Union
from nptyping.base_meta_classes import ContainerMeta
from nptyping.error import InvalidShapeError
from nptyping.nptyping_type import NPTypingType
from nptyping.shape_expression import (
get_dimensions,
normalize_shape_expression,
remove_labels,
)
from nptyping.typing_ import ShapeExpression, ShapeTuple
class ShapeMeta(ContainerMeta, implementation="Shape"):
"""
Metaclass that is coupled to nptyping.Shape.
Overridden from nptyping to use local shape validation function
"""
def _validate_expression(cls, item: str) -> None:
validate_shape_expression(item)
def _normalize_expression(cls, item: str) -> str:
return normalize_shape_expression(item)
def _get_additional_values(cls, item: Any) -> Dict[str, Any]:
dim_strings = get_dimensions(item)
dim_string_without_labels = remove_labels(dim_strings)
return {"prepared_args": dim_string_without_labels}
class Shape(NPTypingType, ABC, metaclass=ShapeMeta):
"""
A container for shape expressions that describe the shape of an multi
dimensional array.
Simple example:
>>> Shape['2, 2']
Shape['2, 2']
A Shape can be compared to a typing.Literal. You can use Literals in
NDArray as well.
>>> from typing import Literal
>>> Shape['2, 2'] == Literal['2, 2']
True
"""
__args__ = ("*, ...",)
prepared_args = ("*", "...")
def validate_shape_expression(shape_expression: Union[ShapeExpression, Any]) -> None:
"""
CHANGES FROM NPTYPING: Allow ranges
"""
shape_expression_no_quotes = shape_expression.replace("'", "").replace('"', "")
if shape_expression is not Any and not re.match(
_REGEX_SHAPE_EXPRESSION, shape_expression_no_quotes
):
raise InvalidShapeError(
f"'{shape_expression}' is not a valid shape expression."
)
@lru_cache
def check_shape(shape: ShapeTuple, target: "Shape") -> bool:
"""
Check whether the given shape corresponds to the given shape_expression.
:param shape: the shape in question.
:param target: the shape expression to which shape is tested.
:return: True if the given shape corresponds to shape_expression.
"""
target_shape = _handle_ellipsis(shape, target.prepared_args)
return _check_dimensions_against_shape(shape, target_shape)
def _check_dimensions_against_shape(shape: ShapeTuple, target: List[str]) -> bool:
# Walk through the shape and test them against the given target,
# taking into consideration variables, wildcards, etc.
if len(shape) != len(target):
return False
shape_as_strings = (str(dim) for dim in shape)
variables: Dict[str, str] = {}
for dim, target_dim in zip(shape_as_strings, target):
if _is_wildcard(target_dim) or _is_assignable_var(dim, target_dim, variables):
continue
if _is_range(target_dim) and _check_range(dim, target_dim):
continue
if dim != target_dim:
return False
return True
def _handle_ellipsis(shape: ShapeTuple, target: List[str]) -> List[str]:
# Let the ellipsis allows for any number of dimensions by replacing the
# ellipsis with the dimension size repeated the number of times that
# corresponds to the shape of the instance.
if target[-1] == "...":
dim_to_repeat = "*"
target = target[0:-1]
if len(shape) > len(target):
difference = len(shape) - len(target)
target += difference * [dim_to_repeat]
return target
def _is_range(target_dim: str) -> bool:
"""Whether the dimension is a range (literally whether it includes a hyphen)"""
return "-" in target_dim and len(target_dim.split("-")) == 2
def _check_range(dim: str, target_dim: str) -> bool:
"""check whether the given dimension is within the target_dim range"""
dim = int(dim)
range_min, range_max = target_dim.split("-")
if _is_wildcard(range_min):
return dim <= int(range_max)
elif _is_wildcard(range_max):
return dim >= int(range_min)
else:
return int(range_min) <= dim <= int(range_max)
def _is_wildcard(dim: str) -> bool:
"""
CHANGES FROM NPTYPING: added '*-*' range, which is a wildcard
"""
# Return whether dim is a wildcard (i.e. the character that takes any
# dimension size).
return dim == "*" or dim == "*-*"
# CHANGES FROM NPTYPING: Allow ranges
_REGEX_SEPARATOR = r"(\s*,\s*)"
_REGEX_DIMENSION_SIZE = r"(\s*[0-9]+\s*)"
_REGEX_DIMENSION_RANGE = r"(\s*[0-9\*]+-[0-9\*]+\s*)"
_REGEX_VARIABLE = r"(\s*\b[A-Z]\w*\s*)"
_REGEX_LABEL = r"(\s*\b[a-z]\w*\s*)"
_REGEX_LABELS = rf"({_REGEX_LABEL}({_REGEX_SEPARATOR}{_REGEX_LABEL})*)"
_REGEX_WILDCARD = r"(\s*\*\s*)"
_REGEX_DIMENSION_BREAKDOWN = rf"(\s*\[{_REGEX_LABELS}\]\s*)"
_REGEX_DIMENSION = (
rf"({_REGEX_DIMENSION_SIZE}"
rf"|{_REGEX_DIMENSION_RANGE}"
rf"|{_REGEX_VARIABLE}"
rf"|{_REGEX_WILDCARD}"
rf"|{_REGEX_DIMENSION_BREAKDOWN})"
)
_REGEX_DIMENSION_WITH_LABEL = rf"({_REGEX_DIMENSION}(\s+{_REGEX_LABEL})*)"
_REGEX_DIMENSIONS = (
rf"{_REGEX_DIMENSION_WITH_LABEL}({_REGEX_SEPARATOR}{_REGEX_DIMENSION_WITH_LABEL})*"
)
_REGEX_DIMENSIONS_ELLIPSIS = rf"({_REGEX_DIMENSIONS}{_REGEX_SEPARATOR}\.\.\.\s*)"
_REGEX_SHAPE_EXPRESSION = rf"^({_REGEX_DIMENSIONS}|{_REGEX_DIMENSIONS_ELLIPSIS})$"
# --------------------------------------------------
# Below - unchanged from nptyping
# --------------------------------------------------
def _is_assignable_var(dim: str, target_dim: str, variables: Dict[str, str]) -> bool:
# Return whether target_dim is a variable and can be assigned with dim.
return _is_variable(target_dim) and _can_assign_variable(dim, target_dim, variables)
def _is_variable(dim: str) -> bool:
# Return whether dim is a variable.
return dim[0] in string.ascii_uppercase
def _can_assign_variable(dim: str, target_dim: str, variables: Dict[str, str]) -> bool:
# Check and assign a variable.
assignable = variables.get(target_dim) in (None, dim)
variables[target_dim] = dim
return assignable

View file

@ -7,9 +7,9 @@ import json
import numpy as np import numpy as np
from pydantic import BaseModel, ValidationError, Field from pydantic import BaseModel, ValidationError, Field
from nptyping import Shape, Number from nptyping import Number
from numpydantic import NDArray from numpydantic import NDArray, Shape
from numpydantic.exceptions import ShapeError, DtypeError from numpydantic.exceptions import ShapeError, DtypeError
from numpydantic import dtype from numpydantic import dtype

80
tests/test_shape.py Normal file
View file

@ -0,0 +1,80 @@
import pdb
import pytest
from typing import Any
from pydantic import BaseModel, ValidationError
import numpy as np
from numpydantic import NDArray, Shape
@pytest.mark.parametrize(
"shape,valid",
[
((2, 6), True),
((2, 7), True),
((3, 6), True),
((3, 7), True),
((4, 6), True),
((4, 7), True),
((1, 6), False),
((5, 6), False),
((2, 5), False),
((2, 8), False),
],
)
def test_shape_range(shape, valid):
"""Specify a dimension with a range of possible sizes"""
class MyModel(BaseModel):
array: NDArray[Shape["2-4, 6-7"], Any]
if valid:
_ = MyModel(array=np.zeros(shape, dtype=np.uint8))
else:
with pytest.raises(ValidationError):
_ = MyModel(array=np.zeros(shape, dtype=np.uint8))
@pytest.mark.parametrize(
"shape,valid",
[
((2, 5), True),
((10, 5), True),
((2, 2), True),
((1, 5), False),
((2, 6), False),
],
)
def test_shape_wildcard(shape, valid):
"""Specify an open-ended minimum or maximum size for a given dimension"""
class MyModel(BaseModel):
array: NDArray[Shape["2-*, *-5"], Any]
if valid:
_ = MyModel(array=np.zeros(shape, dtype=np.uint8))
else:
with pytest.raises(ValidationError):
_ = MyModel(array=np.zeros(shape, dtype=np.uint8))
def test_range_shape_schema():
"""
Range shapes should correctly generate JSON Schema
"""
class MyModel(BaseModel):
array_range: NDArray[Shape["2-4"], Any]
array_range_min: NDArray[Shape["2-*"], Any]
array_range_max: NDArray[Shape["*-4"], Any]
schema = MyModel.model_json_schema()
assert schema["properties"]["array_range"]["minItems"] == 2
assert schema["properties"]["array_range"]["maxItems"] == 4
assert schema["properties"]["array_range_min"]["minItems"] == 2
assert "maxItems" not in schema["properties"]["array_range_min"]
assert schema["properties"]["array_range_max"]["maxItems"] == 4
assert "minItems" not in schema["properties"]["array_range_max"]