Merge pull request #7 from p2p-ld/dtype-subclasses

Check for dtype subclass
This commit is contained in:
Jonny Saunders 2024-08-12 21:37:23 -07:00 committed by GitHub
commit 87e6226ccf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 18 additions and 2 deletions

View file

@ -2,6 +2,10 @@
## 1.*
### 1.3.2 - 24-08-12 - Allow subclasses of dtypes
(also when using objects for dtypes, subclasses of that object are allowed to validate)
### 1.3.1 - 24-08-12 - Allow arbitrary dtypes, pydantic models as dtypes
Previously we would only allow dtypes if we knew for sure that there was some

View file

@ -1,6 +1,6 @@
[project]
name = "numpydantic"
version = "1.3.1"
version = "1.3.2"
description = "Type and shape validation and serialization for numpy arrays in pydantic models"
authors = [
{name = "sneakers-the-rat", email = "sneakers-the-rat@protonmail.com"},

View file

@ -128,7 +128,13 @@ class Interface(ABC, Generic[T]):
elif self.dtype is np.str_:
valid = getattr(dtype, "type", None) is np.str_ or dtype is np.str_
else:
valid = dtype == self.dtype
# try to match as any subclass, if self.dtype is a class
try:
valid = issubclass(dtype, self.dtype)
except TypeError:
# expected, if dtype or self.dtype is not a class
valid = dtype == self.dtype
return valid
def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None:

View file

@ -66,6 +66,10 @@ class BadModel(BaseModel):
x: int
class SubClass(BasicModel):
pass
RGB_UNION: TypeAlias = Union[
NDArray[Shape["* x, * y"], Number],
NDArray[Shape["* x, * y, 3 r_g_b"], Number],
@ -143,6 +147,7 @@ def shape_cases(request) -> ValidationCase:
ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True),
ValidationCase(annotation=MODEL, dtype=BadModel, passes=False),
ValidationCase(annotation=MODEL, dtype=int, passes=False),
ValidationCase(annotation=MODEL, dtype=SubClass, passes=True),
],
ids=[
"float",
@ -169,6 +174,7 @@ def shape_cases(request) -> ValidationCase:
"model-model",
"model-badmodel",
"model-int",
"model-subclass",
],
)
def dtype_cases(request) -> ValidationCase: