numpy interface and type stub generation

This commit is contained in:
sneakers-the-rat 2024-04-03 20:52:33 -07:00
parent 5069c3ddd4
commit a6391c08a3
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
17 changed files with 1470 additions and 57 deletions

3
.gitignore vendored
View file

@ -158,4 +158,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear # and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder. # option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/ #.idea/
.pdm-python .pdm-python
ndarray.pyi

1155
pdm.lock

File diff suppressed because it is too large Load diff

View file

@ -94,11 +94,16 @@ select = [
# whitespace # whitespace
"D210", "D211", "D210", "D211",
# emptiness # emptiness
"D419" "D419",
] ]
ignore = [ ignore = [
"ANN101", "ANN102" "ANN101", "ANN102", "ANN401",
# builtin type annotations
"UP006", "UP035",
# docstrings for __init__
"D107",
] ]
fixable = ["ALL"] fixable = ["ALL"]

View file

@ -2,10 +2,11 @@
# ruff: noqa: F401 # ruff: noqa: F401
# ruff: noqa: I001 # ruff: noqa: I001
from numpydantic.monkeypatch import apply_patches from numpydantic.monkeypatch import apply_patches
apply_patches() apply_patches()
from numpydantic.ndarray import NDArray from numpydantic.ndarray import NDArray
__all__ = [ from numpydantic.meta import update_ndarray_stub
"NDArray"
] update_ndarray_stub()

View file

@ -0,0 +1,6 @@
class DtypeError(TypeError):
    """Raised when an array's dtype is not the dtype that was required."""
class ShapeError(ValueError):
    """Raised when an array's shape is not the shape that was required."""

View file

@ -0,0 +1,4 @@
from numpydantic.interface.interface import Interface
from numpydantic.interface.numpy import NumpyInterface
__all__ = ["Interface", "NumpyInterface"]

View file

View file

View file

