allow arbitrary dtypes, and allow pydantic models as the inner type in json schema array creation

sneakers-the-rat 2024-08-12 20:50:33 -07:00
parent 32db88fc1b
commit dd9a8e959f
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
8 changed files with 71 additions and 20 deletions

@@ -125,14 +125,10 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
             check_type_names(dtype, dtype_per_name)
         elif isinstance(dtype_candidate, tuple):  # pragma: no cover
             dtype = tuple([cls._get_dtype(dt) for dt in dtype_candidate])
-        else:  # pragma: no cover
-            raise InvalidArgumentsError(
-                f"Unexpected argument '{dtype_candidate}', expecting"
-                " Structure[<StructureExpression>]"
-                " or Literal[<StructureExpression>]"
-                " or a dtype"
-                " or typing.Any."
-            )
+        else:
+            # arbitrary dtype - allow failure elsewhere :)
+            dtype = dtype_candidate
         return dtype

     def _dtype_to_str(cls, dtype: Any) -> str:
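The practical effect of this hunk: dtype arguments that `_get_dtype` used to reject are now passed through untouched, deferring any failure to validation or schema generation. A minimal sketch of the new behavior, assuming numpydantic's public `NDArray`/`Shape` exports and a hypothetical custom dtype class:

```python
# Sketch, not part of the diff: MyCustomDtype stands in for an
# interface-specific dtype. Before this commit, using it as a dtype
# raised InvalidArgumentsError when the annotation was constructed.
from numpydantic import NDArray, Shape

class MyCustomDtype:
    """Hypothetical third-party dtype."""

AnyDtypeArray = NDArray[Shape["* x"], MyCustomDtype]  # constructs; failures surface later
```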

@@ -8,7 +8,7 @@ import json
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 import numpy as np
-from pydantic import SerializationInfo
+from pydantic import BaseModel, SerializationInfo
 from pydantic_core import CoreSchema, core_schema
 from pydantic_core.core_schema import ListSchema, ValidationInfo
@@ -66,18 +66,18 @@ def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
     else:
         try:
             python_type = np_to_python[dtype]
-        except KeyError as e:  # pragma: no cover
+        except KeyError:  # pragma: no cover
            # this should pretty much only happen in downstream/3rd-party interfaces
            # that use interface-specific types. those need to provide mappings back
            # to base python types (making this more streamlined is TODO)
            if dtype in np_to_python.values():
                # it's already a python type
                python_type = dtype
+           elif issubclass(dtype, BaseModel):
+               python_type = dtype
            else:
-               raise ValueError(
-                   "dtype given in model does not have a corresponding python base "
-                   "type - add one to the `maps.np_to_python` dict"
-               ) from e
+               # does this need a warning?
+               python_type = Any

     if python_type in _UNSUPPORTED_TYPES:
         array_type = core_schema.any_schema()
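Taken together, the new fallback chain resolves a dtype for JSON schema purposes in four steps: mapped numpy type, already-a-builtin, pydantic model, then `Any`. A standalone sketch of that order (the abridged `np_to_python` dict is a stand-in for the real mapping in `numpydantic.maps`):

```python
from typing import Any
import numpy as np
from pydantic import BaseModel

np_to_python = {np.float64: float, np.int64: int}  # abridged stand-in

def resolve_python_type(dtype):
    try:
        return np_to_python[dtype]  # known numpy scalar type
    except KeyError:
        if dtype in np_to_python.values():
            return dtype  # already a python builtin
        if isinstance(dtype, type) and issubclass(dtype, BaseModel):
            return dtype  # pydantic models generate their own schema
        return Any  # arbitrary dtype: degrade to Any instead of raising
```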

@@ -58,6 +58,10 @@ class ValidationCase(BaseModel):
         return Model

+class BasicModel(BaseModel):
+    x: int

 RGB_UNION: TypeAlias = Union[
     NDArray[Shape["* x, * y"], Number],
     NDArray[Shape["* x, * y, 3 r_g_b"], Number],
@@ -68,6 +72,7 @@ NUMBER: TypeAlias = NDArray[Shape["*, *, *"], Number]
 INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer]
 FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float]
 STRING: TypeAlias = NDArray[Shape["*, *, *"], str]
+MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel]

 @pytest.fixture(
@@ -131,6 +136,8 @@ def shape_cases(request) -> ValidationCase:
         ValidationCase(annotation=STRING, dtype=str, passes=True),
         ValidationCase(annotation=STRING, dtype=int, passes=False),
         ValidationCase(annotation=STRING, dtype=float, passes=False),
+        ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True),
+        ValidationCase(annotation=MODEL, dtype=int, passes=False),
     ],
     ids=[
         "float",
@@ -154,6 +161,8 @@ def shape_cases(request) -> ValidationCase:
         "str-str",
         "str-int",
         "str-float",
+        "model-model",
+        "model-int",
     ],
 )
 def dtype_cases(request) -> ValidationCase:
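Stated outside the fixture machinery, the two new cases assert that an object array of `BasicModel` instances validates against a `BasicModel`-dtyped annotation while an integer array does not. A hedged sketch (`Wrapper` is illustrative, not from the test suite):

```python
import numpy as np
from pydantic import BaseModel, ValidationError
from numpydantic import NDArray, Shape

class BasicModel(BaseModel):
    x: int

class Wrapper(BaseModel):
    array: NDArray[Shape["*, *, *"], BasicModel]

Wrapper(array=np.full((2, 2, 2), BasicModel(x=1)))  # model-model: passes
try:
    Wrapper(array=np.zeros((2, 2, 2), dtype=int))   # model-int: fails
except ValidationError as e:
    print(e)
```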

@@ -4,7 +4,7 @@ import pytest
 import json
 import dask.array as da
-from pydantic import ValidationError
+from pydantic import BaseModel, ValidationError
 from numpydantic.interface import DaskInterface
 from numpydantic.exceptions import DtypeError, ShapeError
@@ -13,7 +13,10 @@ from tests.conftest import ValidationCase

 def dask_array(case: ValidationCase) -> da.Array:
-    return da.zeros(shape=case.shape, dtype=case.dtype, chunks=10)
+    if issubclass(case.dtype, BaseModel):
+        return da.full(shape=case.shape, fill_value=case.dtype(x=1), chunks=-1)
+    else:
+        return da.zeros(shape=case.shape, dtype=case.dtype, chunks=10)

 def _test_dask_case(case: ValidationCase):
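A note on the switch from `da.zeros(..., chunks=10)` to `da.full(..., chunks=-1)`: model instances force an object dtype, and `chunks=-1` keeps the whole array in a single chunk, which is my reading of the intent rather than anything the commit states. A sketch that pins `dtype=object` explicitly (the diff relies on dask's inference instead):

```python
import dask.array as da
from pydantic import BaseModel

class BasicModel(BaseModel):
    x: int

arr = da.full(shape=(2, 2, 2), fill_value=BasicModel(x=1), chunks=-1, dtype=object)
print(arr.npartitions)          # 1: a single chunk holds the whole object array
print(arr.compute()[0, 0, 0])   # x=1
```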

@@ -20,6 +20,8 @@ def hdf5_array_case(case: ValidationCase, array_func) -> H5ArrayPath:
     Returns:
     """
+    if issubclass(case.dtype, BaseModel):
+        pytest.skip("hdf5 can't support arbitrary python objects")
     return array_func(case.shape, case.dtype)
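For context on the skip, an illustration of mine rather than anything in the commit: HDF5 has no native representation for arbitrary Python objects, so (assuming an h5py-backed interface) an object-dtype dataset cannot even be created:

```python
import numpy as np
import h5py

with h5py.File("objects.h5", "w") as f:
    try:
        f.create_dataset("arr", data=np.full((2, 2), None, dtype=object))
    except TypeError as e:
        print(e)  # object dtype has no native HDF5 equivalent
```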

@@ -1,13 +1,16 @@
 import numpy as np
 import pytest
-from pydantic import ValidationError
+from pydantic import ValidationError, BaseModel
 from numpydantic.exceptions import DtypeError, ShapeError
 from tests.conftest import ValidationCase

 def numpy_array(case: ValidationCase) -> np.ndarray:
-    return np.zeros(shape=case.shape, dtype=case.dtype)
+    if issubclass(case.dtype, BaseModel):
+        return np.full(shape=case.shape, fill_value=case.dtype(x=1))
+    else:
+        return np.zeros(shape=case.shape, dtype=case.dtype)

 def _test_np_case(case: ValidationCase):
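One property of the `np.full` pattern worth knowing when reading the model cases (a sketch, not part of the diff): with no `dtype` given, numpy infers `object` from the model instance, and every cell aliases the same instance rather than a copy:

```python
import numpy as np
from pydantic import BaseModel

class BasicModel(BaseModel):
    x: int

arr = np.full(shape=(2, 2, 2), fill_value=BasicModel(x=1))
print(arr.dtype)                      # object
print(arr[0, 0, 0] is arr[1, 1, 1])   # True: one shared instance
```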

@@ -3,7 +3,9 @@ import json
 import pytest
 import zarr
-from pydantic import ValidationError
+from pydantic import BaseModel, ValidationError
+from numcodecs import Pickle
 from numpydantic.interface import ZarrInterface
 from numpydantic.interface.zarr import ZarrArrayPath
@@ -31,7 +33,19 @@ def nested_dir_array(tmp_output_dir_func) -> zarr.NestedDirectoryStore:

 def _zarr_array(case: ValidationCase, store) -> zarr.core.Array:
-    return zarr.zeros(shape=case.shape, dtype=case.dtype, store=store)
+    if issubclass(case.dtype, BaseModel):
+        pytest.skip(
+            "Zarr can't handle objects properly at the moment, "
+            "see https://github.com/zarr-developers/zarr-python/issues/2081"
+        )
+        # return zarr.full(
+        #     shape=case.shape,
+        #     fill_value=case.dtype(x=1),
+        #     dtype=object,
+        #     object_codec=Pickle(),
+        # )
+    else:
+        return zarr.zeros(shape=case.shape, dtype=case.dtype, store=store)

 def _test_zarr_case(case: ValidationCase, store):
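The commented-out block preserves the intended route, and presumably explains the new `numcodecs` import: zarr 2.x can store object arrays when an `object_codec` such as `Pickle` is supplied. A sketch of that API on its own, separate from the test path blocked by the linked issue:

```python
import zarr
from numcodecs import Pickle

# zarr 2.x requires an object_codec for dtype=object arrays
z = zarr.empty(shape=(2, 2), dtype=object, object_codec=Pickle())
z[0, 0] = {"x": 1}  # arbitrary python objects, pickled chunk by chunk
print(z[0, 0])
```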

@@ -266,6 +266,30 @@ def test_json_schema_dtype_builtin(dtype, expected, array_model):
     assert inner_type["type"] == expected

+def test_json_schema_dtype_model():
+    """
+    Pydantic models can be used in arrays as dtypes
+    """
+    class TestModel(BaseModel):
+        x: int
+        y: int
+        z: int
+
+    class MyModel(BaseModel):
+        array: NDArray[Shape["*, *"], TestModel]
+
+    schema = MyModel.model_json_schema()
+    # we should have a "$defs" with TestModel in it,
+    # and our array should be objects of that type
+    assert schema["properties"]["array"]["items"]["items"] == {
+        "$ref": "#/$defs/TestModel"
+    }
+    # we don't test pydantic's generic json schema model generation,
+    # just that one was defined
+    assert "TestModel" in schema["$defs"]

 def _recursive_array(schema):
     assert "$defs" in schema
     # get the key used for the array
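For readers who want the concrete output the two assertions probe, the generated schema looks roughly like this (abridged by hand from pydantic v2's usual output; only the asserted paths are guaranteed):

```python
expected_fragment = {
    "$defs": {
        "TestModel": {
            "properties": {
                "x": {"title": "X", "type": "integer"},
                "y": {"title": "Y", "type": "integer"},
                "z": {"title": "Z", "type": "integer"},
            },
            "required": ["x", "y", "z"],
            "title": "TestModel",
            "type": "object",
        }
    },
    "properties": {
        "array": {
            # 2-D shape -> list of lists; innermost items are model $refs
            "items": {"items": {"$ref": "#/$defs/TestModel"}},
        }
    },
}
```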