From 07ab3d1b7622989361c7c6505821135c68aec238 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Fri, 14 Jun 2024 22:38:13 -0700 Subject: [PATCH] add array shape ranges! take over shape specification and checking from nptyping --- src/numpydantic/__init__.py | 3 +- src/numpydantic/interface/interface.py | 2 +- src/numpydantic/ndarray.py | 24 +++ src/numpydantic/schema.py | 48 ++++-- src/numpydantic/shape.py | 211 +++++++++++++++++++++++++ tests/test_ndarray.py | 4 +- tests/test_shape.py | 80 ++++++++++ 7 files changed, 350 insertions(+), 22 deletions(-) create mode 100644 src/numpydantic/shape.py create mode 100644 tests/test_shape.py diff --git a/src/numpydantic/__init__.py b/src/numpydantic/__init__.py index 33bb17f..d251f8d 100644 --- a/src/numpydantic/__init__.py +++ b/src/numpydantic/__init__.py @@ -8,8 +8,7 @@ apply_patches() from numpydantic.ndarray import NDArray from numpydantic.meta import update_ndarray_stub - -from nptyping import Shape +from numpydantic.shape import Shape update_ndarray_stub() diff --git a/src/numpydantic/interface/interface.py b/src/numpydantic/interface/interface.py index d7d19c1..360b7f0 100644 --- a/src/numpydantic/interface/interface.py +++ b/src/numpydantic/interface/interface.py @@ -7,7 +7,6 @@ from operator import attrgetter from typing import Any, Generic, Optional, Tuple, Type, TypeVar, Union import numpy as np -from nptyping.shape_expression import check_shape from pydantic import SerializationInfo from numpydantic.exceptions import ( @@ -16,6 +15,7 @@ from numpydantic.exceptions import ( ShapeError, TooManyMatchesError, ) +from numpydantic.shape import check_shape from numpydantic.types import DtypeType, NDArrayType, ShapeType T = TypeVar("T", bound=NDArrayType) diff --git a/src/numpydantic/ndarray.py b/src/numpydantic/ndarray.py index 5bc539c..e62df59 100644 --- a/src/numpydantic/ndarray.py +++ b/src/numpydantic/ndarray.py @@ -42,6 +42,8 @@ from numpydantic.types import DtypeType, ShapeType if TYPE_CHECKING: # pragma: no cover from nptyping.base_meta_classes import SubscriptableMeta + from numpydantic import Shape + class NDArrayMeta(_NDArrayMeta, implementation="NDArray"): """ @@ -78,6 +80,28 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"): except InterfaceError: return False + def _get_shape(cls, dtype_candidate: Any) -> "Shape": + """ + Override of base method to use our local definition of shape + """ + from numpydantic.shape import Shape + + if dtype_candidate is Any or dtype_candidate is Shape: + shape = Any + elif issubclass(dtype_candidate, Shape): + shape = dtype_candidate + elif cls._is_literal_like(dtype_candidate): + shape_expression = dtype_candidate.__args__[0] + shape = Shape[shape_expression] + else: + raise InvalidArgumentsError( + f"Unexpected argument '{dtype_candidate}', expecting" + " Shape[]" + " or Literal[]" + " or typing.Any." + ) + return shape + def _get_dtype(cls, dtype_candidate: Any) -> DType: """ Override of base _get_dtype method to allow for compound tuple types diff --git a/src/numpydantic/schema.py b/src/numpydantic/schema.py index e610df4..084ac7c 100644 --- a/src/numpydantic/schema.py +++ b/src/numpydantic/schema.py @@ -5,11 +5,10 @@ Helper functions for use with :class:`~numpydantic.NDArray` - see the note in import hashlib import json -from typing import Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import nptyping.structure import numpy as np -from nptyping import Shape from pydantic import SerializationInfo from pydantic_core import CoreSchema, core_schema from pydantic_core.core_schema import ListSchema, ValidationInfo @@ -19,6 +18,9 @@ from numpydantic.interface import Interface from numpydantic.maps import np_to_python from numpydantic.types import DtypeType, NDArrayType, ShapeType +if TYPE_CHECKING: + from numpydantic import Shape + _handler_type = Callable[[Any], core_schema.CoreSchema] _UNSUPPORTED_TYPES = (complex,) @@ -88,7 +90,7 @@ def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema: return array_type -def list_of_lists_schema(shape: Shape, array_type: CoreSchema) -> ListSchema: +def list_of_lists_schema(shape: "Shape", array_type: CoreSchema) -> ListSchema: """ Make a pydantic JSON schema for an array as a list of lists. @@ -98,13 +100,15 @@ def list_of_lists_schema(shape: Shape, array_type: CoreSchema) -> ListSchema: This function is typically called from :func:`.make_json_schema` Args: - shape (:class:`.Shape` ): Shape determines the depth and max/min elements - for each layer of list schema + shape (:class:`~numpydantic.Shape`): Shape determines the depth and max/min + elements for each layer of list schema array_type ( :class:`pydantic_core.CoreSchema` ): The pre-rendered pydantic core schema to use in the innermost list entry """ + from numpydantic.shape import _is_range - shape_parts = shape.__args__[0].split(",") + shape_parts = [part.strip() for part in shape.__args__[0].split(",")] + # labels, if present split_parts = [ p.split(" ")[1] if len(p.split(" ")) == 2 else None for p in shape_parts ] @@ -129,18 +133,28 @@ def list_of_lists_schema(shape: Shape, array_type: CoreSchema) -> ListSchema: elif arg == "...": list_schema = _unbounded_shape(inner_schema, metadata=metadata) else: - try: - arg = int(arg) - except ValueError as e: - raise ValueError( - "Array shapes must be integers, wildcards, or ellipses. " - "Shape variables (for declaring that one dimension must be the " - "same size as another) are not supported because it is " - "impossible to express dynamic minItems/maxItems in JSON Schema. " - "See: https://github.com/orgs/json-schema-org/discussions/730" - ) from e + if _is_range(arg): + arg_min, arg_max = arg.split("-") + arg_min = None if arg_min == "*" else int(arg_min) + arg_max = None if arg_max == "*" else int(arg_max) + + else: + try: + arg = int(arg) + arg_min = arg + arg_max = arg + except ValueError as e: + + raise ValueError( + "Array shapes must be integers, wildcards, ellipses, or " + "ranges. Shape variables (for declaring that one dimension " + "must be the same size as another) are not supported because " + "it is impossible to express dynamic minItems/maxItems in " + "JSON Schema. " + "See: https://github.com/orgs/json-schema-org/discussions/730" + ) from e list_schema = core_schema.list_schema( - inner_schema, min_length=arg, max_length=arg, metadata=metadata + inner_schema, min_length=arg_min, max_length=arg_max, metadata=metadata ) return list_schema diff --git a/src/numpydantic/shape.py b/src/numpydantic/shape.py new file mode 100644 index 0000000..3b82149 --- /dev/null +++ b/src/numpydantic/shape.py @@ -0,0 +1,211 @@ +""" +Declaration and validation functions for array shapes. + +Mostly a mildly modified version of nptyping's +:func:`npytping.shape_expression.check_shape` +and its internals to allow for extended syntax, including ranges of shapes. + +Modifications from nptyping: + +- **"..."** - In nptyping, ``'...'`` means "any number of dimensions with the same shape + as the last dimension. ie ``Shape[2, ...]`` means "any number of 2-length + dimensions. Here ``'...'`` always means "any number of any-shape dimensions" +- **Ranges** - (inclusive) shape ranges are allowed. eg. to specify an array + where the first dimension can be 2, 3, or 4 length: + + Shape["2-4, ..."] + + To specify a range with an unbounded min or max, use wildcards, eg. for + an array with the first dimension at least length 2, and the second dimension + at most length 5 (both inclusive): + + Shape["2-*, *-5"] + +""" + +import re +import string +from abc import ABC +from functools import lru_cache +from typing import Any, Dict, List, Union + +from nptyping.base_meta_classes import ContainerMeta +from nptyping.error import InvalidShapeError +from nptyping.nptyping_type import NPTypingType +from nptyping.shape_expression import ( + get_dimensions, + normalize_shape_expression, + remove_labels, +) +from nptyping.typing_ import ShapeExpression, ShapeTuple + + +class ShapeMeta(ContainerMeta, implementation="Shape"): + """ + Metaclass that is coupled to nptyping.Shape. + + Overridden from nptyping to use local shape validation function + """ + + def _validate_expression(cls, item: str) -> None: + validate_shape_expression(item) + + def _normalize_expression(cls, item: str) -> str: + return normalize_shape_expression(item) + + def _get_additional_values(cls, item: Any) -> Dict[str, Any]: + dim_strings = get_dimensions(item) + dim_string_without_labels = remove_labels(dim_strings) + return {"prepared_args": dim_string_without_labels} + + +class Shape(NPTypingType, ABC, metaclass=ShapeMeta): + """ + A container for shape expressions that describe the shape of an multi + dimensional array. + + Simple example: + + >>> Shape['2, 2'] + Shape['2, 2'] + + A Shape can be compared to a typing.Literal. You can use Literals in + NDArray as well. + + >>> from typing import Literal + + >>> Shape['2, 2'] == Literal['2, 2'] + True + + """ + + __args__ = ("*, ...",) + prepared_args = ("*", "...") + + +def validate_shape_expression(shape_expression: Union[ShapeExpression, Any]) -> None: + """ + CHANGES FROM NPTYPING: Allow ranges + """ + shape_expression_no_quotes = shape_expression.replace("'", "").replace('"', "") + if shape_expression is not Any and not re.match( + _REGEX_SHAPE_EXPRESSION, shape_expression_no_quotes + ): + raise InvalidShapeError( + f"'{shape_expression}' is not a valid shape expression." + ) + + +@lru_cache +def check_shape(shape: ShapeTuple, target: "Shape") -> bool: + """ + Check whether the given shape corresponds to the given shape_expression. + :param shape: the shape in question. + :param target: the shape expression to which shape is tested. + :return: True if the given shape corresponds to shape_expression. + """ + target_shape = _handle_ellipsis(shape, target.prepared_args) + return _check_dimensions_against_shape(shape, target_shape) + + +def _check_dimensions_against_shape(shape: ShapeTuple, target: List[str]) -> bool: + # Walk through the shape and test them against the given target, + # taking into consideration variables, wildcards, etc. + + if len(shape) != len(target): + return False + shape_as_strings = (str(dim) for dim in shape) + variables: Dict[str, str] = {} + for dim, target_dim in zip(shape_as_strings, target): + if _is_wildcard(target_dim) or _is_assignable_var(dim, target_dim, variables): + continue + if _is_range(target_dim) and _check_range(dim, target_dim): + continue + if dim != target_dim: + return False + return True + + +def _handle_ellipsis(shape: ShapeTuple, target: List[str]) -> List[str]: + # Let the ellipsis allows for any number of dimensions by replacing the + # ellipsis with the dimension size repeated the number of times that + # corresponds to the shape of the instance. + if target[-1] == "...": + dim_to_repeat = "*" + target = target[0:-1] + if len(shape) > len(target): + difference = len(shape) - len(target) + target += difference * [dim_to_repeat] + return target + + +def _is_range(target_dim: str) -> bool: + """Whether the dimension is a range (literally whether it includes a hyphen)""" + return "-" in target_dim and len(target_dim.split("-")) == 2 + + +def _check_range(dim: str, target_dim: str) -> bool: + """check whether the given dimension is within the target_dim range""" + dim = int(dim) + + range_min, range_max = target_dim.split("-") + if _is_wildcard(range_min): + return dim <= int(range_max) + elif _is_wildcard(range_max): + return dim >= int(range_min) + else: + return int(range_min) <= dim <= int(range_max) + + +def _is_wildcard(dim: str) -> bool: + """ + CHANGES FROM NPTYPING: added '*-*' range, which is a wildcard + """ + # Return whether dim is a wildcard (i.e. the character that takes any + # dimension size). + return dim == "*" or dim == "*-*" + + +# CHANGES FROM NPTYPING: Allow ranges +_REGEX_SEPARATOR = r"(\s*,\s*)" +_REGEX_DIMENSION_SIZE = r"(\s*[0-9]+\s*)" +_REGEX_DIMENSION_RANGE = r"(\s*[0-9\*]+-[0-9\*]+\s*)" +_REGEX_VARIABLE = r"(\s*\b[A-Z]\w*\s*)" +_REGEX_LABEL = r"(\s*\b[a-z]\w*\s*)" +_REGEX_LABELS = rf"({_REGEX_LABEL}({_REGEX_SEPARATOR}{_REGEX_LABEL})*)" +_REGEX_WILDCARD = r"(\s*\*\s*)" +_REGEX_DIMENSION_BREAKDOWN = rf"(\s*\[{_REGEX_LABELS}\]\s*)" +_REGEX_DIMENSION = ( + rf"({_REGEX_DIMENSION_SIZE}" + rf"|{_REGEX_DIMENSION_RANGE}" + rf"|{_REGEX_VARIABLE}" + rf"|{_REGEX_WILDCARD}" + rf"|{_REGEX_DIMENSION_BREAKDOWN})" +) +_REGEX_DIMENSION_WITH_LABEL = rf"({_REGEX_DIMENSION}(\s+{_REGEX_LABEL})*)" +_REGEX_DIMENSIONS = ( + rf"{_REGEX_DIMENSION_WITH_LABEL}({_REGEX_SEPARATOR}{_REGEX_DIMENSION_WITH_LABEL})*" +) +_REGEX_DIMENSIONS_ELLIPSIS = rf"({_REGEX_DIMENSIONS}{_REGEX_SEPARATOR}\.\.\.\s*)" +_REGEX_SHAPE_EXPRESSION = rf"^({_REGEX_DIMENSIONS}|{_REGEX_DIMENSIONS_ELLIPSIS})$" + +# -------------------------------------------------- +# Below - unchanged from nptyping +# -------------------------------------------------- + + +def _is_assignable_var(dim: str, target_dim: str, variables: Dict[str, str]) -> bool: + # Return whether target_dim is a variable and can be assigned with dim. + return _is_variable(target_dim) and _can_assign_variable(dim, target_dim, variables) + + +def _is_variable(dim: str) -> bool: + # Return whether dim is a variable. + return dim[0] in string.ascii_uppercase + + +def _can_assign_variable(dim: str, target_dim: str, variables: Dict[str, str]) -> bool: + # Check and assign a variable. + assignable = variables.get(target_dim) in (None, dim) + variables[target_dim] = dim + return assignable diff --git a/tests/test_ndarray.py b/tests/test_ndarray.py index 192c57a..f30cc98 100644 --- a/tests/test_ndarray.py +++ b/tests/test_ndarray.py @@ -7,9 +7,9 @@ import json import numpy as np from pydantic import BaseModel, ValidationError, Field -from nptyping import Shape, Number +from nptyping import Number -from numpydantic import NDArray +from numpydantic import NDArray, Shape from numpydantic.exceptions import ShapeError, DtypeError from numpydantic import dtype diff --git a/tests/test_shape.py b/tests/test_shape.py new file mode 100644 index 0000000..3abff19 --- /dev/null +++ b/tests/test_shape.py @@ -0,0 +1,80 @@ +import pdb + +import pytest + +from typing import Any + +from pydantic import BaseModel, ValidationError +import numpy as np + +from numpydantic import NDArray, Shape + + +@pytest.mark.parametrize( + "shape,valid", + [ + ((2, 6), True), + ((2, 7), True), + ((3, 6), True), + ((3, 7), True), + ((4, 6), True), + ((4, 7), True), + ((1, 6), False), + ((5, 6), False), + ((2, 5), False), + ((2, 8), False), + ], +) +def test_shape_range(shape, valid): + """Specify a dimension with a range of possible sizes""" + + class MyModel(BaseModel): + array: NDArray[Shape["2-4, 6-7"], Any] + + if valid: + _ = MyModel(array=np.zeros(shape, dtype=np.uint8)) + else: + with pytest.raises(ValidationError): + _ = MyModel(array=np.zeros(shape, dtype=np.uint8)) + + +@pytest.mark.parametrize( + "shape,valid", + [ + ((2, 5), True), + ((10, 5), True), + ((2, 2), True), + ((1, 5), False), + ((2, 6), False), + ], +) +def test_shape_wildcard(shape, valid): + """Specify an open-ended minimum or maximum size for a given dimension""" + + class MyModel(BaseModel): + array: NDArray[Shape["2-*, *-5"], Any] + + if valid: + _ = MyModel(array=np.zeros(shape, dtype=np.uint8)) + else: + with pytest.raises(ValidationError): + _ = MyModel(array=np.zeros(shape, dtype=np.uint8)) + + +def test_range_shape_schema(): + """ + Range shapes should correctly generate JSON Schema + """ + + class MyModel(BaseModel): + array_range: NDArray[Shape["2-4"], Any] + array_range_min: NDArray[Shape["2-*"], Any] + array_range_max: NDArray[Shape["*-4"], Any] + + schema = MyModel.model_json_schema() + assert schema["properties"]["array_range"]["minItems"] == 2 + assert schema["properties"]["array_range"]["maxItems"] == 4 + assert schema["properties"]["array_range_min"]["minItems"] == 2 + assert "maxItems" not in schema["properties"]["array_range_min"] + assert schema["properties"]["array_range_max"]["maxItems"] == 4 + assert "minItems" not in schema["properties"]["array_range_max"]