mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2024-11-12 17:54:29 +00:00
get inner object from object array to test arbitrary dtype
This commit is contained in:
parent
dd9a8e959f
commit
90994b1ba1
3 changed files with 22 additions and 1 deletions
|
@ -8,6 +8,7 @@ import numpy as np
|
|||
from pydantic import SerializationInfo
|
||||
|
||||
from numpydantic.interface.interface import Interface
|
||||
from numpydantic.types import DtypeType, NDArrayType
|
||||
|
||||
try:
|
||||
from dask.array.core import Array as DaskArray
|
||||
|
@ -30,6 +31,10 @@ class DaskInterface(Interface):
|
|||
"""
|
||||
return DaskArray is not None and isinstance(array, DaskArray)
|
||||
|
||||
def get_object_dtype(self, array: NDArrayType) -> DtypeType:
|
||||
"""Dask arrays require a compute() call to retrieve a single value"""
|
||||
return type(array.ravel()[0].compute())
|
||||
|
||||
@classmethod
|
||||
def enabled(cls) -> bool:
|
||||
"""check if we successfully imported dask"""
|
||||
|
|
|
@ -101,7 +101,17 @@ class Interface(ABC, Generic[T]):
|
|||
"""
|
||||
Get the dtype from the input array
|
||||
"""
|
||||
return array.dtype
|
||||
if hasattr(array.dtype, "type") and array.dtype.type is np.object_:
|
||||
return self.get_object_dtype(array)
|
||||
else:
|
||||
return array.dtype
|
||||
|
||||
def get_object_dtype(self, array: NDArrayType) -> DtypeType:
|
||||
"""
|
||||
When an array contains an object, get the dtype of the object contained
|
||||
by the array.
|
||||
"""
|
||||
return type(array.ravel()[0])
|
||||
|
||||
def validate_dtype(self, dtype: DtypeType) -> bool:
|
||||
"""
|
||||
|
|
|
@ -62,6 +62,10 @@ class BasicModel(BaseModel):
|
|||
x: int
|
||||
|
||||
|
||||
class BadModel(BaseModel):
|
||||
x: int
|
||||
|
||||
|
||||
RGB_UNION: TypeAlias = Union[
|
||||
NDArray[Shape["* x, * y"], Number],
|
||||
NDArray[Shape["* x, * y, 3 r_g_b"], Number],
|
||||
|
@ -137,6 +141,7 @@ def shape_cases(request) -> ValidationCase:
|
|||
ValidationCase(annotation=STRING, dtype=int, passes=False),
|
||||
ValidationCase(annotation=STRING, dtype=float, passes=False),
|
||||
ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True),
|
||||
ValidationCase(annotation=MODEL, dtype=BadModel, passes=False),
|
||||
ValidationCase(annotation=MODEL, dtype=int, passes=False),
|
||||
],
|
||||
ids=[
|
||||
|
@ -162,6 +167,7 @@ def shape_cases(request) -> ValidationCase:
|
|||
"str-int",
|
||||
"str-float",
|
||||
"model-model",
|
||||
"model-badmodel",
|
||||
"model-int",
|
||||
],
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue