testing for fast matching mode. recursively get interface classes including disabled classes

This commit is contained in:
sneakers-the-rat 2024-05-24 18:31:04 -07:00
parent 1290d64833
commit 025832aa3d
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
3 changed files with 80 additions and 9 deletions

View file

@ -144,17 +144,38 @@ class Interface(ABC, Generic[T]):
return array.tolist()
@classmethod
def interfaces(cls) -> Tuple[Type["Interface"], ...]:
def interfaces(
cls, with_disabled: bool = False, sort: bool = True
) -> Tuple[Type["Interface"], ...]:
"""
Enabled interface subclasses
Args:
with_disabled (bool): If ``True`` , get every known interface.
If ``False`` (default), get only enabled interfaces.
sort (bool): If ``True`` (default), sort interfaces by priority.
If ``False`` , sorted by definition order. Used for recursion:
we only want to sort once at the top level.
"""
return tuple(
sorted(
[i for i in Interface.__subclasses__() if i.enabled()],
# get recursively
subclasses = []
for i in cls.__subclasses__():
if with_disabled:
subclasses.append(i)
if i.enabled():
subclasses.append(i)
subclasses.extend(i.interfaces(with_disabled=with_disabled, sort=False))
if sort:
subclasses = sorted(
subclasses,
key=attrgetter("priority"),
reverse=True,
)
)
return tuple(subclasses)
@classmethod
def return_types(cls) -> Tuple[NDArrayType, ...]:

View file

@ -39,7 +39,7 @@ from numpydantic.schema import (
)
from numpydantic.types import DtypeType, ShapeType
if TYPE_CHECKING:
if TYPE_CHECKING: # pragma: no cover
from nptyping.base_meta_classes import SubscriptableMeta
@ -49,7 +49,7 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
completion of the transition away from nptyping
"""
if TYPE_CHECKING:
if TYPE_CHECKING: # pragma: no cover
__getitem__ = SubscriptableMeta.__getitem__
def __instancecheck__(self, instance: Any):

View file

@ -1,5 +1,3 @@
import pdb
import pytest
import numpy as np
@ -14,9 +12,12 @@ def interfaces():
class Interface1(Interface):
input_types = (list,)
return_type = tuple
priority = 1000
checked = False
@classmethod
def check(cls, array):
cls.checked = True
if isinstance(array, list):
return True
return False
@ -26,18 +27,34 @@ def interfaces():
return True
Interface2 = type("Interface2", Interface1.__bases__, dict(Interface1.__dict__))
Interface2.checked = False
Interface2.priority = 999
class Interface3(Interface1):
priority = 998
checked = False
@classmethod
def enabled(cls) -> bool:
return False
class Interface4(Interface3):
priority = 997
checked = False
@classmethod
def enabled(cls) -> bool:
return True
class Interfaces:
interface1 = Interface1
interface2 = Interface2
interface3 = Interface3
interface4 = Interface4
yield Interfaces
# Interface.__subclasses__().remove(Interface1)
# Interface.__subclasses__().remove(Interface2)
del Interface1
del Interface2
del Interface3
@ -66,6 +83,20 @@ def test_interface_match_error(interfaces):
assert "No matching interfaces" in str(e.value)
def test_interface_match_fast(interfaces):
"""
fast matching should return as soon as an interface is found
and not raise an error for duplicates
"""
Interface.interfaces()[0].checked = False
Interface.interfaces()[1].checked = False
# this doesnt' raise an error
matched = Interface.match([1, 2, 3], fast=True)
assert matched == Interface.interfaces()[0]
assert Interface.interfaces()[0].checked
assert not Interface.interfaces()[1].checked
def test_interface_enabled(interfaces):
"""
An interface shouldn't be included if it's not enabled
@ -101,3 +132,22 @@ def test_interfaces_sorting():
ifaces = Interface.interfaces()
priorities = [i.priority for i in ifaces]
assert (np.diff(priorities) <= 0).all()
def test_interface_with_disabled(interfaces):
"""
Get all interfaces, even if not enabled
"""
ifaces = Interface.interfaces(with_disabled=True)
assert interfaces.interface3 in ifaces
def test_interface_recursive(interfaces):
"""
Get all interfaces, including subclasses of subclasses
"""
ifaces = Interface.interfaces()
assert issubclass(interfaces.interface4, interfaces.interface3)
assert issubclass(interfaces.interface3, interfaces.interface1)
assert issubclass(interfaces.interface1, Interface)
assert interfaces.interface4 in ifaces