@ -0,0 +1,125 @@
from abc import ABC, abstractmethod
from operator import attrgetter
from typing import Any, Generic, List, Type, TypeVar, Tuple
from nptyping.shape_expression import check_shape
from numpydantic.exceptions import DtypeError, ShapeError
from numpydantic.types import DtypeType, NDArrayType, ShapeType
# Type variable for the array type handled by an interface; bounded by the
# NDArrayType protocol so only array-like types can parametrize ``Interface``.
T = TypeVar("T", bound=NDArrayType)
class Interface(ABC, Generic[T]):
    """
    Abstract base class for interfaces to different array formats.

    An instance stores the ``shape`` and ``dtype`` constraints it was created
    with and checks candidate arrays against them in :meth:`.validate`.
    The class methods :meth:`.interfaces`, :meth:`.array_types` and
    :meth:`.match` work over the direct subclasses of :class:`.Interface`.
    """

    # Array class associated with the interface; collected by ``array_types``.
    return_type: Type[T]
    # Sort key used by ``interfaces``: higher values are listed first.
    priority: int = 0

    def __init__(self, shape: ShapeType, dtype: DtypeType) -> None:
        self.shape = shape
        self.dtype = dtype

    def validate(self, array: Any) -> T:
        """
        Run every validation step on ``array`` and return the final array.

        The steps run in this order: :meth:`.before_validation`,
        :meth:`.validate_dtype`, :meth:`.validate_shape`,
        :meth:`.after_validation`.
        """
        coerced = self.before_validation(array)
        dtype_checked = self.validate_dtype(coerced)
        shape_checked = self.validate_shape(dtype_checked)
        return self.after_validation(shape_checked)

    def before_validation(self, array: Any) -> NDArrayType:
        """
        Hook run before dtype and shape validation.

        Subclasses may override it to coerce the input into something whose
        ``dtype`` and ``shape`` can be checked. The base implementation
        returns ``array`` untouched.
        """
        return array

    def validate_dtype(self, array: NDArrayType) -> NDArrayType:
        """
        Check ``array.dtype`` against the required dtype.

        A required dtype of ``Any`` accepts every array. The array itself is
        returned unchanged.

        Raises:
            :class:`~numpydantic.exceptions.DtypeError`
        """
        if self.dtype is Any or array.dtype == self.dtype:
            return array
        raise DtypeError(f"Invalid dtype! expected {self.dtype}, got {array.dtype}")

    def validate_shape(self, array: NDArrayType) -> NDArrayType:
        """
        Check ``array.shape`` against the required shape.

        A required shape of ``Any`` accepts every array; otherwise the check
        is delegated to nptyping's ``check_shape``. The array itself is
        returned unchanged.

        Raises:
            :class:`~numpydantic.exceptions.ShapeError`
        """
        if self.shape is Any or check_shape(array.shape, self.shape):
            return array
        raise ShapeError(
            f"Invalid shape! expected shape {self.shape.prepared_args}, "
            f"got shape {array.shape}"
        )

    def after_validation(self, array: NDArrayType) -> T:
        """
        Hook run after dtype and shape validation.

        Subclasses may override it to convert the intermediate array into the
        interface's return type. The base implementation returns ``array``
        untouched.
        """
        return array

    @classmethod
    @abstractmethod
    def check(cls, array: Any) -> bool:
        """
        Decide whether this interface applies to the given input.
        """

    @classmethod
    @abstractmethod
    def enabled(cls) -> bool:
        """
        Report whether this interface can be used, eg. whether the packages it
        depends on are installed.
        """

    @classmethod
    def interfaces(cls) -> Tuple[Type["Interface"], ...]:
        """
        Enabled direct subclasses of :class:`.Interface`, highest priority first.
        """
        usable = (sub for sub in Interface.__subclasses__() if sub.enabled())
        ranked = sorted(usable, key=attrgetter("priority"), reverse=True)
        return tuple(ranked)

    @classmethod
    def array_types(cls) -> Tuple[NDArrayType, ...]:
        """The ``return_type`` of every enabled interface, in priority order."""
        return tuple(iface.return_type for iface in cls.interfaces())

    @classmethod
    def match(cls, array: Any) -> Type["Interface"]:
        """
        Find the single enabled interface whose :meth:`.check` accepts ``array``.

        Raises:
            ValueError: if no interface, or more than one interface, matches.
        """
        candidates = [iface for iface in cls.interfaces() if iface.check(array)]

        if len(candidates) == 1:
            return candidates[0]

        if not candidates:
            raise ValueError(f"No matching interfaces found for input {array}")

        # Ambiguous input: list every interface that claimed it.
        listing = "\n".join(f" - {i}" for i in candidates)
        raise ValueError(f"More than one interface matches input {array}:\n" + listing)

View file

@ -0,0 +1,45 @@
from typing import Any
from numpydantic.interface.interface import Interface
try:
from numpy import ndarray
ENABLED = True
except ImportError:
ENABLED = False
ndarray = None
class NumpyInterface(Interface):
    """
    Numpy :class:`~numpy.ndarray` s!
    """

    # ``ndarray`` is ``None`` when numpy could not be imported; in that case
    # ``enabled`` returns False.
    return_type = ndarray

    @classmethod
    def check(cls, array: Any) -> bool:
        """
        Check that this is in fact a numpy ndarray or something that can be
        coerced to one.

        Coercibility is probed with :func:`numpy.asarray`, which builds an
        array from the *contents* of its argument. Calling the ``ndarray``
        constructor directly would instead treat the argument as a shape.

        Returns:
            ``True`` for ndarray instances and for inputs ``asarray`` accepts,
            ``False`` when the conversion raises ``TypeError`` or ``ValueError``.
        """
        if isinstance(array, ndarray):
            return True

        # Local import: the module-level numpy import only brings in ``ndarray``.
        from numpy import asarray

        try:
            asarray(array)
        except (TypeError, ValueError):
            return False
        return True

    def before_validation(self, array: Any) -> ndarray:
        """
        Coerce to an ndarray. Inputs that already are ndarrays pass through
        unchanged; anything else is converted from its contents with
        :func:`numpy.asarray`. Whether coercion is possible is what
        :meth:`.check` tests.
        """
        if isinstance(array, ndarray):
            return array

        # Local import: the module-level numpy import only brings in ``ndarray``.
        from numpy import asarray

        return asarray(array)

    @classmethod
    def enabled(cls) -> bool:
        """Check that numpy is present in the environment"""
        return ENABLED

