diff --git a/docs/design.md b/docs/design.md index b06c4af..2a8d6cc 100644 --- a/docs/design.md +++ b/docs/design.md @@ -17,6 +17,7 @@ relatively low. Its `Dtype[ArrayClass, "{shape_expression}"]` syntax is not well suited for modeling arrays intended to be general across implementations, and makes it challenging to adapt to pydantic's schema generation system. +(design_challenges)= ## Challenges The Python type annotation system is weird and not like the rest of Python! diff --git a/docs/index.md b/docs/index.md index 3d9da31..1e6a1f8 100644 --- a/docs/index.md +++ b/docs/index.md @@ -57,6 +57,25 @@ model = MyModel(array=('data.zarr', '/nested/dataset')) model = MyModel(array="data.mp4") ``` +And use the `NDArray` type annotation like a regular type outside +of pydantic -- eg. to validate an array anywhere, use `isinstance`: + +```python +array_type = NDArray[Shape["1, 2, 3"], int] +isinstance(np.zeros((1,2,3), dtype=int), array_type) +# True +isinstance(zarr.zeros((1,2,3), dtype=int), array_type) +# True +isinstance(np.zeros((4,5,6), dtype=int), array_type) +# False +isinstance(np.zeros((1,2,3), dtype=float), array_type) +# False +``` + +```{note} +`NDArray` can't do validation with static type checkers yet, see +{ref}`design_challenges` and {ref}`type_checkers` +``` ## Features: - **Types** - Annotations (based on [npytyping](https://github.com/ramonhagenaars/nptyping)) diff --git a/docs/todo.md b/docs/todo.md index 6d94b8f..9080e29 100644 --- a/docs/todo.md +++ b/docs/todo.md @@ -10,6 +10,21 @@ type system and is no longer actively maintained. We will be reimplementing a sy that extends its array specification syntax to include things like ranges and extensible dtypes with varying precision (and is much less finnicky to deal with). +(type_checkers)= +## Type Checker Integration + +The `.pyi` stubfile generation ({mod}`numpydantic.meta`) works for +keeping type checkers from complaining about various array formats +not literally being `NDArray` objects, but it doesn't do the kind of +validation we would want to be able to use `NDArray` objects as full-fledged +python types, including validation propagation through scopes and +IDE type checking for invalid literals. + +We want to hook into the type checking process to satisfy these type checkers: +- mypy - has hooks, can be done with an extension +- pyright - unclear if has hooks, might nee to monkeypatch +- pycharm - unlikely this is possible, extensions need to be in Java and installed separately + ## Validation diff --git a/src/numpydantic/exceptions.py b/src/numpydantic/exceptions.py index 20d4981..252dbbd 100644 --- a/src/numpydantic/exceptions.py +++ b/src/numpydantic/exceptions.py @@ -3,9 +3,25 @@ Exceptions used within numpydantic """ -class DtypeError(TypeError): +class InterfaceError(Exception): + """Parent mixin class for errors raised by :class:`.Interface` subclasses""" + + +class DtypeError(TypeError, InterfaceError): """Exception raised for invalid dtypes""" -class ShapeError(ValueError): +class ShapeError(ValueError, InterfaceError): """Exception raise for invalid shapes""" + + +class MatchError(ValueError, InterfaceError): + """Exception for errors raised during :class:`.Interface.match`-ing""" + + +class NoMatchError(MatchError): + """No match was found by :class:`.Interface.match`""" + + +class TooManyMatchesError(MatchError): + """Too many matches found by :class:`.Interface.match`""" diff --git a/src/numpydantic/interface/interface.py b/src/numpydantic/interface/interface.py index 14e1231..d7d19c1 100644 --- a/src/numpydantic/interface/interface.py +++ b/src/numpydantic/interface/interface.py @@ -10,7 +10,12 @@ import numpy as np from nptyping.shape_expression import check_shape from pydantic import SerializationInfo -from numpydantic.exceptions import DtypeError, ShapeError +from numpydantic.exceptions import ( + DtypeError, + NoMatchError, + ShapeError, + TooManyMatchesError, +) from numpydantic.types import DtypeType, NDArrayType, ShapeType T = TypeVar("T", bound=NDArrayType) @@ -32,6 +37,25 @@ class Interface(ABC, Generic[T]): def validate(self, array: Any) -> T: """ Validate input, returning final array type + + Calls the methods, in order: + + * :meth:`.before_validation` + * :meth:`.validate_dtype` + * :meth:`.validate_shape` + * :meth:`.after_validation` + + passing the ``array`` argument and returning it from each. + + Implementing an interface subclass largely consists of overriding these methods + as needed. + + Raises: + If validation fails, rather than eg. returning ``False``, exceptions will + be raised (to halt the rest of the pydantic validation process). + When using interfaces outside of pydantic, you must catch both + :class:`.DtypeError` and :class:`.ShapeError` (both of which are children + of :class:`.InterfaceError` ) """ array = self.before_validation(array) array = self.validate_dtype(array) @@ -120,17 +144,38 @@ class Interface(ABC, Generic[T]): return array.tolist() @classmethod - def interfaces(cls) -> Tuple[Type["Interface"], ...]: + def interfaces( + cls, with_disabled: bool = False, sort: bool = True + ) -> Tuple[Type["Interface"], ...]: """ Enabled interface subclasses + + Args: + with_disabled (bool): If ``True`` , get every known interface. + If ``False`` (default), get only enabled interfaces. + sort (bool): If ``True`` (default), sort interfaces by priority. + If ``False`` , sorted by definition order. Used for recursion: + we only want to sort once at the top level. """ - return tuple( - sorted( - [i for i in Interface.__subclasses__() if i.enabled()], + # get recursively + subclasses = [] + for i in cls.__subclasses__(): + if with_disabled: + subclasses.append(i) + + if i.enabled(): + subclasses.append(i) + + subclasses.extend(i.interfaces(with_disabled=with_disabled, sort=False)) + + if sort: + subclasses = sorted( + subclasses, key=attrgetter("priority"), reverse=True, ) - ) + + return tuple(subclasses) @classmethod def return_types(cls) -> Tuple[NDArrayType, ...]: @@ -150,9 +195,21 @@ class Interface(ABC, Generic[T]): return tuple(in_types) @classmethod - def match(cls, array: Any) -> Type["Interface"]: + def match(cls, array: Any, fast: bool = False) -> Type["Interface"]: """ Find the interface that should be used for this array based on its input type + + First runs the ``check`` method for all interfaces returned by + :meth:`.Interface.interfaces` **except** for :class:`.NumpyInterface` , + and if no match is found then try the numpy interface. This is because + :meth:`.NumpyInterface.check` can be expensive, as we could potentially + try to + + Args: + fast (bool): if ``False`` , check all interfaces and raise exceptions for + having multiple matching interfaces (default). If ``True`` , + check each interface (as ordered by its ``priority`` , decreasing), + and return on the first match. """ # first try and find a non-numpy interface, since the numpy interface # will try and load the array into memory in its check method @@ -160,17 +217,24 @@ class Interface(ABC, Generic[T]): non_np_interfaces = [i for i in interfaces if i.__name__ != "NumpyInterface"] np_interface = [i for i in interfaces if i.__name__ == "NumpyInterface"][0] - matches = [i for i in non_np_interfaces if i.check(array)] + if fast: + matches = [] + for i in non_np_interfaces: + if i.check(array): + return i + else: + matches = [i for i in non_np_interfaces if i.check(array)] + if len(matches) > 1: msg = f"More than one interface matches input {array}:\n" msg += "\n".join([f" - {i}" for i in matches]) - raise ValueError(msg) + raise TooManyMatchesError(msg) elif len(matches) == 0: # now try the numpy interface if np_interface.check(array): return np_interface else: - raise ValueError(f"No matching interfaces found for input {array}") + raise NoMatchError(f"No matching interfaces found for input {array}") else: return matches[0] @@ -186,8 +250,8 @@ class Interface(ABC, Generic[T]): if len(matches) > 1: msg = f"More than one interface matches output {array}:\n" msg += "\n".join([f" - {i}" for i in matches]) - raise ValueError(msg) + raise TooManyMatchesError(msg) elif len(matches) == 0: - raise ValueError(f"No matching interfaces found for output {array}") + raise NoMatchError(f"No matching interfaces found for output {array}") else: return matches[0] diff --git a/src/numpydantic/meta.py b/src/numpydantic/meta.py index 671ab8d..3b5222d 100644 --- a/src/numpydantic/meta.py +++ b/src/numpydantic/meta.py @@ -24,7 +24,8 @@ def generate_ndarray_stub() -> str: # Create import statements, saving aliased name of type if needed if arr.__module__.startswith("numpydantic") or arr.__module__ == "typing": type_name = str(arr) if arr.__module__ == "typing" else arr.__name__ - import_strings.append(f"from {arr.__module__} import {type_name}") + if arr.__module__ != "typing": + import_strings.append(f"from {arr.__module__} import {type_name}") else: # since other packages could use the same name for an imported object # (eg dask and zarr both use an Array class) @@ -39,6 +40,7 @@ def generate_ndarray_stub() -> str: type_names.append(type_name) import_strings.extend(_BUILTIN_IMPORTS) + import_strings = list(dict.fromkeys(import_strings)) import_string = "\n".join(import_strings) class_union = " | ".join(type_names) diff --git a/src/numpydantic/ndarray.py b/src/numpydantic/ndarray.py index ddc1316..5bc539c 100644 --- a/src/numpydantic/ndarray.py +++ b/src/numpydantic/ndarray.py @@ -13,7 +13,7 @@ Extension of nptyping NDArray for pydantic that allows for JSON-Schema serializa """ -from typing import Any, Tuple +from typing import TYPE_CHECKING, Any, Tuple import numpy as np from nptyping.error import InvalidArgumentsError @@ -28,6 +28,8 @@ from pydantic import GetJsonSchemaHandler from pydantic_core import core_schema from numpydantic.dtype import DType +from numpydantic.exceptions import InterfaceError +from numpydantic.interface import Interface from numpydantic.maps import python_to_nptyping from numpydantic.schema import ( _handler_type, @@ -37,6 +39,9 @@ from numpydantic.schema import ( ) from numpydantic.types import DtypeType, ShapeType +if TYPE_CHECKING: # pragma: no cover + from nptyping.base_meta_classes import SubscriptableMeta + class NDArrayMeta(_NDArrayMeta, implementation="NDArray"): """ @@ -44,6 +49,35 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"): completion of the transition away from nptyping """ + if TYPE_CHECKING: # pragma: no cover + __getitem__ = SubscriptableMeta.__getitem__ + + def __instancecheck__(self, instance: Any): + """ + Extended type checking that determines whether + + 1) the ``type`` of the given instance is one of those in + :meth:`.Interface.input_types` + + but also + + 2) it satisfies the constraints set on the :class:`.NDArray` annotation + + Args: + instance (:class:`typing.Any`): Thing to check! + + Returns: + bool: ``True`` if matches constraints, ``False`` otherwise. + """ + shape, dtype = self.__args__ + try: + interface_cls = Interface.match(instance, fast=True) + interface = interface_cls(shape, dtype) + _ = interface.validate(instance) + return True + except InterfaceError: + return False + def _get_dtype(cls, dtype_candidate: Any) -> DType: """ Override of base _get_dtype method to allow for compound tuple types diff --git a/src/numpydantic/schema.py b/src/numpydantic/schema.py index 0233610..e610df4 100644 --- a/src/numpydantic/schema.py +++ b/src/numpydantic/schema.py @@ -225,7 +225,9 @@ def get_validate_interface(shape: ShapeType, dtype: DtypeType) -> Callable: :meth:`.Interface.validate` method """ - def validate_interface(value: Any, info: "ValidationInfo") -> NDArrayType: + def validate_interface( + value: Any, info: Optional["ValidationInfo"] = None + ) -> NDArrayType: interface_cls = Interface.match(value) interface = interface_cls(shape, dtype) value = interface.validate(value) diff --git a/tests/test_interface/test_interface.py b/tests/test_interface/test_interface.py index 0efd38d..e8fbe85 100644 --- a/tests/test_interface/test_interface.py +++ b/tests/test_interface/test_interface.py @@ -1,5 +1,3 @@ -import pdb - import pytest import numpy as np @@ -14,9 +12,12 @@ def interfaces(): class Interface1(Interface): input_types = (list,) return_type = tuple + priority = 1000 + checked = False @classmethod def check(cls, array): + cls.checked = True if isinstance(array, list): return True return False @@ -26,18 +27,34 @@ def interfaces(): return True Interface2 = type("Interface2", Interface1.__bases__, dict(Interface1.__dict__)) + Interface2.checked = False + Interface2.priority = 999 class Interface3(Interface1): + priority = 998 + checked = False + @classmethod def enabled(cls) -> bool: return False + class Interface4(Interface3): + priority = 997 + checked = False + + @classmethod + def enabled(cls) -> bool: + return True + class Interfaces: interface1 = Interface1 interface2 = Interface2 interface3 = Interface3 + interface4 = Interface4 yield Interfaces + # Interface.__subclasses__().remove(Interface1) + # Interface.__subclasses__().remove(Interface2) del Interface1 del Interface2 del Interface3 @@ -66,6 +83,20 @@ def test_interface_match_error(interfaces): assert "No matching interfaces" in str(e.value) +def test_interface_match_fast(interfaces): + """ + fast matching should return as soon as an interface is found + and not raise an error for duplicates + """ + Interface.interfaces()[0].checked = False + Interface.interfaces()[1].checked = False + # this doesnt' raise an error + matched = Interface.match([1, 2, 3], fast=True) + assert matched == Interface.interfaces()[0] + assert Interface.interfaces()[0].checked + assert not Interface.interfaces()[1].checked + + def test_interface_enabled(interfaces): """ An interface shouldn't be included if it's not enabled @@ -101,3 +132,22 @@ def test_interfaces_sorting(): ifaces = Interface.interfaces() priorities = [i.priority for i in ifaces] assert (np.diff(priorities) <= 0).all() + + +def test_interface_with_disabled(interfaces): + """ + Get all interfaces, even if not enabled + """ + ifaces = Interface.interfaces(with_disabled=True) + assert interfaces.interface3 in ifaces + + +def test_interface_recursive(interfaces): + """ + Get all interfaces, including subclasses of subclasses + """ + ifaces = Interface.interfaces() + assert issubclass(interfaces.interface4, interfaces.interface3) + assert issubclass(interfaces.interface3, interfaces.interface1) + assert issubclass(interfaces.interface1, Interface) + assert interfaces.interface4 in ifaces diff --git a/tests/test_ndarray.py b/tests/test_ndarray.py index 39972f1..583bd5f 100644 --- a/tests/test_ndarray.py +++ b/tests/test_ndarray.py @@ -223,3 +223,23 @@ def test_json_schema_ellipsis(): schema = ConstrainedAnyShape.model_json_schema() _recursive_array(schema) + + +def test_instancecheck(): + """ + NDArray should handle ``isinstance()`` s.t. valid arrays are ``True`` + and invalid arrays are ``False`` + + We don't make this test exhaustive because correctness of validation + is tested elsewhere. We are just testing that the type checking works + """ + array_type = NDArray[Shape["1, 2, 3"], int] + + assert isinstance(np.zeros((1, 2, 3), dtype=int), array_type) + assert not isinstance(np.zeros((2, 2, 3), dtype=int), array_type) + assert not isinstance(np.zeros((1, 2, 3), dtype=float), array_type) + + def my_function(array: NDArray[Shape["1, 2, 3"], int]): + return array + + my_function(np.zeros((1, 2, 3), int))