mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2025-01-10 05:54:26 +00:00
testing for fast matching mode. recursively get interface classes including disabled classes
This commit is contained in:
parent
1290d64833
commit
025832aa3d
3 changed files with 80 additions and 9 deletions
|
@ -144,17 +144,38 @@ class Interface(ABC, Generic[T]):
|
||||||
return array.tolist()
|
return array.tolist()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def interfaces(cls) -> Tuple[Type["Interface"], ...]:
|
def interfaces(
|
||||||
|
cls, with_disabled: bool = False, sort: bool = True
|
||||||
|
) -> Tuple[Type["Interface"], ...]:
|
||||||
"""
|
"""
|
||||||
Enabled interface subclasses
|
Enabled interface subclasses
|
||||||
|
|
||||||
|
Args:
|
||||||
|
with_disabled (bool): If ``True`` , get every known interface.
|
||||||
|
If ``False`` (default), get only enabled interfaces.
|
||||||
|
sort (bool): If ``True`` (default), sort interfaces by priority.
|
||||||
|
If ``False`` , sorted by definition order. Used for recursion:
|
||||||
|
we only want to sort once at the top level.
|
||||||
"""
|
"""
|
||||||
return tuple(
|
# get recursively
|
||||||
sorted(
|
subclasses = []
|
||||||
[i for i in Interface.__subclasses__() if i.enabled()],
|
for i in cls.__subclasses__():
|
||||||
|
if with_disabled:
|
||||||
|
subclasses.append(i)
|
||||||
|
|
||||||
|
if i.enabled():
|
||||||
|
subclasses.append(i)
|
||||||
|
|
||||||
|
subclasses.extend(i.interfaces(with_disabled=with_disabled, sort=False))
|
||||||
|
|
||||||
|
if sort:
|
||||||
|
subclasses = sorted(
|
||||||
|
subclasses,
|
||||||
key=attrgetter("priority"),
|
key=attrgetter("priority"),
|
||||||
reverse=True,
|
reverse=True,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
return tuple(subclasses)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def return_types(cls) -> Tuple[NDArrayType, ...]:
|
def return_types(cls) -> Tuple[NDArrayType, ...]:
|
||||||
|
|
|
@ -39,7 +39,7 @@ from numpydantic.schema import (
|
||||||
)
|
)
|
||||||
from numpydantic.types import DtypeType, ShapeType
|
from numpydantic.types import DtypeType, ShapeType
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING: # pragma: no cover
|
||||||
from nptyping.base_meta_classes import SubscriptableMeta
|
from nptyping.base_meta_classes import SubscriptableMeta
|
||||||
|
|
||||||
|
|
||||||
|
@ -49,7 +49,7 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
|
||||||
completion of the transition away from nptyping
|
completion of the transition away from nptyping
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING: # pragma: no cover
|
||||||
__getitem__ = SubscriptableMeta.__getitem__
|
__getitem__ = SubscriptableMeta.__getitem__
|
||||||
|
|
||||||
def __instancecheck__(self, instance: Any):
|
def __instancecheck__(self, instance: Any):
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
import pdb
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -14,9 +12,12 @@ def interfaces():
|
||||||
class Interface1(Interface):
|
class Interface1(Interface):
|
||||||
input_types = (list,)
|
input_types = (list,)
|
||||||
return_type = tuple
|
return_type = tuple
|
||||||
|
priority = 1000
|
||||||
|
checked = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def check(cls, array):
|
def check(cls, array):
|
||||||
|
cls.checked = True
|
||||||
if isinstance(array, list):
|
if isinstance(array, list):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
@ -26,18 +27,34 @@ def interfaces():
|
||||||
return True
|
return True
|
||||||
|
|
||||||
Interface2 = type("Interface2", Interface1.__bases__, dict(Interface1.__dict__))
|
Interface2 = type("Interface2", Interface1.__bases__, dict(Interface1.__dict__))
|
||||||
|
Interface2.checked = False
|
||||||
|
Interface2.priority = 999
|
||||||
|
|
||||||
class Interface3(Interface1):
|
class Interface3(Interface1):
|
||||||
|
priority = 998
|
||||||
|
checked = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def enabled(cls) -> bool:
|
def enabled(cls) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
class Interface4(Interface3):
|
||||||
|
priority = 997
|
||||||
|
checked = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def enabled(cls) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
class Interfaces:
|
class Interfaces:
|
||||||
interface1 = Interface1
|
interface1 = Interface1
|
||||||
interface2 = Interface2
|
interface2 = Interface2
|
||||||
interface3 = Interface3
|
interface3 = Interface3
|
||||||
|
interface4 = Interface4
|
||||||
|
|
||||||
yield Interfaces
|
yield Interfaces
|
||||||
|
# Interface.__subclasses__().remove(Interface1)
|
||||||
|
# Interface.__subclasses__().remove(Interface2)
|
||||||
del Interface1
|
del Interface1
|
||||||
del Interface2
|
del Interface2
|
||||||
del Interface3
|
del Interface3
|
||||||
|
@ -66,6 +83,20 @@ def test_interface_match_error(interfaces):
|
||||||
assert "No matching interfaces" in str(e.value)
|
assert "No matching interfaces" in str(e.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_interface_match_fast(interfaces):
|
||||||
|
"""
|
||||||
|
fast matching should return as soon as an interface is found
|
||||||
|
and not raise an error for duplicates
|
||||||
|
"""
|
||||||
|
Interface.interfaces()[0].checked = False
|
||||||
|
Interface.interfaces()[1].checked = False
|
||||||
|
# this doesnt' raise an error
|
||||||
|
matched = Interface.match([1, 2, 3], fast=True)
|
||||||
|
assert matched == Interface.interfaces()[0]
|
||||||
|
assert Interface.interfaces()[0].checked
|
||||||
|
assert not Interface.interfaces()[1].checked
|
||||||
|
|
||||||
|
|
||||||
def test_interface_enabled(interfaces):
|
def test_interface_enabled(interfaces):
|
||||||
"""
|
"""
|
||||||
An interface shouldn't be included if it's not enabled
|
An interface shouldn't be included if it's not enabled
|
||||||
|
@ -101,3 +132,22 @@ def test_interfaces_sorting():
|
||||||
ifaces = Interface.interfaces()
|
ifaces = Interface.interfaces()
|
||||||
priorities = [i.priority for i in ifaces]
|
priorities = [i.priority for i in ifaces]
|
||||||
assert (np.diff(priorities) <= 0).all()
|
assert (np.diff(priorities) <= 0).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_interface_with_disabled(interfaces):
|
||||||
|
"""
|
||||||
|
Get all interfaces, even if not enabled
|
||||||
|
"""
|
||||||
|
ifaces = Interface.interfaces(with_disabled=True)
|
||||||
|
assert interfaces.interface3 in ifaces
|
||||||
|
|
||||||
|
|
||||||
|
def test_interface_recursive(interfaces):
|
||||||
|
"""
|
||||||
|
Get all interfaces, including subclasses of subclasses
|
||||||
|
"""
|
||||||
|
ifaces = Interface.interfaces()
|
||||||
|
assert issubclass(interfaces.interface4, interfaces.interface3)
|
||||||
|
assert issubclass(interfaces.interface3, interfaces.interface1)
|
||||||
|
assert issubclass(interfaces.interface1, Interface)
|
||||||
|
assert interfaces.interface4 in ifaces
|
||||||
|
|
Loading…
Reference in a new issue