From 49a51563d6c5b3aa706db29c16b5b6580eb427d3 Mon Sep 17 00:00:00 2001
From: sneakers-the-rat
Date: Wed, 15 May 2024 22:56:02 -0700
Subject: [PATCH] move helper methods to `schema` module and out of `ndarray`

---
 src/numpydantic/dtype.py   |   1 -
 src/numpydantic/meta.py    |  37 +++++--
 src/numpydantic/ndarray.py | 196 +++++--------------------------
 src/numpydantic/schema.py  | 189 +++++++++++++++++++++++++++++++++++
 tests/test_meta.py         |  10 ++
 tests/test_ndarray.py      |  18 +++-
 6 files changed, 269 insertions(+), 182 deletions(-)
 create mode 100644 src/numpydantic/schema.py

diff --git a/src/numpydantic/dtype.py b/src/numpydantic/dtype.py
index 5a292c5..bbe91c3 100644
--- a/src/numpydantic/dtype.py
+++ b/src/numpydantic/dtype.py
@@ -112,7 +112,6 @@ Unicode = np.unicode_
 
 Number = tuple(
     [
-        np.number,
         *Integer,
         *Float,
         *Complex,
diff --git a/src/numpydantic/meta.py b/src/numpydantic/meta.py
index d42b9b3..30d8593 100644
--- a/src/numpydantic/meta.py
+++ b/src/numpydantic/meta.py
@@ -15,19 +15,36 @@ def generate_ndarray_stub() -> str:
     Make a stub file based on the array interfaces that are available
     """
 
-    import_strings = [
-        f"from {arr.__module__} import {arr.__name__}"
-        for arr in Interface.input_types()
-        if arr.__module__ != "builtins"
-    ]
+    import_strings = []
+    type_names = []
+    for arr in Interface.input_types():
+        if arr.__module__ == "builtins":
+            continue
+
+        # Create import statements, saving the aliased type name if needed
+        if arr.__module__.startswith("numpydantic") or arr.__module__ == "typing":
+            type_name = arr.__name__
+            import_strings.append(f"from {arr.__module__} import {arr.__name__}")
+        else:
+            # since other packages could use the same name for an imported object
+            # (e.g. dask and zarr both use an `Array` class)
+            # we make an import alias from the module names to differentiate them
+            # in the type annotation
+            mod_name = "".join([a.capitalize() for a in arr.__module__.split(".")])
+            type_name = mod_name + arr.__name__
+            import_strings.append(
+                f"from {arr.__module__} import {arr.__name__} as {type_name}"
+            )
+
+        if arr.__module__ != "typing":
+            type_names.append(type_name)
+        else:
+            type_names.append(str(arr))
+
     import_strings.extend(_BUILTIN_IMPORTS)
     import_string = "\n".join(import_strings)
 
-    class_names = [
-        arr.__name__ if arr.__module__ != "typing" else str(arr)
-        for arr in Interface.input_types()
-    ]
-    class_union = " | ".join(class_names)
+    class_union = " | ".join(type_names)
     ndarray_type = "NDArray = " + class_union
 
     stub_string = "\n".join([import_string, ndarray_type])
diff --git a/src/numpydantic/ndarray.py b/src/numpydantic/ndarray.py
index 2ae1bcd..3f205b1 100644
--- a/src/numpydantic/ndarray.py
+++ b/src/numpydantic/ndarray.py
@@ -1,16 +1,21 @@
 """
 Extension of nptyping NDArray for pydantic that allows for JSON-Schema serialization
 
-* Order to store data in (row first)
+.. note::
+
+    This module should *only* have the :class:`.NDArray` class in it, because the
+    type stub ``ndarray.pyi`` is only created for :class:`.NDArray` . Otherwise,
+    type checkers will complain about using any helper functions elsewhere -
+    those all belong in :mod:`numpydantic.schema` .
+
+    In keeping with nptyping's style, NDArrayMeta stays in this module even
+    though it's excluded from the type stub.
+ """ -import pdb -from collections.abc import Callable -from typing import TYPE_CHECKING, Any, Tuple, Union +from typing import TYPE_CHECKING, Any, Tuple -import nptyping.structure import numpy as np -from nptyping import Shape from nptyping.error import InvalidArgumentsError from nptyping.ndarray import NDArrayMeta as _NDArrayMeta from nptyping.nptyping_type import NPTypingType @@ -21,176 +26,26 @@ from nptyping.typing_ import ( ) from pydantic import GetJsonSchemaHandler from pydantic_core import core_schema -from pydantic_core.core_schema import CoreSchema, ListSchema -from numpydantic import dtype as dt from numpydantic.dtype import DType -from numpydantic.interface import Interface -from numpydantic.maps import np_to_python -from numpydantic.types import DtypeType, NDArrayType, ShapeType +from numpydantic.maps import python_to_nptyping +from numpydantic.schema import ( + _handler_type, + _jsonize_array, + coerce_list, + get_validate_interface, + make_json_schema, +) +from numpydantic.types import DtypeType, ShapeType if TYPE_CHECKING: # pragma: no cover - from pydantic import ValidationInfo + pass -_handler_type = Callable[[Any], core_schema.CoreSchema] - -_UNSUPPORTED_TYPES = (complex,) """ python types that pydantic/json schema can't support (and Any will be used instead) """ -def _numeric_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema: - """Make a numeric dtype that respects min/max values from extended numpy types""" - if dtype.__module__ == "builtins": - metadata = None - else: - metadata = {"dtype": ".".join([dtype.__module__, dtype.__name__])} - - if issubclass(dtype, np.floating): - info = np.finfo(dtype) - schema = core_schema.float_schema(le=float(info.max), ge=float(info.min)) - elif issubclass(dtype, np.integer): - info = np.iinfo(dtype) - schema = core_schema.int_schema(le=int(info.max), ge=int(info.min)) - - else: - schema = _handler.generate_schema(dtype, metadata=metadata) - - return schema - - -def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema: - """Get the innermost dtype schema to use in the generated pydantic schema""" - - if isinstance(dtype, nptyping.structure.StructureMeta): # pragma: no cover - raise NotImplementedError("Structured dtypes are currently unsupported") - - if isinstance(dtype, tuple): - # if it's a meta-type that refers to a generic float/int, just make that - if dtype == dt.Float: - array_type = core_schema.float_schema() - elif dtype == dt.Integer: - array_type = core_schema.int_schema() - elif dtype == dt.Complex: - array_type = core_schema.any_schema() - else: - # make a union of dtypes recursively - types_ = list(set(dtype)) - array_type = core_schema.union_schema( - [_lol_dtype(t, _handler) for t in types_] - ) - - else: - try: - python_type = np_to_python[dtype] - except KeyError as e: - if dtype in np_to_python.values(): - # it's already a python type - python_type = dtype - else: - raise ValueError( - "dtype given in model does not have a corresponding python base type - add one to the `maps.np_to_python` dict" - ) from e - - if python_type in _UNSUPPORTED_TYPES: - array_type = core_schema.any_schema() - # TODO: warn and log here - elif python_type in (float, int): - array_type = _numeric_dtype(dtype, _handler) - else: - array_type = _handler.generate_schema(python_type) - - return array_type - - -def list_of_lists_schema(shape: Shape, array_type_handler: dict) -> ListSchema: - """Make a pydantic JSON schema for an array as a list of lists.""" - - shape_parts = shape.__args__[0].split(",") - 
split_parts = [ - p.split(" ")[1] if len(p.split(" ")) == 2 else None for p in shape_parts - ] - - # Construct a list of list schema - # go in reverse order - construct list schemas such that - # the final schema is the one that checks the first dimension - shape_labels = reversed(split_parts) - shape_args = reversed(shape.prepared_args) - list_schema = None - for arg, label in zip(shape_args, shape_labels): - # which handler to use? for the first we use the actual type - # handler, everywhere else we use the prior list handler - inner_schema = array_type_handler if list_schema is None else list_schema - - # make a label annotation, if we have one - metadata = {"name": label} if label is not None else None - - # make the current level list schema, accounting for shape - if arg == "*": - list_schema = core_schema.list_schema(inner_schema, metadata=metadata) - else: - arg = int(arg) - list_schema = core_schema.list_schema( - inner_schema, min_length=arg, max_length=arg, metadata=metadata - ) - return list_schema - - -def make_json_schema( - shape: ShapeType, dtype: DtypeType, _handler: _handler_type -) -> ListSchema: - """ - - Args: - shape: - dtype: - _handler: - - Returns: - - """ - dtype_schema = _lol_dtype(dtype, _handler) - - # get the names of the shape constraints, if any - if shape is Any: - list_schema = core_schema.list_schema(core_schema.any_schema()) - else: - list_schema = list_of_lists_schema(shape, dtype_schema) - - return list_schema - - -def _get_validate_interface(shape: ShapeType, dtype: DtypeType) -> Callable: - """ - Validate using a matching :class:`.Interface` class using its - :meth:`.Interface.validate` method - """ - - def validate_interface(value: Any, info: "ValidationInfo") -> NDArrayType: - interface_cls = Interface.match(value) - interface = interface_cls(shape, dtype) - value = interface.validate(value) - return value - - return validate_interface - - -def _jsonize_array(value: Any) -> Union[list, dict]: - interface_cls = Interface.match_output(value) - return interface_cls.to_json(value) - - -def coerce_list(value: Any) -> np.ndarray: - """ - If a value is passed as a list or list of lists, try and coerce it into an array - rather than failing validation. 
- """ - if isinstance(value, list): - value = np.array(value) - return value - - class NDArrayMeta(_NDArrayMeta, implementation="NDArray"): """ Hooking into nptyping's array metaclass to override methods pending @@ -201,9 +56,12 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"): """ Override of base _get_dtype method to allow for compound tuple types """ + if dtype_candidate in python_to_nptyping: + dtype_candidate = python_to_nptyping[dtype_candidate] is_dtype = isinstance(dtype_candidate, type) and issubclass( dtype_candidate, np.generic ) + if dtype_candidate is Any: dtype = Any elif is_dtype: @@ -265,7 +123,7 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta): [ core_schema.no_info_plain_validator_function(coerce_list), core_schema.with_info_plain_validator_function( - _get_validate_interface(shape, dtype) + get_validate_interface(shape, dtype) ), ] ), @@ -277,12 +135,12 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta): @classmethod def __get_pydantic_json_schema__( cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler - ): + ) -> core_schema.JsonSchema: json_schema = handler(schema) json_schema = handler.resolve_ref_schema(json_schema) dtype = cls.__args__[1] - if dtype.__module__ != "builtins": + if not isinstance(dtype, tuple) and dtype.__module__ != "builtins": json_schema["dtype"] = ".".join([dtype.__module__, dtype.__name__]) return json_schema diff --git a/src/numpydantic/schema.py b/src/numpydantic/schema.py new file mode 100644 index 0000000..568ab5d --- /dev/null +++ b/src/numpydantic/schema.py @@ -0,0 +1,189 @@ +""" +Helper functions for use with :class:`~numpydantic.NDArray` - see the note in +:mod:`~numpydantic.ndarray` for why these are separated. +""" + +from typing import Any, Callable, Union + +import nptyping.structure +import numpy as np +from nptyping import Shape +from pydantic_core import CoreSchema, core_schema +from pydantic_core.core_schema import ListSchema, ValidationInfo + +from numpydantic import dtype as dt +from numpydantic.interface import Interface +from numpydantic.maps import np_to_python +from numpydantic.types import DtypeType, NDArrayType, ShapeType + +_handler_type = Callable[[Any], core_schema.CoreSchema] +_UNSUPPORTED_TYPES = (complex,) + + +def _numeric_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema: + """Make a numeric dtype that respects min/max values from extended numpy types""" + if dtype in (np.number,): + dtype = float + + if issubclass(dtype, np.floating): + info = np.finfo(dtype) + schema = core_schema.float_schema(le=float(info.max), ge=float(info.min)) + elif issubclass(dtype, np.integer): + info = np.iinfo(dtype) + schema = core_schema.int_schema(le=int(info.max), ge=int(info.min)) + + else: + schema = _handler.generate_schema(dtype) + + return schema + + +def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema: + """Get the innermost dtype schema to use in the generated pydantic schema""" + + if isinstance(dtype, nptyping.structure.StructureMeta): # pragma: no cover + raise NotImplementedError("Structured dtypes are currently unsupported") + + if isinstance(dtype, tuple): + # if it's a meta-type that refers to a generic float/int, just make that + if dtype == dt.Float: + array_type = core_schema.float_schema() + elif dtype == dt.Integer: + array_type = core_schema.int_schema() + elif dtype == dt.Complex: + array_type = core_schema.any_schema() + else: + # make a union of dtypes recursively + types_ = list(set(dtype)) + array_type = core_schema.union_schema( + 
[_lol_dtype(t, _handler) for t in types_]
+            )
+
+    else:
+        try:
+            python_type = np_to_python[dtype]
+        except KeyError as e:
+            if dtype in np_to_python.values():
+                # it's already a python type
+                python_type = dtype
+            else:
+                raise ValueError(
+                    "dtype given in model does not have a corresponding python base "
+                    "type - add one to the `maps.np_to_python` dict"
+                ) from e
+
+        if python_type in _UNSUPPORTED_TYPES:
+            array_type = core_schema.any_schema()
+            # TODO: warn and log here
+        elif python_type in (float, int):
+            array_type = _numeric_dtype(dtype, _handler)
+        else:
+            array_type = _handler.generate_schema(python_type)
+
+    return array_type
+
+
+def list_of_lists_schema(shape: Shape, array_type: CoreSchema) -> ListSchema:
+    """
+    Make a pydantic JSON schema for an array as a list of lists.
+
+    For each item in the shape, create a list schema. In the innermost schema
+    insert the passed ``array_type`` schema.
+
+    This function is typically called from :func:`.make_json_schema` .
+
+    Args:
+        shape (:class:`.Shape` ): Shape determines the depth and max/min elements
+            for each layer of list schema
+        array_type ( :class:`pydantic_core.CoreSchema` ): The pre-rendered pydantic
+            core schema to use in the innermost list entry
+    """
+
+    shape_parts = shape.__args__[0].split(",")
+    split_parts = [
+        p.split(" ")[1] if len(p.split(" ")) == 2 else None for p in shape_parts
+    ]
+
+    # Construct a list-of-lists schema
+    # go in reverse order - construct list schemas such that
+    # the final schema is the one that checks the first dimension
+    shape_labels = reversed(split_parts)
+    shape_args = reversed(shape.prepared_args)
+    list_schema = None
+    for arg, label in zip(shape_args, shape_labels):
+        # which schema to nest? for the innermost level we use the actual dtype
+        # schema, everywhere else we use the previous level's list schema
+        inner_schema = array_type if list_schema is None else list_schema
+
+        # make a label annotation, if we have one
+        metadata = {"name": label} if label is not None else None
+
+        # make the current level list schema, accounting for shape
+        if arg == "*":
+            list_schema = core_schema.list_schema(inner_schema, metadata=metadata)
+        else:
+            arg = int(arg)
+            list_schema = core_schema.list_schema(
+                inner_schema, min_length=arg, max_length=arg, metadata=metadata
+            )
+    return list_schema
+
+
+def make_json_schema(
+    shape: ShapeType, dtype: DtypeType, _handler: _handler_type
+) -> ListSchema:
+    """
+    Make a list-of-lists JSON schema from a shape and a dtype.
+
+    First resolves the dtype into a pydantic ``CoreSchema`` ,
+    and then uses that with :func:`.list_of_lists_schema` .
+
+    Args:
+        shape ( ShapeType ): Specification of a shape, as a tuple or
+            an nptyping ``Shape``
+        dtype ( DtypeType ): A builtin type or numpy dtype
+        _handler: The pydantic schema generation handler (see pydantic docs)
+
+    Returns:
+        :class:`pydantic_core.core_schema.ListSchema`
+    """
+    dtype_schema = _lol_dtype(dtype, _handler)
+
+    # build the list-of-lists schema, or accept any list if shape is unconstrained
+    if shape is Any:
+        list_schema = core_schema.list_schema(core_schema.any_schema())
+    else:
+        list_schema = list_of_lists_schema(shape, dtype_schema)
+
+    return list_schema
+
+
+def get_validate_interface(shape: ShapeType, dtype: DtypeType) -> Callable:
+    """
+    Make a validator function that dispatches to the matching :class:`.Interface`
+    class, validating with its :meth:`.Interface.validate` method
+    """
+
+    def validate_interface(value: Any, info: "ValidationInfo") -> NDArrayType:
+        interface_cls = Interface.match(value)
+        interface = interface_cls(shape, dtype)
+        value = interface.validate(value)
+        return value
+
+    return validate_interface
+
+
+def _jsonize_array(value: Any) -> Union[list, dict]:
+    """Use an interface class to render an array as JSON"""
+    interface_cls = Interface.match_output(value)
+    return interface_cls.to_json(value)
+
+
+def coerce_list(value: Any) -> np.ndarray:
+    """
+    If a value is passed as a list or list of lists, try to coerce it into an array
+    rather than failing validation.
+    """
+    if isinstance(value, list):
+        value = np.array(value)
+    return value
diff --git a/tests/test_meta.py b/tests/test_meta.py
index 380e179..c137633 100644
--- a/tests/test_meta.py
+++ b/tests/test_meta.py
@@ -2,6 +2,7 @@ import sys
 import pytest
 
 from numpydantic import NDArray
+from numpydantic.meta import update_ndarray_stub
 
 if sys.version_info.minor < 11:
     from typing_extensions import reveal_type
@@ -9,6 +10,15 @@ else:
     from typing import reveal_type
 
 
+def test_no_warn(recwarn):
+    """
+    If something goes wrong while generating the meta stubs, a warning will be
+    emitted. That is bad.
+    """
+    update_ndarray_stub()
+    assert len(recwarn) == 0
+
+
 @pytest.mark.skip("TODO")
 def test_generate_stub():
     """
diff --git a/tests/test_ndarray.py b/tests/test_ndarray.py
index 0276001..88fe55b 100644
--- a/tests/test_ndarray.py
+++ b/tests/test_ndarray.py
@@ -160,12 +160,26 @@ def test_json_schema_dtype_single(dtype, array_model):
     )
 
 
-@pytest.mark.skip()
-def test_json_schema_dtype_builtin(dtype):
+@pytest.mark.parametrize(
+    "dtype,expected",
+    [
+        (dtype.Integer, "integer"),
+        (dtype.Float, "number"),
+        (dtype.Bool, "boolean"),
+        (int, "integer"),
+        (float, "number"),
+        (bool, "boolean"),
+    ],
+)
+def test_json_schema_dtype_builtin(dtype, expected, array_model):
     """
     Using builtin or generic (eg. `dtype.Integer` ) dtypes should make a simple
    json schema without mins/maxes/dtypes.
     """
+    model = array_model(dtype=dtype)
+    schema = model.model_json_schema()
+    inner_type = schema["properties"]["array"]["items"]["items"]
+    assert inner_type["type"] == expected
 
 
 @pytest.mark.skip("Not implemented yet")
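
-- 
Reviewer notes below are illustrative sketches, not part of the patch.

First, the path a value takes through the relocated helpers when an `NDArray`
field is validated and a JSON schema is generated, mirroring the new
`test_json_schema_dtype_builtin` test. The `Image` model and `array` field
names are assumed for illustration:

    import numpy as np
    from nptyping import Shape
    from pydantic import BaseModel

    from numpydantic import NDArray


    class Image(BaseModel):
        # an any-sized 2D integer array
        array: NDArray[Shape["* x, * y"], int]


    # plain lists are coerced to np.ndarray by schema.coerce_list, then the
    # matching Interface class validates shape and dtype
    img = Image(array=[[1, 2], [3, 4]])
    assert isinstance(img.array, np.ndarray)

    # schema.make_json_schema renders the field as a list-of-lists schema;
    # the innermost "items" carries the dtype, and builtin ints map to a
    # plain {"type": "integer"} with no min/max
    schema = Image.model_json_schema()
    assert schema["properties"]["array"]["items"]["items"]["type"] == "integer"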
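
Second, the import-alias scheme in `generate_ndarray_stub` can be checked in
isolation. A sketch of the same transform, assuming dask's array class lives at
`dask.array.core.Array`:

    # the module path is capitalized and joined, then the class name appended,
    # so same-named classes from different packages don't collide in the stub
    module, name = "dask.array.core", "Array"
    alias = "".join(part.capitalize() for part in module.split(".")) + name
    assert alias == "DaskArrayCoreArray"
    # stub import line: "from dask.array.core import Array as DaskArrayCoreArray"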