From 025832aa3d99fdf93dab103cd00c0a596055be4e Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Fri, 24 May 2024 18:31:04 -0700 Subject: [PATCH] testing for fast matching mode. recursively get interface classes including disabled classes --- src/numpydantic/interface/interface.py | 31 ++++++++++++--- src/numpydantic/ndarray.py | 4 +- tests/test_interface/test_interface.py | 54 +++++++++++++++++++++++++- 3 files changed, 80 insertions(+), 9 deletions(-) diff --git a/src/numpydantic/interface/interface.py b/src/numpydantic/interface/interface.py index 805e0ba..d7d19c1 100644 --- a/src/numpydantic/interface/interface.py +++ b/src/numpydantic/interface/interface.py @@ -144,17 +144,38 @@ class Interface(ABC, Generic[T]): return array.tolist() @classmethod - def interfaces(cls) -> Tuple[Type["Interface"], ...]: + def interfaces( + cls, with_disabled: bool = False, sort: bool = True + ) -> Tuple[Type["Interface"], ...]: """ Enabled interface subclasses + + Args: + with_disabled (bool): If ``True`` , get every known interface. + If ``False`` (default), get only enabled interfaces. + sort (bool): If ``True`` (default), sort interfaces by priority. + If ``False`` , sorted by definition order. Used for recursion: + we only want to sort once at the top level. """ - return tuple( - sorted( - [i for i in Interface.__subclasses__() if i.enabled()], + # get recursively + subclasses = [] + for i in cls.__subclasses__(): + if with_disabled: + subclasses.append(i) + + if i.enabled(): + subclasses.append(i) + + subclasses.extend(i.interfaces(with_disabled=with_disabled, sort=False)) + + if sort: + subclasses = sorted( + subclasses, key=attrgetter("priority"), reverse=True, ) - ) + + return tuple(subclasses) @classmethod def return_types(cls) -> Tuple[NDArrayType, ...]: diff --git a/src/numpydantic/ndarray.py b/src/numpydantic/ndarray.py index ce63464..5bc539c 100644 --- a/src/numpydantic/ndarray.py +++ b/src/numpydantic/ndarray.py @@ -39,7 +39,7 @@ from numpydantic.schema import ( ) from numpydantic.types import DtypeType, ShapeType -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from nptyping.base_meta_classes import SubscriptableMeta @@ -49,7 +49,7 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"): completion of the transition away from nptyping """ - if TYPE_CHECKING: + if TYPE_CHECKING: # pragma: no cover __getitem__ = SubscriptableMeta.__getitem__ def __instancecheck__(self, instance: Any): diff --git a/tests/test_interface/test_interface.py b/tests/test_interface/test_interface.py index 0efd38d..e8fbe85 100644 --- a/tests/test_interface/test_interface.py +++ b/tests/test_interface/test_interface.py @@ -1,5 +1,3 @@ -import pdb - import pytest import numpy as np @@ -14,9 +12,12 @@ def interfaces(): class Interface1(Interface): input_types = (list,) return_type = tuple + priority = 1000 + checked = False @classmethod def check(cls, array): + cls.checked = True if isinstance(array, list): return True return False @@ -26,18 +27,34 @@ def interfaces(): return True Interface2 = type("Interface2", Interface1.__bases__, dict(Interface1.__dict__)) + Interface2.checked = False + Interface2.priority = 999 class Interface3(Interface1): + priority = 998 + checked = False + @classmethod def enabled(cls) -> bool: return False + class Interface4(Interface3): + priority = 997 + checked = False + + @classmethod + def enabled(cls) -> bool: + return True + class Interfaces: interface1 = Interface1 interface2 = Interface2 interface3 = Interface3 + interface4 = Interface4 yield Interfaces + # Interface.__subclasses__().remove(Interface1) + # Interface.__subclasses__().remove(Interface2) del Interface1 del Interface2 del Interface3 @@ -66,6 +83,20 @@ def test_interface_match_error(interfaces): assert "No matching interfaces" in str(e.value) +def test_interface_match_fast(interfaces): + """ + fast matching should return as soon as an interface is found + and not raise an error for duplicates + """ + Interface.interfaces()[0].checked = False + Interface.interfaces()[1].checked = False + # this doesnt' raise an error + matched = Interface.match([1, 2, 3], fast=True) + assert matched == Interface.interfaces()[0] + assert Interface.interfaces()[0].checked + assert not Interface.interfaces()[1].checked + + def test_interface_enabled(interfaces): """ An interface shouldn't be included if it's not enabled @@ -101,3 +132,22 @@ def test_interfaces_sorting(): ifaces = Interface.interfaces() priorities = [i.priority for i in ifaces] assert (np.diff(priorities) <= 0).all() + + +def test_interface_with_disabled(interfaces): + """ + Get all interfaces, even if not enabled + """ + ifaces = Interface.interfaces(with_disabled=True) + assert interfaces.interface3 in ifaces + + +def test_interface_recursive(interfaces): + """ + Get all interfaces, including subclasses of subclasses + """ + ifaces = Interface.interfaces() + assert issubclass(interfaces.interface4, interfaces.interface3) + assert issubclass(interfaces.interface3, interfaces.interface1) + assert issubclass(interfaces.interface1, Interface) + assert interfaces.interface4 in ifaces