mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2025-01-10 05:54:26 +00:00
some experimental scratch work on generic typing - don't look @ me lol
This commit is contained in:
parent
69dbe39557
commit
943df965e5
7 changed files with 118 additions and 11 deletions
|
@ -333,8 +333,11 @@ class Interface(ABC, Generic[T]):
|
||||||
:class:`~numpydantic.exceptions.ShapeError`
|
:class:`~numpydantic.exceptions.ShapeError`
|
||||||
"""
|
"""
|
||||||
if not valid:
|
if not valid:
|
||||||
|
from numpydantic.validation.shape import to_shape
|
||||||
|
|
||||||
|
_shape = to_shape(self.shape)
|
||||||
raise ShapeError(
|
raise ShapeError(
|
||||||
f"Invalid shape! expected shape {self.shape.prepared_args}, "
|
f"Invalid shape! expected shape {_shape.prepared_args}, "
|
||||||
f"got shape {shape}"
|
f"got shape {shape}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
72
src/numpydantic/ndarray_generic.py
Normal file
72
src/numpydantic/ndarray_generic.py
Normal file
|
@ -0,0 +1,72 @@
|
||||||
|
from typing import Protocol, TypeVar, runtime_checkable
|
||||||
|
|
||||||
|
from typing_extensions import Unpack
|
||||||
|
|
||||||
|
from numpydantic.types import DtypeType
|
||||||
|
|
||||||
|
# Shape = TypeVarTuple("Shape")
|
||||||
|
# Shape = tuple[int, ...]
|
||||||
|
Shape = TypeVar("Shape", bound=tuple[int, ...])
|
||||||
|
DType = TypeVar("DType", bound=DtypeType)
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class NDArray(Protocol[Shape, DType]):
|
||||||
|
"""v2 generic protocol ndarray"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self) -> DType:
|
||||||
|
"""dtype"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self) -> Unpack[Shape]:
|
||||||
|
"""shape"""
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def __get_pydantic_core_schema__(
|
||||||
|
# typ: Type, handler: CallbackGetCoreSchemaHandler
|
||||||
|
# ) -> core_schema.CoreSchema:
|
||||||
|
# args = get_args(typ)
|
||||||
|
# if len(args) == 0:
|
||||||
|
# shape, dtype = Any, Any
|
||||||
|
# elif len(args) == 1:
|
||||||
|
# shape, dtype = args[0], Any
|
||||||
|
# elif len(args) == 2:
|
||||||
|
# shape, dtype = args[0], args[1]
|
||||||
|
# else:
|
||||||
|
# shape, dtype = args[:-1], args[-1]
|
||||||
|
#
|
||||||
|
# json_schema = make_json_schema(shape, dtype, handler)
|
||||||
|
# return core_schema.with_info_plain_validator_function(
|
||||||
|
# get_validate_interface(shape, dtype),
|
||||||
|
# serialization=core_schema.plain_serializer_function_ser_schema(
|
||||||
|
# jsonize_array, when_used="json", info_arg=True
|
||||||
|
# ),
|
||||||
|
# metadata=json_schema,
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# def __get_pydantic_json_schema__(
|
||||||
|
# schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
|
||||||
|
# ) -> core_schema.JsonSchema:
|
||||||
|
# # shape, dtype = cls.__args__
|
||||||
|
# json_schema = handler(schema["metadata"])
|
||||||
|
# json_schema = handler.resolve_ref_schema(json_schema)
|
||||||
|
#
|
||||||
|
# # if not isinstance(dtype, tuple) and dtype.__module__ not in (
|
||||||
|
# # "builtins",
|
||||||
|
# # "typing",
|
||||||
|
# # ):
|
||||||
|
# # json_schema["dtype"] = ".".join([dtype.__module__, dtype.__name__])
|
||||||
|
#
|
||||||
|
# return json_schema
|
||||||
|
|
||||||
|
|
||||||
|
# NDArray = Annotated[
|
||||||
|
# _NDArray[Unpack[Shape], DType],
|
||||||
|
# GetPydanticSchema(__get_pydantic_core_schema__),
|
||||||
|
# # GetJsonSchemaFunction(__get_pydantic_json_schema__),
|
||||||
|
# ]
|
0
src/numpydantic/py.typed
Normal file
0
src/numpydantic/py.typed
Normal file
|
@ -113,7 +113,9 @@ def list_of_lists_schema(shape: "Shape", array_type: CoreSchema) -> ListSchema:
|
||||||
array_type ( :class:`pydantic_core.CoreSchema` ): The pre-rendered pydantic
|
array_type ( :class:`pydantic_core.CoreSchema` ): The pre-rendered pydantic
|
||||||
core schema to use in the innermost list entry
|
core schema to use in the innermost list entry
|
||||||
"""
|
"""
|
||||||
from numpydantic.validation.shape import _is_range
|
from numpydantic.validation.shape import _is_range, to_shape
|
||||||
|
|
||||||
|
shape = to_shape(shape)
|
||||||
|
|
||||||
shape_parts = [part.strip() for part in shape.__args__[0].split(",")]
|
shape_parts = [part.strip() for part in shape.__args__[0].split(",")]
|
||||||
# labels, if present
|
# labels, if present
|
||||||
|
|
|
@ -91,6 +91,16 @@ class Shape(NPTypingType, ABC, metaclass=ShapeMeta):
|
||||||
prepared_args = ("*", "...")
|
prepared_args = ("*", "...")
|
||||||
|
|
||||||
|
|
||||||
|
def to_shape(shape) -> "Shape":
|
||||||
|
from numpydantic import Shape
|
||||||
|
|
||||||
|
if isinstance(shape, int):
|
||||||
|
shape = Shape[f"{shape}"]
|
||||||
|
elif isinstance(shape, tuple):
|
||||||
|
shape = Shape[f"{', '.join([s for s in shape])}"]
|
||||||
|
return shape
|
||||||
|
|
||||||
|
|
||||||
def validate_shape_expression(shape_expression: Union[ShapeExpression, Any]) -> None:
|
def validate_shape_expression(shape_expression: Union[ShapeExpression, Any]) -> None:
|
||||||
"""
|
"""
|
||||||
CHANGES FROM NPTYPING: Allow ranges
|
CHANGES FROM NPTYPING: Allow ranges
|
||||||
|
@ -112,6 +122,7 @@ def validate_shape(shape: ShapeTuple, target: "Shape") -> bool:
|
||||||
:param target: the shape expression to which shape is tested.
|
:param target: the shape expression to which shape is tested.
|
||||||
:return: True if the given shape corresponds to shape_expression.
|
:return: True if the given shape corresponds to shape_expression.
|
||||||
"""
|
"""
|
||||||
|
target = to_shape(target)
|
||||||
target_shape = _handle_ellipsis(shape, target.prepared_args)
|
target_shape = _handle_ellipsis(shape, target.prepared_args)
|
||||||
return _check_dimensions_against_shape(shape, target_shape)
|
return _check_dimensions_against_shape(shape, target_shape)
|
||||||
|
|
||||||
|
|
19
tests/test_generic.py
Normal file
19
tests/test_generic.py
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
from numpydantic.ndarray_generic import NDArray
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Literal as L
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class MyClass(BaseModel):
|
||||||
|
array: NDArray[L[4, 5], int]
|
||||||
|
|
||||||
|
|
||||||
|
model = MyClass(array=np.array([1, 2, 3, 4]))
|
||||||
|
|
||||||
|
model2 = MyClass(array=np.array([1, 2, 3]))
|
||||||
|
|
||||||
|
model3 = MyClass(array=(1, 2))
|
||||||
|
|
||||||
|
|
||||||
|
array: NDArray[L[4], np.int64] = np.array([1, 2, 3])
|
||||||
|
array2: NDArray[L[3], np.int64] = [1, 2, 3]
|
|
@ -1,16 +1,15 @@
|
||||||
import pytest
|
|
||||||
|
|
||||||
from typing import Union, Optional, Any
|
|
||||||
import json
|
import json
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import BaseModel, ValidationError, Field
|
import pytest
|
||||||
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
|
|
||||||
|
# from numpydantic import NDArray, Shape
|
||||||
from numpydantic import NDArray, Shape
|
from numpydantic import Shape, dtype
|
||||||
from numpydantic.exceptions import ShapeError, DtypeError
|
|
||||||
from numpydantic import dtype
|
|
||||||
from numpydantic.dtype import Number
|
from numpydantic.dtype import Number
|
||||||
|
from numpydantic.exceptions import DtypeError
|
||||||
|
from numpydantic.ndarray_generic import NDArray
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.json_schema
|
@pytest.mark.json_schema
|
||||||
|
@ -177,9 +176,10 @@ def test_shape_ellipsis():
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class MyModel(BaseModel):
|
class MyModel(BaseModel):
|
||||||
array: NDArray[Shape["1, 2, ..."], Number]
|
array: NDArray[1, 2, ..., Number]
|
||||||
|
|
||||||
_ = MyModel(array=np.zeros((1, 2, 3, 4, 5)))
|
_ = MyModel(array=np.zeros((1, 2, 3, 4, 5)))
|
||||||
|
_ = MyModel(array="hey")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.serialization
|
@pytest.mark.serialization
|
||||||
|
|
Loading…
Reference in a new issue