mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2024-11-12 17:54:29 +00:00
numpy interface and type stub generation
This commit is contained in:
parent
5069c3ddd4
commit
a6391c08a3
17 changed files with 1470 additions and 57 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -158,4 +158,5 @@ cython_debug/
|
|||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
.pdm-python
|
||||
.pdm-python
|
||||
ndarray.pyi
|
|
@ -94,11 +94,16 @@ select = [
|
|||
# whitespace
|
||||
"D210", "D211",
|
||||
# emptiness
|
||||
"D419"
|
||||
"D419",
|
||||
|
||||
|
||||
]
|
||||
ignore = [
|
||||
"ANN101", "ANN102"
|
||||
"ANN101", "ANN102", "ANN401",
|
||||
# builtin type annotations
|
||||
"UP006", "UP035",
|
||||
# docstrings for __init__
|
||||
"D107",
|
||||
]
|
||||
|
||||
fixable = ["ALL"]
|
||||
|
|
|
@ -2,10 +2,11 @@
|
|||
# ruff: noqa: F401
|
||||
# ruff: noqa: I001
|
||||
from numpydantic.monkeypatch import apply_patches
|
||||
|
||||
apply_patches()
|
||||
|
||||
from numpydantic.ndarray import NDArray
|
||||
|
||||
__all__ = [
|
||||
"NDArray"
|
||||
]
|
||||
from numpydantic.meta import update_ndarray_stub
|
||||
|
||||
update_ndarray_stub()
|
||||
|
|
6
src/numpydantic/exceptions.py
Normal file
6
src/numpydantic/exceptions.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
class DtypeError(TypeError):
|
||||
"""Exception raised for invalid dtypes"""
|
||||
|
||||
|
||||
class ShapeError(ValueError):
|
||||
"""Exception raise for invalid shapes"""
|
4
src/numpydantic/interface/__init__.py
Normal file
4
src/numpydantic/interface/__init__.py
Normal file
|
@ -0,0 +1,4 @@
|
|||
from numpydantic.interface.interface import Interface
|
||||
from numpydantic.interface.numpy import NumpyInterface
|
||||
|
||||
__all__ = ["Interface", "NumpyInterface"]
|
0
src/numpydantic/interface/dask.py
Normal file
0
src/numpydantic/interface/dask.py
Normal file
0
src/numpydantic/interface/hdf5.py
Normal file
0
src/numpydantic/interface/hdf5.py
Normal file
125
src/numpydantic/interface/interface.py
Normal file
125
src/numpydantic/interface/interface.py
Normal file
|
@ -0,0 +1,125 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from operator import attrgetter
|
||||
from typing import Any, Generic, List, Type, TypeVar, Tuple
|
||||
|
||||
from nptyping.shape_expression import check_shape
|
||||
|
||||
from numpydantic.exceptions import DtypeError, ShapeError
|
||||
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||
|
||||
T = TypeVar("T", bound=NDArrayType)
|
||||
|
||||
|
||||
class Interface(ABC, Generic[T]):
|
||||
"""
|
||||
Abstract parent class for interfaces to different array formats
|
||||
"""
|
||||
|
||||
return_type: Type[T]
|
||||
priority: int = 0
|
||||
|
||||
def __init__(self, shape: ShapeType, dtype: DtypeType) -> None:
|
||||
self.shape = shape
|
||||
self.dtype = dtype
|
||||
|
||||
def validate(self, array: Any) -> T:
|
||||
"""
|
||||
Validate input, returning final array type
|
||||
"""
|
||||
array = self.before_validation(array)
|
||||
array = self.validate_dtype(array)
|
||||
array = self.validate_shape(array)
|
||||
array = self.after_validation(array)
|
||||
return array
|
||||
|
||||
def before_validation(self, array: Any) -> NDArrayType:
|
||||
"""
|
||||
Optional step pre-validation that coerces the input into a type that can be
|
||||
validated for shape and dtype
|
||||
|
||||
Default method is a no-op
|
||||
"""
|
||||
return array
|
||||
|
||||
def validate_dtype(self, array: NDArrayType) -> NDArrayType:
|
||||
"""
|
||||
Validate the dtype of the given array, returning it unmutated.
|
||||
|
||||
Raises:
|
||||
:class:`~numpydantic.exceptions.DtypeError`
|
||||
"""
|
||||
if self.dtype is Any:
|
||||
return array
|
||||
if not array.dtype == self.dtype:
|
||||
raise DtypeError(f"Invalid dtype! expected {self.dtype}, got {array.dtype}")
|
||||
return array
|
||||
|
||||
def validate_shape(self, array: NDArrayType) -> NDArrayType:
|
||||
"""
|
||||
Validate the shape of the given array, returning it unmutated
|
||||
|
||||
Raises:
|
||||
:class:`~numpydantic.exceptions.ShapeError`
|
||||
"""
|
||||
if self.shape is Any:
|
||||
return array
|
||||
if not check_shape(array.shape, self.shape):
|
||||
raise ShapeError(
|
||||
f"Invalid shape! expected shape {self.shape.prepared_args}, got shape {array.shape}"
|
||||
)
|
||||
return array
|
||||
|
||||
def after_validation(self, array: NDArrayType) -> T:
|
||||
"""
|
||||
Optional step post-validation that coerces the intermediate array type into the return type
|
||||
|
||||
Default method is a no-op
|
||||
"""
|
||||
return array
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def check(cls, array: Any) -> bool:
|
||||
"""
|
||||
Method to check whether a given input applies to this interface
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def enabled(cls) -> bool:
|
||||
"""
|
||||
Check whether this array interface can be used (eg. its dependent packages are installed, etc.)
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def interfaces(cls) -> Tuple[Type["Interface"], ...]:
|
||||
"""
|
||||
Enabled interface subclasses
|
||||
"""
|
||||
return tuple(
|
||||
sorted(
|
||||
[i for i in Interface.__subclasses__() if i.enabled()],
|
||||
key=attrgetter("priority"),
|
||||
reverse=True,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def array_types(cls) -> Tuple[NDArrayType, ...]:
|
||||
"""Return types for all enabled interfaces"""
|
||||
return tuple([i.return_type for i in cls.interfaces()])
|
||||
|
||||
@classmethod
|
||||
def match(cls, array: Any) -> Type["Interface"]:
|
||||
"""
|
||||
Find the interface that should be used for this array
|
||||
"""
|
||||
matches = [i for i in cls.interfaces() if i.check(array)]
|
||||
if len(matches) > 1:
|
||||
msg = f"More than one interface matches input {array}:\n"
|
||||
msg += "\n".join([f" - {i}" for i in matches])
|
||||
raise ValueError(msg)
|
||||
elif len(matches) == 0:
|
||||
raise ValueError(f"No matching interfaces found for input {array}")
|
||||
else:
|
||||
return matches[0]
|
45
src/numpydantic/interface/numpy.py
Normal file
45
src/numpydantic/interface/numpy.py
Normal file
|
@ -0,0 +1,45 @@
|
|||
from typing import Any
|
||||
|
||||
from numpydantic.interface.interface import Interface
|
||||
|
||||
try:
|
||||
from numpy import ndarray
|
||||
|
||||
ENABLED = True
|
||||
|
||||
except ImportError:
|
||||
ENABLED = False
|
||||
ndarray = None
|
||||
|
||||
|
||||
class NumpyInterface(Interface):
|
||||
"""
|
||||
Numpy :class:`~numpy.ndarray` s!
|
||||
"""
|
||||
|
||||
return_type = ndarray
|
||||
|
||||
@classmethod
|
||||
def check(cls, array: Any) -> bool:
|
||||
"""Check that this is in fact a numpy ndarray or something that can be coerced to one"""
|
||||
if isinstance(array, ndarray):
|
||||
return True
|
||||
else:
|
||||
try:
|
||||
_ = ndarray(array)
|
||||
return True
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
def before_validation(self, array: Any) -> ndarray:
|
||||
"""
|
||||
Coerce to an ndarray. We have already checked if coercion is possible in :meth:`.check`
|
||||
"""
|
||||
if not isinstance(array, ndarray):
|
||||
array = ndarray(array)
|
||||
return array
|
||||
|
||||
@classmethod
|
||||
def enabled(cls) -> bool:
|
||||
"""Check that numpy is present in the environment"""
|
||||
return ENABLED
|
0
src/numpydantic/interface/xarray.py
Normal file
0
src/numpydantic/interface/xarray.py
Normal file
0
src/numpydantic/interface/zarr.py
Normal file
0
src/numpydantic/interface/zarr.py
Normal file
39
src/numpydantic/meta.py
Normal file
39
src/numpydantic/meta.py
Normal file
|
@ -0,0 +1,39 @@
|
|||
"""
|
||||
Metaprogramming functions for numpydantic to modify itself :)
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from numpydantic.interface import Interface
|
||||
|
||||
|
||||
def generate_ndarray_stub() -> str:
|
||||
"""
|
||||
Make a stub file based on the array interfaces that are available
|
||||
"""
|
||||
|
||||
import_strings = [
|
||||
f"from {arr.__module__} import {arr.__name__}"
|
||||
for arr in Interface.array_types()
|
||||
]
|
||||
import_string = "\n".join(import_strings)
|
||||
|
||||
class_names = [arr.__name__ for arr in Interface.array_types()]
|
||||
class_union = " | ".join(class_names)
|
||||
ndarray_type = "NDArray = " + class_union
|
||||
|
||||
stub_string = "\n".join([import_string, ndarray_type])
|
||||
return stub_string
|
||||
|
||||
|
||||
def update_ndarray_stub() -> None:
|
||||
"""
|
||||
Update the ndarray.pyi string in the numpydantic file
|
||||
"""
|
||||
from numpydantic import ndarray
|
||||
|
||||
stub_string = generate_ndarray_stub()
|
||||
|
||||
pyi_file = Path(ndarray.__file__).with_suffix(".pyi")
|
||||
with open(pyi_file, "w") as pyi:
|
||||
pyi.write(stub_string)
|
|
@ -8,14 +8,11 @@ import base64
|
|||
import sys
|
||||
from collections.abc import Callable
|
||||
from copy import copy
|
||||
from typing import Any
|
||||
from typing import Any, Tuple, TypeVar, cast, Union
|
||||
|
||||
import blosc2
|
||||
import nptyping.structure
|
||||
import numpy as np
|
||||
|
||||
# TODO: conditional import of dask, remove from required dependencies
|
||||
from dask.array.core import Array as DaskArray
|
||||
from nptyping import Shape
|
||||
from nptyping.ndarray import NDArrayMeta as _NDArrayMeta
|
||||
from nptyping.nptyping_type import NPTypingType
|
||||
|
@ -23,9 +20,11 @@ from nptyping.shape_expression import check_shape
|
|||
from pydantic_core import core_schema
|
||||
from pydantic_core.core_schema import ListSchema
|
||||
|
||||
from numpydantic.interface import Interface
|
||||
from numpydantic.maps import np_to_python
|
||||
|
||||
from numpydantic.proxy import NDArrayProxy
|
||||
# from numpydantic.proxy import NDArrayProxy
|
||||
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||
|
||||
COMPRESSION_THRESHOLD = 16 * 1024
|
||||
"""
|
||||
|
@ -33,8 +32,6 @@ Arrays larger than this size (in bytes) will be compressed and b64 encoded when
|
|||
serializing to JSON.
|
||||
"""
|
||||
|
||||
ARRAY_TYPES = np.ndarray | DaskArray | NDArrayProxy
|
||||
|
||||
|
||||
def list_of_lists_schema(shape: Shape, array_type_handler: dict) -> ListSchema:
|
||||
"""Make a pydantic JSON schema for an array as a list of lists."""
|
||||
|
@ -68,7 +65,7 @@ def list_of_lists_schema(shape: Shape, array_type_handler: dict) -> ListSchema:
|
|||
return list_schema
|
||||
|
||||
|
||||
def jsonize_array(array: ARRAY_TYPES) -> list | dict:
|
||||
def jsonize_array(array: NDArrayType) -> list | dict:
|
||||
"""
|
||||
Render an array to base python types that can be serialized to JSON
|
||||
|
||||
|
@ -80,12 +77,13 @@ def jsonize_array(array: ARRAY_TYPES) -> list | dict:
|
|||
Args:
|
||||
array (:class:`np.ndarray`, :class:`dask.DaskArray`): Array to render as a list!
|
||||
"""
|
||||
if isinstance(array, DaskArray):
|
||||
arr = array.__array__()
|
||||
elif isinstance(array, NDArrayProxy):
|
||||
arr = array[:]
|
||||
else:
|
||||
arr = array
|
||||
# if isinstance(array, DaskArray):
|
||||
# arr = array.__array__()
|
||||
# elif isinstance(array, NDArrayProxy):
|
||||
# arr = array[:]
|
||||
# else:
|
||||
# arr = array
|
||||
arr = array
|
||||
|
||||
# If we're larger than 16kB then compress array!
|
||||
if sys.getsizeof(arr) > COMPRESSION_THRESHOLD:
|
||||
|
@ -117,21 +115,18 @@ def get_validate_shape(shape: Shape) -> Callable:
|
|||
return validate_shape
|
||||
|
||||
|
||||
def get_validate_dtype(dtype: np.dtype) -> Callable:
|
||||
def get_validate_interface(shape: ShapeType, dtype: DtypeType) -> Callable:
|
||||
"""
|
||||
Get a closure around a dtype validation function that includes the dtype definition
|
||||
Validate using a matching :class:`.Interface` class using its :meth:`.Interface.validate` method
|
||||
"""
|
||||
|
||||
def validate_dtype(value: np.ndarray) -> np.ndarray:
|
||||
if dtype is Any:
|
||||
return value
|
||||
|
||||
assert (
|
||||
value.dtype == dtype
|
||||
), f"Invalid dtype! expected {dtype}, got {value.dtype}"
|
||||
def validate_interface(value: Any, info) -> NDArrayType:
|
||||
interface_cls = Interface.match(value)
|
||||
interface = interface_cls(shape, dtype)
|
||||
value = interface.validate(value)
|
||||
return value
|
||||
|
||||
return validate_dtype
|
||||
return validate_interface
|
||||
|
||||
|
||||
def coerce_list(value: Any) -> np.ndarray:
|
||||
|
@ -152,6 +147,9 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
|
|||
"""
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class NDArray(NPTypingType, metaclass=NDArrayMeta):
|
||||
"""
|
||||
Constrained array type allowing npytyping syntax for dtype and shape validation and serialization.
|
||||
|
@ -167,7 +165,10 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
|
|||
- https://docs.pydantic.dev/latest/usage/types/custom/#handling-third-party-types
|
||||
"""
|
||||
|
||||
__args__ = (Any, Any)
|
||||
def __init__(self: T):
|
||||
pass
|
||||
|
||||
__args__: Tuple[ShapeType, DtypeType] = (Any, Any)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
|
@ -176,6 +177,8 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
|
|||
_handler: Callable[[Any], core_schema.CoreSchema],
|
||||
) -> core_schema.CoreSchema:
|
||||
shape, dtype = _source_type.__args__
|
||||
shape: ShapeType
|
||||
dtype: DtypeType
|
||||
|
||||
# get pydantic core schema for the given specified type
|
||||
if isinstance(dtype, nptyping.structure.StructureMeta):
|
||||
|
@ -195,18 +198,8 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
|
|||
python_schema=core_schema.chain_schema(
|
||||
[
|
||||
core_schema.no_info_plain_validator_function(coerce_list),
|
||||
core_schema.union_schema(
|
||||
[
|
||||
core_schema.is_instance_schema(cls=np.ndarray),
|
||||
core_schema.is_instance_schema(cls=DaskArray),
|
||||
core_schema.is_instance_schema(cls=NDArrayProxy),
|
||||
]
|
||||
),
|
||||
core_schema.no_info_plain_validator_function(
|
||||
get_validate_dtype(dtype)
|
||||
),
|
||||
core_schema.no_info_plain_validator_function(
|
||||
get_validate_shape(shape)
|
||||
core_schema.with_info_plain_validator_function(
|
||||
get_validate_interface(shape, dtype)
|
||||
),
|
||||
]
|
||||
),
|
||||
|
@ -214,3 +207,7 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
|
|||
jsonize_array, when_used="json"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
NDArray = cast(Union[np.ndarray, list[int]], NDArray)
|
||||
# NDArray = cast(Union[Interface.array_types()], NDArray)
|
||||
|
|
28
src/numpydantic/types.py
Normal file
28
src/numpydantic/types.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
"""
|
||||
Types for numpydantic
|
||||
|
||||
Note that these are types as in python typing types, not classes.
|
||||
"""
|
||||
|
||||
from typing import Any, Protocol, Tuple, TypeVar, Union, runtime_checkable
|
||||
|
||||
import numpy as np
|
||||
from nptyping import DType
|
||||
|
||||
|
||||
ShapeType = Tuple[int, ...] | Any
|
||||
DtypeType = np.dtype | str | type | Any | DType
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class NDArrayType(Protocol):
|
||||
|
||||
@property
|
||||
def dtype(self) -> DtypeType: ...
|
||||
|
||||
@property
|
||||
def shape(self) -> ShapeType: ...
|
||||
|
||||
def __getitem__(self, key: int | slice) -> "NDArrayType": ...
|
||||
|
||||
def __setitem__(self, key: int | slice, value: "NDArrayType"): ...
|
29
tests/test_meta.py
Normal file
29
tests/test_meta.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
import pytest
|
||||
|
||||
from numpydantic import NDArray
|
||||
|
||||
from typing import reveal_type
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO")
|
||||
def test_generate_stub():
|
||||
"""
|
||||
Test that we generate the stub file correctly...
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO")
|
||||
def test_update_stub():
|
||||
"""
|
||||
Test that the update stub file correctly updates the stub stored in the package
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO")
|
||||
def test_stub_revealed_type():
|
||||
"""
|
||||
Check that the revealed type matches the stub
|
||||
"""
|
||||
type = reveal_type(NDArray)
|
|
@ -8,7 +8,7 @@ from pydantic import BaseModel, ValidationError, Field
|
|||
from nptyping import Shape, Number
|
||||
|
||||
from numpydantic import NDArray
|
||||
from numpydantic.proxy import NDArrayProxy
|
||||
from numpydantic.exceptions import ShapeError, DtypeError
|
||||
|
||||
|
||||
# from .fixtures import tmp_output_dir_func
|
||||
|
@ -33,7 +33,7 @@ def test_ndarray_type():
|
|||
with pytest.raises(ValidationError):
|
||||
instance = Model(array=np.zeros((4, 6)))
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
with pytest.raises(DtypeError):
|
||||
instance = Model(array=np.ones((2, 3), dtype=bool))
|
||||
|
||||
instance = Model(array=np.zeros((2, 3)), array_any=np.ones((3, 4, 5)))
|
||||
|
@ -77,7 +77,7 @@ def test_ndarray_coercion():
|
|||
|
||||
amod = Model(array=[1, 2, 3, 4.5])
|
||||
assert np.allclose(amod.array, np.array([1, 2, 3, 4.5]))
|
||||
with pytest.raises(ValidationError):
|
||||
with pytest.raises(DtypeError):
|
||||
amod = Model(array=["a", "b", "c"])
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue