better json schema generation

This commit is contained in:
sneakers-the-rat 2024-05-15 20:49:15 -07:00
parent a1786144fa
commit 060de62334
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
4 changed files with 206 additions and 53 deletions

View file

@ -6,7 +6,6 @@ from datetime import datetime
from typing import Any
import numpy as np
from nptyping import Bool, Float, Int, String
from numpydantic import dtype as dt
@ -59,5 +58,5 @@ flat_to_nptyping = {
}
"""Map from NWB-style flat dtypes to nptyping types"""
python_to_nptyping = {float: Float, str: String, int: Int, bool: Bool}
python_to_nptyping = {float: dt.Float, str: dt.String, int: dt.Int, bool: dt.Bool}
"""Map from python types to nptyping types"""

View file

@ -4,6 +4,7 @@ Extension of nptyping NDArray for pydantic that allows for JSON-Schema serializa
* Order to store data in (row first)
"""
import pdb
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Tuple, Union
@ -18,22 +19,94 @@ from nptyping.structure_expression import check_type_names
from nptyping.typing_ import (
dtype_per_name,
)
from pydantic import GetJsonSchemaHandler
from pydantic_core import core_schema
from pydantic_core.core_schema import ListSchema
from pydantic_core.core_schema import CoreSchema, ListSchema
from numpydantic import dtype as dt
from numpydantic.dtype import DType
from numpydantic.interface import Interface
from numpydantic.maps import np_to_python
# from numpydantic.proxy import NDArrayProxy
from numpydantic.types import DtypeType, NDArrayType, ShapeType
if TYPE_CHECKING:
if TYPE_CHECKING: # pragma: no cover
from pydantic import ValidationInfo
# Annotation used for the ``_handler`` argument that is passed into
# ``NDArray.__get_pydantic_core_schema__`` and forwarded to the schema helpers
# in this module.
# NOTE(review): the helpers call ``_handler.generate_schema(...)``, so the
# object is richer than a bare callable -- confirm whether pydantic's
# ``GetCoreSchemaHandler`` would be the more accurate annotation.
_handler_type = Callable[[Any], core_schema.CoreSchema]
# Checked by ``_lol_dtype``: a dtype whose python base type is listed here gets
# ``core_schema.any_schema()`` instead of a typed schema.
_UNSUPPORTED_TYPES = (complex,)
"""
python types that pydantic/json schema can't support (and Any will be used instead)
"""
def _numeric_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
    """
    Make a numeric schema that respects the min/max values of sized numpy types.

    Args:
        dtype: The numeric type to describe. Subclasses of ``np.floating`` and
            ``np.integer`` get explicit bounds; anything else is delegated to
            ``_handler``.
        _handler: The pydantic core schema handler, used to generate a schema
            for types that are not numpy floating/integer types
            (eg. the builtin ``float`` and ``int``).

    Returns:
        A float schema bounded by ``np.finfo(dtype)``, an int schema bounded by
        ``np.iinfo(dtype)``, or whatever ``_handler.generate_schema`` returns
        for ``dtype``.
    """
    if issubclass(dtype, np.floating):
        info = np.finfo(dtype)
        return core_schema.float_schema(le=float(info.max), ge=float(info.min))

    if issubclass(dtype, np.integer):
        info = np.iinfo(dtype)
        return core_schema.int_schema(le=int(info.max), ge=int(info.min))

    # The handler's ``generate_schema`` accepts only the source type, so no
    # extra keyword arguments (eg. ``metadata``) may be passed along here.
    return _handler.generate_schema(dtype)
def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
    """
    Get the innermost dtype schema to use in the generated pydantic schema.

    Args:
        dtype: A single dtype, or a tuple of dtypes that is treated as a union.
        _handler: The pydantic core schema handler used to generate schemas
            for python base types.

    Returns:
        The core schema that describes a single array element.

    Raises:
        NotImplementedError: If ``dtype`` is an nptyping structured dtype.
        ValueError: If ``dtype`` has no entry in ``maps.np_to_python`` and is
            not itself one of the python types that map contains.
    """
    if isinstance(dtype, nptyping.structure.StructureMeta):  # pragma: no cover
        raise NotImplementedError("Structured dtypes are currently unsupported")

    if isinstance(dtype, tuple):
        # if it's a meta-type that refers to a generic float/int, just make that
        if dtype == dt.Float:
            return core_schema.float_schema()
        if dtype == dt.Integer:
            return core_schema.int_schema()
        if dtype == dt.Complex:
            return core_schema.any_schema()

        # make a union of dtypes recursively.
        # Deduplicate with a dict rather than a set so the members keep their
        # declared order and the generated schema is identical between runs.
        unique_types = list(dict.fromkeys(dtype))
        return core_schema.union_schema(
            [_lol_dtype(t, _handler) for t in unique_types]
        )

    try:
        python_type = np_to_python[dtype]
    except KeyError as e:
        if dtype not in np_to_python.values():
            raise ValueError(
                "dtype given in model does not have a corresponding python base type - add one to the `maps.np_to_python` dict"
            ) from e
        # it's already a python type
        python_type = dtype

    if python_type in _UNSUPPORTED_TYPES:
        # TODO: warn and log here
        return core_schema.any_schema()
    if python_type in (float, int):
        # pass the original dtype so sized numpy types keep their bounds
        return _numeric_dtype(dtype, _handler)
    return _handler.generate_schema(python_type)
def list_of_lists_schema(shape: Shape, array_type_handler: dict) -> ListSchema:
"""Make a pydantic JSON schema for an array as a list of lists."""
shape_parts = shape.__args__[0].split(",")
split_parts = [
p.split(" ")[1] if len(p.split(" ")) == 2 else None for p in shape_parts
@ -64,6 +137,30 @@ def list_of_lists_schema(shape: Shape, array_type_handler: dict) -> ListSchema:
return list_schema
def make_json_schema(
    shape: ShapeType, dtype: DtypeType, _handler: _handler_type
) -> ListSchema:
    """
    Build the list schema that describes an array as (nested) lists.

    Args:
        shape: The shape specification of the array, or ``Any`` when the
            shape is unconstrained.
        dtype: The dtype specification of the array.
        _handler: The pydantic core schema handler, forwarded to the
            dtype schema generation.

    Returns:
        A list schema. When ``shape`` is ``Any`` this is a plain list of
        anything; otherwise it is the list-of-lists schema built from
        ``shape`` with the dtype schema as its innermost item type.
    """
    # Resolved before looking at the shape, so a dtype that can't be
    # converted raises regardless of which branch is taken below.
    item_schema = _lol_dtype(dtype, _handler)

    if shape is not Any:
        return list_of_lists_schema(shape, item_schema)

    return core_schema.list_schema(core_schema.any_schema())
def _get_validate_interface(shape: ShapeType, dtype: DtypeType) -> Callable:
"""
Validate using a matching :class:`.Interface` class using its
@ -111,16 +208,16 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
dtype = Any
elif is_dtype:
dtype = dtype_candidate
elif issubclass(dtype_candidate, Structure):
elif issubclass(dtype_candidate, Structure): # pragma: no cover
dtype = dtype_candidate
check_type_names(dtype, dtype_per_name)
elif cls._is_literal_like(dtype_candidate):
elif cls._is_literal_like(dtype_candidate): # pragma: no cover
structure_expression = dtype_candidate.__args__[0]
dtype = Structure[structure_expression]
check_type_names(dtype, dtype_per_name)
elif isinstance(dtype_candidate, tuple):
elif isinstance(dtype_candidate, tuple): # pragma: no cover
dtype = tuple([cls._get_dtype(dt) for dt in dtype_candidate])
else:
else: # pragma: no cover
raise InvalidArgumentsError(
f"Unexpected argument '{dtype_candidate}', expecting"
" Structure[<StructureExpression>]"
@ -153,33 +250,14 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
def __get_pydantic_core_schema__(
cls,
_source_type: "NDArray",
_handler: Callable[[Any], core_schema.CoreSchema],
_handler: _handler_type,
) -> core_schema.CoreSchema:
shape, dtype = _source_type.__args__
shape: ShapeType
dtype: DtypeType
# get pydantic core schema for the given specified type
if isinstance(dtype, nptyping.structure.StructureMeta):
raise NotImplementedError("Finish handling structured dtypes!")
# functools.reduce(operator.or_, [int, float, str])
else:
if isinstance(dtype, tuple):
types_ = list(set([np_to_python[dt] for dt in dtype]))
# TODO: better type filtering - explicitly model what
# numeric types are supported by JSON schema
types_ = [t for t in types_ if t not in (complex,)]
schemas = [_handler.generate_schema(dt) for dt in types_]
array_type_handler = core_schema.union_schema(schemas)
else:
array_type_handler = _handler.generate_schema(np_to_python[dtype])
# get the names of the shape constraints, if any
if shape is Any:
list_schema = core_schema.list_schema(core_schema.any_schema())
else:
list_schema = list_of_lists_schema(shape, array_type_handler)
# get pydantic core schema as a list of lists for JSON schema
list_schema = make_json_schema(shape, dtype, _handler)
return core_schema.json_or_python_schema(
json_schema=list_schema,
@ -195,3 +273,16 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
_jsonize_array, when_used="json"
),
)
@classmethod
def __get_pydantic_json_schema__(
    cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> dict:
    """
    Add the dotted name of the source dtype to the generated JSON schema.

    Args:
        schema: The core schema generated for this type.
        handler: The pydantic JSON schema handler used to render ``schema``
            and to resolve any reference it produces.

    Returns:
        The resolved JSON schema. A ``"dtype"`` key of the form
        ``"<module>.<name>"`` is added when the dtype has both attributes and
        is not a builtin type.
    """
    json_schema = handler(schema)
    json_schema = handler.resolve_ref_schema(json_schema)

    dtype = cls.__args__[1]
    # A dtype given as a tuple of types (a union / generic dtype) has neither
    # ``__module__`` nor ``__name__``; skip the annotation instead of raising
    # ``AttributeError`` during schema generation.
    module = getattr(dtype, "__module__", None)
    name = getattr(dtype, "__name__", None)
    if module is not None and name is not None and module != "builtins":
        json_schema["dtype"] = ".".join([module, name])

    return json_schema

View file

@ -64,7 +64,7 @@ def array_model() -> (
shape_str = ", ".join([str(s) for s in shape])
class MyModel(BaseModel):
array: NDArray[Shape[shape_str], python_to_nptyping[dtype]]
array: NDArray[Shape[shape_str], dtype]
return MyModel

View file

@ -1,3 +1,5 @@
import pdb
import pytest
from typing import Union, Optional, Any
@ -9,6 +11,7 @@ from nptyping import Shape, Number
from numpydantic import NDArray
from numpydantic.exceptions import ShapeError, DtypeError
from numpydantic import dtype
# from .fixtures import tmp_output_dir_func
@ -99,23 +102,83 @@ def test_ndarray_serialize():
assert isinstance(mod_dict["array"], np.ndarray)
# def test_ndarray_proxy(tmp_output_dir_func):
# h5f_source = tmp_output_dir_func / 'test.h5'
# with h5py.File(h5f_source, 'w') as h5f:
# dset_good = h5f.create_dataset('/data', data=np.random.random((1024,1024,3)))
# dset_bad = h5f.create_dataset('/data_bad', data=np.random.random((1024, 1024, 4)))
#
# class Model(BaseModel):
# array: NDArray[Shape["* x, * y, 3 z"], Number]
#
# mod = Model(array=NDArrayProxy(h5f_file=h5f_source, path='/data'))
# subarray = mod.array[0:5, 0:5, :]
# assert isinstance(subarray, np.ndarray)
# assert isinstance(subarray.sum(), float)
# assert mod.array.name == '/data'
#
# with pytest.raises(NotImplementedError):
# mod.array[0] = 5
#
# with pytest.raises(ValidationError):
# mod = Model(array=NDArrayProxy(h5f_file=h5f_source, path='/data_bad'))
# (numpy dtype, python base type) pairs for every float and integer dtype
_json_schema_types = [(t, float) for t in dtype.Float] + [
    (t, int) for t in dtype.Integer
]
def test_json_schema_basic(array_model):
    """
    A fixed-shape float NDArray field should be described in JSON schema
    as an array of arrays of numbers, with each level pinned to its size
    """
    n_outer, n_inner = 15, 10
    model = array_model((n_outer, n_inner), float)
    array_field = model.model_json_schema()["properties"]["array"]

    # first dimension
    assert array_field["maxItems"] == n_outer
    assert array_field["minItems"] == n_outer
    assert array_field["type"] == "array"

    # second dimension
    row = array_field["items"]
    assert row["minItems"] == n_inner
    assert row["maxItems"] == n_inner
    assert row["items"]["type"] == "number"
@pytest.mark.parametrize("dtype", [*dtype.Integer, *dtype.Float])
def test_json_schema_dtype_single(dtype, array_model):
    """
    Sized numpy dtypes should carry their min/max bounds into the JSON schema,
    and the field should record the dotted name of the source dtype
    """
    if issubclass(dtype, np.floating):
        limits, expected_type = np.finfo(dtype), "number"
    elif issubclass(dtype, np.integer):
        limits, expected_type = np.iinfo(dtype), "integer"
    else:
        raise ValueError("These should all be numpy types!")
    lower, upper = limits.min, limits.max

    model = array_model((15, 10), dtype)
    json_schema = model.model_json_schema()
    array_field = json_schema["properties"]["array"]

    # the innermost items describe a single array element
    element = array_field["items"]["items"]
    assert element["minimum"] == lower
    assert element["maximum"] == upper
    assert element["type"] == expected_type

    expected_name = ".".join([dtype.__module__, dtype.__name__])
    assert json_schema["properties"]["array"]["dtype"] == expected_name
@pytest.mark.skip("Not implemented yet")
def test_json_schema_dtype_builtin(dtype):
    """
    Using builtin or generic (eg. `dtype.Integer` ) dtypes should
    make a simple json schema without mins/maxes/dtypes.

    Skipped with an explicit reason, like the other unimplemented schema tests
    in this module, so the test report says why it did not run.
    """
@pytest.mark.skip("Not implemented yet")
def test_json_schema_wildcard():
    """
    Placeholder: NDArray types with wildcard dimensions should generate
    a JSON schema without shape constraints
    """
@pytest.mark.skip("Not implemented yet")
def test_json_schema_ellipsis():
    """
    Placeholder: NDArray types with an ellipsis shape should create
    a recursive JSON schema for any-shaped arrays
    """