get inner object from object array to test arbitrary dtype

This commit is contained in:
sneakers-the-rat 2024-08-12 21:10:56 -07:00
parent dd9a8e959f
commit 90994b1ba1
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
3 changed files with 22 additions and 1 deletions

View file

@ -8,6 +8,7 @@ import numpy as np
from pydantic import SerializationInfo
from numpydantic.interface.interface import Interface
from numpydantic.types import DtypeType, NDArrayType
try:
from dask.array.core import Array as DaskArray
@ -30,6 +31,10 @@ class DaskInterface(Interface):
"""
return DaskArray is not None and isinstance(array, DaskArray)
def get_object_dtype(self, array: NDArrayType) -> DtypeType:
"""Dask arrays require a compute() call to retrieve a single value"""
return type(array.ravel()[0].compute())
@classmethod
def enabled(cls) -> bool:
"""check if we successfully imported dask"""

View file

@ -101,8 +101,18 @@ class Interface(ABC, Generic[T]):
"""
Get the dtype from the input array
"""
if hasattr(array.dtype, "type") and array.dtype.type is np.object_:
return self.get_object_dtype(array)
else:
return array.dtype
def get_object_dtype(self, array: NDArrayType) -> DtypeType:
"""
When an array contains an object, get the dtype of the object contained
by the array.
"""
return type(array.ravel()[0])
def validate_dtype(self, dtype: DtypeType) -> bool:
"""
Validate the dtype of the given array, returning

View file

@ -62,6 +62,10 @@ class BasicModel(BaseModel):
x: int
class BadModel(BaseModel):
x: int
RGB_UNION: TypeAlias = Union[
NDArray[Shape["* x, * y"], Number],
NDArray[Shape["* x, * y, 3 r_g_b"], Number],
@ -137,6 +141,7 @@ def shape_cases(request) -> ValidationCase:
ValidationCase(annotation=STRING, dtype=int, passes=False),
ValidationCase(annotation=STRING, dtype=float, passes=False),
ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True),
ValidationCase(annotation=MODEL, dtype=BadModel, passes=False),
ValidationCase(annotation=MODEL, dtype=int, passes=False),
],
ids=[
@ -162,6 +167,7 @@ def shape_cases(request) -> ValidationCase:
"str-int",
"str-float",
"model-model",
"model-badmodel",
"model-int",
],
)