mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2024-11-12 17:54:29 +00:00
move helper methods to schema
module and out of ndarray
This commit is contained in:
parent
060de62334
commit
49a51563d6
6 changed files with 269 additions and 182 deletions
|
@ -112,7 +112,6 @@ Unicode = np.unicode_
|
|||
|
||||
Number = tuple(
|
||||
[
|
||||
np.number,
|
||||
*Integer,
|
||||
*Float,
|
||||
*Complex,
|
||||
|
|
|
@ -15,19 +15,36 @@ def generate_ndarray_stub() -> str:
|
|||
Make a stub file based on the array interfaces that are available
|
||||
"""
|
||||
|
||||
import_strings = [
|
||||
f"from {arr.__module__} import {arr.__name__}"
|
||||
for arr in Interface.input_types()
|
||||
if arr.__module__ != "builtins"
|
||||
]
|
||||
import_strings = []
|
||||
type_names = []
|
||||
for arr in Interface.input_types():
|
||||
if arr.__module__ == "builtins":
|
||||
continue
|
||||
|
||||
# Create import statements, saving aliased name of type if needed
|
||||
if arr.__module__.startswith("numpydantic") or arr.__module__ == "typing":
|
||||
type_name = arr.__name__
|
||||
import_strings.append(f"from {arr.__module__} import {arr.__name__}")
|
||||
else:
|
||||
# since other packages could use the same name for an imported object
|
||||
# (eg dask and zarr both use an Array class)
|
||||
# we make an import alias from the module names to differentiate them
|
||||
# in the type annotation
|
||||
mod_name = "".join([a.capitalize() for a in arr.__module__.split(".")])
|
||||
type_name = mod_name + arr.__name__
|
||||
import_strings.append(
|
||||
f"from {arr.__module__} import {arr.__name__} " f"as {type_name}"
|
||||
)
|
||||
|
||||
if arr.__module__ != "typing":
|
||||
type_names.append(type_name)
|
||||
else:
|
||||
type_names.append(str(arr))
|
||||
|
||||
import_strings.extend(_BUILTIN_IMPORTS)
|
||||
import_string = "\n".join(import_strings)
|
||||
|
||||
class_names = [
|
||||
arr.__name__ if arr.__module__ != "typing" else str(arr)
|
||||
for arr in Interface.input_types()
|
||||
]
|
||||
class_union = " | ".join(class_names)
|
||||
class_union = " | ".join(type_names)
|
||||
ndarray_type = "NDArray = " + class_union
|
||||
|
||||
stub_string = "\n".join([import_string, ndarray_type])
|
||||
|
|
|
@ -1,16 +1,21 @@
|
|||
"""
|
||||
Extension of nptyping NDArray for pydantic that allows for JSON-Schema serialization
|
||||
|
||||
* Order to store data in (row first)
|
||||
.. note::
|
||||
|
||||
This module should *only* have the :class:`.NDArray` class in it, because the
|
||||
type stub ``ndarray.pyi`` is only created for :class:`.NDArray` . Otherwise,
|
||||
type checkers will complain about using any helper functions elsewhere -
|
||||
those all belong in :mod:`numpydantic.schema` .
|
||||
|
||||
Keeping with nptyping's style, NDArrayMeta is in this module even if it's
|
||||
excluded from the type stub.
|
||||
|
||||
"""
|
||||
|
||||
import pdb
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Tuple
|
||||
|
||||
import nptyping.structure
|
||||
import numpy as np
|
||||
from nptyping import Shape
|
||||
from nptyping.error import InvalidArgumentsError
|
||||
from nptyping.ndarray import NDArrayMeta as _NDArrayMeta
|
||||
from nptyping.nptyping_type import NPTypingType
|
||||
|
@ -21,176 +26,26 @@ from nptyping.typing_ import (
|
|||
)
|
||||
from pydantic import GetJsonSchemaHandler
|
||||
from pydantic_core import core_schema
|
||||
from pydantic_core.core_schema import CoreSchema, ListSchema
|
||||
|
||||
from numpydantic import dtype as dt
|
||||
from numpydantic.dtype import DType
|
||||
from numpydantic.interface import Interface
|
||||
from numpydantic.maps import np_to_python
|
||||
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||
from numpydantic.maps import python_to_nptyping
|
||||
from numpydantic.schema import (
|
||||
_handler_type,
|
||||
_jsonize_array,
|
||||
coerce_list,
|
||||
get_validate_interface,
|
||||
make_json_schema,
|
||||
)
|
||||
from numpydantic.types import DtypeType, ShapeType
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from pydantic import ValidationInfo
|
||||
pass
|
||||
|
||||
_handler_type = Callable[[Any], core_schema.CoreSchema]
|
||||
|
||||
_UNSUPPORTED_TYPES = (complex,)
|
||||
"""
|
||||
python types that pydantic/json schema can't support (and Any will be used instead)
|
||||
"""
|
||||
|
||||
|
||||
def _numeric_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
|
||||
"""Make a numeric dtype that respects min/max values from extended numpy types"""
|
||||
if dtype.__module__ == "builtins":
|
||||
metadata = None
|
||||
else:
|
||||
metadata = {"dtype": ".".join([dtype.__module__, dtype.__name__])}
|
||||
|
||||
if issubclass(dtype, np.floating):
|
||||
info = np.finfo(dtype)
|
||||
schema = core_schema.float_schema(le=float(info.max), ge=float(info.min))
|
||||
elif issubclass(dtype, np.integer):
|
||||
info = np.iinfo(dtype)
|
||||
schema = core_schema.int_schema(le=int(info.max), ge=int(info.min))
|
||||
|
||||
else:
|
||||
schema = _handler.generate_schema(dtype, metadata=metadata)
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
|
||||
"""Get the innermost dtype schema to use in the generated pydantic schema"""
|
||||
|
||||
if isinstance(dtype, nptyping.structure.StructureMeta): # pragma: no cover
|
||||
raise NotImplementedError("Structured dtypes are currently unsupported")
|
||||
|
||||
if isinstance(dtype, tuple):
|
||||
# if it's a meta-type that refers to a generic float/int, just make that
|
||||
if dtype == dt.Float:
|
||||
array_type = core_schema.float_schema()
|
||||
elif dtype == dt.Integer:
|
||||
array_type = core_schema.int_schema()
|
||||
elif dtype == dt.Complex:
|
||||
array_type = core_schema.any_schema()
|
||||
else:
|
||||
# make a union of dtypes recursively
|
||||
types_ = list(set(dtype))
|
||||
array_type = core_schema.union_schema(
|
||||
[_lol_dtype(t, _handler) for t in types_]
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
python_type = np_to_python[dtype]
|
||||
except KeyError as e:
|
||||
if dtype in np_to_python.values():
|
||||
# it's already a python type
|
||||
python_type = dtype
|
||||
else:
|
||||
raise ValueError(
|
||||
"dtype given in model does not have a corresponding python base type - add one to the `maps.np_to_python` dict"
|
||||
) from e
|
||||
|
||||
if python_type in _UNSUPPORTED_TYPES:
|
||||
array_type = core_schema.any_schema()
|
||||
# TODO: warn and log here
|
||||
elif python_type in (float, int):
|
||||
array_type = _numeric_dtype(dtype, _handler)
|
||||
else:
|
||||
array_type = _handler.generate_schema(python_type)
|
||||
|
||||
return array_type
|
||||
|
||||
|
||||
def list_of_lists_schema(shape: Shape, array_type_handler: dict) -> ListSchema:
|
||||
"""Make a pydantic JSON schema for an array as a list of lists."""
|
||||
|
||||
shape_parts = shape.__args__[0].split(",")
|
||||
split_parts = [
|
||||
p.split(" ")[1] if len(p.split(" ")) == 2 else None for p in shape_parts
|
||||
]
|
||||
|
||||
# Construct a list of list schema
|
||||
# go in reverse order - construct list schemas such that
|
||||
# the final schema is the one that checks the first dimension
|
||||
shape_labels = reversed(split_parts)
|
||||
shape_args = reversed(shape.prepared_args)
|
||||
list_schema = None
|
||||
for arg, label in zip(shape_args, shape_labels):
|
||||
# which handler to use? for the first we use the actual type
|
||||
# handler, everywhere else we use the prior list handler
|
||||
inner_schema = array_type_handler if list_schema is None else list_schema
|
||||
|
||||
# make a label annotation, if we have one
|
||||
metadata = {"name": label} if label is not None else None
|
||||
|
||||
# make the current level list schema, accounting for shape
|
||||
if arg == "*":
|
||||
list_schema = core_schema.list_schema(inner_schema, metadata=metadata)
|
||||
else:
|
||||
arg = int(arg)
|
||||
list_schema = core_schema.list_schema(
|
||||
inner_schema, min_length=arg, max_length=arg, metadata=metadata
|
||||
)
|
||||
return list_schema
|
||||
|
||||
|
||||
def make_json_schema(
|
||||
shape: ShapeType, dtype: DtypeType, _handler: _handler_type
|
||||
) -> ListSchema:
|
||||
"""
|
||||
|
||||
Args:
|
||||
shape:
|
||||
dtype:
|
||||
_handler:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
dtype_schema = _lol_dtype(dtype, _handler)
|
||||
|
||||
# get the names of the shape constraints, if any
|
||||
if shape is Any:
|
||||
list_schema = core_schema.list_schema(core_schema.any_schema())
|
||||
else:
|
||||
list_schema = list_of_lists_schema(shape, dtype_schema)
|
||||
|
||||
return list_schema
|
||||
|
||||
|
||||
def _get_validate_interface(shape: ShapeType, dtype: DtypeType) -> Callable:
|
||||
"""
|
||||
Validate using a matching :class:`.Interface` class using its
|
||||
:meth:`.Interface.validate` method
|
||||
"""
|
||||
|
||||
def validate_interface(value: Any, info: "ValidationInfo") -> NDArrayType:
|
||||
interface_cls = Interface.match(value)
|
||||
interface = interface_cls(shape, dtype)
|
||||
value = interface.validate(value)
|
||||
return value
|
||||
|
||||
return validate_interface
|
||||
|
||||
|
||||
def _jsonize_array(value: Any) -> Union[list, dict]:
|
||||
interface_cls = Interface.match_output(value)
|
||||
return interface_cls.to_json(value)
|
||||
|
||||
|
||||
def coerce_list(value: Any) -> np.ndarray:
|
||||
"""
|
||||
If a value is passed as a list or list of lists, try and coerce it into an array
|
||||
rather than failing validation.
|
||||
"""
|
||||
if isinstance(value, list):
|
||||
value = np.array(value)
|
||||
return value
|
||||
|
||||
|
||||
class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
|
||||
"""
|
||||
Hooking into nptyping's array metaclass to override methods pending
|
||||
|
@ -201,9 +56,12 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
|
|||
"""
|
||||
Override of base _get_dtype method to allow for compound tuple types
|
||||
"""
|
||||
if dtype_candidate in python_to_nptyping:
|
||||
dtype_candidate = python_to_nptyping[dtype_candidate]
|
||||
is_dtype = isinstance(dtype_candidate, type) and issubclass(
|
||||
dtype_candidate, np.generic
|
||||
)
|
||||
|
||||
if dtype_candidate is Any:
|
||||
dtype = Any
|
||||
elif is_dtype:
|
||||
|
@ -265,7 +123,7 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
|
|||
[
|
||||
core_schema.no_info_plain_validator_function(coerce_list),
|
||||
core_schema.with_info_plain_validator_function(
|
||||
_get_validate_interface(shape, dtype)
|
||||
get_validate_interface(shape, dtype)
|
||||
),
|
||||
]
|
||||
),
|
||||
|
@ -277,12 +135,12 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
|
|||
@classmethod
|
||||
def __get_pydantic_json_schema__(
|
||||
cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
|
||||
):
|
||||
) -> core_schema.JsonSchema:
|
||||
json_schema = handler(schema)
|
||||
json_schema = handler.resolve_ref_schema(json_schema)
|
||||
|
||||
dtype = cls.__args__[1]
|
||||
if dtype.__module__ != "builtins":
|
||||
if not isinstance(dtype, tuple) and dtype.__module__ != "builtins":
|
||||
json_schema["dtype"] = ".".join([dtype.__module__, dtype.__name__])
|
||||
|
||||
return json_schema
|
||||
|
|
189
src/numpydantic/schema.py
Normal file
189
src/numpydantic/schema.py
Normal file
|
@ -0,0 +1,189 @@
|
|||
"""
|
||||
Helper functions for use with :class:`~numpydantic.NDArray` - see the note in
|
||||
:mod:`~numpydantic.ndarray` for why these are separated.
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
import nptyping.structure
|
||||
import numpy as np
|
||||
from nptyping import Shape
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
from pydantic_core.core_schema import ListSchema, ValidationInfo
|
||||
|
||||
from numpydantic import dtype as dt
|
||||
from numpydantic.interface import Interface
|
||||
from numpydantic.maps import np_to_python
|
||||
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||
|
||||
_handler_type = Callable[[Any], core_schema.CoreSchema]
|
||||
_UNSUPPORTED_TYPES = (complex,)
|
||||
|
||||
|
||||
def _numeric_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
|
||||
"""Make a numeric dtype that respects min/max values from extended numpy types"""
|
||||
if dtype in (np.number,):
|
||||
dtype = float
|
||||
|
||||
if issubclass(dtype, np.floating):
|
||||
info = np.finfo(dtype)
|
||||
schema = core_schema.float_schema(le=float(info.max), ge=float(info.min))
|
||||
elif issubclass(dtype, np.integer):
|
||||
info = np.iinfo(dtype)
|
||||
schema = core_schema.int_schema(le=int(info.max), ge=int(info.min))
|
||||
|
||||
else:
|
||||
schema = _handler.generate_schema(dtype)
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
|
||||
"""Get the innermost dtype schema to use in the generated pydantic schema"""
|
||||
|
||||
if isinstance(dtype, nptyping.structure.StructureMeta): # pragma: no cover
|
||||
raise NotImplementedError("Structured dtypes are currently unsupported")
|
||||
|
||||
if isinstance(dtype, tuple):
|
||||
# if it's a meta-type that refers to a generic float/int, just make that
|
||||
if dtype == dt.Float:
|
||||
array_type = core_schema.float_schema()
|
||||
elif dtype == dt.Integer:
|
||||
array_type = core_schema.int_schema()
|
||||
elif dtype == dt.Complex:
|
||||
array_type = core_schema.any_schema()
|
||||
else:
|
||||
# make a union of dtypes recursively
|
||||
types_ = list(set(dtype))
|
||||
array_type = core_schema.union_schema(
|
||||
[_lol_dtype(t, _handler) for t in types_]
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
python_type = np_to_python[dtype]
|
||||
except KeyError as e:
|
||||
if dtype in np_to_python.values():
|
||||
# it's already a python type
|
||||
python_type = dtype
|
||||
else:
|
||||
raise ValueError(
|
||||
"dtype given in model does not have a corresponding python base "
|
||||
"type - add one to the `maps.np_to_python` dict"
|
||||
) from e
|
||||
|
||||
if python_type in _UNSUPPORTED_TYPES:
|
||||
array_type = core_schema.any_schema()
|
||||
# TODO: warn and log here
|
||||
elif python_type in (float, int):
|
||||
array_type = _numeric_dtype(dtype, _handler)
|
||||
else:
|
||||
array_type = _handler.generate_schema(python_type)
|
||||
|
||||
return array_type
|
||||
|
||||
|
||||
def list_of_lists_schema(shape: Shape, array_type: CoreSchema) -> ListSchema:
|
||||
"""
|
||||
Make a pydantic JSON schema for an array as a list of lists.
|
||||
|
||||
For each item in the shape, create a list schema. In the innermost schema
|
||||
insert the passed ``array_type`` schema.
|
||||
|
||||
This function is typically called from :func:`.make_json_schema`
|
||||
|
||||
Args:
|
||||
shape (:class:`.Shape` ): Shape determines the depth and max/min elements
|
||||
for each layer of list schema
|
||||
array_type ( :class:`pydantic_core.CoreSchema` ): The pre-rendered pydantic
|
||||
core schema to use in the innermost list entry
|
||||
"""
|
||||
|
||||
shape_parts = shape.__args__[0].split(",")
|
||||
split_parts = [
|
||||
p.split(" ")[1] if len(p.split(" ")) == 2 else None for p in shape_parts
|
||||
]
|
||||
|
||||
# Construct a list of list schema
|
||||
# go in reverse order - construct list schemas such that
|
||||
# the final schema is the one that checks the first dimension
|
||||
shape_labels = reversed(split_parts)
|
||||
shape_args = reversed(shape.prepared_args)
|
||||
list_schema = None
|
||||
for arg, label in zip(shape_args, shape_labels):
|
||||
# which handler to use? for the first we use the actual type
|
||||
# handler, everywhere else we use the prior list handler
|
||||
inner_schema = array_type if list_schema is None else list_schema
|
||||
|
||||
# make a label annotation, if we have one
|
||||
metadata = {"name": label} if label is not None else None
|
||||
|
||||
# make the current level list schema, accounting for shape
|
||||
if arg == "*":
|
||||
list_schema = core_schema.list_schema(inner_schema, metadata=metadata)
|
||||
else:
|
||||
arg = int(arg)
|
||||
list_schema = core_schema.list_schema(
|
||||
inner_schema, min_length=arg, max_length=arg, metadata=metadata
|
||||
)
|
||||
return list_schema
|
||||
|
||||
|
||||
def make_json_schema(
|
||||
shape: ShapeType, dtype: DtypeType, _handler: _handler_type
|
||||
) -> ListSchema:
|
||||
"""
|
||||
Make a list of list JSON schema from a shape and a dtype.
|
||||
|
||||
First resolves the dtype into a pydantic ``CoreSchema`` ,
|
||||
and then uses that with :func:`.list_of_lists_schema` .
|
||||
|
||||
Args:
|
||||
shape ( ShapeType ): Specification of a shape, as a tuple or
|
||||
an nptyping ``Shape``
|
||||
dtype ( DtypeType ): A builtin type or numpy dtype
|
||||
_handler: The pydantic schema generation handler (see pydantic docs)
|
||||
|
||||
Returns:
|
||||
:class:`pydantic_core.core_schema.ListSchema`
|
||||
"""
|
||||
dtype_schema = _lol_dtype(dtype, _handler)
|
||||
|
||||
# get the names of the shape constraints, if any
|
||||
if shape is Any:
|
||||
list_schema = core_schema.list_schema(core_schema.any_schema())
|
||||
else:
|
||||
list_schema = list_of_lists_schema(shape, dtype_schema)
|
||||
|
||||
return list_schema
|
||||
|
||||
|
||||
def get_validate_interface(shape: ShapeType, dtype: DtypeType) -> Callable:
|
||||
"""
|
||||
Validate using a matching :class:`.Interface` class using its
|
||||
:meth:`.Interface.validate` method
|
||||
"""
|
||||
|
||||
def validate_interface(value: Any, info: "ValidationInfo") -> NDArrayType:
|
||||
interface_cls = Interface.match(value)
|
||||
interface = interface_cls(shape, dtype)
|
||||
value = interface.validate(value)
|
||||
return value
|
||||
|
||||
return validate_interface
|
||||
|
||||
|
||||
def _jsonize_array(value: Any) -> Union[list, dict]:
|
||||
"""Use an interface class to render an array as JSON"""
|
||||
interface_cls = Interface.match_output(value)
|
||||
return interface_cls.to_json(value)
|
||||
|
||||
|
||||
def coerce_list(value: Any) -> np.ndarray:
|
||||
"""
|
||||
If a value is passed as a list or list of lists, try and coerce it into an array
|
||||
rather than failing validation.
|
||||
"""
|
||||
if isinstance(value, list):
|
||||
value = np.array(value)
|
||||
return value
|
|
@ -2,6 +2,7 @@ import sys
|
|||
import pytest
|
||||
|
||||
from numpydantic import NDArray
|
||||
from numpydantic.meta import update_ndarray_stub
|
||||
|
||||
if sys.version_info.minor < 11:
|
||||
from typing_extensions import reveal_type
|
||||
|
@ -9,6 +10,15 @@ else:
|
|||
from typing import reveal_type
|
||||
|
||||
|
||||
def test_no_warn(recwarn):
|
||||
"""
|
||||
If something is going wrong with generating meta stubs, a warning will be emitted.
|
||||
that is bad.
|
||||
"""
|
||||
update_ndarray_stub()
|
||||
assert len(recwarn) == 0
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO")
|
||||
def test_generate_stub():
|
||||
"""
|
||||
|
|
|
@ -160,12 +160,26 @@ def test_json_schema_dtype_single(dtype, array_model):
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.skip()
|
||||
def test_json_schema_dtype_builtin(dtype):
|
||||
@pytest.mark.parametrize(
|
||||
"dtype,expected",
|
||||
[
|
||||
(dtype.Integer, "integer"),
|
||||
(dtype.Float, "number"),
|
||||
(dtype.Bool, "boolean"),
|
||||
(int, "integer"),
|
||||
(float, "number"),
|
||||
(bool, "boolean"),
|
||||
],
|
||||
)
|
||||
def test_json_schema_dtype_builtin(dtype, expected, array_model):
|
||||
"""
|
||||
Using builtin or generic (eg. `dtype.Integer` ) dtypes should
|
||||
make a simple json schema without mins/maxes/dtypes.
|
||||
"""
|
||||
model = array_model(dtype=dtype)
|
||||
schema = model.model_json_schema()
|
||||
inner_type = schema["properties"]["array"]["items"]["items"]
|
||||
assert inner_type["type"] == expected
|
||||
|
||||
|
||||
@pytest.mark.skip("Not implemented yet")
|
||||
|
|
Loading…
Reference in a new issue