From a8ba11f772d5b92ad8f4fff797eb5fbdea2768be Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Mon, 12 Aug 2024 21:35:01 -0700 Subject: [PATCH] update changelog, bump version --- src/numpydantic/interface/interface.py | 8 +++++++- tests/conftest.py | 6 ++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/numpydantic/interface/interface.py b/src/numpydantic/interface/interface.py index 3030220..3dc3fdc 100644 --- a/src/numpydantic/interface/interface.py +++ b/src/numpydantic/interface/interface.py @@ -128,7 +128,13 @@ class Interface(ABC, Generic[T]): elif self.dtype is np.str_: valid = getattr(dtype, "type", None) is np.str_ or dtype is np.str_ else: - valid = dtype == self.dtype + # try to match as any subclass, if self.dtype is a class + try: + valid = issubclass(dtype, self.dtype) + except TypeError: + # expected, if dtype or self.dtype is not a class + valid = dtype == self.dtype + return valid def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None: diff --git a/tests/conftest.py b/tests/conftest.py index 0655362..0467f25 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -66,6 +66,10 @@ class BadModel(BaseModel): x: int +class SubClass(BasicModel): + pass + + RGB_UNION: TypeAlias = Union[ NDArray[Shape["* x, * y"], Number], NDArray[Shape["* x, * y, 3 r_g_b"], Number], @@ -143,6 +147,7 @@ def shape_cases(request) -> ValidationCase: ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True), ValidationCase(annotation=MODEL, dtype=BadModel, passes=False), ValidationCase(annotation=MODEL, dtype=int, passes=False), + ValidationCase(annotation=MODEL, dtype=SubClass, passes=True), ], ids=[ "float", @@ -169,6 +174,7 @@ def shape_cases(request) -> ValidationCase: "model-model", "model-badmodel", "model-int", + "model-subclass", ], ) def dtype_cases(request) -> ValidationCase: