mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2025-01-09 21:44:27 +00:00
Merge pull request #7 from p2p-ld/dtype-subclasses
Check for dtype subclass
This commit is contained in:
commit
87e6226ccf
4 changed files with 18 additions and 2 deletions
|
@ -2,6 +2,10 @@
|
|||
|
||||
## 1.*
|
||||
|
||||
### 1.3.2 - 24-08-12 - Allow subclasses of dtypes
|
||||
|
||||
(also when using objects for dtypes, subclasses of that object are allowed to validate)
|
||||
|
||||
### 1.3.1 - 24-08-12 - Allow arbitrary dtypes, pydantic models as dtypes
|
||||
|
||||
Previously we would only allow dtypes if we knew for sure that there was some
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "numpydantic"
|
||||
version = "1.3.1"
|
||||
version = "1.3.2"
|
||||
description = "Type and shape validation and serialization for numpy arrays in pydantic models"
|
||||
authors = [
|
||||
{name = "sneakers-the-rat", email = "sneakers-the-rat@protonmail.com"},
|
||||
|
|
|
@ -128,7 +128,13 @@ class Interface(ABC, Generic[T]):
|
|||
elif self.dtype is np.str_:
|
||||
valid = getattr(dtype, "type", None) is np.str_ or dtype is np.str_
|
||||
else:
|
||||
valid = dtype == self.dtype
|
||||
# try to match as any subclass, if self.dtype is a class
|
||||
try:
|
||||
valid = issubclass(dtype, self.dtype)
|
||||
except TypeError:
|
||||
# expected, if dtype or self.dtype is not a class
|
||||
valid = dtype == self.dtype
|
||||
|
||||
return valid
|
||||
|
||||
def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None:
|
||||
|
|
|
@ -66,6 +66,10 @@ class BadModel(BaseModel):
|
|||
x: int
|
||||
|
||||
|
||||
class SubClass(BasicModel):
|
||||
pass
|
||||
|
||||
|
||||
RGB_UNION: TypeAlias = Union[
|
||||
NDArray[Shape["* x, * y"], Number],
|
||||
NDArray[Shape["* x, * y, 3 r_g_b"], Number],
|
||||
|
@ -143,6 +147,7 @@ def shape_cases(request) -> ValidationCase:
|
|||
ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True),
|
||||
ValidationCase(annotation=MODEL, dtype=BadModel, passes=False),
|
||||
ValidationCase(annotation=MODEL, dtype=int, passes=False),
|
||||
ValidationCase(annotation=MODEL, dtype=SubClass, passes=True),
|
||||
],
|
||||
ids=[
|
||||
"float",
|
||||
|
@ -169,6 +174,7 @@ def shape_cases(request) -> ValidationCase:
|
|||
"model-model",
|
||||
"model-badmodel",
|
||||
"model-int",
|
||||
"model-subclass",
|
||||
],
|
||||
)
|
||||
def dtype_cases(request) -> ValidationCase:
|
||||
|
|
Loading…
Reference in a new issue