diff --git a/src/numpydantic/ndarray.py b/src/numpydantic/ndarray.py index 8756ae0..d951d3a 100644 --- a/src/numpydantic/ndarray.py +++ b/src/numpydantic/ndarray.py @@ -29,7 +29,7 @@ from numpydantic.schema import ( get_validate_interface, make_json_schema, ) -from numpydantic.types import DtypeType, ShapeType +from numpydantic.types import DtypeType, NDArrayType, ShapeType from numpydantic.vendor.nptyping.error import InvalidArgumentsError from numpydantic.vendor.nptyping.ndarray import NDArrayMeta as _NDArrayMeta from numpydantic.vendor.nptyping.nptyping_type import NPTypingType @@ -54,6 +54,10 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"): if TYPE_CHECKING: # pragma: no cover __getitem__ = SubscriptableMeta.__getitem__ + def __call__(cls, val: NDArrayType) -> NDArrayType: + """Call ndarray as a validator function""" + return get_validate_interface(cls.__args__[0], cls.__args__[1])(val) + def __instancecheck__(self, instance: Any): """ Extended type checking that determines whether diff --git a/src/numpydantic/vendor/nptyping/__init__.py b/src/numpydantic/vendor/nptyping/__init__.py index 4ca8fdb..0a51854 100644 --- a/src/numpydantic/vendor/nptyping/__init__.py +++ b/src/numpydantic/vendor/nptyping/__init__.py @@ -32,7 +32,9 @@ from numpydantic.vendor.nptyping.error import ( ) from numpydantic.vendor.nptyping.ndarray import NDArray from numpydantic.vendor.nptyping.package_info import __version__ -from numpydantic.vendor.nptyping.pandas_.dataframe import DataFrame + +# don't import unnecessarily since we don't use it +# from numpydantic.vendor.nptyping.pandas_.dataframe import DataFrame from numpydantic.vendor.nptyping.recarray import RecArray from numpydantic.vendor.nptyping.shape import Shape from numpydantic.vendor.nptyping.shape_expression import ( diff --git a/src/numpydantic/vendor/nptyping/ndarray.py b/src/numpydantic/vendor/nptyping/ndarray.py index 90a4793..19f06b4 100644 --- a/src/numpydantic/vendor/nptyping/ndarray.py +++ b/src/numpydantic/vendor/nptyping/ndarray.py @@ -31,7 +31,6 @@ import numpy as np from numpydantic.vendor.nptyping.base_meta_classes import ( FinalMeta, ImmutableMeta, - InconstructableMeta, MaybeCheckableMeta, PrintableMeta, SubscriptableMeta, @@ -54,7 +53,6 @@ from numpydantic.vendor.nptyping.typing_ import ( class NDArrayMeta( SubscriptableMeta, - InconstructableMeta, ImmutableMeta, FinalMeta, MaybeCheckableMeta, diff --git a/tests/test_ndarray.py b/tests/test_ndarray.py index 9883c2a..f92a66d 100644 --- a/tests/test_ndarray.py +++ b/tests/test_ndarray.py @@ -350,3 +350,17 @@ def test_instancecheck(): return array my_function(np.zeros((1, 2, 3), int)) + + +def test_callable(): + """ + NDArray objects are callable to validate and cast + Don't test validation here, just that we can be called + """ + annotation = NDArray[Shape["3"], int] + array = np.array([1, 2, 3], dtype=int) + validated = annotation(array) + assert validated is array + + with pytest.raises(DtypeError): + _ = annotation(np.zeros((1, 2, 3)))