some experimental scratch work on generic typing - don't look @ me lol

This commit is contained in:
sneakers-the-rat 2024-10-02 18:43:01 -07:00
parent 69dbe39557
commit 943df965e5
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
7 changed files with 118 additions and 11 deletions

View file

@ -333,8 +333,11 @@ class Interface(ABC, Generic[T]):
:class:`~numpydantic.exceptions.ShapeError` :class:`~numpydantic.exceptions.ShapeError`
""" """
if not valid: if not valid:
from numpydantic.validation.shape import to_shape
_shape = to_shape(self.shape)
raise ShapeError( raise ShapeError(
f"Invalid shape! expected shape {self.shape.prepared_args}, " f"Invalid shape! expected shape {_shape.prepared_args}, "
f"got shape {shape}" f"got shape {shape}"
) )

View file

@ -0,0 +1,72 @@
from typing import Protocol, TypeVar, runtime_checkable
from typing_extensions import Unpack
from numpydantic.types import DtypeType
# Shape = TypeVarTuple("Shape")
# Shape = tuple[int, ...]
Shape = TypeVar("Shape", bound=tuple[int, ...])
DType = TypeVar("DType", bound=DtypeType)
@runtime_checkable
class NDArray(Protocol[Shape, DType]):
"""v2 generic protocol ndarray"""
@property
def dtype(self) -> DType:
"""dtype"""
@property
def shape(self) -> Unpack[Shape]:
"""shape"""
#
#
# def __get_pydantic_core_schema__(
# typ: Type, handler: CallbackGetCoreSchemaHandler
# ) -> core_schema.CoreSchema:
# args = get_args(typ)
# if len(args) == 0:
# shape, dtype = Any, Any
# elif len(args) == 1:
# shape, dtype = args[0], Any
# elif len(args) == 2:
# shape, dtype = args[0], args[1]
# else:
# shape, dtype = args[:-1], args[-1]
#
# json_schema = make_json_schema(shape, dtype, handler)
# return core_schema.with_info_plain_validator_function(
# get_validate_interface(shape, dtype),
# serialization=core_schema.plain_serializer_function_ser_schema(
# jsonize_array, when_used="json", info_arg=True
# ),
# metadata=json_schema,
# )
#
# def __get_pydantic_json_schema__(
# schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
# ) -> core_schema.JsonSchema:
# # shape, dtype = cls.__args__
# json_schema = handler(schema["metadata"])
# json_schema = handler.resolve_ref_schema(json_schema)
#
# # if not isinstance(dtype, tuple) and dtype.__module__ not in (
# # "builtins",
# # "typing",
# # ):
# # json_schema["dtype"] = ".".join([dtype.__module__, dtype.__name__])
#
# return json_schema
# NDArray = Annotated[
# _NDArray[Unpack[Shape], DType],
# GetPydanticSchema(__get_pydantic_core_schema__),
# # GetJsonSchemaFunction(__get_pydantic_json_schema__),
# ]

0
src/numpydantic/py.typed Normal file
View file

View file

@ -113,7 +113,9 @@ def list_of_lists_schema(shape: "Shape", array_type: CoreSchema) -> ListSchema:
array_type ( :class:`pydantic_core.CoreSchema` ): The pre-rendered pydantic array_type ( :class:`pydantic_core.CoreSchema` ): The pre-rendered pydantic
core schema to use in the innermost list entry core schema to use in the innermost list entry
""" """
from numpydantic.validation.shape import _is_range from numpydantic.validation.shape import _is_range, to_shape
shape = to_shape(shape)
shape_parts = [part.strip() for part in shape.__args__[0].split(",")] shape_parts = [part.strip() for part in shape.__args__[0].split(",")]
# labels, if present # labels, if present

View file

@ -91,6 +91,16 @@ class Shape(NPTypingType, ABC, metaclass=ShapeMeta):
prepared_args = ("*", "...") prepared_args = ("*", "...")
def to_shape(shape) -> "Shape":
from numpydantic import Shape
if isinstance(shape, int):
shape = Shape[f"{shape}"]
elif isinstance(shape, tuple):
shape = Shape[f"{', '.join([s for s in shape])}"]
return shape
def validate_shape_expression(shape_expression: Union[ShapeExpression, Any]) -> None: def validate_shape_expression(shape_expression: Union[ShapeExpression, Any]) -> None:
""" """
CHANGES FROM NPTYPING: Allow ranges CHANGES FROM NPTYPING: Allow ranges
@ -112,6 +122,7 @@ def validate_shape(shape: ShapeTuple, target: "Shape") -> bool:
:param target: the shape expression to which shape is tested. :param target: the shape expression to which shape is tested.
:return: True if the given shape corresponds to shape_expression. :return: True if the given shape corresponds to shape_expression.
""" """
target = to_shape(target)
target_shape = _handle_ellipsis(shape, target.prepared_args) target_shape = _handle_ellipsis(shape, target.prepared_args)
return _check_dimensions_against_shape(shape, target_shape) return _check_dimensions_against_shape(shape, target_shape)

19
tests/test_generic.py Normal file
View file

@ -0,0 +1,19 @@
from numpydantic.ndarray_generic import NDArray
from pydantic import BaseModel
from typing import Literal as L
import numpy as np
class MyClass(BaseModel):
array: NDArray[L[4, 5], int]
model = MyClass(array=np.array([1, 2, 3, 4]))
model2 = MyClass(array=np.array([1, 2, 3]))
model3 = MyClass(array=(1, 2))
array: NDArray[L[4], np.int64] = np.array([1, 2, 3])
array2: NDArray[L[3], np.int64] = [1, 2, 3]

View file

@ -1,16 +1,15 @@
import pytest
from typing import Union, Optional, Any
import json import json
from typing import Any, Optional, Union
import numpy as np import numpy as np
from pydantic import BaseModel, ValidationError, Field import pytest
from pydantic import BaseModel, Field, ValidationError
# from numpydantic import NDArray, Shape
from numpydantic import NDArray, Shape from numpydantic import Shape, dtype
from numpydantic.exceptions import ShapeError, DtypeError
from numpydantic import dtype
from numpydantic.dtype import Number from numpydantic.dtype import Number
from numpydantic.exceptions import DtypeError
from numpydantic.ndarray_generic import NDArray
@pytest.mark.json_schema @pytest.mark.json_schema
@ -177,9 +176,10 @@ def test_shape_ellipsis():
""" """
class MyModel(BaseModel): class MyModel(BaseModel):
array: NDArray[Shape["1, 2, ..."], Number] array: NDArray[1, 2, ..., Number]
_ = MyModel(array=np.zeros((1, 2, 3, 4, 5))) _ = MyModel(array=np.zeros((1, 2, 3, 4, 5)))
_ = MyModel(array="hey")
@pytest.mark.serialization @pytest.mark.serialization