diff --git a/src/numpydantic/ndarray.py b/src/numpydantic/ndarray.py index 42fc3f8..8756ae0 100644 --- a/src/numpydantic/ndarray.py +++ b/src/numpydantic/ndarray.py @@ -125,14 +125,10 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"): check_type_names(dtype, dtype_per_name) elif isinstance(dtype_candidate, tuple): # pragma: no cover dtype = tuple([cls._get_dtype(dt) for dt in dtype_candidate]) - else: # pragma: no cover - raise InvalidArgumentsError( - f"Unexpected argument '{dtype_candidate}', expecting" - " Structure[]" - " or Literal[]" - " or a dtype" - " or typing.Any." - ) + else: + # arbitrary dtype - allow failure elsewhere :) + dtype = dtype_candidate + return dtype def _dtype_to_str(cls, dtype: Any) -> str: diff --git a/src/numpydantic/schema.py b/src/numpydantic/schema.py index 9636190..552c27a 100644 --- a/src/numpydantic/schema.py +++ b/src/numpydantic/schema.py @@ -8,7 +8,7 @@ import json from typing import TYPE_CHECKING, Any, Callable, Optional, Union import numpy as np -from pydantic import SerializationInfo +from pydantic import BaseModel, SerializationInfo from pydantic_core import CoreSchema, core_schema from pydantic_core.core_schema import ListSchema, ValidationInfo @@ -66,18 +66,18 @@ def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema: else: try: python_type = np_to_python[dtype] - except KeyError as e: # pragma: no cover + except KeyError: # pragma: no cover # this should pretty much only happen in downstream/3rd-party interfaces # that use interface-specific types. those need to provide mappings back # to base python types (making this more streamlined is TODO) if dtype in np_to_python.values(): # it's already a python type python_type = dtype + elif issubclass(dtype, BaseModel): + python_type = dtype else: - raise ValueError( - "dtype given in model does not have a corresponding python base " - "type - add one to the `maps.np_to_python` dict" - ) from e + # does this need a warning? + python_type = Any if python_type in _UNSUPPORTED_TYPES: array_type = core_schema.any_schema() diff --git a/tests/conftest.py b/tests/conftest.py index 2292dd1..af3a48e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -58,6 +58,10 @@ class ValidationCase(BaseModel): return Model +class BasicModel(BaseModel): + x: int + + RGB_UNION: TypeAlias = Union[ NDArray[Shape["* x, * y"], Number], NDArray[Shape["* x, * y, 3 r_g_b"], Number], @@ -68,6 +72,7 @@ NUMBER: TypeAlias = NDArray[Shape["*, *, *"], Number] INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer] FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float] STRING: TypeAlias = NDArray[Shape["*, *, *"], str] +MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel] @pytest.fixture( @@ -131,6 +136,8 @@ def shape_cases(request) -> ValidationCase: ValidationCase(annotation=STRING, dtype=str, passes=True), ValidationCase(annotation=STRING, dtype=int, passes=False), ValidationCase(annotation=STRING, dtype=float, passes=False), + ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True), + ValidationCase(annotation=MODEL, dtype=int, passes=False), ], ids=[ "float", @@ -154,6 +161,8 @@ def shape_cases(request) -> ValidationCase: "str-str", "str-int", "str-float", + "model-model", + "model-int", ], ) def dtype_cases(request) -> ValidationCase: diff --git a/tests/test_interface/test_dask.py b/tests/test_interface/test_dask.py index 6f7a8ac..fb1e4cb 100644 --- a/tests/test_interface/test_dask.py +++ b/tests/test_interface/test_dask.py @@ -4,7 +4,7 @@ import pytest import json import dask.array as da -from pydantic import ValidationError +from pydantic import BaseModel, ValidationError from numpydantic.interface import DaskInterface from numpydantic.exceptions import DtypeError, ShapeError @@ -13,7 +13,10 @@ from tests.conftest import ValidationCase def dask_array(case: ValidationCase) -> da.Array: - return da.zeros(shape=case.shape, dtype=case.dtype, chunks=10) + if issubclass(case.dtype, BaseModel): + return da.full(shape=case.shape, fill_value=case.dtype(x=1), chunks=-1) + else: + return da.zeros(shape=case.shape, dtype=case.dtype, chunks=10) def _test_dask_case(case: ValidationCase): diff --git a/tests/test_interface/test_hdf5.py b/tests/test_interface/test_hdf5.py index f12bd87..78af785 100644 --- a/tests/test_interface/test_hdf5.py +++ b/tests/test_interface/test_hdf5.py @@ -20,6 +20,8 @@ def hdf5_array_case(case: ValidationCase, array_func) -> H5ArrayPath: Returns: """ + if issubclass(case.dtype, BaseModel): + pytest.skip("hdf5 cant support arbitrary python objects") return array_func(case.shape, case.dtype) diff --git a/tests/test_interface/test_numpy.py b/tests/test_interface/test_numpy.py index 1ab6208..6a34b98 100644 --- a/tests/test_interface/test_numpy.py +++ b/tests/test_interface/test_numpy.py @@ -1,13 +1,16 @@ import numpy as np import pytest -from pydantic import ValidationError +from pydantic import ValidationError, BaseModel from numpydantic.exceptions import DtypeError, ShapeError from tests.conftest import ValidationCase def numpy_array(case: ValidationCase) -> np.ndarray: - return np.zeros(shape=case.shape, dtype=case.dtype) + if issubclass(case.dtype, BaseModel): + return np.full(shape=case.shape, fill_value=case.dtype(x=1)) + else: + return np.zeros(shape=case.shape, dtype=case.dtype) def _test_np_case(case: ValidationCase): diff --git a/tests/test_interface/test_zarr.py b/tests/test_interface/test_zarr.py index eab3e52..2e465f2 100644 --- a/tests/test_interface/test_zarr.py +++ b/tests/test_interface/test_zarr.py @@ -3,7 +3,9 @@ import json import pytest import zarr -from pydantic import ValidationError +from pydantic import BaseModel, ValidationError +from numcodecs import Pickle + from numpydantic.interface import ZarrInterface from numpydantic.interface.zarr import ZarrArrayPath @@ -31,7 +33,19 @@ def nested_dir_array(tmp_output_dir_func) -> zarr.NestedDirectoryStore: def _zarr_array(case: ValidationCase, store) -> zarr.core.Array: - return zarr.zeros(shape=case.shape, dtype=case.dtype, store=store) + if issubclass(case.dtype, BaseModel): + pytest.skip( + f"Zarr can't handle objects properly at the moment, " + "see https://github.com/zarr-developers/zarr-python/issues/2081" + ) + # return zarr.full( + # shape=case.shape, + # fill_value=case.dtype(x=1), + # dtype=object, + # object_codec=Pickle(), + # ) + else: + return zarr.zeros(shape=case.shape, dtype=case.dtype, store=store) def _test_zarr_case(case: ValidationCase, store): diff --git a/tests/test_ndarray.py b/tests/test_ndarray.py index 7b512ef..9883c2a 100644 --- a/tests/test_ndarray.py +++ b/tests/test_ndarray.py @@ -266,6 +266,30 @@ def test_json_schema_dtype_builtin(dtype, expected, array_model): assert inner_type["type"] == expected +def test_json_schema_dtype_model(): + """ + Pydantic models can be used in arrays as dtypes + """ + + class TestModel(BaseModel): + x: int + y: int + z: int + + class MyModel(BaseModel): + array: NDArray[Shape["*, *"], TestModel] + + schema = MyModel.model_json_schema() + # we should have a "$defs" with TestModel in it, + # and our array should be objects of that type + assert schema["properties"]["array"]["items"]["items"] == { + "$ref": "#/$defs/TestModel" + } + # we don't test pydantic' generic json schema model generation, + # just that one was defined + assert "TestModel" in schema["$defs"] + + def _recursive_array(schema): assert "$defs" in schema # get the key uses for the array