View file

View file

39
src/numpydantic/meta.py Normal file
View file

@ -0,0 +1,39 @@
"""
Metaprogramming functions for numpydantic to modify itself :)
"""
from pathlib import Path
from numpydantic.interface import Interface
def generate_ndarray_stub() -> str:
    """
    Make a stub file based on the array interfaces that are available.

    The stub holds one ``from <module> import <name>`` line per distinct array
    type returned by ``Interface.array_types()``, followed by an ``NDArray``
    alias that is the ``|`` union of those names.

    Returns:
        The stub source as a string. When no array type is available the alias
        falls back to ``Any``, because an empty union (``NDArray = ``) would
        not be valid stub syntax.
    """
    # Query the interfaces once so imports and union members always agree, and
    # drop repeated (module, name) pairs while keeping the original order.
    type_names = list(
        dict.fromkeys(
            (arr.__module__, arr.__name__) for arr in Interface.array_types()
        )
    )

    if not type_names:
        return "\n".join(["from typing import Any", "NDArray = Any"])

    import_string = "\n".join(
        f"from {module} import {name}" for module, name in type_names
    )
    class_union = " | ".join(name for _, name in type_names)
    ndarray_type = "NDArray = " + class_union

    return "\n".join([import_string, ndarray_type])
def update_ndarray_stub() -> None:
    """
    Write a freshly generated stub to ``ndarray.pyi``, next to the
    ``numpydantic.ndarray`` module file.
    """
    # Imported inside the function, as in the original
    # (presumably to avoid a circular import at package load; confirm).
    from numpydantic import ndarray

    # Generate before touching the file so a failure leaves the old stub intact.
    contents = generate_ndarray_stub()
    stub_path = Path(ndarray.__file__).with_suffix(".pyi")
    stub_path.write_text(contents)

View file

