mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2024-11-12 17:54:29 +00:00
working instancecheck, but not working static analysis
This commit is contained in:
parent
b0b391947f
commit
1290d64833
5 changed files with 126 additions and 11 deletions
|
@ -3,9 +3,25 @@ Exceptions used within numpydantic
|
|||
"""
|
||||
|
||||
|
||||
class DtypeError(TypeError):
|
||||
class InterfaceError(Exception):
|
||||
"""Parent mixin class for errors raised by :class:`.Interface` subclasses"""
|
||||
|
||||
|
||||
class DtypeError(TypeError, InterfaceError):
|
||||
"""Exception raised for invalid dtypes"""
|
||||
|
||||
|
||||
class ShapeError(ValueError):
|
||||
class ShapeError(ValueError, InterfaceError):
|
||||
"""Exception raise for invalid shapes"""
|
||||
|
||||
|
||||
class MatchError(ValueError, InterfaceError):
|
||||
"""Exception for errors raised during :class:`.Interface.match`-ing"""
|
||||
|
||||
|
||||
class NoMatchError(MatchError):
|
||||
"""No match was found by :class:`.Interface.match`"""
|
||||
|
||||
|
||||
class TooManyMatchesError(MatchError):
|
||||
"""Too many matches found by :class:`.Interface.match`"""
|
||||
|
|
|
@ -10,7 +10,12 @@ import numpy as np
|
|||
from nptyping.shape_expression import check_shape
|
||||
from pydantic import SerializationInfo
|
||||
|
||||
from numpydantic.exceptions import DtypeError, ShapeError
|
||||
from numpydantic.exceptions import (
|
||||
DtypeError,
|
||||
NoMatchError,
|
||||
ShapeError,
|
||||
TooManyMatchesError,
|
||||
)
|
||||
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||
|
||||
T = TypeVar("T", bound=NDArrayType)
|
||||
|
@ -32,6 +37,25 @@ class Interface(ABC, Generic[T]):
|
|||
def validate(self, array: Any) -> T:
|
||||
"""
|
||||
Validate input, returning final array type
|
||||
|
||||
Calls the methods, in order:
|
||||
|
||||
* :meth:`.before_validation`
|
||||
* :meth:`.validate_dtype`
|
||||
* :meth:`.validate_shape`
|
||||
* :meth:`.after_validation`
|
||||
|
||||
passing the ``array`` argument and returning it from each.
|
||||
|
||||
Implementing an interface subclass largely consists of overriding these methods
|
||||
as needed.
|
||||
|
||||
Raises:
|
||||
If validation fails, rather than eg. returning ``False``, exceptions will
|
||||
be raised (to halt the rest of the pydantic validation process).
|
||||
When using interfaces outside of pydantic, you must catch both
|
||||
:class:`.DtypeError` and :class:`.ShapeError` (both of which are children
|
||||
of :class:`.InterfaceError` )
|
||||
"""
|
||||
array = self.before_validation(array)
|
||||
array = self.validate_dtype(array)
|
||||
|
@ -150,9 +174,21 @@ class Interface(ABC, Generic[T]):
|
|||
return tuple(in_types)
|
||||
|
||||
@classmethod
|
||||
def match(cls, array: Any) -> Type["Interface"]:
|
||||
def match(cls, array: Any, fast: bool = False) -> Type["Interface"]:
|
||||
"""
|
||||
Find the interface that should be used for this array based on its input type
|
||||
|
||||
First runs the ``check`` method for all interfaces returned by
|
||||
:meth:`.Interface.interfaces` **except** for :class:`.NumpyInterface` ,
|
||||
and if no match is found then try the numpy interface. This is because
|
||||
:meth:`.NumpyInterface.check` can be expensive, as we could potentially
|
||||
try to
|
||||
|
||||
Args:
|
||||
fast (bool): if ``False`` , check all interfaces and raise exceptions for
|
||||
having multiple matching interfaces (default). If ``True`` ,
|
||||
check each interface (as ordered by its ``priority`` , decreasing),
|
||||
and return on the first match.
|
||||
"""
|
||||
# first try and find a non-numpy interface, since the numpy interface
|
||||
# will try and load the array into memory in its check method
|
||||
|
@ -160,17 +196,24 @@ class Interface(ABC, Generic[T]):
|
|||
non_np_interfaces = [i for i in interfaces if i.__name__ != "NumpyInterface"]
|
||||
np_interface = [i for i in interfaces if i.__name__ == "NumpyInterface"][0]
|
||||
|
||||
matches = [i for i in non_np_interfaces if i.check(array)]
|
||||
if fast:
|
||||
matches = []
|
||||
for i in non_np_interfaces:
|
||||
if i.check(array):
|
||||
return i
|
||||
else:
|
||||
matches = [i for i in non_np_interfaces if i.check(array)]
|
||||
|
||||
if len(matches) > 1:
|
||||
msg = f"More than one interface matches input {array}:\n"
|
||||
msg += "\n".join([f" - {i}" for i in matches])
|
||||
raise ValueError(msg)
|
||||
raise TooManyMatchesError(msg)
|
||||
elif len(matches) == 0:
|
||||
# now try the numpy interface
|
||||
if np_interface.check(array):
|
||||
return np_interface
|
||||
else:
|
||||
raise ValueError(f"No matching interfaces found for input {array}")
|
||||
raise NoMatchError(f"No matching interfaces found for input {array}")
|
||||
else:
|
||||
return matches[0]
|
||||
|
||||
|
@ -186,8 +229,8 @@ class Interface(ABC, Generic[T]):
|
|||
if len(matches) > 1:
|
||||
msg = f"More than one interface matches output {array}:\n"
|
||||
msg += "\n".join([f" - {i}" for i in matches])
|
||||
raise ValueError(msg)
|
||||
raise TooManyMatchesError(msg)
|
||||
elif len(matches) == 0:
|
||||
raise ValueError(f"No matching interfaces found for output {array}")
|
||||
raise NoMatchError(f"No matching interfaces found for output {array}")
|
||||
else:
|
||||
return matches[0]
|
||||
|
|
|
@ -13,7 +13,7 @@ Extension of nptyping NDArray for pydantic that allows for JSON-Schema serializa
|
|||
|
||||
"""
|
||||
|
||||
from typing import Any, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Tuple
|
||||
|
||||
import numpy as np
|
||||
from nptyping.error import InvalidArgumentsError
|
||||
|
@ -28,6 +28,8 @@ from pydantic import GetJsonSchemaHandler
|
|||
from pydantic_core import core_schema
|
||||
|
||||
from numpydantic.dtype import DType
|
||||
from numpydantic.exceptions import InterfaceError
|
||||
from numpydantic.interface import Interface
|
||||
from numpydantic.maps import python_to_nptyping
|
||||
from numpydantic.schema import (
|
||||
_handler_type,
|
||||
|
@ -37,6 +39,9 @@ from numpydantic.schema import (
|
|||
)
|
||||
from numpydantic.types import DtypeType, ShapeType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nptyping.base_meta_classes import SubscriptableMeta
|
||||
|
||||
|
||||
class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
|
||||
"""
|
||||
|
@ -44,6 +49,35 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
|
|||
completion of the transition away from nptyping
|
||||
"""
|
||||
|
||||
if TYPE_CHECKING:
|
||||
__getitem__ = SubscriptableMeta.__getitem__
|
||||
|
||||
def __instancecheck__(self, instance: Any):
|
||||
"""
|
||||
Extended type checking that determines whether
|
||||
|
||||
1) the ``type`` of the given instance is one of those in
|
||||
:meth:`.Interface.input_types`
|
||||
|
||||
but also
|
||||
|
||||
2) it satisfies the constraints set on the :class:`.NDArray` annotation
|
||||
|
||||
Args:
|
||||
instance (:class:`typing.Any`): Thing to check!
|
||||
|
||||
Returns:
|
||||
bool: ``True`` if matches constraints, ``False`` otherwise.
|
||||
"""
|
||||
shape, dtype = self.__args__
|
||||
try:
|
||||
interface_cls = Interface.match(instance, fast=True)
|
||||
interface = interface_cls(shape, dtype)
|
||||
_ = interface.validate(instance)
|
||||
return True
|
||||
except InterfaceError:
|
||||
return False
|
||||
|
||||
def _get_dtype(cls, dtype_candidate: Any) -> DType:
|
||||
"""
|
||||
Override of base _get_dtype method to allow for compound tuple types
|
||||
|
|
|
@ -225,7 +225,9 @@ def get_validate_interface(shape: ShapeType, dtype: DtypeType) -> Callable:
|
|||
:meth:`.Interface.validate` method
|
||||
"""
|
||||
|
||||
def validate_interface(value: Any, info: "ValidationInfo") -> NDArrayType:
|
||||
def validate_interface(
|
||||
value: Any, info: Optional["ValidationInfo"] = None
|
||||
) -> NDArrayType:
|
||||
interface_cls = Interface.match(value)
|
||||
interface = interface_cls(shape, dtype)
|
||||
value = interface.validate(value)
|
||||
|
|
|
@ -223,3 +223,23 @@ def test_json_schema_ellipsis():
|
|||
|
||||
schema = ConstrainedAnyShape.model_json_schema()
|
||||
_recursive_array(schema)
|
||||
|
||||
|
||||
def test_instancecheck():
|
||||
"""
|
||||
NDArray should handle ``isinstance()`` s.t. valid arrays are ``True``
|
||||
and invalid arrays are ``False``
|
||||
|
||||
We don't make this test exhaustive because correctness of validation
|
||||
is tested elsewhere. We are just testing that the type checking works
|
||||
"""
|
||||
array_type = NDArray[Shape["1, 2, 3"], int]
|
||||
|
||||
assert isinstance(np.zeros((1, 2, 3), dtype=int), array_type)
|
||||
assert not isinstance(np.zeros((2, 2, 3), dtype=int), array_type)
|
||||
assert not isinstance(np.zeros((1, 2, 3), dtype=float), array_type)
|
||||
|
||||
def my_function(array: NDArray[Shape["1, 2, 3"], int]):
|
||||
return array
|
||||
|
||||
my_function(np.zeros((1, 2, 3), int))
|
||||
|
|
Loading…
Reference in a new issue