Merge pull request #7 from p2p-ld/dtype-subclasses

Check for dtype subclass
This commit is contained in:
Jonny Saunders 2024-08-12 21:37:23 -07:00 committed by GitHub
commit 87e6226ccf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 18 additions and 2 deletions

View file

@ -2,6 +2,10 @@
## 1.* ## 1.*
### 1.3.2 - 24-08-12 - Allow subclasses of dtypes
(also when using objects for dtypes, subclasses of that object are allowed to validate)
### 1.3.1 - 24-08-12 - Allow arbitrary dtypes, pydantic models as dtypes ### 1.3.1 - 24-08-12 - Allow arbitrary dtypes, pydantic models as dtypes
Previously we would only allow dtypes if we knew for sure that there was some Previously we would only allow dtypes if we knew for sure that there was some

View file

@ -1,6 +1,6 @@
[project] [project]
name = "numpydantic" name = "numpydantic"
version = "1.3.1" version = "1.3.2"
description = "Type and shape validation and serialization for numpy arrays in pydantic models" description = "Type and shape validation and serialization for numpy arrays in pydantic models"
authors = [ authors = [
{name = "sneakers-the-rat", email = "sneakers-the-rat@protonmail.com"}, {name = "sneakers-the-rat", email = "sneakers-the-rat@protonmail.com"},

View file

@ -128,7 +128,13 @@ class Interface(ABC, Generic[T]):
elif self.dtype is np.str_: elif self.dtype is np.str_:
valid = getattr(dtype, "type", None) is np.str_ or dtype is np.str_ valid = getattr(dtype, "type", None) is np.str_ or dtype is np.str_
else: else:
valid = dtype == self.dtype # try to match as any subclass, if self.dtype is a class
try:
valid = issubclass(dtype, self.dtype)
except TypeError:
# expected, if dtype or self.dtype is not a class
valid = dtype == self.dtype
return valid return valid
def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None: def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None:

View file

@ -66,6 +66,10 @@ class BadModel(BaseModel):
x: int x: int
class SubClass(BasicModel):
pass
RGB_UNION: TypeAlias = Union[ RGB_UNION: TypeAlias = Union[
NDArray[Shape["* x, * y"], Number], NDArray[Shape["* x, * y"], Number],
NDArray[Shape["* x, * y, 3 r_g_b"], Number], NDArray[Shape["* x, * y, 3 r_g_b"], Number],
@ -143,6 +147,7 @@ def shape_cases(request) -> ValidationCase:
ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True), ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True),
ValidationCase(annotation=MODEL, dtype=BadModel, passes=False), ValidationCase(annotation=MODEL, dtype=BadModel, passes=False),
ValidationCase(annotation=MODEL, dtype=int, passes=False), ValidationCase(annotation=MODEL, dtype=int, passes=False),
ValidationCase(annotation=MODEL, dtype=SubClass, passes=True),
], ],
ids=[ ids=[
"float", "float",
@ -169,6 +174,7 @@ def shape_cases(request) -> ValidationCase:
"model-model", "model-model",
"model-badmodel", "model-badmodel",
"model-int", "model-int",
"model-subclass",
], ],
) )
def dtype_cases(request) -> ValidationCase: def dtype_cases(request) -> ValidationCase: