move helper methods to schema module and out of ndarray

sneakers-the-rat 2024-05-15 22:56:02 -07:00
parent 060de62334
commit 49a51563d6
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
6 changed files with 269 additions and 182 deletions
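For orientation, a minimal sketch (not part of the commit) of what the move means for downstream imports. The helper names are taken from the diff below; the usage itself is illustrative:

# After this commit, the public helpers live in numpydantic.schema
from numpydantic.schema import (
    coerce_list,
    get_validate_interface,
    make_json_schema,
)

# numpydantic.ndarray keeps only the NDArray class itself, matching the
# generated ndarray.pyi type stub
from numpydantic import NDArray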

View file

@@ -112,7 +112,6 @@ Unicode = np.unicode_
Number = tuple(
[
np.number,
*Integer,
*Float,
*Complex,

View file

@@ -15,19 +15,36 @@ def generate_ndarray_stub() -> str:
Make a stub file based on the array interfaces that are available
"""
import_strings = [
f"from {arr.__module__} import {arr.__name__}"
for arr in Interface.input_types()
if arr.__module__ != "builtins"
]
import_strings = []
type_names = []
for arr in Interface.input_types():
if arr.__module__ == "builtins":
continue
# Create import statements, saving aliased name of type if needed
if arr.__module__.startswith("numpydantic") or arr.__module__ == "typing":
type_name = arr.__name__
import_strings.append(f"from {arr.__module__} import {arr.__name__}")
else:
# since other packages could use the same name for an imported object
# (e.g. dask and zarr both use an Array class)
# we make an import alias from the module names to differentiate them
# in the type annotation
mod_name = "".join([a.capitalize() for a in arr.__module__.split(".")])
type_name = mod_name + arr.__name__
import_strings.append(
f"from {arr.__module__} import {arr.__name__} " f"as {type_name}"
)
if arr.__module__ != "typing":
type_names.append(type_name)
else:
type_names.append(str(arr))
import_strings.extend(_BUILTIN_IMPORTS)
import_string = "\n".join(import_strings)
class_names = [
arr.__name__ if arr.__module__ != "typing" else str(arr)
for arr in Interface.input_types()
]
class_union = " | ".join(class_names)
class_union = " | ".join(type_names)
ndarray_type = "NDArray = " + class_union
stub_string = "\n".join([import_string, ndarray_type])
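To illustrate the aliasing scheme above, here is a hypothetical excerpt of a generated ``ndarray.pyi``. The module paths are stand-ins - the real ones depend on where each package defines its Array class at import time:

# hypothetical stub excerpt: clashing Array names become distinct aliases,
# built by capitalizing each segment of the defining module's path
from dask.array.core import Array as DaskArrayCoreArray
from zarr.core import Array as ZarrCoreArray

# the real union also includes numpy's ndarray and any other input types
NDArray = DaskArrayCoreArray | ZarrCoreArray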

View file

@@ -1,16 +1,21 @@
"""
Extension of nptyping NDArray for pydantic that allows for JSON-Schema serialization
* Order to store data in (row first)
.. note::
This module should *only* have the :class:`.NDArray` class in it, because the
type stub ``ndarray.pyi`` is only created for :class:`.NDArray` . Otherwise,
type checkers will complain about any helper functions used elsewhere -
those all belong in :mod:`numpydantic.schema` .
In keeping with nptyping's style, NDArrayMeta stays in this module even though
it is excluded from the type stub.
"""
import pdb
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Tuple, Union
from typing import TYPE_CHECKING, Any, Tuple
import nptyping.structure
import numpy as np
from nptyping import Shape
from nptyping.error import InvalidArgumentsError
from nptyping.ndarray import NDArrayMeta as _NDArrayMeta
from nptyping.nptyping_type import NPTypingType
@@ -21,176 +26,26 @@ from nptyping.typing_ import (
)
from pydantic import GetJsonSchemaHandler
from pydantic_core import core_schema
from pydantic_core.core_schema import CoreSchema, ListSchema
from numpydantic import dtype as dt
from numpydantic.dtype import DType
from numpydantic.interface import Interface
from numpydantic.maps import np_to_python
from numpydantic.types import DtypeType, NDArrayType, ShapeType
from numpydantic.maps import python_to_nptyping
from numpydantic.schema import (
_handler_type,
_jsonize_array,
coerce_list,
get_validate_interface,
make_json_schema,
)
from numpydantic.types import DtypeType, ShapeType
if TYPE_CHECKING: # pragma: no cover
from pydantic import ValidationInfo
pass
_handler_type = Callable[[Any], core_schema.CoreSchema]
_UNSUPPORTED_TYPES = (complex,)
"""
python types that pydantic/json schema can't support (and Any will be used instead)
"""
def _numeric_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
"""Make a numeric dtype that respects min/max values from extended numpy types"""
if dtype.__module__ == "builtins":
metadata = None
else:
metadata = {"dtype": ".".join([dtype.__module__, dtype.__name__])}
if issubclass(dtype, np.floating):
info = np.finfo(dtype)
schema = core_schema.float_schema(le=float(info.max), ge=float(info.min))
elif issubclass(dtype, np.integer):
info = np.iinfo(dtype)
schema = core_schema.int_schema(le=int(info.max), ge=int(info.min))
else:
schema = _handler.generate_schema(dtype, metadata=metadata)
return schema
def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
"""Get the innermost dtype schema to use in the generated pydantic schema"""
if isinstance(dtype, nptyping.structure.StructureMeta): # pragma: no cover
raise NotImplementedError("Structured dtypes are currently unsupported")
if isinstance(dtype, tuple):
# if it's a meta-type that refers to a generic float/int, just make that
if dtype == dt.Float:
array_type = core_schema.float_schema()
elif dtype == dt.Integer:
array_type = core_schema.int_schema()
elif dtype == dt.Complex:
array_type = core_schema.any_schema()
else:
# make a union of dtypes recursively
types_ = list(set(dtype))
array_type = core_schema.union_schema(
[_lol_dtype(t, _handler) for t in types_]
)
else:
try:
python_type = np_to_python[dtype]
except KeyError as e:
if dtype in np_to_python.values():
# it's already a python type
python_type = dtype
else:
raise ValueError(
"dtype given in model does not have a corresponding python base type - add one to the `maps.np_to_python` dict"
) from e
if python_type in _UNSUPPORTED_TYPES:
array_type = core_schema.any_schema()
# TODO: warn and log here
elif python_type in (float, int):
array_type = _numeric_dtype(dtype, _handler)
else:
array_type = _handler.generate_schema(python_type)
return array_type
def list_of_lists_schema(shape: Shape, array_type_handler: dict) -> ListSchema:
"""Make a pydantic JSON schema for an array as a list of lists."""
shape_parts = shape.__args__[0].split(",")
split_parts = [
p.split(" ")[1] if len(p.split(" ")) == 2 else None for p in shape_parts
]
# Construct a list of list schema
# go in reverse order - construct list schemas such that
# the final schema is the one that checks the first dimension
shape_labels = reversed(split_parts)
shape_args = reversed(shape.prepared_args)
list_schema = None
for arg, label in zip(shape_args, shape_labels):
# which handler to use? for the first we use the actual type
# handler, everywhere else we use the prior list handler
inner_schema = array_type_handler if list_schema is None else list_schema
# make a label annotation, if we have one
metadata = {"name": label} if label is not None else None
# make the current level list schema, accounting for shape
if arg == "*":
list_schema = core_schema.list_schema(inner_schema, metadata=metadata)
else:
arg = int(arg)
list_schema = core_schema.list_schema(
inner_schema, min_length=arg, max_length=arg, metadata=metadata
)
return list_schema
def make_json_schema(
shape: ShapeType, dtype: DtypeType, _handler: _handler_type
) -> ListSchema:
"""
Args:
shape:
dtype:
_handler:
Returns:
"""
dtype_schema = _lol_dtype(dtype, _handler)
# get the names of the shape constraints, if any
if shape is Any:
list_schema = core_schema.list_schema(core_schema.any_schema())
else:
list_schema = list_of_lists_schema(shape, dtype_schema)
return list_schema
def _get_validate_interface(shape: ShapeType, dtype: DtypeType) -> Callable:
"""
Validate using a matching :class:`.Interface` class using its
:meth:`.Interface.validate` method
"""
def validate_interface(value: Any, info: "ValidationInfo") -> NDArrayType:
interface_cls = Interface.match(value)
interface = interface_cls(shape, dtype)
value = interface.validate(value)
return value
return validate_interface
def _jsonize_array(value: Any) -> Union[list, dict]:
interface_cls = Interface.match_output(value)
return interface_cls.to_json(value)
def coerce_list(value: Any) -> np.ndarray:
"""
If a value is passed as a list or list of lists, try to coerce it into an array
rather than failing validation.
"""
if isinstance(value, list):
value = np.array(value)
return value
class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
"""
Hooking into nptyping's array metaclass to override methods pending
@@ -201,9 +56,12 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
"""
Override of base _get_dtype method to allow for compound tuple types
"""
if dtype_candidate in python_to_nptyping:
dtype_candidate = python_to_nptyping[dtype_candidate]
is_dtype = isinstance(dtype_candidate, type) and issubclass(
dtype_candidate, np.generic
)
if dtype_candidate is Any:
dtype = Any
elif is_dtype:
@@ -265,7 +123,7 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
[
core_schema.no_info_plain_validator_function(coerce_list),
core_schema.with_info_plain_validator_function(
_get_validate_interface(shape, dtype)
get_validate_interface(shape, dtype)
),
]
),
@@ -277,12 +135,12 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
@classmethod
def __get_pydantic_json_schema__(
cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
):
) -> core_schema.JsonSchema:
json_schema = handler(schema)
json_schema = handler.resolve_ref_schema(json_schema)
dtype = cls.__args__[1]
if dtype.__module__ != "builtins":
if not isinstance(dtype, tuple) and dtype.__module__ != "builtins":
json_schema["dtype"] = ".".join([dtype.__module__, dtype.__name__])
return json_schema
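A hedged sketch of the dtype annotation produced by ``__get_pydantic_json_schema__`` after this change - the model and field names are invented, and it assumes the documented ``from numpydantic import NDArray, Shape`` import path:

import numpy as np
from pydantic import BaseModel

from numpydantic import NDArray, Shape


class Image(BaseModel):
    array: NDArray[Shape["* x, * y"], np.uint8]


schema = Image.model_json_schema()
# non-builtin, non-tuple dtypes are recorded as "module.name":
#   schema["properties"]["array"]["dtype"] == "numpy.uint8"
# a tuple dtype such as dt.Float now skips the annotation instead of
# raising AttributeError when its missing __module__ is read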

src/numpydantic/schema.py (new file, 189 additions)
View file

@@ -0,0 +1,189 @@
"""
Helper functions for use with :class:`~numpydantic.NDArray` - see the note in
:mod:`~numpydantic.ndarray` for why these are separated.
"""
from typing import Any, Callable, Union
import nptyping.structure
import numpy as np
from nptyping import Shape
from pydantic_core import CoreSchema, core_schema
from pydantic_core.core_schema import ListSchema, ValidationInfo
from numpydantic import dtype as dt
from numpydantic.interface import Interface
from numpydantic.maps import np_to_python
from numpydantic.types import DtypeType, NDArrayType, ShapeType
_handler_type = Callable[[Any], core_schema.CoreSchema]
_UNSUPPORTED_TYPES = (complex,)
def _numeric_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
"""Make a numeric dtype that respects min/max values from extended numpy types"""
if dtype in (np.number,):
dtype = float
if issubclass(dtype, np.floating):
info = np.finfo(dtype)
schema = core_schema.float_schema(le=float(info.max), ge=float(info.min))
elif issubclass(dtype, np.integer):
info = np.iinfo(dtype)
schema = core_schema.int_schema(le=int(info.max), ge=int(info.min))
else:
schema = _handler.generate_schema(dtype)
return schema
def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
"""Get the innermost dtype schema to use in the generated pydantic schema"""
if isinstance(dtype, nptyping.structure.StructureMeta): # pragma: no cover
raise NotImplementedError("Structured dtypes are currently unsupported")
if isinstance(dtype, tuple):
# if it's a meta-type that refers to a generic float/int, just make that
if dtype == dt.Float:
array_type = core_schema.float_schema()
elif dtype == dt.Integer:
array_type = core_schema.int_schema()
elif dtype == dt.Complex:
array_type = core_schema.any_schema()
else:
# make a union of dtypes recursively
types_ = list(set(dtype))
array_type = core_schema.union_schema(
[_lol_dtype(t, _handler) for t in types_]
)
else:
try:
python_type = np_to_python[dtype]
except KeyError as e:
if dtype in np_to_python.values():
# it's already a python type
python_type = dtype
else:
raise ValueError(
"dtype given in model does not have a corresponding python base "
"type - add one to the `maps.np_to_python` dict"
) from e
if python_type in _UNSUPPORTED_TYPES:
array_type = core_schema.any_schema()
# TODO: warn and log here
elif python_type in (float, int):
array_type = _numeric_dtype(dtype, _handler)
else:
array_type = _handler.generate_schema(python_type)
return array_type
def list_of_lists_schema(shape: Shape, array_type: CoreSchema) -> ListSchema:
"""
Make a pydantic JSON schema for an array as a list of lists.
For each item in the shape, create a list schema. In the innermost schema
insert the passed ``array_type`` schema.
This function is typically called from :func:`.make_json_schema` .
Args:
shape (:class:`.Shape` ): Shape determines the depth and max/min elements
for each layer of list schema
array_type ( :class:`pydantic_core.CoreSchema` ): The pre-rendered pydantic
core schema to use in the innermost list entry
"""
shape_parts = shape.__args__[0].split(",")
split_parts = [
p.split(" ")[1] if len(p.split(" ")) == 2 else None for p in shape_parts
]
# Construct a list of list schema
# go in reverse order - construct list schemas such that
# the final schema is the one that checks the first dimension
shape_labels = reversed(split_parts)
shape_args = reversed(shape.prepared_args)
list_schema = None
for arg, label in zip(shape_args, shape_labels):
# which handler to use? for the first we use the actual type
# handler, everywhere else we use the prior list handler
inner_schema = array_type if list_schema is None else list_schema
# make a label annotation, if we have one
metadata = {"name": label} if label is not None else None
# make the current level list schema, accounting for shape
if arg == "*":
list_schema = core_schema.list_schema(inner_schema, metadata=metadata)
else:
arg = int(arg)
list_schema = core_schema.list_schema(
inner_schema, min_length=arg, max_length=arg, metadata=metadata
)
return list_schema
def make_json_schema(
shape: ShapeType, dtype: DtypeType, _handler: _handler_type
) -> ListSchema:
"""
Make a list of list JSON schema from a shape and a dtype.
First resolves the dtype into a pydantic ``CoreSchema`` ,
and then uses that with :func:`.list_of_lists_schema` .
Args:
shape ( ShapeType ): Specification of a shape, as a tuple or
an nptyping ``Shape``
dtype ( DtypeType ): A builtin type or numpy dtype
_handler: The pydantic schema generation handler (see pydantic docs)
Returns:
:class:`pydantic_core.core_schema.ListSchema`
"""
dtype_schema = _lol_dtype(dtype, _handler)
# get the names of the shape constraints, if any
if shape is Any:
list_schema = core_schema.list_schema(core_schema.any_schema())
else:
list_schema = list_of_lists_schema(shape, dtype_schema)
return list_schema
def get_validate_interface(shape: ShapeType, dtype: DtypeType) -> Callable:
"""
Validate using a matching :class:`.Interface` class using its
:meth:`.Interface.validate` method
"""
def validate_interface(value: Any, info: "ValidationInfo") -> NDArrayType:
interface_cls = Interface.match(value)
interface = interface_cls(shape, dtype)
value = interface.validate(value)
return value
return validate_interface
def _jsonize_array(value: Any) -> Union[list, dict]:
"""Use an interface class to render an array as JSON"""
interface_cls = Interface.match_output(value)
return interface_cls.to_json(value)
def coerce_list(value: Any) -> np.ndarray:
"""
If a value is passed as a list or list of lists, try to coerce it into an array
rather than failing validation.
"""
if isinstance(value, list):
value = np.array(value)
return value
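A minimal sketch of calling ``make_json_schema`` directly. It assumes ``np.int8`` has an entry in ``maps.np_to_python``; because ``_numeric_dtype`` then resolves the bounds from ``np.iinfo`` alone, the pydantic handler is never invoked and ``None`` can stand in for it:

import numpy as np
from nptyping import Shape

from numpydantic.schema import make_json_schema

# np.int8 is bounded via np.iinfo, so no handler call is needed here
schema = make_json_schema(Shape["2 x, * y"], np.int8, None)
# roughly: a length-2 list schema for the "x" dimension wrapping an
# unbounded list schema for "y", whose innermost items are
# int_schema(ge=-128, le=127)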

View file

@@ -2,6 +2,7 @@ import sys
import pytest
from numpydantic import NDArray
from numpydantic.meta import update_ndarray_stub
if sys.version_info.minor < 11:
from typing_extensions import reveal_type
@@ -9,6 +10,15 @@ else:
from typing import reveal_type
def test_no_warn(recwarn):
"""
If something goes wrong when generating meta stubs, a warning will be emitted.
That is bad.
"""
update_ndarray_stub()
assert len(recwarn) == 0
@pytest.mark.skip("TODO")
def test_generate_stub():
"""

View file

@@ -160,12 +160,26 @@ def test_json_schema_dtype_single(dtype, array_model):
)
@pytest.mark.skip()
def test_json_schema_dtype_builtin(dtype):
@pytest.mark.parametrize(
"dtype,expected",
[
(dtype.Integer, "integer"),
(dtype.Float, "number"),
(dtype.Bool, "boolean"),
(int, "integer"),
(float, "number"),
(bool, "boolean"),
],
)
def test_json_schema_dtype_builtin(dtype, expected, array_model):
"""
Using builtin or generic (e.g. `dtype.Integer` ) dtypes should
make a simple JSON schema without mins/maxes/dtypes.
"""
model = array_model(dtype=dtype)
schema = model.model_json_schema()
inner_type = schema["properties"]["array"]["items"]["items"]
assert inner_type["type"] == expected
@pytest.mark.skip("Not implemented yet")