allow arbitrary dtypes, and allow pydantic models as the inner type in json schema array creation

This commit is contained in:
sneakers-the-rat 2024-08-12 20:50:33 -07:00
parent 32db88fc1b
commit dd9a8e959f
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
8 changed files with 71 additions and 20 deletions

View file

@ -125,14 +125,10 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
check_type_names(dtype, dtype_per_name) check_type_names(dtype, dtype_per_name)
elif isinstance(dtype_candidate, tuple): # pragma: no cover elif isinstance(dtype_candidate, tuple): # pragma: no cover
dtype = tuple([cls._get_dtype(dt) for dt in dtype_candidate]) dtype = tuple([cls._get_dtype(dt) for dt in dtype_candidate])
else: # pragma: no cover else:
raise InvalidArgumentsError( # arbitrary dtype - allow failure elsewhere :)
f"Unexpected argument '{dtype_candidate}', expecting" dtype = dtype_candidate
" Structure[<StructureExpression>]"
" or Literal[<StructureExpression>]"
" or a dtype"
" or typing.Any."
)
return dtype return dtype
def _dtype_to_str(cls, dtype: Any) -> str: def _dtype_to_str(cls, dtype: Any) -> str:

View file

@ -8,7 +8,7 @@ import json
from typing import TYPE_CHECKING, Any, Callable, Optional, Union from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import numpy as np import numpy as np
from pydantic import SerializationInfo from pydantic import BaseModel, SerializationInfo
from pydantic_core import CoreSchema, core_schema from pydantic_core import CoreSchema, core_schema
from pydantic_core.core_schema import ListSchema, ValidationInfo from pydantic_core.core_schema import ListSchema, ValidationInfo
@ -66,18 +66,18 @@ def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
else: else:
try: try:
python_type = np_to_python[dtype] python_type = np_to_python[dtype]
except KeyError as e: # pragma: no cover except KeyError: # pragma: no cover
# this should pretty much only happen in downstream/3rd-party interfaces # this should pretty much only happen in downstream/3rd-party interfaces
# that use interface-specific types. those need to provide mappings back # that use interface-specific types. those need to provide mappings back
# to base python types (making this more streamlined is TODO) # to base python types (making this more streamlined is TODO)
if dtype in np_to_python.values(): if dtype in np_to_python.values():
# it's already a python type # it's already a python type
python_type = dtype python_type = dtype
elif issubclass(dtype, BaseModel):
python_type = dtype
else: else:
raise ValueError( # does this need a warning?
"dtype given in model does not have a corresponding python base " python_type = Any
"type - add one to the `maps.np_to_python` dict"
) from e
if python_type in _UNSUPPORTED_TYPES: if python_type in _UNSUPPORTED_TYPES:
array_type = core_schema.any_schema() array_type = core_schema.any_schema()

View file

@ -58,6 +58,10 @@ class ValidationCase(BaseModel):
return Model return Model
class BasicModel(BaseModel):
    """Minimal pydantic model with a single integer field, for use as an array dtype in tests."""

    # the only field; array-building test helpers construct instances with ``x=1``
    x: int
RGB_UNION: TypeAlias = Union[ RGB_UNION: TypeAlias = Union[
NDArray[Shape["* x, * y"], Number], NDArray[Shape["* x, * y"], Number],
NDArray[Shape["* x, * y, 3 r_g_b"], Number], NDArray[Shape["* x, * y, 3 r_g_b"], Number],
@ -68,6 +72,7 @@ NUMBER: TypeAlias = NDArray[Shape["*, *, *"], Number]
INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer] INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer]
FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float] FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float]
STRING: TypeAlias = NDArray[Shape["*, *, *"], str] STRING: TypeAlias = NDArray[Shape["*, *, *"], str]
MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel]
@pytest.fixture( @pytest.fixture(
@ -131,6 +136,8 @@ def shape_cases(request) -> ValidationCase:
ValidationCase(annotation=STRING, dtype=str, passes=True), ValidationCase(annotation=STRING, dtype=str, passes=True),
ValidationCase(annotation=STRING, dtype=int, passes=False), ValidationCase(annotation=STRING, dtype=int, passes=False),
ValidationCase(annotation=STRING, dtype=float, passes=False), ValidationCase(annotation=STRING, dtype=float, passes=False),
ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True),
ValidationCase(annotation=MODEL, dtype=int, passes=False),
], ],
ids=[ ids=[
"float", "float",
@ -154,6 +161,8 @@ def shape_cases(request) -> ValidationCase:
"str-str", "str-str",
"str-int", "str-int",
"str-float", "str-float",
"model-model",
"model-int",
], ],
) )
def dtype_cases(request) -> ValidationCase: def dtype_cases(request) -> ValidationCase:

