Merge pull request #1 from p2p-ld/feat-instancecheck

Instance Checking
This commit is contained in:
Jonny Saunders 2024-05-24 18:52:58 -07:00 committed by GitHub
commit 13a8fce4ef
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 242 additions and 19 deletions

View file

@ -17,6 +17,7 @@ relatively low. Its `Dtype[ArrayClass, "{shape_expression}"]` syntax is not well
suited for modeling arrays intended to be general across implementations, and suited for modeling arrays intended to be general across implementations, and
makes it challenging to adapt to pydantic's schema generation system. makes it challenging to adapt to pydantic's schema generation system.
(design_challenges)=
## Challenges ## Challenges
The Python type annotation system is weird and not like the rest of Python! The Python type annotation system is weird and not like the rest of Python!

View file

@ -57,6 +57,25 @@ model = MyModel(array=('data.zarr', '/nested/dataset'))
model = MyModel(array="data.mp4") model = MyModel(array="data.mp4")
``` ```
And use the `NDArray` type annotation like a regular type outside
of pydantic -- eg. to validate an array anywhere, use `isinstance`:
```python
array_type = NDArray[Shape["1, 2, 3"], int]
isinstance(np.zeros((1,2,3), dtype=int), array_type)
# True
isinstance(zarr.zeros((1,2,3), dtype=int), array_type)
# True
isinstance(np.zeros((4,5,6), dtype=int), array_type)
# False
isinstance(np.zeros((1,2,3), dtype=float), array_type)
# False
```
```{note}
`NDArray` can't do validation with static type checkers yet, see
{ref}`design_challenges` and {ref}`type_checkers`
```
## Features: ## Features:
- **Types** - Annotations (based on [npytyping](https://github.com/ramonhagenaars/nptyping)) - **Types** - Annotations (based on [npytyping](https://github.com/ramonhagenaars/nptyping))

View file

@ -10,6 +10,21 @@ type system and is no longer actively maintained. We will be reimplementing a sy
that extends its array specification syntax to include things like ranges and extensible that extends its array specification syntax to include things like ranges and extensible
dtypes with varying precision (and is much less finnicky to deal with). dtypes with varying precision (and is much less finnicky to deal with).
(type_checkers)=
## Type Checker Integration
The `.pyi` stubfile generation ({mod}`numpydantic.meta`) works for
keeping type checkers from complaining about various array formats
not literally being `NDArray` objects, but it doesn't do the kind of
validation we would want to be able to use `NDArray` objects as full-fledged
python types, including validation propagation through scopes and
IDE type checking for invalid literals.
We want to hook into the type checking process to satisfy these type checkers:
- mypy - has hooks, can be done with an extension
- pyright - unclear if has hooks, might nee to monkeypatch
- pycharm - unlikely this is possible, extensions need to be in Java and installed separately
## Validation ## Validation

View file

@ -3,9 +3,25 @@ Exceptions used within numpydantic
""" """
class DtypeError(TypeError): class InterfaceError(Exception):
"""Parent mixin class for errors raised by :class:`.Interface` subclasses"""
class DtypeError(TypeError, InterfaceError):
"""Exception raised for invalid dtypes""" """Exception raised for invalid dtypes"""
class ShapeError(ValueError): class ShapeError(ValueError, InterfaceError):
"""Exception raise for invalid shapes""" """Exception raise for invalid shapes"""
class MatchError(ValueError, InterfaceError):
"""Exception for errors raised during :class:`.Interface.match`-ing"""
class NoMatchError(MatchError):
"""No match was found by :class:`.Interface.match`"""
class TooManyMatchesError(MatchError):
"""Too many matches found by :class:`.Interface.match`"""

View file

@ -10,7 +10,12 @@ import numpy as np
from nptyping.shape_expression import check_shape from nptyping.shape_expression import check_shape
from pydantic import SerializationInfo from pydantic import SerializationInfo
from numpydantic.exceptions import DtypeError, ShapeError from numpydantic.exceptions import (
DtypeError,
NoMatchError,
ShapeError,
TooManyMatchesError,
)
from numpydantic.types import DtypeType, NDArrayType, ShapeType from numpydantic.types import DtypeType, NDArrayType, ShapeType
T = TypeVar("T", bound=NDArrayType) T = TypeVar("T", bound=NDArrayType)
@ -32,6 +37,25 @@ class Interface(ABC, Generic[T]):
def validate(self, array: Any) -> T: def validate(self, array: Any) -> T:
""" """
Validate input, returning final array type Validate input, returning final array type
Calls the methods, in order:
* :meth:`.before_validation`
* :meth:`.validate_dtype`
* :meth:`.validate_shape`
* :meth:`.after_validation`
passing the ``array`` argument and returning it from each.
Implementing an interface subclass largely consists of overriding these methods
as needed.
Raises:
If validation fails, rather than eg. returning ``False``, exceptions will
be raised (to halt the rest of the pydantic validation process).
When using interfaces outside of pydantic, you must catch both
:class:`.DtypeError` and :class:`.ShapeError` (both of which are children
of :class:`.InterfaceError` )
""" """
array = self.before_validation(array) array = self.before_validation(array)
array = self.validate_dtype(array) array = self.validate_dtype(array)
@ -120,17 +144,38 @@ class Interface(ABC, Generic[T]):
return array.tolist() return array.tolist()
@classmethod @classmethod
def interfaces(cls) -> Tuple[Type["Interface"], ...]: def interfaces(
cls, with_disabled: bool = False, sort: bool = True
) -> Tuple[Type["Interface"], ...]:
""" """
Enabled interface subclasses Enabled interface subclasses
Args:
with_disabled (bool): If ``True`` , get every known interface.
If ``False`` (default), get only enabled interfaces.
sort (bool): If ``True`` (default), sort interfaces by priority.
If ``False`` , sorted by definition order. Used for recursion:
we only want to sort once at the top level.
""" """
return tuple( # get recursively
sorted( subclasses = []
[i for i in Interface.__subclasses__() if i.enabled()], for i in cls.__subclasses__():
if with_disabled:
subclasses.append(i)
if i.enabled():
subclasses.append(i)
subclasses.extend(i.interfaces(with_disabled=with_disabled, sort=False))
if sort:
subclasses = sorted(
subclasses,
key=attrgetter("priority"), key=attrgetter("priority"),
reverse=True, reverse=True,
) )
)
return tuple(subclasses)
@classmethod @classmethod
def return_types(cls) -> Tuple[NDArrayType, ...]: def return_types(cls) -> Tuple[NDArrayType, ...]:
@ -150,9 +195,21 @@ class Interface(ABC, Generic[T]):
return tuple(in_types) return tuple(in_types)
@classmethod @classmethod
def match(cls, array: Any) -> Type["Interface"]: def match(cls, array: Any, fast: bool = False) -> Type["Interface"]:
""" """
Find the interface that should be used for this array based on its input type Find the interface that should be used for this array based on its input type
First runs the ``check`` method for all interfaces returned by
:meth:`.Interface.interfaces` **except** for :class:`.NumpyInterface` ,
and if no match is found then try the numpy interface. This is because
:meth:`.NumpyInterface.check` can be expensive, as we could potentially
try to
Args:
fast (bool): if ``False`` , check all interfaces and raise exceptions for
having multiple matching interfaces (default). If ``True`` ,
check each interface (as ordered by its ``priority`` , decreasing),
and return on the first match.
""" """
# first try and find a non-numpy interface, since the numpy interface # first try and find a non-numpy interface, since the numpy interface
# will try and load the array into memory in its check method # will try and load the array into memory in its check method
@ -160,17 +217,24 @@ class Interface(ABC, Generic[T]):
non_np_interfaces = [i for i in interfaces if i.__name__ != "NumpyInterface"] non_np_interfaces = [i for i in interfaces if i.__name__ != "NumpyInterface"]
np_interface = [i for i in interfaces if i.__name__ == "NumpyInterface"][0] np_interface = [i for i in interfaces if i.__name__ == "NumpyInterface"][0]
matches = [i for i in non_np_interfaces if i.check(array)] if fast:
matches = []
for i in non_np_interfaces:
if i.check(array):
return i
else:
matches = [i for i in non_np_interfaces if i.check(array)]
if len(matches) > 1: if len(matches) > 1:
msg = f"More than one interface matches input {array}:\n" msg = f"More than one interface matches input {array}:\n"
msg += "\n".join([f" - {i}" for i in matches]) msg += "\n".join([f" - {i}" for i in matches])
raise ValueError(msg) raise TooManyMatchesError(msg)
elif len(matches) == 0: elif len(matches) == 0:
# now try the numpy interface # now try the numpy interface
if np_interface.check(array): if np_interface.check(array):
return np_interface return np_interface
else: else:
raise ValueError(f"No matching interfaces found for input {array}") raise NoMatchError(f"No matching interfaces found for input {array}")
else: else:
return matches[0] return matches[0]
@ -186,8 +250,8 @@ class Interface(ABC, Generic[T]):
if len(matches) > 1: if len(matches) > 1:
msg = f"More than one interface matches output {array}:\n" msg = f"More than one interface matches output {array}:\n"
msg += "\n".join([f" - {i}" for i in matches]) msg += "\n".join([f" - {i}" for i in matches])
raise ValueError(msg) raise TooManyMatchesError(msg)
elif len(matches) == 0: elif len(matches) == 0:
raise ValueError(f"No matching interfaces found for output {array}") raise NoMatchError(f"No matching interfaces found for output {array}")
else: else:
return matches[0] return matches[0]

View file

@ -24,7 +24,8 @@ def generate_ndarray_stub() -> str:
# Create import statements, saving aliased name of type if needed # Create import statements, saving aliased name of type if needed
if arr.__module__.startswith("numpydantic") or arr.__module__ == "typing": if arr.__module__.startswith("numpydantic") or arr.__module__ == "typing":
type_name = str(arr) if arr.__module__ == "typing" else arr.__name__ type_name = str(arr) if arr.__module__ == "typing" else arr.__name__
import_strings.append(f"from {arr.__module__} import {type_name}") if arr.__module__ != "typing":
import_strings.append(f"from {arr.__module__} import {type_name}")
else: else:
# since other packages could use the same name for an imported object # since other packages could use the same name for an imported object
# (eg dask and zarr both use an Array class) # (eg dask and zarr both use an Array class)
@ -39,6 +40,7 @@ def generate_ndarray_stub() -> str:
type_names.append(type_name) type_names.append(type_name)
import_strings.extend(_BUILTIN_IMPORTS) import_strings.extend(_BUILTIN_IMPORTS)
import_strings = list(dict.fromkeys(import_strings))
import_string = "\n".join(import_strings) import_string = "\n".join(import_strings)
class_union = " | ".join(type_names) class_union = " | ".join(type_names)

View file

@ -13,7 +13,7 @@ Extension of nptyping NDArray for pydantic that allows for JSON-Schema serializa
""" """
from typing import Any, Tuple from typing import TYPE_CHECKING, Any, Tuple
import numpy as np import numpy as np
from nptyping.error import InvalidArgumentsError from nptyping.error import InvalidArgumentsError
@ -28,6 +28,8 @@ from pydantic import GetJsonSchemaHandler
from pydantic_core import core_schema from pydantic_core import core_schema
from numpydantic.dtype import DType from numpydantic.dtype import DType
from numpydantic.exceptions import InterfaceError
from numpydantic.interface import Interface
from numpydantic.maps import python_to_nptyping from numpydantic.maps import python_to_nptyping
from numpydantic.schema import ( from numpydantic.schema import (
_handler_type, _handler_type,
@ -37,6 +39,9 @@ from numpydantic.schema import (
) )
from numpydantic.types import DtypeType, ShapeType from numpydantic.types import DtypeType, ShapeType
if TYPE_CHECKING: # pragma: no cover
from nptyping.base_meta_classes import SubscriptableMeta
class NDArrayMeta(_NDArrayMeta, implementation="NDArray"): class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
""" """
@ -44,6 +49,35 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
completion of the transition away from nptyping completion of the transition away from nptyping
""" """
if TYPE_CHECKING: # pragma: no cover
__getitem__ = SubscriptableMeta.__getitem__
def __instancecheck__(self, instance: Any):
"""
Extended type checking that determines whether
1) the ``type`` of the given instance is one of those in
:meth:`.Interface.input_types`
but also
2) it satisfies the constraints set on the :class:`.NDArray` annotation
Args:
instance (:class:`typing.Any`): Thing to check!
Returns:
bool: ``True`` if matches constraints, ``False`` otherwise.
"""
shape, dtype = self.__args__
try:
interface_cls = Interface.match(instance, fast=True)
interface = interface_cls(shape, dtype)
_ = interface.validate(instance)
return True
except InterfaceError:
return False
def _get_dtype(cls, dtype_candidate: Any) -> DType: def _get_dtype(cls, dtype_candidate: Any) -> DType:
""" """
Override of base _get_dtype method to allow for compound tuple types Override of base _get_dtype method to allow for compound tuple types

View file

@ -225,7 +225,9 @@ def get_validate_interface(shape: ShapeType, dtype: DtypeType) -> Callable:
:meth:`.Interface.validate` method :meth:`.Interface.validate` method
""" """
def validate_interface(value: Any, info: "ValidationInfo") -> NDArrayType: def validate_interface(
value: Any, info: Optional["ValidationInfo"] = None
) -> NDArrayType:
interface_cls = Interface.match(value) interface_cls = Interface.match(value)
interface = interface_cls(shape, dtype) interface = interface_cls(shape, dtype)
value = interface.validate(value) value = interface.validate(value)

View file

@ -1,5 +1,3 @@
import pdb
import pytest import pytest
import numpy as np import numpy as np
@ -14,9 +12,12 @@ def interfaces():
class Interface1(Interface): class Interface1(Interface):
input_types = (list,) input_types = (list,)
return_type = tuple return_type = tuple
priority = 1000
checked = False
@classmethod @classmethod
def check(cls, array): def check(cls, array):
cls.checked = True
if isinstance(array, list): if isinstance(array, list):
return True return True
return False return False
@ -26,18 +27,34 @@ def interfaces():
return True return True
Interface2 = type("Interface2", Interface1.__bases__, dict(Interface1.__dict__)) Interface2 = type("Interface2", Interface1.__bases__, dict(Interface1.__dict__))
Interface2.checked = False
Interface2.priority = 999
class Interface3(Interface1): class Interface3(Interface1):
priority = 998
checked = False
@classmethod @classmethod
def enabled(cls) -> bool: def enabled(cls) -> bool:
return False return False
class Interface4(Interface3):
priority = 997
checked = False
@classmethod
def enabled(cls) -> bool:
return True
class Interfaces: class Interfaces:
interface1 = Interface1 interface1 = Interface1
interface2 = Interface2 interface2 = Interface2
interface3 = Interface3 interface3 = Interface3
interface4 = Interface4
yield Interfaces yield Interfaces
# Interface.__subclasses__().remove(Interface1)
# Interface.__subclasses__().remove(Interface2)
del Interface1 del Interface1
del Interface2 del Interface2
del Interface3 del Interface3
@ -66,6 +83,20 @@ def test_interface_match_error(interfaces):
assert "No matching interfaces" in str(e.value) assert "No matching interfaces" in str(e.value)
def test_interface_match_fast(interfaces):
"""
fast matching should return as soon as an interface is found
and not raise an error for duplicates
"""
Interface.interfaces()[0].checked = False
Interface.interfaces()[1].checked = False
# this doesnt' raise an error
matched = Interface.match([1, 2, 3], fast=True)
assert matched == Interface.interfaces()[0]
assert Interface.interfaces()[0].checked
assert not Interface.interfaces()[1].checked
def test_interface_enabled(interfaces): def test_interface_enabled(interfaces):
""" """
An interface shouldn't be included if it's not enabled An interface shouldn't be included if it's not enabled
@ -101,3 +132,22 @@ def test_interfaces_sorting():
ifaces = Interface.interfaces() ifaces = Interface.interfaces()
priorities = [i.priority for i in ifaces] priorities = [i.priority for i in ifaces]
assert (np.diff(priorities) <= 0).all() assert (np.diff(priorities) <= 0).all()
def test_interface_with_disabled(interfaces):
"""
Get all interfaces, even if not enabled
"""
ifaces = Interface.interfaces(with_disabled=True)
assert interfaces.interface3 in ifaces
def test_interface_recursive(interfaces):
"""
Get all interfaces, including subclasses of subclasses
"""
ifaces = Interface.interfaces()
assert issubclass(interfaces.interface4, interfaces.interface3)
assert issubclass(interfaces.interface3, interfaces.interface1)
assert issubclass(interfaces.interface1, Interface)
assert interfaces.interface4 in ifaces

View file

@ -223,3 +223,23 @@ def test_json_schema_ellipsis():
schema = ConstrainedAnyShape.model_json_schema() schema = ConstrainedAnyShape.model_json_schema()
_recursive_array(schema) _recursive_array(schema)
def test_instancecheck():
"""
NDArray should handle ``isinstance()`` s.t. valid arrays are ``True``
and invalid arrays are ``False``
We don't make this test exhaustive because correctness of validation
is tested elsewhere. We are just testing that the type checking works
"""
array_type = NDArray[Shape["1, 2, 3"], int]
assert isinstance(np.zeros((1, 2, 3), dtype=int), array_type)
assert not isinstance(np.zeros((2, 2, 3), dtype=int), array_type)
assert not isinstance(np.zeros((1, 2, 3), dtype=float), array_type)
def my_function(array: NDArray[Shape["1, 2, 3"], int]):
return array
my_function(np.zeros((1, 2, 3), int))