better json schema generation

sneakers-the-rat 2024-05-15 20:49:15 -07:00
parent a1786144fa
commit 060de62334
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
4 changed files with 206 additions and 53 deletions

View file

```diff
@@ -6,7 +6,6 @@ from datetime import datetime
 from typing import Any
 import numpy as np
-from nptyping import Bool, Float, Int, String
 from numpydantic import dtype as dt
@@ -59,5 +58,5 @@ flat_to_nptyping = {
 }
 """Map from NWB-style flat dtypes to nptyping types"""
-python_to_nptyping = {float: Float, str: String, int: Int, bool: Bool}
+python_to_nptyping = {float: dt.Float, str: dt.String, int: dt.Int, bool: dt.Bool}
 """Map from python types to nptyping types"""
```

View file

```diff
@@ -4,6 +4,7 @@ Extension of nptyping NDArray for pydantic that allows for JSON-Schema serialization
 * Order to store data in (row first)
 """
+import pdb
 from collections.abc import Callable
 from typing import TYPE_CHECKING, Any, Tuple, Union
@@ -18,22 +19,94 @@ from nptyping.structure_expression import check_type_names
 from nptyping.typing_ import (
     dtype_per_name,
 )
+from pydantic import GetJsonSchemaHandler
 from pydantic_core import core_schema
-from pydantic_core.core_schema import ListSchema
+from pydantic_core.core_schema import CoreSchema, ListSchema
 
+from numpydantic import dtype as dt
 from numpydantic.dtype import DType
 from numpydantic.interface import Interface
 from numpydantic.maps import np_to_python
-# from numpydantic.proxy import NDArrayProxy
 from numpydantic.types import DtypeType, NDArrayType, ShapeType
 
-if TYPE_CHECKING:
+if TYPE_CHECKING:  # pragma: no cover
     from pydantic import ValidationInfo
 
+_handler_type = Callable[[Any], core_schema.CoreSchema]
+
+_UNSUPPORTED_TYPES = (complex,)
+"""
+python types that pydantic/json schema can't support (and Any will be used instead)
+"""
+
+
+def _numeric_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
+    """Make a numeric dtype that respects min/max values from extended numpy types"""
+    if dtype.__module__ == "builtins":
+        metadata = None
+    else:
+        metadata = {"dtype": ".".join([dtype.__module__, dtype.__name__])}
+
+    if issubclass(dtype, np.floating):
+        info = np.finfo(dtype)
+        schema = core_schema.float_schema(le=float(info.max), ge=float(info.min))
+    elif issubclass(dtype, np.integer):
+        info = np.iinfo(dtype)
+        schema = core_schema.int_schema(le=int(info.max), ge=int(info.min))
+    else:
+        schema = _handler.generate_schema(dtype, metadata=metadata)
+
+    return schema
+
+
+def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
+    """Get the innermost dtype schema to use in the generated pydantic schema"""
+    if isinstance(dtype, nptyping.structure.StructureMeta):  # pragma: no cover
+        raise NotImplementedError("Structured dtypes are currently unsupported")
+
+    if isinstance(dtype, tuple):
+        # if it's a meta-type that refers to a generic float/int, just make that
+        if dtype == dt.Float:
+            array_type = core_schema.float_schema()
+        elif dtype == dt.Integer:
+            array_type = core_schema.int_schema()
+        elif dtype == dt.Complex:
+            array_type = core_schema.any_schema()
+        else:
+            # make a union of dtypes recursively
+            types_ = list(set(dtype))
+            array_type = core_schema.union_schema(
+                [_lol_dtype(t, _handler) for t in types_]
+            )
+    else:
+        try:
+            python_type = np_to_python[dtype]
+        except KeyError as e:
+            if dtype in np_to_python.values():
+                # it's already a python type
+                python_type = dtype
+            else:
+                raise ValueError(
+                    "dtype given in model does not have a corresponding python base "
+                    "type - add one to the `maps.np_to_python` dict"
+                ) from e
+
+        if python_type in _UNSUPPORTED_TYPES:
+            array_type = core_schema.any_schema()
+            # TODO: warn and log here
+        elif python_type in (float, int):
+            array_type = _numeric_dtype(dtype, _handler)
+        else:
+            array_type = _handler.generate_schema(python_type)
+
+    return array_type
+
+
 def list_of_lists_schema(shape: Shape, array_type_handler: dict) -> ListSchema:
     """Make a pydantic JSON schema for an array as a list of lists."""
     shape_parts = shape.__args__[0].split(",")
     split_parts = [
         p.split(" ")[1] if len(p.split(" ")) == 2 else None for p in shape_parts
```
```diff
@@ -64,6 +137,30 @@ def list_of_lists_schema(shape: Shape, array_type_handler: dict) -> ListSchema:
     return list_schema
 
 
+def make_json_schema(
+    shape: ShapeType, dtype: DtypeType, _handler: _handler_type
+) -> ListSchema:
+    """
+    Make a list-of-lists JSON schema from a shape and a dtype.
+
+    Args:
+        shape: Shape constraint, or :class:`typing.Any` for an unconstrained array
+        dtype: dtype to build the innermost item schema from
+        _handler: The handler from the pydantic schema generation process
+
+    Returns:
+        :class:`pydantic_core.core_schema.ListSchema`
+    """
+    dtype_schema = _lol_dtype(dtype, _handler)
+
+    # get the names of the shape constraints, if any
+    if shape is Any:
+        list_schema = core_schema.list_schema(core_schema.any_schema())
+    else:
+        list_schema = list_of_lists_schema(shape, dtype_schema)
+
+    return list_schema
+
+
 def _get_validate_interface(shape: ShapeType, dtype: DtypeType) -> Callable:
     """
     Validate using a matching :class:`.Interface` class using its
```
```diff
@@ -111,16 +208,16 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
             dtype = Any
         elif is_dtype:
             dtype = dtype_candidate
-        elif issubclass(dtype_candidate, Structure):
+        elif issubclass(dtype_candidate, Structure):  # pragma: no cover
             dtype = dtype_candidate
             check_type_names(dtype, dtype_per_name)
-        elif cls._is_literal_like(dtype_candidate):
+        elif cls._is_literal_like(dtype_candidate):  # pragma: no cover
             structure_expression = dtype_candidate.__args__[0]
             dtype = Structure[structure_expression]
             check_type_names(dtype, dtype_per_name)
-        elif isinstance(dtype_candidate, tuple):
+        elif isinstance(dtype_candidate, tuple):  # pragma: no cover
             dtype = tuple([cls._get_dtype(dt) for dt in dtype_candidate])
-        else:
+        else:  # pragma: no cover
             raise InvalidArgumentsError(
                 f"Unexpected argument '{dtype_candidate}', expecting"
                 " Structure[<StructureExpression>]"
```
```diff
@@ -153,33 +250,14 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
     def __get_pydantic_core_schema__(
         cls,
         _source_type: "NDArray",
-        _handler: Callable[[Any], core_schema.CoreSchema],
+        _handler: _handler_type,
     ) -> core_schema.CoreSchema:
         shape, dtype = _source_type.__args__
         shape: ShapeType
         dtype: DtypeType
 
-        # get pydantic core schema for the given specified type
-        if isinstance(dtype, nptyping.structure.StructureMeta):
-            raise NotImplementedError("Finish handling structured dtypes!")
-            # functools.reduce(operator.or_, [int, float, str])
-        else:
-            if isinstance(dtype, tuple):
-                types_ = list(set([np_to_python[dt] for dt in dtype]))
-                # TODO: better type filtering - explicitly model what
-                # numeric types are supported by JSON schema
-                types_ = [t for t in types_ if t not in (complex,)]
-                schemas = [_handler.generate_schema(dt) for dt in types_]
-                array_type_handler = core_schema.union_schema(schemas)
-            else:
-                array_type_handler = _handler.generate_schema(np_to_python[dtype])
-
-        # get the names of the shape constraints, if any
-        if shape is Any:
-            list_schema = core_schema.list_schema(core_schema.any_schema())
-        else:
-            list_schema = list_of_lists_schema(shape, array_type_handler)
+        # get pydantic core schema as a list of lists for JSON schema
+        list_schema = make_json_schema(shape, dtype, _handler)
 
         return core_schema.json_or_python_schema(
             json_schema=list_schema,
```
```diff
@@ -195,3 +273,16 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
                 _jsonize_array, when_used="json"
             ),
         )
+
+    @classmethod
+    def __get_pydantic_json_schema__(
+        cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
+    ):
+        json_schema = handler(schema)
+        json_schema = handler.resolve_ref_schema(json_schema)
+
+        dtype = cls.__args__[1]
+        if dtype.__module__ != "builtins":
+            json_schema["dtype"] = ".".join([dtype.__module__, dtype.__name__])
+
+        return json_schema
```
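Taken together, `make_json_schema` and `__get_pydantic_json_schema__` mean an `NDArray` field now serializes to a list-of-lists JSON schema with shape and numeric bounds encoded. A rough usage sketch, with the exact key layout inferred from this diff and the tests below rather than copied from real output:

```python
import numpy as np
from pydantic import BaseModel
from nptyping import Shape
from numpydantic import NDArray

class Image(BaseModel):
    array: NDArray[Shape["2 x, 2 y"], np.uint8]

field = Image.model_json_schema()["properties"]["array"]
# shape constraints become min/max item counts on the nested list schemas
assert field["minItems"] == field["maxItems"] == 2
# integer bounds come from np.iinfo (floats would use np.finfo)
assert field["items"]["items"]["maximum"] == 255
# non-builtin dtypes are recorded as "module.name"
assert field["dtype"] == "numpy.uint8"
```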

View file

```diff
@@ -64,7 +64,7 @@ def array_model() -> (
     shape_str = ", ".join([str(s) for s in shape])
 
     class MyModel(BaseModel):
-        array: NDArray[Shape[shape_str], python_to_nptyping[dtype]]
+        array: NDArray[Shape[shape_str], dtype]
 
     return MyModel
```

View file

```diff
@@ -1,3 +1,5 @@
+import pdb
+
 import pytest
 from typing import Union, Optional, Any
@@ -9,6 +11,7 @@ from nptyping import Shape, Number
 from numpydantic import NDArray
 from numpydantic.exceptions import ShapeError, DtypeError
+from numpydantic import dtype
 
 # from .fixtures import tmp_output_dir_func
```
```diff
@@ -99,23 +102,83 @@ def test_ndarray_serialize():
     assert isinstance(mod_dict["array"], np.ndarray)
 
-# def test_ndarray_proxy(tmp_output_dir_func):
-#     h5f_source = tmp_output_dir_func / 'test.h5'
-#     with h5py.File(h5f_source, 'w') as h5f:
-#         dset_good = h5f.create_dataset('/data', data=np.random.random((1024,1024,3)))
-#         dset_bad = h5f.create_dataset('/data_bad', data=np.random.random((1024, 1024, 4)))
-#
-#     class Model(BaseModel):
-#         array: NDArray[Shape["* x, * y, 3 z"], Number]
-#
-#     mod = Model(array=NDArrayProxy(h5f_file=h5f_source, path='/data'))
-#     subarray = mod.array[0:5, 0:5, :]
-#     assert isinstance(subarray, np.ndarray)
-#     assert isinstance(subarray.sum(), float)
-#     assert mod.array.name == '/data'
-#
-#     with pytest.raises(NotImplementedError):
-#         mod.array[0] = 5
-#
-#     with pytest.raises(ValidationError):
-#         mod = Model(array=NDArrayProxy(h5f_file=h5f_source, path='/data_bad'))
+
+_json_schema_types = [
+    *[(t, float) for t in dtype.Float],
+    *[(t, int) for t in dtype.Integer],
+]
+
+
+def test_json_schema_basic(array_model):
+    """
+    NDArray types should correctly generate a list of lists JSON schema
+    """
+    shape = (15, 10)
+    dtype = float
+    model = array_model(shape, dtype)
+    schema = model.model_json_schema()
+    field = schema["properties"]["array"]
+
+    # outer shape
+    assert field["maxItems"] == shape[0]
+    assert field["minItems"] == shape[0]
+    assert field["type"] == "array"
+
+    # inner shape
+    inner = field["items"]
+    assert inner["minItems"] == shape[1]
+    assert inner["maxItems"] == shape[1]
+    assert inner["items"]["type"] == "number"
+
+
+@pytest.mark.parametrize("dtype", [*dtype.Integer, *dtype.Float])
+def test_json_schema_dtype_single(dtype, array_model):
+    """
+    dtypes should have correct mins and maxes set, and store the source dtype
+    """
+    if issubclass(dtype, np.floating):
+        info = np.finfo(dtype)
+        min_val = info.min
+        max_val = info.max
+        schema_type = "number"
+    elif issubclass(dtype, np.integer):
+        info = np.iinfo(dtype)
+        min_val = info.min
+        max_val = info.max
+        schema_type = "integer"
+    else:
+        raise ValueError("These should all be numpy types!")
+
+    shape = (15, 10)
+    model = array_model(shape, dtype)
+    schema = model.model_json_schema()
+    inner_type = schema["properties"]["array"]["items"]["items"]
+    assert inner_type["minimum"] == min_val
+    assert inner_type["maximum"] == max_val
+    assert inner_type["type"] == schema_type
+    assert schema["properties"]["array"]["dtype"] == ".".join(
+        [dtype.__module__, dtype.__name__]
+    )
+
+
+@pytest.mark.skip()
+def test_json_schema_dtype_builtin(dtype):
+    """
+    Using builtin or generic (e.g. `dtype.Integer`) dtypes should
+    make a simple json schema without mins/maxes/dtypes.
+    """
+
+
+@pytest.mark.skip("Not implemented yet")
+def test_json_schema_wildcard():
+    """
+    NDArray types should generate a JSON schema without shape constraints
+    """
+    pass
+
+
+@pytest.mark.skip("Not implemented yet")
+def test_json_schema_ellipsis():
+    """
+    NDArray types should create a recursive JSON schema for any-shaped arrays
+    """
+    pass
```
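For a concrete sense of what `test_json_schema_dtype_single` asserts, the inner item schema it expects for `np.int16` would roughly look like the following; a sketch derived from the `np.iinfo` calls above, not output copied from the suite:

```python
import numpy as np

info = np.iinfo(np.int16)
expected_inner = {
    "type": "integer",
    "minimum": int(info.min),  # -32768
    "maximum": int(info.max),  # 32767
}
```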