working instancecheck, but not working static analysis

This commit is contained in:
sneakers-the-rat 2024-05-23 21:08:38 -07:00
parent b0b391947f
commit 1290d64833
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
5 changed files with 126 additions and 11 deletions

View file

@ -3,9 +3,25 @@ Exceptions used within numpydantic
""" """
class DtypeError(TypeError): class InterfaceError(Exception):
"""Parent mixin class for errors raised by :class:`.Interface` subclasses"""
class DtypeError(TypeError, InterfaceError):
"""Exception raised for invalid dtypes""" """Exception raised for invalid dtypes"""
class ShapeError(ValueError): class ShapeError(ValueError, InterfaceError):
"""Exception raise for invalid shapes""" """Exception raise for invalid shapes"""
class MatchError(ValueError, InterfaceError):
"""Exception for errors raised during :class:`.Interface.match`-ing"""
class NoMatchError(MatchError):
"""No match was found by :class:`.Interface.match`"""
class TooManyMatchesError(MatchError):
"""Too many matches found by :class:`.Interface.match`"""

View file

@ -10,7 +10,12 @@ import numpy as np
from nptyping.shape_expression import check_shape from nptyping.shape_expression import check_shape
from pydantic import SerializationInfo from pydantic import SerializationInfo
from numpydantic.exceptions import DtypeError, ShapeError from numpydantic.exceptions import (
DtypeError,
NoMatchError,
ShapeError,
TooManyMatchesError,
)
from numpydantic.types import DtypeType, NDArrayType, ShapeType from numpydantic.types import DtypeType, NDArrayType, ShapeType
T = TypeVar("T", bound=NDArrayType) T = TypeVar("T", bound=NDArrayType)
@ -32,6 +37,25 @@ class Interface(ABC, Generic[T]):
def validate(self, array: Any) -> T: def validate(self, array: Any) -> T:
""" """
Validate input, returning final array type Validate input, returning final array type
Calls the methods, in order:
* :meth:`.before_validation`
* :meth:`.validate_dtype`
* :meth:`.validate_shape`
* :meth:`.after_validation`
passing the ``array`` argument and returning it from each.
Implementing an interface subclass largely consists of overriding these methods
as needed.
Raises:
If validation fails, rather than eg. returning ``False``, exceptions will
be raised (to halt the rest of the pydantic validation process).
When using interfaces outside of pydantic, you must catch both
:class:`.DtypeError` and :class:`.ShapeError` (both of which are children
of :class:`.InterfaceError` )
""" """
array = self.before_validation(array) array = self.before_validation(array)
array = self.validate_dtype(array) array = self.validate_dtype(array)
@ -150,9 +174,21 @@ class Interface(ABC, Generic[T]):
return tuple(in_types) return tuple(in_types)
@classmethod @classmethod
def match(cls, array: Any) -> Type["Interface"]: def match(cls, array: Any, fast: bool = False) -> Type["Interface"]:
""" """
Find the interface that should be used for this array based on its input type Find the interface that should be used for this array based on its input type
First runs the ``check`` method for all interfaces returned by
:meth:`.Interface.interfaces` **except** for :class:`.NumpyInterface` ,
and if no match is found then try the numpy interface. This is because
:meth:`.NumpyInterface.check` can be expensive, as we could potentially
try to
Args:
fast (bool): if ``False`` , check all interfaces and raise exceptions for
having multiple matching interfaces (default). If ``True`` ,
check each interface (as ordered by its ``priority`` , decreasing),
and return on the first match.
""" """
# first try and find a non-numpy interface, since the numpy interface # first try and find a non-numpy interface, since the numpy interface
# will try and load the array into memory in its check method # will try and load the array into memory in its check method
@ -160,17 +196,24 @@ class Interface(ABC, Generic[T]):
non_np_interfaces = [i for i in interfaces if i.__name__ != "NumpyInterface"] non_np_interfaces = [i for i in interfaces if i.__name__ != "NumpyInterface"]
np_interface = [i for i in interfaces if i.__name__ == "NumpyInterface"][0] np_interface = [i for i in interfaces if i.__name__ == "NumpyInterface"][0]
matches = [i for i in non_np_interfaces if i.check(array)] if fast:
matches = []
for i in non_np_interfaces:
if i.check(array):
return i
else:
matches = [i for i in non_np_interfaces if i.check(array)]
if len(matches) > 1: if len(matches) > 1:
msg = f"More than one interface matches input {array}:\n" msg = f"More than one interface matches input {array}:\n"
msg += "\n".join([f" - {i}" for i in matches]) msg += "\n".join([f" - {i}" for i in matches])
raise ValueError(msg) raise TooManyMatchesError(msg)
elif len(matches) == 0: elif len(matches) == 0:
# now try the numpy interface # now try the numpy interface
if np_interface.check(array): if np_interface.check(array):
return np_interface return np_interface
else: else:
raise ValueError(f"No matching interfaces found for input {array}") raise NoMatchError(f"No matching interfaces found for input {array}")
else: else:
return matches[0] return matches[0]
@ -186,8 +229,8 @@ class Interface(ABC, Generic[T]):
if len(matches) > 1: if len(matches) > 1:
msg = f"More than one interface matches output {array}:\n" msg = f"More than one interface matches output {array}:\n"
msg += "\n".join([f" - {i}" for i in matches]) msg += "\n".join([f" - {i}" for i in matches])
raise ValueError(msg) raise TooManyMatchesError(msg)
elif len(matches) == 0: elif len(matches) == 0:
raise ValueError(f"No matching interfaces found for output {array}") raise NoMatchError(f"No matching interfaces found for output {array}")
else: else:
return matches[0] return matches[0]

