From 943df965e5fa67f01972806368a0e4cae26a910a Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Wed, 2 Oct 2024 18:43:01 -0700 Subject: [PATCH] some experimental scratch work on generic typing - don't look @ me lol --- src/numpydantic/interface/interface.py | 5 +- src/numpydantic/ndarray_generic.py | 72 ++++++++++++++++++++++++++ src/numpydantic/py.typed | 0 src/numpydantic/schema.py | 4 +- src/numpydantic/validation/shape.py | 11 ++++ tests/test_generic.py | 19 +++++++ tests/test_ndarray.py | 18 +++---- 7 files changed, 118 insertions(+), 11 deletions(-) create mode 100644 src/numpydantic/ndarray_generic.py create mode 100644 src/numpydantic/py.typed create mode 100644 tests/test_generic.py diff --git a/src/numpydantic/interface/interface.py b/src/numpydantic/interface/interface.py index 42bb891..5533827 100644 --- a/src/numpydantic/interface/interface.py +++ b/src/numpydantic/interface/interface.py @@ -333,8 +333,11 @@ class Interface(ABC, Generic[T]): :class:`~numpydantic.exceptions.ShapeError` """ if not valid: + from numpydantic.validation.shape import to_shape + + _shape = to_shape(self.shape) raise ShapeError( - f"Invalid shape! expected shape {self.shape.prepared_args}, " + f"Invalid shape! expected shape {_shape.prepared_args}, " f"got shape {shape}" ) diff --git a/src/numpydantic/ndarray_generic.py b/src/numpydantic/ndarray_generic.py new file mode 100644 index 0000000..d298132 --- /dev/null +++ b/src/numpydantic/ndarray_generic.py @@ -0,0 +1,72 @@ +from typing import Protocol, TypeVar, runtime_checkable + +from typing_extensions import Unpack + +from numpydantic.types import DtypeType + +# Shape = TypeVarTuple("Shape") +# Shape = tuple[int, ...] +Shape = TypeVar("Shape", bound=tuple[int, ...]) +DType = TypeVar("DType", bound=DtypeType) + + +@runtime_checkable +class NDArray(Protocol[Shape, DType]): + """v2 generic protocol ndarray""" + + @property + def dtype(self) -> DType: + """dtype""" + + @property + def shape(self) -> Unpack[Shape]: + """shape""" + + +# +# +# def __get_pydantic_core_schema__( +# typ: Type, handler: CallbackGetCoreSchemaHandler +# ) -> core_schema.CoreSchema: +# args = get_args(typ) +# if len(args) == 0: +# shape, dtype = Any, Any +# elif len(args) == 1: +# shape, dtype = args[0], Any +# elif len(args) == 2: +# shape, dtype = args[0], args[1] +# else: +# shape, dtype = args[:-1], args[-1] +# +# json_schema = make_json_schema(shape, dtype, handler) +# return core_schema.with_info_plain_validator_function( +# get_validate_interface(shape, dtype), +# serialization=core_schema.plain_serializer_function_ser_schema( +# jsonize_array, when_used="json", info_arg=True +# ), +# metadata=json_schema, +# ) + + +# +# def __get_pydantic_json_schema__( +# schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler +# ) -> core_schema.JsonSchema: +# # shape, dtype = cls.__args__ +# json_schema = handler(schema["metadata"]) +# json_schema = handler.resolve_ref_schema(json_schema) +# +# # if not isinstance(dtype, tuple) and dtype.__module__ not in ( +# # "builtins", +# # "typing", +# # ): +# # json_schema["dtype"] = ".".join([dtype.__module__, dtype.__name__]) +# +# return json_schema + + +# NDArray = Annotated[ +# _NDArray[Unpack[Shape], DType], +# GetPydanticSchema(__get_pydantic_core_schema__), +# # GetJsonSchemaFunction(__get_pydantic_json_schema__), +# ] diff --git a/src/numpydantic/py.typed b/src/numpydantic/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/numpydantic/schema.py b/src/numpydantic/schema.py index bfea3aa..8683166 100644 --- a/src/numpydantic/schema.py +++ b/src/numpydantic/schema.py @@ -113,7 +113,9 @@ def list_of_lists_schema(shape: "Shape", array_type: CoreSchema) -> ListSchema: array_type ( :class:`pydantic_core.CoreSchema` ): The pre-rendered pydantic core schema to use in the innermost list entry """ - from numpydantic.validation.shape import _is_range + from numpydantic.validation.shape import _is_range, to_shape + + shape = to_shape(shape) shape_parts = [part.strip() for part in shape.__args__[0].split(",")] # labels, if present diff --git a/src/numpydantic/validation/shape.py b/src/numpydantic/validation/shape.py index e899ecd..20ad9a3 100644 --- a/src/numpydantic/validation/shape.py +++ b/src/numpydantic/validation/shape.py @@ -91,6 +91,16 @@ class Shape(NPTypingType, ABC, metaclass=ShapeMeta): prepared_args = ("*", "...") +def to_shape(shape) -> "Shape": + from numpydantic import Shape + + if isinstance(shape, int): + shape = Shape[f"{shape}"] + elif isinstance(shape, tuple): + shape = Shape[f"{', '.join([s for s in shape])}"] + return shape + + def validate_shape_expression(shape_expression: Union[ShapeExpression, Any]) -> None: """ CHANGES FROM NPTYPING: Allow ranges @@ -112,6 +122,7 @@ def validate_shape(shape: ShapeTuple, target: "Shape") -> bool: :param target: the shape expression to which shape is tested. :return: True if the given shape corresponds to shape_expression. """ + target = to_shape(target) target_shape = _handle_ellipsis(shape, target.prepared_args) return _check_dimensions_against_shape(shape, target_shape) diff --git a/tests/test_generic.py b/tests/test_generic.py new file mode 100644 index 0000000..f511ccd --- /dev/null +++ b/tests/test_generic.py @@ -0,0 +1,19 @@ +from numpydantic.ndarray_generic import NDArray +from pydantic import BaseModel +from typing import Literal as L +import numpy as np + + +class MyClass(BaseModel): + array: NDArray[L[4, 5], int] + + +model = MyClass(array=np.array([1, 2, 3, 4])) + +model2 = MyClass(array=np.array([1, 2, 3])) + +model3 = MyClass(array=(1, 2)) + + +array: NDArray[L[4], np.int64] = np.array([1, 2, 3]) +array2: NDArray[L[3], np.int64] = [1, 2, 3] diff --git a/tests/test_ndarray.py b/tests/test_ndarray.py index cda092c..3042e29 100644 --- a/tests/test_ndarray.py +++ b/tests/test_ndarray.py @@ -1,16 +1,15 @@ -import pytest - -from typing import Union, Optional, Any import json +from typing import Any, Optional, Union import numpy as np -from pydantic import BaseModel, ValidationError, Field +import pytest +from pydantic import BaseModel, Field, ValidationError - -from numpydantic import NDArray, Shape -from numpydantic.exceptions import ShapeError, DtypeError -from numpydantic import dtype +# from numpydantic import NDArray, Shape +from numpydantic import Shape, dtype from numpydantic.dtype import Number +from numpydantic.exceptions import DtypeError +from numpydantic.ndarray_generic import NDArray @pytest.mark.json_schema @@ -177,9 +176,10 @@ def test_shape_ellipsis(): """ class MyModel(BaseModel): - array: NDArray[Shape["1, 2, ..."], Number] + array: NDArray[1, 2, ..., Number] _ = MyModel(array=np.zeros((1, 2, 3, 4, 5))) + _ = MyModel(array="hey") @pytest.mark.serialization