diff --git a/src/numpydantic/interface/dask.py b/src/numpydantic/interface/dask.py
index d334a0b..7719e98 100644
--- a/src/numpydantic/interface/dask.py
+++ b/src/numpydantic/interface/dask.py
@@ -8,6 +8,7 @@ import numpy as np
 from pydantic import SerializationInfo
 
 from numpydantic.interface.interface import Interface
+from numpydantic.types import DtypeType, NDArrayType
 
 try:
     from dask.array.core import Array as DaskArray
@@ -30,6 +31,20 @@ class DaskInterface(Interface):
         """
         return DaskArray is not None and isinstance(array, DaskArray)
 
+    def get_object_dtype(self, array: NDArrayType) -> DtypeType:
+        """
+        Dask arrays require a compute() call to retrieve a single value.
+
+        An empty array has no element to compute, so its own dtype is
+        returned unchanged.
+        """
+        try:
+            first = array.ravel()[0]
+        except IndexError:
+            # empty array - there is no first element to compute
+            return array.dtype
+        return type(first.compute())
+
     @classmethod
     def enabled(cls) -> bool:
         """check if we successfully imported dask"""
diff --git a/src/numpydantic/interface/interface.py b/src/numpydantic/interface/interface.py
index 832fe83..3030220 100644
--- a/src/numpydantic/interface/interface.py
+++ b/src/numpydantic/interface/interface.py
@@ -101,7 +101,25 @@ class Interface(ABC, Generic[T]):
         """
         Get the dtype from the input array
         """
-        return array.dtype
+        if hasattr(array.dtype, "type") and array.dtype.type is np.object_:
+            return self.get_object_dtype(array)
+        else:
+            return array.dtype
+
+    def get_object_dtype(self, array: NDArrayType) -> DtypeType:
+        """
+        When an array contains an object, get the dtype of the object contained
+        by the array.
+
+        Only the first element is inspected. An empty array has no element to
+        inspect, so its own dtype is returned unchanged.
+        """
+        try:
+            first = array.ravel()[0]
+        except IndexError:
+            # empty array - there is no first element to inspect
+            return array.dtype
+        return type(first)
 
     def validate_dtype(self, dtype: DtypeType) -> bool:
         """
diff --git a/tests/conftest.py b/tests/conftest.py
index af3a48e..0655362 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -62,6 +62,10 @@ class BasicModel(BaseModel):
     x: int
 
 
+class BadModel(BaseModel):
+    x: int
+
+
 RGB_UNION: TypeAlias = Union[
     NDArray[Shape["* x, * y"], Number],
     NDArray[Shape["* x, * y, 3 r_g_b"], Number],
@@ -137,6 +141,7 @@ def shape_cases(request) -> ValidationCase:
         ValidationCase(annotation=STRING, dtype=int, passes=False),
         ValidationCase(annotation=STRING, dtype=float, passes=False),
         ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True),
+        ValidationCase(annotation=MODEL, dtype=BadModel, passes=False),
         ValidationCase(annotation=MODEL, dtype=int, passes=False),
     ],
     ids=[
@@ -162,6 +167,7 @@ def shape_cases(request) -> ValidationCase:
         "str-int",
         "str-float",
         "model-model",
+        "model-badmodel",
         "model-int",
     ],
 )