mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2025-01-10 05:54:26 +00:00
numpy interface and type stub generation
This commit is contained in:
parent
5069c3ddd4
commit
a6391c08a3
17 changed files with 1470 additions and 57 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -158,4 +158,5 @@ cython_debug/
|
||||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
#.idea/
|
#.idea/
|
||||||
.pdm-python
|
.pdm-python
|
||||||
|
ndarray.pyi
|
|
@ -94,11 +94,16 @@ select = [
|
||||||
# whitespace
|
# whitespace
|
||||||
"D210", "D211",
|
"D210", "D211",
|
||||||
# emptiness
|
# emptiness
|
||||||
"D419"
|
"D419",
|
||||||
|
|
||||||
|
|
||||||
]
|
]
|
||||||
ignore = [
|
ignore = [
|
||||||
"ANN101", "ANN102"
|
"ANN101", "ANN102", "ANN401",
|
||||||
|
# builtin type annotations
|
||||||
|
"UP006", "UP035",
|
||||||
|
# docstrings for __init__
|
||||||
|
"D107",
|
||||||
]
|
]
|
||||||
|
|
||||||
fixable = ["ALL"]
|
fixable = ["ALL"]
|
||||||
|
|
|
@ -2,10 +2,11 @@
|
||||||
# ruff: noqa: F401
|
# ruff: noqa: F401
|
||||||
# ruff: noqa: I001
|
# ruff: noqa: I001
|
||||||
from numpydantic.monkeypatch import apply_patches
|
from numpydantic.monkeypatch import apply_patches
|
||||||
|
|
||||||
apply_patches()
|
apply_patches()
|
||||||
|
|
||||||
from numpydantic.ndarray import NDArray
|
from numpydantic.ndarray import NDArray
|
||||||
|
|
||||||
__all__ = [
|
from numpydantic.meta import update_ndarray_stub
|
||||||
"NDArray"
|
|
||||||
]
|
update_ndarray_stub()
|
||||||
|
|
6
src/numpydantic/exceptions.py
Normal file
6
src/numpydantic/exceptions.py
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
class DtypeError(TypeError):
|
||||||
|
"""Exception raised for invalid dtypes"""
|
||||||
|
|
||||||
|
|
||||||
|
class ShapeError(ValueError):
|
||||||
|
"""Exception raise for invalid shapes"""
|
4
src/numpydantic/interface/__init__.py
Normal file
4
src/numpydantic/interface/__init__.py
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
from numpydantic.interface.interface import Interface
|
||||||
|
from numpydantic.interface.numpy import NumpyInterface
|
||||||
|
|
||||||
|
__all__ = ["Interface", "NumpyInterface"]
|
0
src/numpydantic/interface/dask.py
Normal file
0
src/numpydantic/interface/dask.py
Normal file
0
src/numpydantic/interface/hdf5.py
Normal file
0
src/numpydantic/interface/hdf5.py
Normal file
125
src/numpydantic/interface/interface.py
Normal file
125
src/numpydantic/interface/interface.py
Normal file
|
@ -0,0 +1,125 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from operator import attrgetter
|
||||||
|
from typing import Any, Generic, List, Type, TypeVar, Tuple
|
||||||
|
|
||||||
|
from nptyping.shape_expression import check_shape
|
||||||
|
|
||||||
|
from numpydantic.exceptions import DtypeError, ShapeError
|
||||||
|
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=NDArrayType)
|
||||||
|
|
||||||
|
|
||||||
|
class Interface(ABC, Generic[T]):
|
||||||
|
"""
|
||||||
|
Abstract parent class for interfaces to different array formats
|
||||||
|
"""
|
||||||
|
|
||||||
|
return_type: Type[T]
|
||||||
|
priority: int = 0
|
||||||
|
|
||||||
|
def __init__(self, shape: ShapeType, dtype: DtypeType) -> None:
|
||||||
|
self.shape = shape
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
def validate(self, array: Any) -> T:
|
||||||
|
"""
|
||||||
|
Validate input, returning final array type
|
||||||
|
"""
|
||||||
|
array = self.before_validation(array)
|
||||||
|
array = self.validate_dtype(array)
|
||||||
|
array = self.validate_shape(array)
|
||||||
|
array = self.after_validation(array)
|
||||||
|
return array
|
||||||
|
|
||||||
|
def before_validation(self, array: Any) -> NDArrayType:
|
||||||
|
"""
|
||||||
|
Optional step pre-validation that coerces the input into a type that can be
|
||||||
|
validated for shape and dtype
|
||||||
|
|
||||||
|
Default method is a no-op
|
||||||
|
"""
|
||||||
|
return array
|
||||||
|
|
||||||
|
def validate_dtype(self, array: NDArrayType) -> NDArrayType:
|
||||||
|
"""
|
||||||
|
Validate the dtype of the given array, returning it unmutated.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
:class:`~numpydantic.exceptions.DtypeError`
|
||||||
|
"""
|
||||||
|
if self.dtype is Any:
|
||||||
|
return array
|
||||||
|
if not array.dtype == self.dtype:
|
||||||
|
raise DtypeError(f"Invalid dtype! expected {self.dtype}, got {array.dtype}")
|
||||||
|
return array
|
||||||
|
|
||||||
|
def validate_shape(self, array: NDArrayType) -> NDArrayType:
|
||||||
|
"""
|
||||||
|
Validate the shape of the given array, returning it unmutated
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
:class:`~numpydantic.exceptions.ShapeError`
|
||||||
|
"""
|
||||||
|
if self.shape is Any:
|
||||||
|
return array
|
||||||
|
if not check_shape(array.shape, self.shape):
|
||||||
|
raise ShapeError(
|
||||||
|
f"Invalid shape! expected shape {self.shape.prepared_args}, got shape {array.shape}"
|
||||||
|
)
|
||||||
|
return array
|
||||||
|
|
||||||
|
def after_validation(self, array: NDArrayType) -> T:
|
||||||
|
"""
|
||||||
|
Optional step post-validation that coerces the intermediate array type into the return type
|
||||||
|
|
||||||
|
Default method is a no-op
|
||||||
|
"""
|
||||||
|
return array
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def check(cls, array: Any) -> bool:
|
||||||
|
"""
|
||||||
|
Method to check whether a given input applies to this interface
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def enabled(cls) -> bool:
|
||||||
|
"""
|
||||||
|
Check whether this array interface can be used (eg. its dependent packages are installed, etc.)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def interfaces(cls) -> Tuple[Type["Interface"], ...]:
|
||||||
|
"""
|
||||||
|
Enabled interface subclasses
|
||||||
|
"""
|
||||||
|
return tuple(
|
||||||
|
sorted(
|
||||||
|
[i for i in Interface.__subclasses__() if i.enabled()],
|
||||||
|
key=attrgetter("priority"),
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def array_types(cls) -> Tuple[NDArrayType, ...]:
|
||||||
|
"""Return types for all enabled interfaces"""
|
||||||
|
return tuple([i.return_type for i in cls.interfaces()])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def match(cls, array: Any) -> Type["Interface"]:
|
||||||
|
"""
|
||||||
|
Find the interface that should be used for this array
|
||||||
|
"""
|
||||||
|
matches = [i for i in cls.interfaces() if i.check(array)]
|
||||||
|
if len(matches) > 1:
|
||||||
|
msg = f"More than one interface matches input {array}:\n"
|
||||||
|
msg += "\n".join([f" - {i}" for i in matches])
|
||||||
|
raise ValueError(msg)
|
||||||
|
elif len(matches) == 0:
|
||||||
|
raise ValueError(f"No matching interfaces found for input {array}")
|
||||||
|
else:
|
||||||
|
return matches[0]
|
45
src/numpydantic/interface/numpy.py
Normal file
45
src/numpydantic/interface/numpy.py
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from numpydantic.interface.interface import Interface
|
||||||
|
|
||||||
|
try:
|
||||||
|
from numpy import ndarray
|
||||||
|
|
||||||
|
ENABLED = True
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
ENABLED = False
|
||||||
|
ndarray = None
|
||||||
|
|
||||||
|
|
||||||
|
class NumpyInterface(Interface):
|
||||||
|
"""
|
||||||
|
Numpy :class:`~numpy.ndarray` s!
|
||||||
|
"""
|
||||||
|
|
||||||
|
return_type = ndarray
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def check(cls, array: Any) -> bool:
|
||||||
|
"""Check that this is in fact a numpy ndarray or something that can be coerced to one"""
|
||||||
|
if isinstance(array, ndarray):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
_ = ndarray(array)
|
||||||
|
return True
|
||||||
|
except TypeError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def before_validation(self, array: Any) -> ndarray:
|
||||||
|
"""
|
||||||
|
Coerce to an ndarray. We have already checked if coercion is possible in :meth:`.check`
|
||||||
|
"""
|
||||||
|
if not isinstance(array, ndarray):
|
||||||
|
array = ndarray(array)
|
||||||
|
return array
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def enabled(cls) -> bool:
|
||||||
|
"""Check that numpy is present in the environment"""
|
||||||
|
return ENABLED
|
0
src/numpydantic/interface/xarray.py
Normal file
0
src/numpydantic/interface/xarray.py
Normal file
0
src/numpydantic/interface/zarr.py
Normal file
0
src/numpydantic/interface/zarr.py
Normal file
39
src/numpydantic/meta.py
Normal file
39
src/numpydantic/meta.py
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
"""
|
||||||
|
Metaprogramming functions for numpydantic to modify itself :)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from numpydantic.interface import Interface
|
||||||
|
|
||||||
|
|
||||||
|
def generate_ndarray_stub() -> str:
|
||||||
|
"""
|
||||||
|
Make a stub file based on the array interfaces that are available
|
||||||
|
"""
|
||||||
|
|
||||||
|
import_strings = [
|
||||||
|
f"from {arr.__module__} import {arr.__name__}"
|
||||||
|
for arr in Interface.array_types()
|
||||||
|
]
|
||||||
|
import_string = "\n".join(import_strings)
|
||||||
|
|
||||||
|
class_names = [arr.__name__ for arr in Interface.array_types()]
|
||||||
|
class_union = " | ".join(class_names)
|
||||||
|
ndarray_type = "NDArray = " + class_union
|
||||||
|
|
||||||
|
stub_string = "\n".join([import_string, ndarray_type])
|
||||||
|
return stub_string
|
||||||
|
|
||||||
|
|
||||||
|
def update_ndarray_stub() -> None:
|
||||||
|
"""
|
||||||
|
Update the ndarray.pyi string in the numpydantic file
|
||||||
|
"""
|
||||||
|
from numpydantic import ndarray
|
||||||
|
|
||||||
|
stub_string = generate_ndarray_stub()
|
||||||
|
|
||||||
|
pyi_file = Path(ndarray.__file__).with_suffix(".pyi")
|
||||||
|
with open(pyi_file, "w") as pyi:
|
||||||
|
pyi.write(stub_string)
|
|
@ -8,14 +8,11 @@ import base64
|
||||||
import sys
|
import sys
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from typing import Any
|
from typing import Any, Tuple, TypeVar, cast, Union
|
||||||
|
|
||||||
import blosc2
|
import blosc2
|
||||||
import nptyping.structure
|
import nptyping.structure
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
# TODO: conditional import of dask, remove from required dependencies
|
|
||||||
from dask.array.core import Array as DaskArray
|
|
||||||
from nptyping import Shape
|
from nptyping import Shape
|
||||||
from nptyping.ndarray import NDArrayMeta as _NDArrayMeta
|
from nptyping.ndarray import NDArrayMeta as _NDArrayMeta
|
||||||
from nptyping.nptyping_type import NPTypingType
|
from nptyping.nptyping_type import NPTypingType
|
||||||
|
@ -23,9 +20,11 @@ from nptyping.shape_expression import check_shape
|
||||||
from pydantic_core import core_schema
|
from pydantic_core import core_schema
|
||||||
from pydantic_core.core_schema import ListSchema
|
from pydantic_core.core_schema import ListSchema
|
||||||
|
|
||||||
|
from numpydantic.interface import Interface
|
||||||
from numpydantic.maps import np_to_python
|
from numpydantic.maps import np_to_python
|
||||||
|
|
||||||
from numpydantic.proxy import NDArrayProxy
|
# from numpydantic.proxy import NDArrayProxy
|
||||||
|
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||||
|
|
||||||
COMPRESSION_THRESHOLD = 16 * 1024
|
COMPRESSION_THRESHOLD = 16 * 1024
|
||||||
"""
|
"""
|
||||||
|
@ -33,8 +32,6 @@ Arrays larger than this size (in bytes) will be compressed and b64 encoded when
|
||||||
serializing to JSON.
|
serializing to JSON.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ARRAY_TYPES = np.ndarray | DaskArray | NDArrayProxy
|
|
||||||
|
|
||||||
|
|
||||||
def list_of_lists_schema(shape: Shape, array_type_handler: dict) -> ListSchema:
|
def list_of_lists_schema(shape: Shape, array_type_handler: dict) -> ListSchema:
|
||||||
"""Make a pydantic JSON schema for an array as a list of lists."""
|
"""Make a pydantic JSON schema for an array as a list of lists."""
|
||||||
|
@ -68,7 +65,7 @@ def list_of_lists_schema(shape: Shape, array_type_handler: dict) -> ListSchema:
|
||||||
return list_schema
|
return list_schema
|
||||||
|
|
||||||
|
|
||||||
def jsonize_array(array: ARRAY_TYPES) -> list | dict:
|
def jsonize_array(array: NDArrayType) -> list | dict:
|
||||||
"""
|
"""
|
||||||
Render an array to base python types that can be serialized to JSON
|
Render an array to base python types that can be serialized to JSON
|
||||||
|
|
||||||
|
@ -80,12 +77,13 @@ def jsonize_array(array: ARRAY_TYPES) -> list | dict:
|
||||||
Args:
|
Args:
|
||||||
array (:class:`np.ndarray`, :class:`dask.DaskArray`): Array to render as a list!
|
array (:class:`np.ndarray`, :class:`dask.DaskArray`): Array to render as a list!
|
||||||
"""
|
"""
|
||||||
if isinstance(array, DaskArray):
|
# if isinstance(array, DaskArray):
|
||||||
arr = array.__array__()
|
# arr = array.__array__()
|
||||||
elif isinstance(array, NDArrayProxy):
|
# elif isinstance(array, NDArrayProxy):
|
||||||
arr = array[:]
|
# arr = array[:]
|
||||||
else:
|
# else:
|
||||||
arr = array
|
# arr = array
|
||||||
|
arr = array
|
||||||
|
|
||||||
# If we're larger than 16kB then compress array!
|
# If we're larger than 16kB then compress array!
|
||||||
if sys.getsizeof(arr) > COMPRESSION_THRESHOLD:
|
if sys.getsizeof(arr) > COMPRESSION_THRESHOLD:
|
||||||
|
@ -117,21 +115,18 @@ def get_validate_shape(shape: Shape) -> Callable:
|
||||||
return validate_shape
|
return validate_shape
|
||||||
|
|
||||||
|
|
||||||
def get_validate_dtype(dtype: np.dtype) -> Callable:
|
def get_validate_interface(shape: ShapeType, dtype: DtypeType) -> Callable:
|
||||||
"""
|
"""
|
||||||
Get a closure around a dtype validation function that includes the dtype definition
|
Validate using a matching :class:`.Interface` class using its :meth:`.Interface.validate` method
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def validate_dtype(value: np.ndarray) -> np.ndarray:
|
def validate_interface(value: Any, info) -> NDArrayType:
|
||||||
if dtype is Any:
|
interface_cls = Interface.match(value)
|
||||||
return value
|
interface = interface_cls(shape, dtype)
|
||||||
|
value = interface.validate(value)
|
||||||
assert (
|
|
||||||
value.dtype == dtype
|
|
||||||
), f"Invalid dtype! expected {dtype}, got {value.dtype}"
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
return validate_dtype
|
return validate_interface
|
||||||
|
|
||||||
|
|
||||||
def coerce_list(value: Any) -> np.ndarray:
|
def coerce_list(value: Any) -> np.ndarray:
|
||||||
|
@ -152,6 +147,9 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class NDArray(NPTypingType, metaclass=NDArrayMeta):
|
class NDArray(NPTypingType, metaclass=NDArrayMeta):
|
||||||
"""
|
"""
|
||||||
Constrained array type allowing npytyping syntax for dtype and shape validation and serialization.
|
Constrained array type allowing npytyping syntax for dtype and shape validation and serialization.
|
||||||
|
@ -167,7 +165,10 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
|
||||||
- https://docs.pydantic.dev/latest/usage/types/custom/#handling-third-party-types
|
- https://docs.pydantic.dev/latest/usage/types/custom/#handling-third-party-types
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__args__ = (Any, Any)
|
def __init__(self: T):
|
||||||
|
pass
|
||||||
|
|
||||||
|
__args__: Tuple[ShapeType, DtypeType] = (Any, Any)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __get_pydantic_core_schema__(
|
def __get_pydantic_core_schema__(
|
||||||
|
@ -176,6 +177,8 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
|
||||||
_handler: Callable[[Any], core_schema.CoreSchema],
|
_handler: Callable[[Any], core_schema.CoreSchema],
|
||||||
) -> core_schema.CoreSchema:
|
) -> core_schema.CoreSchema:
|
||||||
shape, dtype = _source_type.__args__
|
shape, dtype = _source_type.__args__
|
||||||
|
shape: ShapeType
|
||||||
|
dtype: DtypeType
|
||||||
|
|
||||||
# get pydantic core schema for the given specified type
|
# get pydantic core schema for the given specified type
|
||||||
if isinstance(dtype, nptyping.structure.StructureMeta):
|
if isinstance(dtype, nptyping.structure.StructureMeta):
|
||||||
|
@ -195,18 +198,8 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
|
||||||
python_schema=core_schema.chain_schema(
|
python_schema=core_schema.chain_schema(
|
||||||
[
|
[
|
||||||
core_schema.no_info_plain_validator_function(coerce_list),
|
core_schema.no_info_plain_validator_function(coerce_list),
|
||||||
core_schema.union_schema(
|
core_schema.with_info_plain_validator_function(
|
||||||
[
|
get_validate_interface(shape, dtype)
|
||||||
core_schema.is_instance_schema(cls=np.ndarray),
|
|
||||||
core_schema.is_instance_schema(cls=DaskArray),
|
|
||||||
core_schema.is_instance_schema(cls=NDArrayProxy),
|
|
||||||
]
|
|
||||||
),
|
|
||||||
core_schema.no_info_plain_validator_function(
|
|
||||||
get_validate_dtype(dtype)
|
|
||||||
),
|
|
||||||
core_schema.no_info_plain_validator_function(
|
|
||||||
get_validate_shape(shape)
|
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
|
@ -214,3 +207,7 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
|
||||||
jsonize_array, when_used="json"
|
jsonize_array, when_used="json"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
NDArray = cast(Union[np.ndarray, list[int]], NDArray)
|
||||||
|
# NDArray = cast(Union[Interface.array_types()], NDArray)
|
||||||
|
|
28
src/numpydantic/types.py
Normal file
28
src/numpydantic/types.py
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
"""
|
||||||
|
Types for numpydantic
|
||||||
|
|
||||||
|
Note that these are types as in python typing types, not classes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Protocol, Tuple, TypeVar, Union, runtime_checkable
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from nptyping import DType
|
||||||
|
|
||||||
|
|
||||||
|
ShapeType = Tuple[int, ...] | Any
|
||||||
|
DtypeType = np.dtype | str | type | Any | DType
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class NDArrayType(Protocol):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self) -> DtypeType: ...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self) -> ShapeType: ...
|
||||||
|
|
||||||
|
def __getitem__(self, key: int | slice) -> "NDArrayType": ...
|
||||||
|
|
||||||
|
def __setitem__(self, key: int | slice, value: "NDArrayType"): ...
|
29
tests/test_meta.py
Normal file
29
tests/test_meta.py
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from numpydantic import NDArray
|
||||||
|
|
||||||
|
from typing import reveal_type
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip("TODO")
|
||||||
|
def test_generate_stub():
|
||||||
|
"""
|
||||||
|
Test that we generate the stub file correctly...
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip("TODO")
|
||||||
|
def test_update_stub():
|
||||||
|
"""
|
||||||
|
Test that the update stub file correctly updates the stub stored in the package
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip("TODO")
|
||||||
|
def test_stub_revealed_type():
|
||||||
|
"""
|
||||||
|
Check that the revealed type matches the stub
|
||||||
|
"""
|
||||||
|
type = reveal_type(NDArray)
|
|
@ -8,7 +8,7 @@ from pydantic import BaseModel, ValidationError, Field
|
||||||
from nptyping import Shape, Number
|
from nptyping import Shape, Number
|
||||||
|
|
||||||
from numpydantic import NDArray
|
from numpydantic import NDArray
|
||||||
from numpydantic.proxy import NDArrayProxy
|
from numpydantic.exceptions import ShapeError, DtypeError
|
||||||
|
|
||||||
|
|
||||||
# from .fixtures import tmp_output_dir_func
|
# from .fixtures import tmp_output_dir_func
|
||||||
|
@ -33,7 +33,7 @@ def test_ndarray_type():
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
instance = Model(array=np.zeros((4, 6)))
|
instance = Model(array=np.zeros((4, 6)))
|
||||||
|
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises(DtypeError):
|
||||||
instance = Model(array=np.ones((2, 3), dtype=bool))
|
instance = Model(array=np.ones((2, 3), dtype=bool))
|
||||||
|
|
||||||
instance = Model(array=np.zeros((2, 3)), array_any=np.ones((3, 4, 5)))
|
instance = Model(array=np.zeros((2, 3)), array_any=np.ones((3, 4, 5)))
|
||||||
|
@ -77,7 +77,7 @@ def test_ndarray_coercion():
|
||||||
|
|
||||||
amod = Model(array=[1, 2, 3, 4.5])
|
amod = Model(array=[1, 2, 3, 4.5])
|
||||||
assert np.allclose(amod.array, np.array([1, 2, 3, 4.5]))
|
assert np.allclose(amod.array, np.array([1, 2, 3, 4.5]))
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises(DtypeError):
|
||||||
amod = Model(array=["a", "b", "c"])
|
amod = Model(array=["a", "b", "c"])
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue