mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2025-01-10 05:54:26 +00:00
get inner object from object array to test arbitrary dtype
This commit is contained in:
parent
dd9a8e959f
commit
90994b1ba1
3 changed files with 22 additions and 1 deletions
|
@ -8,6 +8,7 @@ import numpy as np
|
||||||
from pydantic import SerializationInfo
|
from pydantic import SerializationInfo
|
||||||
|
|
||||||
from numpydantic.interface.interface import Interface
|
from numpydantic.interface.interface import Interface
|
||||||
|
from numpydantic.types import DtypeType, NDArrayType
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from dask.array.core import Array as DaskArray
|
from dask.array.core import Array as DaskArray
|
||||||
|
@ -30,6 +31,10 @@ class DaskInterface(Interface):
|
||||||
"""
|
"""
|
||||||
return DaskArray is not None and isinstance(array, DaskArray)
|
return DaskArray is not None and isinstance(array, DaskArray)
|
||||||
|
|
||||||
|
def get_object_dtype(self, array: NDArrayType) -> DtypeType:
|
||||||
|
"""Dask arrays require a compute() call to retrieve a single value"""
|
||||||
|
return type(array.ravel()[0].compute())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def enabled(cls) -> bool:
|
def enabled(cls) -> bool:
|
||||||
"""check if we successfully imported dask"""
|
"""check if we successfully imported dask"""
|
||||||
|
|
|
@ -101,8 +101,18 @@ class Interface(ABC, Generic[T]):
|
||||||
"""
|
"""
|
||||||
Get the dtype from the input array
|
Get the dtype from the input array
|
||||||
"""
|
"""
|
||||||
|
if hasattr(array.dtype, "type") and array.dtype.type is np.object_:
|
||||||
|
return self.get_object_dtype(array)
|
||||||
|
else:
|
||||||
return array.dtype
|
return array.dtype
|
||||||
|
|
||||||
|
def get_object_dtype(self, array: NDArrayType) -> DtypeType:
|
||||||
|
"""
|
||||||
|
When an array contains an object, get the dtype of the object contained
|
||||||
|
by the array.
|
||||||
|
"""
|
||||||
|
return type(array.ravel()[0])
|
||||||
|
|
||||||
def validate_dtype(self, dtype: DtypeType) -> bool:
|
def validate_dtype(self, dtype: DtypeType) -> bool:
|
||||||
"""
|
"""
|
||||||
Validate the dtype of the given array, returning
|
Validate the dtype of the given array, returning
|
||||||
|
|
|
@ -62,6 +62,10 @@ class BasicModel(BaseModel):
|
||||||
x: int
|
x: int
|
||||||
|
|
||||||
|
|
||||||
|
class BadModel(BaseModel):
|
||||||
|
x: int
|
||||||
|
|
||||||
|
|
||||||
RGB_UNION: TypeAlias = Union[
|
RGB_UNION: TypeAlias = Union[
|
||||||
NDArray[Shape["* x, * y"], Number],
|
NDArray[Shape["* x, * y"], Number],
|
||||||
NDArray[Shape["* x, * y, 3 r_g_b"], Number],
|
NDArray[Shape["* x, * y, 3 r_g_b"], Number],
|
||||||
|
@ -137,6 +141,7 @@ def shape_cases(request) -> ValidationCase:
|
||||||
ValidationCase(annotation=STRING, dtype=int, passes=False),
|
ValidationCase(annotation=STRING, dtype=int, passes=False),
|
||||||
ValidationCase(annotation=STRING, dtype=float, passes=False),
|
ValidationCase(annotation=STRING, dtype=float, passes=False),
|
||||||
ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True),
|
ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True),
|
||||||
|
ValidationCase(annotation=MODEL, dtype=BadModel, passes=False),
|
||||||
ValidationCase(annotation=MODEL, dtype=int, passes=False),
|
ValidationCase(annotation=MODEL, dtype=int, passes=False),
|
||||||
],
|
],
|
||||||
ids=[
|
ids=[
|
||||||
|
@ -162,6 +167,7 @@ def shape_cases(request) -> ValidationCase:
|
||||||
"str-int",
|
"str-int",
|
||||||
"str-float",
|
"str-float",
|
||||||
"model-model",
|
"model-model",
|
||||||
|
"model-badmodel",
|
||||||
"model-int",
|
"model-int",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue