mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2025-01-09 13:44:26 +00:00
some experimental scratch work on generic typing - don't look @ me lol
This commit is contained in:
parent
69dbe39557
commit
943df965e5
7 changed files with 118 additions and 11 deletions
|
@ -333,8 +333,11 @@ class Interface(ABC, Generic[T]):
|
|||
:class:`~numpydantic.exceptions.ShapeError`
|
||||
"""
|
||||
if not valid:
|
||||
from numpydantic.validation.shape import to_shape
|
||||
|
||||
_shape = to_shape(self.shape)
|
||||
raise ShapeError(
|
||||
f"Invalid shape! expected shape {self.shape.prepared_args}, "
|
||||
f"Invalid shape! expected shape {_shape.prepared_args}, "
|
||||
f"got shape {shape}"
|
||||
)
|
||||
|
||||
|
|
72
src/numpydantic/ndarray_generic.py
Normal file
72
src/numpydantic/ndarray_generic.py
Normal file
|
@ -0,0 +1,72 @@
|
|||
from typing import Protocol, TypeVar, runtime_checkable
|
||||
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from numpydantic.types import DtypeType
|
||||
|
||||
# Shape = TypeVarTuple("Shape")
|
||||
# Shape = tuple[int, ...]
|
||||
Shape = TypeVar("Shape", bound=tuple[int, ...])
|
||||
DType = TypeVar("DType", bound=DtypeType)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class NDArray(Protocol[Shape, DType]):
|
||||
"""v2 generic protocol ndarray"""
|
||||
|
||||
@property
|
||||
def dtype(self) -> DType:
|
||||
"""dtype"""
|
||||
|
||||
@property
|
||||
def shape(self) -> Unpack[Shape]:
|
||||
"""shape"""
|
||||
|
||||
|
||||
#
|
||||
#
|
||||
# def __get_pydantic_core_schema__(
|
||||
# typ: Type, handler: CallbackGetCoreSchemaHandler
|
||||
# ) -> core_schema.CoreSchema:
|
||||
# args = get_args(typ)
|
||||
# if len(args) == 0:
|
||||
# shape, dtype = Any, Any
|
||||
# elif len(args) == 1:
|
||||
# shape, dtype = args[0], Any
|
||||
# elif len(args) == 2:
|
||||
# shape, dtype = args[0], args[1]
|
||||
# else:
|
||||
# shape, dtype = args[:-1], args[-1]
|
||||
#
|
||||
# json_schema = make_json_schema(shape, dtype, handler)
|
||||
# return core_schema.with_info_plain_validator_function(
|
||||
# get_validate_interface(shape, dtype),
|
||||
# serialization=core_schema.plain_serializer_function_ser_schema(
|
||||
# jsonize_array, when_used="json", info_arg=True
|
||||
# ),
|
||||
# metadata=json_schema,
|
||||
# )
|
||||
|
||||
|
||||
#
|
||||
# def __get_pydantic_json_schema__(
|
||||
# schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
|
||||
# ) -> core_schema.JsonSchema:
|
||||
# # shape, dtype = cls.__args__
|
||||
# json_schema = handler(schema["metadata"])
|
||||
# json_schema = handler.resolve_ref_schema(json_schema)
|
||||
#
|
||||
# # if not isinstance(dtype, tuple) and dtype.__module__ not in (
|
||||
# # "builtins",
|
||||
# # "typing",
|
||||
# # ):
|
||||
# # json_schema["dtype"] = ".".join([dtype.__module__, dtype.__name__])
|
||||
#
|
||||
# return json_schema
|
||||
|
||||
|
||||
# NDArray = Annotated[
|
||||
# _NDArray[Unpack[Shape], DType],
|
||||
# GetPydanticSchema(__get_pydantic_core_schema__),
|
||||
# # GetJsonSchemaFunction(__get_pydantic_json_schema__),
|
||||
# ]
|
0
src/numpydantic/py.typed
Normal file
0
src/numpydantic/py.typed
Normal file
|
@ -113,7 +113,9 @@ def list_of_lists_schema(shape: "Shape", array_type: CoreSchema) -> ListSchema:
|
|||
array_type ( :class:`pydantic_core.CoreSchema` ): The pre-rendered pydantic
|
||||
core schema to use in the innermost list entry
|
||||
"""
|
||||
from numpydantic.validation.shape import _is_range
|
||||
from numpydantic.validation.shape import _is_range, to_shape
|
||||
|
||||
shape = to_shape(shape)
|
||||
|
||||
shape_parts = [part.strip() for part in shape.__args__[0].split(",")]
|
||||
# labels, if present
|
||||
|
|
|
@ -91,6 +91,16 @@ class Shape(NPTypingType, ABC, metaclass=ShapeMeta):
|
|||
prepared_args = ("*", "...")
|
||||
|
||||
|
||||
def to_shape(shape) -> "Shape":
|
||||
from numpydantic import Shape
|
||||
|
||||
if isinstance(shape, int):
|
||||
shape = Shape[f"{shape}"]
|
||||
elif isinstance(shape, tuple):
|
||||
shape = Shape[f"{', '.join([s for s in shape])}"]
|
||||
return shape
|
||||
|
||||
|
||||
def validate_shape_expression(shape_expression: Union[ShapeExpression, Any]) -> None:
|
||||
"""
|
||||
CHANGES FROM NPTYPING: Allow ranges
|
||||
|
@ -112,6 +122,7 @@ def validate_shape(shape: ShapeTuple, target: "Shape") -> bool:
|
|||
:param target: the shape expression to which shape is tested.
|
||||
:return: True if the given shape corresponds to shape_expression.
|
||||
"""
|
||||
target = to_shape(target)
|
||||
target_shape = _handle_ellipsis(shape, target.prepared_args)
|
||||
return _check_dimensions_against_shape(shape, target_shape)
|
||||
|
||||
|
|
19
tests/test_generic.py
Normal file
19
tests/test_generic.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
from numpydantic.ndarray_generic import NDArray
|
||||
from pydantic import BaseModel
|
||||
from typing import Literal as L
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyClass(BaseModel):
|
||||
array: NDArray[L[4, 5], int]
|
||||
|
||||
|
||||
model = MyClass(array=np.array([1, 2, 3, 4]))
|
||||
|
||||
model2 = MyClass(array=np.array([1, 2, 3]))
|
||||
|
||||
model3 = MyClass(array=(1, 2))
|
||||
|
||||
|
||||
array: NDArray[L[4], np.int64] = np.array([1, 2, 3])
|
||||
array2: NDArray[L[3], np.int64] = [1, 2, 3]
|
|
@ -1,16 +1,15 @@
|
|||
import pytest
|
||||
|
||||
from typing import Union, Optional, Any
|
||||
import json
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, ValidationError, Field
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
|
||||
from numpydantic import NDArray, Shape
|
||||
from numpydantic.exceptions import ShapeError, DtypeError
|
||||
from numpydantic import dtype
|
||||
# from numpydantic import NDArray, Shape
|
||||
from numpydantic import Shape, dtype
|
||||
from numpydantic.dtype import Number
|
||||
from numpydantic.exceptions import DtypeError
|
||||
from numpydantic.ndarray_generic import NDArray
|
||||
|
||||
|
||||
@pytest.mark.json_schema
|
||||
|
@ -177,9 +176,10 @@ def test_shape_ellipsis():
|
|||
"""
|
||||
|
||||
class MyModel(BaseModel):
|
||||
array: NDArray[Shape["1, 2, ..."], Number]
|
||||
array: NDArray[1, 2, ..., Number]
|
||||
|
||||
_ = MyModel(array=np.zeros((1, 2, 3, 4, 5)))
|
||||
_ = MyModel(array="hey")
|
||||
|
||||
|
||||
@pytest.mark.serialization
|
||||
|
|
Loading…
Reference in a new issue