View file

@ -4,7 +4,7 @@ import pytest
import json import json
import dask.array as da import dask.array as da
from pydantic import ValidationError from pydantic import BaseModel, ValidationError
from numpydantic.interface import DaskInterface from numpydantic.interface import DaskInterface
from numpydantic.exceptions import DtypeError, ShapeError from numpydantic.exceptions import DtypeError, ShapeError
@ -13,6 +13,9 @@ from tests.conftest import ValidationCase
def dask_array(case: ValidationCase) -> da.Array: def dask_array(case: ValidationCase) -> da.Array:
if issubclass(case.dtype, BaseModel):
return da.full(shape=case.shape, fill_value=case.dtype(x=1), chunks=-1)
else:
return da.zeros(shape=case.shape, dtype=case.dtype, chunks=10) return da.zeros(shape=case.shape, dtype=case.dtype, chunks=10)

View file

@ -20,6 +20,8 @@ def hdf5_array_case(case: ValidationCase, array_func) -> H5ArrayPath:
Returns: Returns:
""" """
if issubclass(case.dtype, BaseModel):
pytest.skip("hdf5 cant support arbitrary python objects")
return array_func(case.shape, case.dtype) return array_func(case.shape, case.dtype)

View file

@ -1,12 +1,15 @@
import numpy as np import numpy as np
import pytest import pytest
from pydantic import ValidationError from pydantic import ValidationError, BaseModel
from numpydantic.exceptions import DtypeError, ShapeError from numpydantic.exceptions import DtypeError, ShapeError
from tests.conftest import ValidationCase from tests.conftest import ValidationCase
def numpy_array(case: ValidationCase) -> np.ndarray: def numpy_array(case: ValidationCase) -> np.ndarray:
if issubclass(case.dtype, BaseModel):
return np.full(shape=case.shape, fill_value=case.dtype(x=1))
else:
return np.zeros(shape=case.shape, dtype=case.dtype) return np.zeros(shape=case.shape, dtype=case.dtype)

View file

@ -3,7 +3,9 @@ import json
import pytest import pytest
import zarr import zarr
from pydantic import ValidationError from pydantic import BaseModel, ValidationError
from numcodecs import Pickle
from numpydantic.interface import ZarrInterface from numpydantic.interface import ZarrInterface
from numpydantic.interface.zarr import ZarrArrayPath from numpydantic.interface.zarr import ZarrArrayPath
@ -31,6 +33,18 @@ def nested_dir_array(tmp_output_dir_func) -> zarr.NestedDirectoryStore:
def _zarr_array(case: ValidationCase, store) -> zarr.core.Array: def _zarr_array(case: ValidationCase, store) -> zarr.core.Array:
if issubclass(case.dtype, BaseModel):
pytest.skip(
f"Zarr can't handle objects properly at the moment, "
"see https://github.com/zarr-developers/zarr-python/issues/2081"
)
# return zarr.full(
# shape=case.shape,
# fill_value=case.dtype(x=1),
# dtype=object,
# object_codec=Pickle(),
# )
else:
return zarr.zeros(shape=case.shape, dtype=case.dtype, store=store) return zarr.zeros(shape=case.shape, dtype=case.dtype, store=store)

View file

@ -266,6 +266,30 @@ def test_json_schema_dtype_builtin(dtype, expected, array_model):
assert inner_type["type"] == expected assert inner_type["type"] == expected
def test_json_schema_dtype_model():
    """
    Pydantic models can be used in arrays as dtypes.

    Builds a model with a 2-dimensional ``NDArray`` field whose dtype is another
    pydantic model, generates its JSON schema, and checks that the innermost
    ``items`` entry is a ``$ref`` to that model's entry in ``$defs``.
    """
    # dtype model: referenced by name in the generated schema's "$defs"
    class TestModel(BaseModel):
        x: int
        y: int
        z: int
    # container model whose array field uses the dtype model as its inner type
    class MyModel(BaseModel):
        array: NDArray[Shape["*, *"], TestModel]
    schema = MyModel.model_json_schema()
    # we should have a "$defs" with TestModel in it,
    # and our array should be objects of that type
    # (two nested "items" levels: one per dimension in the "*, *" shape)
    assert schema["properties"]["array"]["items"]["items"] == {
        "$ref": "#/$defs/TestModel"
    }
    # we don't test pydantic's generic json schema model generation,
    # just that one was defined
    assert "TestModel" in schema["$defs"]
def _recursive_array(schema): def _recursive_array(schema):
assert "$defs" in schema assert "$defs" in schema
# get the key uses for the array # get the key uses for the array