From 1290d6483335d5e84b4d4dc2e5fc05c472bccefa Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Thu, 23 May 2024 21:08:38 -0700 Subject: [PATCH] working instancecheck, but not working static analysis --- src/numpydantic/exceptions.py | 20 ++++++++- src/numpydantic/interface/interface.py | 57 ++++++++++++++++++++++---- src/numpydantic/ndarray.py | 36 +++++++++++++++- src/numpydantic/schema.py | 4 +- tests/test_ndarray.py | 20 +++++++++ 5 files changed, 126 insertions(+), 11 deletions(-) diff --git a/src/numpydantic/exceptions.py b/src/numpydantic/exceptions.py index 20d4981..252dbbd 100644 --- a/src/numpydantic/exceptions.py +++ b/src/numpydantic/exceptions.py @@ -3,9 +3,25 @@ Exceptions used within numpydantic """ -class DtypeError(TypeError): +class InterfaceError(Exception): + """Parent mixin class for errors raised by :class:`.Interface` subclasses""" + + +class DtypeError(TypeError, InterfaceError): """Exception raised for invalid dtypes""" -class ShapeError(ValueError): +class ShapeError(ValueError, InterfaceError): """Exception raise for invalid shapes""" + + +class MatchError(ValueError, InterfaceError): + """Exception for errors raised during :class:`.Interface.match`-ing""" + + +class NoMatchError(MatchError): + """No match was found by :class:`.Interface.match`""" + + +class TooManyMatchesError(MatchError): + """Too many matches found by :class:`.Interface.match`""" diff --git a/src/numpydantic/interface/interface.py b/src/numpydantic/interface/interface.py index 14e1231..805e0ba 100644 --- a/src/numpydantic/interface/interface.py +++ b/src/numpydantic/interface/interface.py @@ -10,7 +10,12 @@ import numpy as np from nptyping.shape_expression import check_shape from pydantic import SerializationInfo -from numpydantic.exceptions import DtypeError, ShapeError +from numpydantic.exceptions import ( + DtypeError, + NoMatchError, + ShapeError, + TooManyMatchesError, +) from numpydantic.types import DtypeType, NDArrayType, ShapeType T = TypeVar("T", bound=NDArrayType) @@ -32,6 +37,25 @@ class Interface(ABC, Generic[T]): def validate(self, array: Any) -> T: """ Validate input, returning final array type + + Calls the methods, in order: + + * :meth:`.before_validation` + * :meth:`.validate_dtype` + * :meth:`.validate_shape` + * :meth:`.after_validation` + + passing the ``array`` argument and returning it from each. + + Implementing an interface subclass largely consists of overriding these methods + as needed. + + Raises: + If validation fails, rather than eg. returning ``False``, exceptions will + be raised (to halt the rest of the pydantic validation process). + When using interfaces outside of pydantic, you must catch both + :class:`.DtypeError` and :class:`.ShapeError` (both of which are children + of :class:`.InterfaceError` ) """ array = self.before_validation(array) array = self.validate_dtype(array) @@ -150,9 +174,21 @@ class Interface(ABC, Generic[T]): return tuple(in_types) @classmethod - def match(cls, array: Any) -> Type["Interface"]: + def match(cls, array: Any, fast: bool = False) -> Type["Interface"]: """ Find the interface that should be used for this array based on its input type + + First runs the ``check`` method for all interfaces returned by + :meth:`.Interface.interfaces` **except** for :class:`.NumpyInterface` , + and if no match is found then try the numpy interface. This is because + :meth:`.NumpyInterface.check` can be expensive, as we could potentially + try to + + Args: + fast (bool): if ``False`` , check all interfaces and raise exceptions for + having multiple matching interfaces (default). If ``True`` , + check each interface (as ordered by its ``priority`` , decreasing), + and return on the first match. """ # first try and find a non-numpy interface, since the numpy interface # will try and load the array into memory in its check method @@ -160,17 +196,24 @@ class Interface(ABC, Generic[T]): non_np_interfaces = [i for i in interfaces if i.__name__ != "NumpyInterface"] np_interface = [i for i in interfaces if i.__name__ == "NumpyInterface"][0] - matches = [i for i in non_np_interfaces if i.check(array)] + if fast: + matches = [] + for i in non_np_interfaces: + if i.check(array): + return i + else: + matches = [i for i in non_np_interfaces if i.check(array)] + if len(matches) > 1: msg = f"More than one interface matches input {array}:\n" msg += "\n".join([f" - {i}" for i in matches]) - raise ValueError(msg) + raise TooManyMatchesError(msg) elif len(matches) == 0: # now try the numpy interface if np_interface.check(array): return np_interface else: - raise ValueError(f"No matching interfaces found for input {array}") + raise NoMatchError(f"No matching interfaces found for input {array}") else: return matches[0] @@ -186,8 +229,8 @@ class Interface(ABC, Generic[T]): if len(matches) > 1: msg = f"More than one interface matches output {array}:\n" msg += "\n".join([f" - {i}" for i in matches]) - raise ValueError(msg) + raise TooManyMatchesError(msg) elif len(matches) == 0: - raise ValueError(f"No matching interfaces found for output {array}") + raise NoMatchError(f"No matching interfaces found for output {array}") else: return matches[0] diff --git a/src/numpydantic/ndarray.py b/src/numpydantic/ndarray.py index ddc1316..ce63464 100644 --- a/src/numpydantic/ndarray.py +++ b/src/numpydantic/ndarray.py @@ -13,7 +13,7 @@ Extension of nptyping NDArray for pydantic that allows for JSON-Schema serializa """ -from typing import Any, Tuple +from typing import TYPE_CHECKING, Any, Tuple import numpy as np from nptyping.error import InvalidArgumentsError @@ -28,6 +28,8 @@ from pydantic import GetJsonSchemaHandler from pydantic_core import core_schema from numpydantic.dtype import DType +from numpydantic.exceptions import InterfaceError +from numpydantic.interface import Interface from numpydantic.maps import python_to_nptyping from numpydantic.schema import ( _handler_type, @@ -37,6 +39,9 @@ from numpydantic.schema import ( ) from numpydantic.types import DtypeType, ShapeType +if TYPE_CHECKING: + from nptyping.base_meta_classes import SubscriptableMeta + class NDArrayMeta(_NDArrayMeta, implementation="NDArray"): """ @@ -44,6 +49,35 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"): completion of the transition away from nptyping """ + if TYPE_CHECKING: + __getitem__ = SubscriptableMeta.__getitem__ + + def __instancecheck__(self, instance: Any): + """ + Extended type checking that determines whether + + 1) the ``type`` of the given instance is one of those in + :meth:`.Interface.input_types` + + but also + + 2) it satisfies the constraints set on the :class:`.NDArray` annotation + + Args: + instance (:class:`typing.Any`): Thing to check! + + Returns: + bool: ``True`` if matches constraints, ``False`` otherwise. + """ + shape, dtype = self.__args__ + try: + interface_cls = Interface.match(instance, fast=True) + interface = interface_cls(shape, dtype) + _ = interface.validate(instance) + return True + except InterfaceError: + return False + def _get_dtype(cls, dtype_candidate: Any) -> DType: """ Override of base _get_dtype method to allow for compound tuple types diff --git a/src/numpydantic/schema.py b/src/numpydantic/schema.py index 0233610..e610df4 100644 --- a/src/numpydantic/schema.py +++ b/src/numpydantic/schema.py @@ -225,7 +225,9 @@ def get_validate_interface(shape: ShapeType, dtype: DtypeType) -> Callable: :meth:`.Interface.validate` method """ - def validate_interface(value: Any, info: "ValidationInfo") -> NDArrayType: + def validate_interface( + value: Any, info: Optional["ValidationInfo"] = None + ) -> NDArrayType: interface_cls = Interface.match(value) interface = interface_cls(shape, dtype) value = interface.validate(value) diff --git a/tests/test_ndarray.py b/tests/test_ndarray.py index 39972f1..583bd5f 100644 --- a/tests/test_ndarray.py +++ b/tests/test_ndarray.py @@ -223,3 +223,23 @@ def test_json_schema_ellipsis(): schema = ConstrainedAnyShape.model_json_schema() _recursive_array(schema) + + +def test_instancecheck(): + """ + NDArray should handle ``isinstance()`` s.t. valid arrays are ``True`` + and invalid arrays are ``False`` + + We don't make this test exhaustive because correctness of validation + is tested elsewhere. We are just testing that the type checking works + """ + array_type = NDArray[Shape["1, 2, 3"], int] + + assert isinstance(np.zeros((1, 2, 3), dtype=int), array_type) + assert not isinstance(np.zeros((2, 2, 3), dtype=int), array_type) + assert not isinstance(np.zeros((1, 2, 3), dtype=float), array_type) + + def my_function(array: NDArray[Shape["1, 2, 3"], int]): + return array + + my_function(np.zeros((1, 2, 3), int))