From dd9a8e959fa1fb09f116f7a0709e2e5cc1db1bdf Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Mon, 12 Aug 2024 20:50:33 -0700 Subject: [PATCH 1/3] allow arbitrary dtypes, and allow pydantic models as the inner type in json schema array creation --- src/numpydantic/ndarray.py | 12 ++++-------- src/numpydantic/schema.py | 12 ++++++------ tests/conftest.py | 9 +++++++++ tests/test_interface/test_dask.py | 7 +++++-- tests/test_interface/test_hdf5.py | 2 ++ tests/test_interface/test_numpy.py | 7 +++++-- tests/test_interface/test_zarr.py | 18 ++++++++++++++++-- tests/test_ndarray.py | 24 ++++++++++++++++++++++++ 8 files changed, 71 insertions(+), 20 deletions(-) diff --git a/src/numpydantic/ndarray.py b/src/numpydantic/ndarray.py index 42fc3f8..8756ae0 100644 --- a/src/numpydantic/ndarray.py +++ b/src/numpydantic/ndarray.py @@ -125,14 +125,10 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"): check_type_names(dtype, dtype_per_name) elif isinstance(dtype_candidate, tuple): # pragma: no cover dtype = tuple([cls._get_dtype(dt) for dt in dtype_candidate]) - else: # pragma: no cover - raise InvalidArgumentsError( - f"Unexpected argument '{dtype_candidate}', expecting" - " Structure[]" - " or Literal[]" - " or a dtype" - " or typing.Any." 
- ) + else: + # arbitrary dtype - allow failure elsewhere :) + dtype = dtype_candidate + return dtype def _dtype_to_str(cls, dtype: Any) -> str: diff --git a/src/numpydantic/schema.py b/src/numpydantic/schema.py index 9636190..552c27a 100644 --- a/src/numpydantic/schema.py +++ b/src/numpydantic/schema.py @@ -8,7 +8,7 @@ import json from typing import TYPE_CHECKING, Any, Callable, Optional, Union import numpy as np -from pydantic import SerializationInfo +from pydantic import BaseModel, SerializationInfo from pydantic_core import CoreSchema, core_schema from pydantic_core.core_schema import ListSchema, ValidationInfo @@ -66,18 +66,18 @@ def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema: else: try: python_type = np_to_python[dtype] - except KeyError as e: # pragma: no cover + except KeyError: # pragma: no cover # this should pretty much only happen in downstream/3rd-party interfaces # that use interface-specific types. those need to provide mappings back # to base python types (making this more streamlined is TODO) if dtype in np_to_python.values(): # it's already a python type python_type = dtype + elif issubclass(dtype, BaseModel): + python_type = dtype else: - raise ValueError( - "dtype given in model does not have a corresponding python base " - "type - add one to the `maps.np_to_python` dict" - ) from e + # does this need a warning? 
+ python_type = Any if python_type in _UNSUPPORTED_TYPES: array_type = core_schema.any_schema() diff --git a/tests/conftest.py b/tests/conftest.py index 2292dd1..af3a48e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -58,6 +58,10 @@ class ValidationCase(BaseModel): return Model +class BasicModel(BaseModel): + x: int + + RGB_UNION: TypeAlias = Union[ NDArray[Shape["* x, * y"], Number], NDArray[Shape["* x, * y, 3 r_g_b"], Number], @@ -68,6 +72,7 @@ NUMBER: TypeAlias = NDArray[Shape["*, *, *"], Number] INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer] FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float] STRING: TypeAlias = NDArray[Shape["*, *, *"], str] +MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel] @pytest.fixture( @@ -131,6 +136,8 @@ def shape_cases(request) -> ValidationCase: ValidationCase(annotation=STRING, dtype=str, passes=True), ValidationCase(annotation=STRING, dtype=int, passes=False), ValidationCase(annotation=STRING, dtype=float, passes=False), + ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True), + ValidationCase(annotation=MODEL, dtype=int, passes=False), ], ids=[ "float", @@ -154,6 +161,8 @@ def shape_cases(request) -> ValidationCase: "str-str", "str-int", "str-float", + "model-model", + "model-int", ], ) def dtype_cases(request) -> ValidationCase: diff --git a/tests/test_interface/test_dask.py b/tests/test_interface/test_dask.py index 6f7a8ac..fb1e4cb 100644 --- a/tests/test_interface/test_dask.py +++ b/tests/test_interface/test_dask.py @@ -4,7 +4,7 @@ import pytest import json import dask.array as da -from pydantic import ValidationError +from pydantic import BaseModel, ValidationError from numpydantic.interface import DaskInterface from numpydantic.exceptions import DtypeError, ShapeError @@ -13,7 +13,10 @@ from tests.conftest import ValidationCase def dask_array(case: ValidationCase) -> da.Array: - return da.zeros(shape=case.shape, dtype=case.dtype, chunks=10) + if issubclass(case.dtype, BaseModel): + return 
da.full(shape=case.shape, fill_value=case.dtype(x=1), chunks=-1) + else: + return da.zeros(shape=case.shape, dtype=case.dtype, chunks=10) def _test_dask_case(case: ValidationCase): diff --git a/tests/test_interface/test_hdf5.py b/tests/test_interface/test_hdf5.py index f12bd87..78af785 100644 --- a/tests/test_interface/test_hdf5.py +++ b/tests/test_interface/test_hdf5.py @@ -20,6 +20,8 @@ def hdf5_array_case(case: ValidationCase, array_func) -> H5ArrayPath: Returns: """ + if issubclass(case.dtype, BaseModel): + pytest.skip("hdf5 cant support arbitrary python objects") return array_func(case.shape, case.dtype) diff --git a/tests/test_interface/test_numpy.py b/tests/test_interface/test_numpy.py index 1ab6208..6a34b98 100644 --- a/tests/test_interface/test_numpy.py +++ b/tests/test_interface/test_numpy.py @@ -1,13 +1,16 @@ import numpy as np import pytest -from pydantic import ValidationError +from pydantic import ValidationError, BaseModel from numpydantic.exceptions import DtypeError, ShapeError from tests.conftest import ValidationCase def numpy_array(case: ValidationCase) -> np.ndarray: - return np.zeros(shape=case.shape, dtype=case.dtype) + if issubclass(case.dtype, BaseModel): + return np.full(shape=case.shape, fill_value=case.dtype(x=1)) + else: + return np.zeros(shape=case.shape, dtype=case.dtype) def _test_np_case(case: ValidationCase): diff --git a/tests/test_interface/test_zarr.py b/tests/test_interface/test_zarr.py index eab3e52..2e465f2 100644 --- a/tests/test_interface/test_zarr.py +++ b/tests/test_interface/test_zarr.py @@ -3,7 +3,9 @@ import json import pytest import zarr -from pydantic import ValidationError +from pydantic import BaseModel, ValidationError +from numcodecs import Pickle + from numpydantic.interface import ZarrInterface from numpydantic.interface.zarr import ZarrArrayPath @@ -31,7 +33,19 @@ def nested_dir_array(tmp_output_dir_func) -> zarr.NestedDirectoryStore: def _zarr_array(case: ValidationCase, store) -> zarr.core.Array: - return 
zarr.zeros(shape=case.shape, dtype=case.dtype, store=store) + if issubclass(case.dtype, BaseModel): + pytest.skip( + f"Zarr can't handle objects properly at the moment, " + "see https://github.com/zarr-developers/zarr-python/issues/2081" + ) + # return zarr.full( + # shape=case.shape, + # fill_value=case.dtype(x=1), + # dtype=object, + # object_codec=Pickle(), + # ) + else: + return zarr.zeros(shape=case.shape, dtype=case.dtype, store=store) def _test_zarr_case(case: ValidationCase, store): diff --git a/tests/test_ndarray.py b/tests/test_ndarray.py index 7b512ef..9883c2a 100644 --- a/tests/test_ndarray.py +++ b/tests/test_ndarray.py @@ -266,6 +266,30 @@ def test_json_schema_dtype_builtin(dtype, expected, array_model): assert inner_type["type"] == expected +def test_json_schema_dtype_model(): + """ + Pydantic models can be used in arrays as dtypes + """ + + class TestModel(BaseModel): + x: int + y: int + z: int + + class MyModel(BaseModel): + array: NDArray[Shape["*, *"], TestModel] + + schema = MyModel.model_json_schema() + # we should have a "$defs" with TestModel in it, + # and our array should be objects of that type + assert schema["properties"]["array"]["items"]["items"] == { + "$ref": "#/$defs/TestModel" + } + # we don't test pydantic's generic JSON schema generation, + # just that one was defined + assert "TestModel" in schema["$defs"] + + def _recursive_array(schema): assert "$defs" in schema # get the key uses for the array From 90994b1ba143f8c746cf451352c7c851a635ba47 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Mon, 12 Aug 2024 21:10:56 -0700 Subject: [PATCH 2/3] get inner object from object array to test arbitrary dtype --- src/numpydantic/interface/dask.py | 5 +++++ src/numpydantic/interface/interface.py | 12 +++++++++++- tests/conftest.py | 6 ++++++ 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/numpydantic/interface/dask.py b/src/numpydantic/interface/dask.py index d334a0b..7719e98 100644 --- 
a/src/numpydantic/interface/dask.py +++ b/src/numpydantic/interface/dask.py @@ -8,6 +8,7 @@ import numpy as np from pydantic import SerializationInfo from numpydantic.interface.interface import Interface +from numpydantic.types import DtypeType, NDArrayType try: from dask.array.core import Array as DaskArray @@ -30,6 +31,10 @@ class DaskInterface(Interface): """ return DaskArray is not None and isinstance(array, DaskArray) + def get_object_dtype(self, array: NDArrayType) -> DtypeType: + """Dask arrays require a compute() call to retrieve a single value""" + return type(array.ravel()[0].compute()) + @classmethod def enabled(cls) -> bool: """check if we successfully imported dask""" diff --git a/src/numpydantic/interface/interface.py b/src/numpydantic/interface/interface.py index 832fe83..3030220 100644 --- a/src/numpydantic/interface/interface.py +++ b/src/numpydantic/interface/interface.py @@ -101,7 +101,17 @@ class Interface(ABC, Generic[T]): """ Get the dtype from the input array """ - return array.dtype + if hasattr(array.dtype, "type") and array.dtype.type is np.object_: + return self.get_object_dtype(array) + else: + return array.dtype + + def get_object_dtype(self, array: NDArrayType) -> DtypeType: + """ + When an array contains an object, get the dtype of the object contained + by the array. 
+ """ + return type(array.ravel()[0]) def validate_dtype(self, dtype: DtypeType) -> bool: """ diff --git a/tests/conftest.py b/tests/conftest.py index af3a48e..0655362 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -62,6 +62,10 @@ class BasicModel(BaseModel): x: int +class BadModel(BaseModel): + x: int + + RGB_UNION: TypeAlias = Union[ NDArray[Shape["* x, * y"], Number], NDArray[Shape["* x, * y, 3 r_g_b"], Number], @@ -137,6 +141,7 @@ def shape_cases(request) -> ValidationCase: ValidationCase(annotation=STRING, dtype=int, passes=False), ValidationCase(annotation=STRING, dtype=float, passes=False), ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True), + ValidationCase(annotation=MODEL, dtype=BadModel, passes=False), ValidationCase(annotation=MODEL, dtype=int, passes=False), ], ids=[ @@ -162,6 +167,7 @@ def shape_cases(request) -> ValidationCase: "str-int", "str-float", "model-model", + "model-badmodel", "model-int", ], ) From bd8b075561b93f14ad4536f880749bd20dc40b8a Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Mon, 12 Aug 2024 21:14:39 -0700 Subject: [PATCH 3/3] update changelog, bump version --- docs/changelog.md | 17 +++++++++++++++++ pyproject.toml | 2 +- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/docs/changelog.md b/docs/changelog.md index 8a0a566..af19479 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,23 @@ ## 1.* +### 1.3.1 - 24-08-12 - Allow arbitrary dtypes, pydantic models as dtypes + +Previously we would only allow dtypes if we knew for sure that there was some +python base type to generate a schema with. + +That seems overly restrictive, so relax the requirements to allow +any type to be a dtype. If there are problems with serialization (we assume there will be) +or handling the object in a given array framework, we leave that up to the person +who declared the model to handle :). Let people break things and have fun! 
+ +Also support the ability to use a pydantic model as the inner type, which works +as expected because pydantic already knows how to generate a schema from its own models. + +Only one substantial change, and that is a `get_object_dtype` method which +interfaces can override if there is some fancy way they have of getting +types/items from an object array. + ### 1.3.0 - 24-08-05 - Better string dtype handling API Changes: diff --git a/pyproject.toml b/pyproject.toml index 3a1ebf1..3f14e28 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "numpydantic" -version = "1.3.0" +version = "1.3.1" description = "Type and shape validation and serialization for numpy arrays in pydantic models" authors = [ {name = "sneakers-the-rat", email = "sneakers-the-rat@protonmail.com"},