@ -8,14 +8,11 @@ import base64
import sys import sys
from collections.abc import Callable from collections.abc import Callable
from copy import copy from copy import copy
from typing import Any from typing import Any, Tuple, TypeVar, cast, Union
import blosc2 import blosc2
import nptyping.structure import nptyping.structure
import numpy as np import numpy as np
# TODO: conditional import of dask, remove from required dependencies
from dask.array.core import Array as DaskArray
from nptyping import Shape from nptyping import Shape
from nptyping.ndarray import NDArrayMeta as _NDArrayMeta from nptyping.ndarray import NDArrayMeta as _NDArrayMeta
from nptyping.nptyping_type import NPTypingType from nptyping.nptyping_type import NPTypingType
@ -23,9 +20,11 @@ from nptyping.shape_expression import check_shape
from pydantic_core import core_schema from pydantic_core import core_schema
from pydantic_core.core_schema import ListSchema from pydantic_core.core_schema import ListSchema
from numpydantic.interface import Interface
from numpydantic.maps import np_to_python from numpydantic.maps import np_to_python
from numpydantic.proxy import NDArrayProxy # from numpydantic.proxy import NDArrayProxy
from numpydantic.types import DtypeType, NDArrayType, ShapeType
COMPRESSION_THRESHOLD = 16 * 1024 COMPRESSION_THRESHOLD = 16 * 1024
""" """
@ -33,8 +32,6 @@ Arrays larger than this size (in bytes) will be compressed and b64 encoded when
serializing to JSON. serializing to JSON.
""" """
ARRAY_TYPES = np.ndarray | DaskArray | NDArrayProxy
def list_of_lists_schema(shape: Shape, array_type_handler: dict) -> ListSchema: def list_of_lists_schema(shape: Shape, array_type_handler: dict) -> ListSchema:
"""Make a pydantic JSON schema for an array as a list of lists.""" """Make a pydantic JSON schema for an array as a list of lists."""
@ -68,7 +65,7 @@ def list_of_lists_schema(shape: Shape, array_type_handler: dict) -> ListSchema:
return list_schema return list_schema
def jsonize_array(array: ARRAY_TYPES) -> list | dict: def jsonize_array(array: NDArrayType) -> list | dict:
""" """
Render an array to base python types that can be serialized to JSON Render an array to base python types that can be serialized to JSON
@ -80,12 +77,13 @@ def jsonize_array(array: ARRAY_TYPES) -> list | dict:
Args: Args:
array (:class:`np.ndarray`, :class:`dask.DaskArray`): Array to render as a list! array (:class:`np.ndarray`, :class:`dask.DaskArray`): Array to render as a list!
""" """
if isinstance(array, DaskArray): # if isinstance(array, DaskArray):
arr = array.__array__() # arr = array.__array__()
elif isinstance(array, NDArrayProxy): # elif isinstance(array, NDArrayProxy):
arr = array[:] # arr = array[:]
else: # else:
arr = array # arr = array
arr = array
# If we're larger than 16kB then compress array! # If we're larger than 16kB then compress array!
if sys.getsizeof(arr) > COMPRESSION_THRESHOLD: if sys.getsizeof(arr) > COMPRESSION_THRESHOLD:
@ -117,21 +115,18 @@ def get_validate_shape(shape: Shape) -> Callable:
return validate_shape return validate_shape
def get_validate_dtype(dtype: np.dtype) -> Callable: def get_validate_interface(shape: ShapeType, dtype: DtypeType) -> Callable:
""" """
Get a closure around a dtype validation function that includes the dtype definition Validate using a matching :class:`.Interface` class using its :meth:`.Interface.validate` method
""" """
def validate_dtype(value: np.ndarray) -> np.ndarray: def validate_interface(value: Any, info) -> NDArrayType:
if dtype is Any: interface_cls = Interface.match(value)
return value interface = interface_cls(shape, dtype)
value = interface.validate(value)
assert (
value.dtype == dtype
), f"Invalid dtype! expected {dtype}, got {value.dtype}"
return value return value
return validate_dtype return validate_interface
def coerce_list(value: Any) -> np.ndarray: def coerce_list(value: Any) -> np.ndarray:
@ -152,6 +147,9 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
""" """
T = TypeVar("T")
class NDArray(NPTypingType, metaclass=NDArrayMeta): class NDArray(NPTypingType, metaclass=NDArrayMeta):
""" """
Constrained array type allowing npytyping syntax for dtype and shape validation and serialization. Constrained array type allowing npytyping syntax for dtype and shape validation and serialization.
@ -167,7 +165,10 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
- https://docs.pydantic.dev/latest/usage/types/custom/#handling-third-party-types - https://docs.pydantic.dev/latest/usage/types/custom/#handling-third-party-types
""" """
__args__ = (Any, Any) def __init__(self: T):
pass
__args__: Tuple[ShapeType, DtypeType] = (Any, Any)
@classmethod @classmethod
def __get_pydantic_core_schema__( def __get_pydantic_core_schema__(
@ -176,6 +177,8 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
_handler: Callable[[Any], core_schema.CoreSchema], _handler: Callable[[Any], core_schema.CoreSchema],
) -> core_schema.CoreSchema: ) -> core_schema.CoreSchema:
shape, dtype = _source_type.__args__ shape, dtype = _source_type.__args__
shape: ShapeType
dtype: DtypeType
# get pydantic core schema for the given specified type # get pydantic core schema for the given specified type
if isinstance(dtype, nptyping.structure.StructureMeta): if isinstance(dtype, nptyping.structure.StructureMeta):
@ -195,18 +198,8 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
python_schema=core_schema.chain_schema( python_schema=core_schema.chain_schema(
[ [
core_schema.no_info_plain_validator_function(coerce_list), core_schema.no_info_plain_validator_function(coerce_list),
core_schema.union_schema( core_schema.with_info_plain_validator_function(
[ get_validate_interface(shape, dtype)
core_schema.is_instance_schema(cls=np.ndarray),
core_schema.is_instance_schema(cls=DaskArray),
core_schema.is_instance_schema(cls=NDArrayProxy),
]
),
core_schema.no_info_plain_validator_function(
get_validate_dtype(dtype)
),
core_schema.no_info_plain_validator_function(
get_validate_shape(shape)
), ),
] ]
), ),
@ -214,3 +207,7 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
jsonize_array, when_used="json" jsonize_array, when_used="json"
), ),
) )
NDArray = cast(Union[np.ndarray, list[int]], NDArray)
# NDArray = cast(Union[Interface.array_types()], NDArray)

