From a8ba11f772d5b92ad8f4fff797eb5fbdea2768be Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Mon, 12 Aug 2024 21:35:01 -0700 Subject: [PATCH 1/2] update changelog, bump version --- src/numpydantic/interface/interface.py | 8 +++++++- tests/conftest.py | 6 ++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/numpydantic/interface/interface.py b/src/numpydantic/interface/interface.py index 3030220..3dc3fdc 100644 --- a/src/numpydantic/interface/interface.py +++ b/src/numpydantic/interface/interface.py @@ -128,7 +128,13 @@ class Interface(ABC, Generic[T]): elif self.dtype is np.str_: valid = getattr(dtype, "type", None) is np.str_ or dtype is np.str_ else: - valid = dtype == self.dtype + # try to match as any subclass, if self.dtype is a class + try: + valid = issubclass(dtype, self.dtype) + except TypeError: + # expected, if dtype or self.dtype is not a class + valid = dtype == self.dtype + return valid def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None: diff --git a/tests/conftest.py b/tests/conftest.py index 0655362..0467f25 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -66,6 +66,10 @@ class BadModel(BaseModel): x: int +class SubClass(BasicModel): + pass + + RGB_UNION: TypeAlias = Union[ NDArray[Shape["* x, * y"], Number], NDArray[Shape["* x, * y, 3 r_g_b"], Number], @@ -143,6 +147,7 @@ def shape_cases(request) -> ValidationCase: ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True), ValidationCase(annotation=MODEL, dtype=BadModel, passes=False), ValidationCase(annotation=MODEL, dtype=int, passes=False), + ValidationCase(annotation=MODEL, dtype=SubClass, passes=True), ], ids=[ "float", @@ -169,6 +174,7 @@ def shape_cases(request) -> ValidationCase: "model-model", "model-badmodel", "model-int", + "model-subclass", ], ) def dtype_cases(request) -> ValidationCase: From f63bd9c1714b6a6bb815a8e4bb1c7d16057ab189 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Mon, 12 Aug 2024 21:36:57 -0700 Subject: [PATCH 2/2] update changelog, bump version --- docs/changelog.md | 4 ++++ pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/changelog.md b/docs/changelog.md index af19479..938d9bd 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,10 @@ ## 1.* +### 1.3.2 - 24-08-12 - Allow subclasses of dtypes + +(also when using objects for dtypes, subclasses of that object are allowed to validate) + ### 1.3.1 - 24-08-12 - Allow arbitrary dtypes, pydantic models as dtypes Previously we would only allow dtypes if we knew for sure that there was some diff --git a/pyproject.toml b/pyproject.toml index 3f14e28..e1ffa3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "numpydantic" -version = "1.3.1" +version = "1.3.2" description = "Type and shape validation and serialization for numpy arrays in pydantic models" authors = [ {name = "sneakers-the-rat", email = "sneakers-the-rat@protonmail.com"},