mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2025-01-10 05:54:26 +00:00
better json schema generation
This commit is contained in:
parent
a1786144fa
commit
060de62334
4 changed files with 206 additions and 53 deletions
|
@ -6,7 +6,6 @@ from datetime import datetime
|
|||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from nptyping import Bool, Float, Int, String
|
||||
|
||||
from numpydantic import dtype as dt
|
||||
|
||||
|
@ -59,5 +58,5 @@ flat_to_nptyping = {
|
|||
}
|
||||
"""Map from NWB-style flat dtypes to nptyping types"""
|
||||
|
||||
python_to_nptyping = {float: dt.Float, str: dt.String, int: dt.Int, bool: dt.Bool}
"""
Map from python builtin types to the matching ``numpydantic.dtype`` members.

The name is kept for backwards compatibility; the values are no longer
imported from ``nptyping``.
"""
|
||||
|
|
|
@ -4,6 +4,7 @@ Extension of nptyping NDArray for pydantic that allows for JSON-Schema serializa
|
|||
* Order to store data in (row first)
|
||||
"""
|
||||
|
||||
import pdb
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any, Tuple, Union
|
||||
|
||||
|
@ -18,22 +19,94 @@ from nptyping.structure_expression import check_type_names
|
|||
from nptyping.typing_ import (
|
||||
dtype_per_name,
|
||||
)
|
||||
from pydantic import GetJsonSchemaHandler
|
||||
from pydantic_core import core_schema
|
||||
from pydantic_core.core_schema import ListSchema
|
||||
from pydantic_core.core_schema import CoreSchema, ListSchema
|
||||
|
||||
from numpydantic import dtype as dt
|
||||
from numpydantic.dtype import DType
|
||||
from numpydantic.interface import Interface
|
||||
from numpydantic.maps import np_to_python
|
||||
|
||||
# from numpydantic.proxy import NDArrayProxy
|
||||
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from pydantic import ValidationInfo
|
||||
|
||||
# Type alias used to annotate the ``_handler`` argument of the
# schema-building functions in this module.
_handler_type = Callable[[Any], core_schema.CoreSchema]

_UNSUPPORTED_TYPES = (complex,)
"""
python types that pydantic/json schema can't support (and Any will be used instead)
"""
|
||||
|
||||
|
||||
def _numeric_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
    """
    Make a numeric dtype schema that respects min/max values from extended numpy types.

    Args:
        dtype: A single numeric type (a class, not a tuple of classes).
            Subclasses of ``np.floating`` and ``np.integer`` get explicit
            bounds; anything else is delegated to ``_handler``.
        _handler: Schema handler supplied by pydantic, used only for types
            that are neither numpy floats nor numpy integers.

    Returns:
        A float schema bounded by ``np.finfo(dtype)``, an int schema bounded by
        ``np.iinfo(dtype)``, or whatever ``_handler.generate_schema`` produces
        for the remaining types.
    """
    if issubclass(dtype, np.floating):
        info = np.finfo(dtype)
        schema = core_schema.float_schema(le=float(info.max), ge=float(info.min))
    elif issubclass(dtype, np.integer):
        info = np.iinfo(dtype)
        schema = core_schema.int_schema(le=int(info.max), ge=int(info.min))
    else:
        # The handler's ``generate_schema`` accepts the source type as its
        # only argument, so no extra keyword arguments are forwarded.
        schema = _handler.generate_schema(dtype)

    return schema
|
||||
|
||||
|
||||
def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
    """
    Get the innermost dtype schema to use in the generated pydantic schema.

    Args:
        dtype: A single dtype, or a tuple of dtypes that is treated as a union.
        _handler: Schema handler supplied by pydantic, used to build schemas
            for non-numeric python types.

    Returns:
        The core schema describing a single array element.

    Raises:
        NotImplementedError: If ``dtype`` is an nptyping structured dtype.
        ValueError: If ``dtype`` is neither a key nor a value of
            ``maps.np_to_python``.
    """
    if isinstance(dtype, nptyping.structure.StructureMeta):  # pragma: no cover
        raise NotImplementedError("Structured dtypes are currently unsupported")

    if isinstance(dtype, tuple):
        # if it's a meta-type that refers to a generic float/int, just make that
        if dtype == dt.Float:
            return core_schema.float_schema()
        if dtype == dt.Integer:
            return core_schema.int_schema()
        if dtype == dt.Complex:
            return core_schema.any_schema()

        # Otherwise make a union of dtypes recursively. De-duplicate with
        # ``dict.fromkeys`` rather than ``set`` so the members keep their
        # declaration order and the generated schema is the same on every run.
        members = list(dict.fromkeys(dtype))
        return core_schema.union_schema([_lol_dtype(t, _handler) for t in members])

    try:
        python_type = np_to_python[dtype]
    except KeyError as e:
        if dtype not in np_to_python.values():
            raise ValueError(
                "dtype given in model does not have a corresponding python base "
                "type - add one to the `maps.np_to_python` dict"
            ) from e
        # it's already a python type
        python_type = dtype

    if python_type in _UNSUPPORTED_TYPES:
        # TODO: warn and log here
        return core_schema.any_schema()
    if python_type in (float, int):
        return _numeric_dtype(dtype, _handler)
    return _handler.generate_schema(python_type)
|
||||
|
||||
|
||||
def list_of_lists_schema(shape: Shape, array_type_handler: dict) -> ListSchema:
|
||||
"""Make a pydantic JSON schema for an array as a list of lists."""
|
||||
|
||||
shape_parts = shape.__args__[0].split(",")
|
||||
split_parts = [
|
||||
p.split(" ")[1] if len(p.split(" ")) == 2 else None for p in shape_parts
|
||||
|
@ -64,6 +137,30 @@ def list_of_lists_schema(shape: Shape, array_type_handler: dict) -> ListSchema:
|
|||
return list_schema
|
||||
|
||||
|
||||
def make_json_schema(
    shape: ShapeType, dtype: DtypeType, _handler: _handler_type
) -> ListSchema:
    """
    Build the list schema that represents an array in JSON.

    Args:
        shape: Shape specification of the array, or ``Any`` when the shape
            is unconstrained.
        dtype: Dtype specification of the array; resolved to an element
            schema with :func:`_lol_dtype`.
        _handler: Schema handler supplied by pydantic, forwarded to
            :func:`_lol_dtype`.

    Returns:
        A list of ``Any`` when ``shape`` is ``Any``, otherwise the result of
        :func:`list_of_lists_schema` for the given shape and element schema.
    """
    # Resolved before looking at the shape, so the dtype is always processed.
    element_schema = _lol_dtype(dtype, _handler)

    if shape is not Any:
        return list_of_lists_schema(shape, element_schema)
    return core_schema.list_schema(core_schema.any_schema())
|
||||
|
||||
|
||||
def _get_validate_interface(shape: ShapeType, dtype: DtypeType) -> Callable:
|
||||
"""
|
||||
Validate using a matching :class:`.Interface` class using its
|
||||
|
@ -111,16 +208,16 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
|
|||
dtype = Any
|
||||
elif is_dtype:
|
||||
dtype = dtype_candidate
|
||||
elif issubclass(dtype_candidate, Structure):
|
||||
elif issubclass(dtype_candidate, Structure): # pragma: no cover
|
||||
dtype = dtype_candidate
|
||||
check_type_names(dtype, dtype_per_name)
|
||||
elif cls._is_literal_like(dtype_candidate):
|
||||
elif cls._is_literal_like(dtype_candidate): # pragma: no cover
|
||||
structure_expression = dtype_candidate.__args__[0]
|
||||
dtype = Structure[structure_expression]
|
||||
check_type_names(dtype, dtype_per_name)
|
||||
elif isinstance(dtype_candidate, tuple):
|
||||
elif isinstance(dtype_candidate, tuple): # pragma: no cover
|
||||
dtype = tuple([cls._get_dtype(dt) for dt in dtype_candidate])
|
||||
else:
|
||||
else: # pragma: no cover
|
||||
raise InvalidArgumentsError(
|
||||
f"Unexpected argument '{dtype_candidate}', expecting"
|
||||
" Structure[<StructureExpression>]"
|
||||
|
@ -153,33 +250,14 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
|
|||
def __get_pydantic_core_schema__(
|
||||
cls,
|
||||
_source_type: "NDArray",
|
||||
_handler: Callable[[Any], core_schema.CoreSchema],
|
||||
_handler: _handler_type,
|
||||
) -> core_schema.CoreSchema:
|
||||
shape, dtype = _source_type.__args__
|
||||
shape: ShapeType
|
||||
dtype: DtypeType
|
||||
|
||||
# get pydantic core schema for the given specified type
|
||||
if isinstance(dtype, nptyping.structure.StructureMeta):
|
||||
raise NotImplementedError("Finish handling structured dtypes!")
|
||||
# functools.reduce(operator.or_, [int, float, str])
|
||||
else:
|
||||
if isinstance(dtype, tuple):
|
||||
types_ = list(set([np_to_python[dt] for dt in dtype]))
|
||||
# TODO: better type filtering - explicitly model what
|
||||
# numeric types are supported by JSON schema
|
||||
types_ = [t for t in types_ if t not in (complex,)]
|
||||
schemas = [_handler.generate_schema(dt) for dt in types_]
|
||||
array_type_handler = core_schema.union_schema(schemas)
|
||||
|
||||
else:
|
||||
array_type_handler = _handler.generate_schema(np_to_python[dtype])
|
||||
|
||||
# get the names of the shape constraints, if any
|
||||
if shape is Any:
|
||||
list_schema = core_schema.list_schema(core_schema.any_schema())
|
||||
else:
|
||||
list_schema = list_of_lists_schema(shape, array_type_handler)
|
||||
# get pydantic core schema as a list of lists for JSON schema
|
||||
list_schema = make_json_schema(shape, dtype, _handler)
|
||||
|
||||
return core_schema.json_or_python_schema(
|
||||
json_schema=list_schema,
|
||||
|
@ -195,3 +273,16 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
|
|||
_jsonize_array, when_used="json"
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
def __get_pydantic_json_schema__(
    cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> dict:
    """
    Generate the JSON schema and tag it with the source dtype.

    Args:
        schema: The core schema to convert.
        handler: pydantic's JSON schema handler; called to build the JSON
            schema and to resolve any ``$ref`` it returns.

    Returns:
        The resolved JSON schema. When the dtype exposes both ``__module__``
        and ``__name__`` and is not a builtin, a ``"dtype"`` key holding the
        dotted ``module.name`` path is added.
    """
    json_schema = handler(schema)
    json_schema = handler.resolve_ref_schema(json_schema)

    dtype = cls.__args__[1]
    # Tuple dtypes (unions of types) are instances and carry neither
    # ``__module__`` nor ``__name__``; skip the annotation instead of raising.
    module = getattr(dtype, "__module__", None)
    name = getattr(dtype, "__name__", None)
    if module is not None and name is not None and module != "builtins":
        json_schema["dtype"] = ".".join([module, name])

    return json_schema
|
||||
|
|
|
@ -64,7 +64,7 @@ def array_model() -> (
|
|||
shape_str = ", ".join([str(s) for s in shape])
|
||||
|
||||
class MyModel(BaseModel):
|
||||
array: NDArray[Shape[shape_str], python_to_nptyping[dtype]]
|
||||
array: NDArray[Shape[shape_str], dtype]
|
||||
|
||||
return MyModel
|
||||
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
import pdb
|
||||
|
||||
import pytest
|
||||
|
||||
from typing import Union, Optional, Any
|
||||
|
@ -9,6 +11,7 @@ from nptyping import Shape, Number
|
|||
|
||||
from numpydantic import NDArray
|
||||
from numpydantic.exceptions import ShapeError, DtypeError
|
||||
from numpydantic import dtype
|
||||
|
||||
|
||||
# from .fixtures import tmp_output_dir_func
|
||||
|
@ -99,23 +102,83 @@ def test_ndarray_serialize():
|
|||
assert isinstance(mod_dict["array"], np.ndarray)
|
||||
|
||||
|
||||
# def test_ndarray_proxy(tmp_output_dir_func):
|
||||
# h5f_source = tmp_output_dir_func / 'test.h5'
|
||||
# with h5py.File(h5f_source, 'w') as h5f:
|
||||
# dset_good = h5f.create_dataset('/data', data=np.random.random((1024,1024,3)))
|
||||
# dset_bad = h5f.create_dataset('/data_bad', data=np.random.random((1024, 1024, 4)))
|
||||
#
|
||||
# class Model(BaseModel):
|
||||
# array: NDArray[Shape["* x, * y, 3 z"], Number]
|
||||
#
|
||||
# mod = Model(array=NDArrayProxy(h5f_file=h5f_source, path='/data'))
|
||||
# subarray = mod.array[0:5, 0:5, :]
|
||||
# assert isinstance(subarray, np.ndarray)
|
||||
# assert isinstance(subarray.sum(), float)
|
||||
# assert mod.array.name == '/data'
|
||||
#
|
||||
# with pytest.raises(NotImplementedError):
|
||||
# mod.array[0] = 5
|
||||
#
|
||||
# with pytest.raises(ValidationError):
|
||||
# mod = Model(array=NDArrayProxy(h5f_file=h5f_source, path='/data_bad'))
|
||||
# Pairs of (dtype member, python type): every member of ``dtype.Float`` is
# paired with ``float`` and every member of ``dtype.Integer`` with ``int``.
_json_schema_types = [(t, float) for t in dtype.Float] + [
    (t, int) for t in dtype.Integer
]
|
||||
|
||||
|
||||
def test_json_schema_basic(array_model):
    """
    A fixed-shape float NDArray should generate a list-of-lists JSON schema
    whose item counts match the shape at each nesting level.
    """
    n_outer, n_inner = 15, 10
    model = array_model((n_outer, n_inner), float)
    array_field = model.model_json_schema()["properties"]["array"]

    # outer shape
    assert array_field["maxItems"] == n_outer
    assert array_field["minItems"] == n_outer
    assert array_field["type"] == "array"

    # inner shape
    row_schema = array_field["items"]
    assert row_schema["minItems"] == n_inner
    assert row_schema["maxItems"] == n_inner
    assert row_schema["items"]["type"] == "number"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [*dtype.Integer, *dtype.Float])
def test_json_schema_dtype_single(dtype, array_model):
    """
    Each numpy dtype should set its own minimum and maximum on the element
    schema, and the array schema should record the source dtype.
    """
    if issubclass(dtype, np.floating):
        limits, schema_type = np.finfo(dtype), "number"
    elif issubclass(dtype, np.integer):
        limits, schema_type = np.iinfo(dtype), "integer"
    else:
        raise ValueError("These should all be numpy types!")

    model = array_model((15, 10), dtype)
    array_schema = model.model_json_schema()["properties"]["array"]
    element_schema = array_schema["items"]["items"]

    assert element_schema["minimum"] == limits.min
    assert element_schema["maximum"] == limits.max
    assert element_schema["type"] == schema_type
    assert array_schema["dtype"] == f"{dtype.__module__}.{dtype.__name__}"
|
||||
|
||||
|
||||
@pytest.mark.skip("Not implemented yet")
def test_json_schema_dtype_builtin(dtype):
    """
    Using builtin or generic (eg. `dtype.Integer` ) dtypes should
    make a simple json schema without mins/maxes/dtypes.
    """
|
||||
|
||||
|
||||
@pytest.mark.skip("Not implemented yet")
def test_json_schema_wildcard():
    """
    Placeholder: NDArray types should generate a JSON schema without shape
    constraints.
    """
|
||||
|
||||
|
||||
@pytest.mark.skip("Not implemented yet")
def test_json_schema_ellipsis():
    """
    Placeholder: NDArray types should create a recursive JSON schema for
    any-shaped arrays.
    """
|
||||
|
|
Loading…
Reference in a new issue