28
src/numpydantic/types.py Normal file
View file

@ -0,0 +1,28 @@
"""
Types for numpydantic
Note that these are types as in python typing types, not classes.
"""
from typing import Any, Protocol, Tuple, TypeVar, Union, runtime_checkable
import numpy as np
from nptyping import DType
# Shape constraint: a tuple of ints, or ``Any``.
ShapeType = Tuple[int, ...] | Any
# Dtype constraint: a numpy dtype, a string, a python type, an nptyping
# ``DType``, or ``Any``.
DtypeType = np.dtype | str | type | Any | DType
@runtime_checkable
class NDArrayType(Protocol):
    """
    Structural type for array-like objects.

    Anything that exposes ``dtype`` and ``shape`` and supports item access
    and item assignment satisfies this protocol.
    """

    @property
    def dtype(self) -> DtypeType:
        """Dtype of the array."""

    @property
    def shape(self) -> ShapeType:
        """Shape of the array."""

    def __getitem__(self, key: int | slice) -> "NDArrayType":
        """Return the item or slice selected by ``key``."""

    def __setitem__(self, key: int | slice, value: "NDArrayType"):
        """Assign ``value`` to the item or slice selected by ``key``."""

29
tests/test_meta.py Normal file
View file

@ -0,0 +1,29 @@
import pytest
from numpydantic import NDArray
from typing import reveal_type
@pytest.mark.skip("TODO")
def test_generate_stub():
    """
    Placeholder: should check that the stub file is generated correctly.
    """
@pytest.mark.skip("TODO")
def test_update_stub():
    """
    Placeholder: should check that updating the stub file rewrites the stub
    stored in the package.
    """
@pytest.mark.skip("TODO")
def test_stub_revealed_type():
    """
    Placeholder: should check that the revealed type of ``NDArray`` matches
    the generated stub.
    """
    # Named ``revealed`` so the local does not shadow the builtin ``type``.
    revealed = reveal_type(NDArray)

View file

@ -8,7 +8,7 @@ from pydantic import BaseModel, ValidationError, Field
from nptyping import Shape, Number from nptyping import Shape, Number
from numpydantic import NDArray from numpydantic import NDArray
from numpydantic.proxy import NDArrayProxy from numpydantic.exceptions import ShapeError, DtypeError
# from .fixtures import tmp_output_dir_func # from .fixtures import tmp_output_dir_func
@ -33,7 +33,7 @@ def test_ndarray_type():
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
instance = Model(array=np.zeros((4, 6))) instance = Model(array=np.zeros((4, 6)))
with pytest.raises(ValidationError): with pytest.raises(DtypeError):
instance = Model(array=np.ones((2, 3), dtype=bool)) instance = Model(array=np.ones((2, 3), dtype=bool))
instance = Model(array=np.zeros((2, 3)), array_any=np.ones((3, 4, 5))) instance = Model(array=np.zeros((2, 3)), array_any=np.ones((3, 4, 5)))
@ -77,7 +77,7 @@ def test_ndarray_coercion():
amod = Model(array=[1, 2, 3, 4.5]) amod = Model(array=[1, 2, 3, 4.5])
assert np.allclose(amod.array, np.array([1, 2, 3, 4.5])) assert np.allclose(amod.array, np.array([1, 2, 3, 4.5]))
with pytest.raises(ValidationError): with pytest.raises(DtypeError):
amod = Model(array=["a", "b", "c"]) amod = Model(array=["a", "b", "c"])