View file

@ -13,7 +13,7 @@ Extension of nptyping NDArray for pydantic that allows for JSON-Schema serializa
""" """
from typing import Any, Tuple from typing import TYPE_CHECKING, Any, Tuple
import numpy as np import numpy as np
from nptyping.error import InvalidArgumentsError from nptyping.error import InvalidArgumentsError
@ -28,6 +28,8 @@ from pydantic import GetJsonSchemaHandler
from pydantic_core import core_schema from pydantic_core import core_schema
from numpydantic.dtype import DType from numpydantic.dtype import DType
from numpydantic.exceptions import InterfaceError
from numpydantic.interface import Interface
from numpydantic.maps import python_to_nptyping from numpydantic.maps import python_to_nptyping
from numpydantic.schema import ( from numpydantic.schema import (
_handler_type, _handler_type,
@ -37,6 +39,9 @@ from numpydantic.schema import (
) )
from numpydantic.types import DtypeType, ShapeType from numpydantic.types import DtypeType, ShapeType
if TYPE_CHECKING:
from nptyping.base_meta_classes import SubscriptableMeta
class NDArrayMeta(_NDArrayMeta, implementation="NDArray"): class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
""" """
@ -44,6 +49,35 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
completion of the transition away from nptyping completion of the transition away from nptyping
""" """
if TYPE_CHECKING:
__getitem__ = SubscriptableMeta.__getitem__
def __instancecheck__(self, instance: Any):
"""
Extended type checking that determines whether
1) the ``type`` of the given instance is one of those in
:meth:`.Interface.input_types`
but also
2) it satisfies the constraints set on the :class:`.NDArray` annotation
Args:
instance (:class:`typing.Any`): Thing to check!
Returns:
bool: ``True`` if matches constraints, ``False`` otherwise.
"""
shape, dtype = self.__args__
try:
interface_cls = Interface.match(instance, fast=True)
interface = interface_cls(shape, dtype)
_ = interface.validate(instance)
return True
except InterfaceError:
return False
def _get_dtype(cls, dtype_candidate: Any) -> DType: def _get_dtype(cls, dtype_candidate: Any) -> DType:
""" """
Override of base _get_dtype method to allow for compound tuple types Override of base _get_dtype method to allow for compound tuple types

View file

@ -225,7 +225,9 @@ def get_validate_interface(shape: ShapeType, dtype: DtypeType) -> Callable:
:meth:`.Interface.validate` method :meth:`.Interface.validate` method
""" """
def validate_interface(value: Any, info: "ValidationInfo") -> NDArrayType: def validate_interface(
value: Any, info: Optional["ValidationInfo"] = None
) -> NDArrayType:
interface_cls = Interface.match(value) interface_cls = Interface.match(value)
interface = interface_cls(shape, dtype) interface = interface_cls(shape, dtype)
value = interface.validate(value) value = interface.validate(value)

View file

@ -223,3 +223,23 @@ def test_json_schema_ellipsis():
schema = ConstrainedAnyShape.model_json_schema() schema = ConstrainedAnyShape.model_json_schema()
_recursive_array(schema) _recursive_array(schema)
def test_instancecheck():
"""
NDArray should handle ``isinstance()`` s.t. valid arrays are ``True``
and invalid arrays are ``False``
We don't make this test exhaustive because correctness of validation
is tested elsewhere. We are just testing that the type checking works
"""
array_type = NDArray[Shape["1, 2, 3"], int]
assert isinstance(np.zeros((1, 2, 3), dtype=int), array_type)
assert not isinstance(np.zeros((2, 2, 3), dtype=int), array_type)
assert not isinstance(np.zeros((1, 2, 3), dtype=float), array_type)
def my_function(array: NDArray[Shape["1, 2, 3"], int]):
return array
my_function(np.zeros((1, 2, 3), int))