diff --git a/src/numpydantic/maps.py b/src/numpydantic/maps.py
index 61a7879..7de60c0 100644
--- a/src/numpydantic/maps.py
+++ b/src/numpydantic/maps.py
@@ -6,7 +6,6 @@
 from datetime import datetime
 from typing import Any
 
 import numpy as np
-from nptyping import Bool, Float, Int, String
 
 from numpydantic import dtype as dt
@@ -59,5 +58,5 @@ flat_to_nptyping = {
 }
 """Map from NWB-style flat dtypes to nptyping types"""
 
-python_to_nptyping = {float: Float, str: String, int: Int, bool: Bool}
+python_to_nptyping = {float: dt.Float, str: dt.String, int: dt.Int, bool: dt.Bool}
 """Map from python types to nptyping types"""
diff --git a/src/numpydantic/ndarray.py b/src/numpydantic/ndarray.py
index 13ed7d7..2ae1bcd 100644
--- a/src/numpydantic/ndarray.py
+++ b/src/numpydantic/ndarray.py
@@ -18,22 +18,94 @@ from nptyping.structure_expression import check_type_names
 from nptyping.typing_ import (
     dtype_per_name,
 )
+from pydantic import GetJsonSchemaHandler
 from pydantic_core import core_schema
-from pydantic_core.core_schema import ListSchema
+from pydantic_core.core_schema import CoreSchema, ListSchema
 
+from numpydantic import dtype as dt
 from numpydantic.dtype import DType
 from numpydantic.interface import Interface
 from numpydantic.maps import np_to_python
-
-# from numpydantic.proxy import NDArrayProxy
 from numpydantic.types import DtypeType, NDArrayType, ShapeType
 
-if TYPE_CHECKING:
+if TYPE_CHECKING:  # pragma: no cover
     from pydantic import ValidationInfo
 
 
+_handler_type = Callable[[Any], core_schema.CoreSchema]
+
+_UNSUPPORTED_TYPES = (complex,)
+"""
+python types that pydantic/json schema can't support (and Any will be used instead)
+"""
+
+
+def _numeric_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
+    """Make a numeric dtype that respects min/max values from extended numpy types"""
+    if dtype.__module__ == "builtins":
+        metadata = None
+    else:
+        metadata = {"dtype": ".".join([dtype.__module__, dtype.__name__])}
+
+    if issubclass(dtype, np.floating):
+        info = np.finfo(dtype)
+        schema = core_schema.float_schema(le=float(info.max), ge=float(info.min))
+    elif issubclass(dtype, np.integer):
+        info = np.iinfo(dtype)
+        schema = core_schema.int_schema(le=int(info.max), ge=int(info.min))
+    else:
+        schema = _handler.generate_schema(dtype, metadata=metadata)
+
+    return schema
+
+
+def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
+    """Get the innermost dtype schema to use in the generated pydantic schema"""
+    if isinstance(dtype, nptyping.structure.StructureMeta):  # pragma: no cover
+        raise NotImplementedError("Structured dtypes are currently unsupported")
+
+    if isinstance(dtype, tuple):
+        # if it's a meta-type that refers to a generic float/int, just make that
+        if dtype == dt.Float:
+            array_type = core_schema.float_schema()
+        elif dtype == dt.Integer:
+            array_type = core_schema.int_schema()
+        elif dtype == dt.Complex:
+            array_type = core_schema.any_schema()
+        else:
+            # make a union of dtypes recursively
+            types_ = list(set(dtype))
+            array_type = core_schema.union_schema(
+                [_lol_dtype(t, _handler) for t in types_]
+            )
+    else:
+        try:
+            python_type = np_to_python[dtype]
+        except KeyError as e:
+            if dtype in np_to_python.values():
+                # it's already a python type
+                python_type = dtype
+            else:
+                raise ValueError(
+                    "dtype given "
+                    "in model does not have a corresponding python base type "
+                    "- add one to the `maps.np_to_python` dict"
+                ) from e
+
+        if python_type in _UNSUPPORTED_TYPES:
+            array_type = core_schema.any_schema()
+            # TODO: warn and log here
+        elif python_type in (float, int):
+            array_type = _numeric_dtype(dtype, _handler)
+        else:
+            array_type = _handler.generate_schema(python_type)
+
+    return array_type
+
+
 def list_of_lists_schema(shape: Shape, array_type_handler: dict) -> ListSchema:
     """Make a pydantic JSON schema for an array as a list of lists."""
+
     shape_parts = shape.__args__[0].split(",")
     split_parts = [
         p.split(" ")[1] if len(p.split(" ")) == 2 else None for p in shape_parts
@@ -64,6 +136,31 @@ def list_of_lists_schema(shape: Shape, array_type_handler: dict) -> ListSchema:
     return list_schema
 
 
+def make_json_schema(
+    shape: ShapeType, dtype: DtypeType, _handler: _handler_type
+) -> ListSchema:
+    """
+    Make a list-of-lists JSON schema from a shape and a dtype.
+
+    Args:
+        shape: Shape constraint, or :class:`typing.Any` for arbitrary-shaped arrays
+        dtype: dtype to use for the innermost items of the schema
+        _handler: pydantic schema generation handler
+
+    Returns:
+        A list schema constrained to the given shape and dtype
+    """
+    dtype_schema = _lol_dtype(dtype, _handler)
+
+    # constrain the shape as a list of lists, if there is a shape constraint
+    if shape is Any:
+        list_schema = core_schema.list_schema(core_schema.any_schema())
+    else:
+        list_schema = list_of_lists_schema(shape, dtype_schema)
+
+    return list_schema
+
+
 def _get_validate_interface(shape: ShapeType, dtype: DtypeType) -> Callable:
     """
     Validate using a matching :class:`.Interface` class using its
@@ -111,16 +208,16 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
             dtype = Any
         elif is_dtype:
             dtype = dtype_candidate
-        elif issubclass(dtype_candidate, Structure):
+        elif issubclass(dtype_candidate, Structure):  # pragma: no cover
             dtype = dtype_candidate
             check_type_names(dtype, dtype_per_name)
-        elif cls._is_literal_like(dtype_candidate):
+        elif cls._is_literal_like(dtype_candidate):  # pragma: no cover
             structure_expression = dtype_candidate.__args__[0]
             dtype = Structure[structure_expression]
             check_type_names(dtype, dtype_per_name)
-        elif isinstance(dtype_candidate, tuple):
+        elif isinstance(dtype_candidate, tuple):  # pragma: no cover
             dtype = tuple([cls._get_dtype(dt) for dt in dtype_candidate])
-        else:
+        else:  # pragma: no cover
             raise InvalidArgumentsError(
                 f"Unexpected argument '{dtype_candidate}', expecting"
                 " Structure[<StructureExpression>]"
@@ -153,33 +250,14 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
     def __get_pydantic_core_schema__(
         cls,
         _source_type: "NDArray",
-        _handler: Callable[[Any], core_schema.CoreSchema],
+        _handler: _handler_type,
     ) -> core_schema.CoreSchema:
         shape, dtype = _source_type.__args__
         shape: ShapeType
        dtype: DtypeType
 
-        # get pydantic core schema for the given specified type
-        if isinstance(dtype, nptyping.structure.StructureMeta):
-            raise NotImplementedError("Finish handling structured dtypes!")
-            # functools.reduce(operator.or_, [int, float, str])
-        else:
-            if isinstance(dtype, tuple):
-                types_ = list(set([np_to_python[dt] for dt in dtype]))
-                # TODO: better type filtering - explicitly model what
-                # numeric types are supported by JSON schema
-                types_ = [t for t in types_ if t not in (complex,)]
-                schemas = [_handler.generate_schema(dt) for dt in types_]
-                array_type_handler = core_schema.union_schema(schemas)
-
-            else:
-                array_type_handler = _handler.generate_schema(np_to_python[dtype])
-
-        # get the names of the shape constraints, if any
-        if shape is Any:
-            list_schema = core_schema.list_schema(core_schema.any_schema())
-        else:
-            list_schema = list_of_lists_schema(shape, array_type_handler)
+        # get pydantic core schema as a list of lists for JSON schema
+        list_schema = make_json_schema(shape, dtype, _handler)
 
         return core_schema.json_or_python_schema(
             json_schema=list_schema,
@@ -195,3 +273,17 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
                 _jsonize_array, when_used="json"
             ),
         )
+
+    @classmethod
+    def __get_pydantic_json_schema__(
+        cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
+    ):
+        json_schema = handler(schema)
+        json_schema = handler.resolve_ref_schema(json_schema)
+
+        dtype = cls.__args__[1]
+        # union (tuple) dtypes have no single source module to record
+        if not isinstance(dtype, tuple) and dtype.__module__ != "builtins":
+            json_schema["dtype"] = ".".join([dtype.__module__, dtype.__name__])
+
+        return json_schema
diff --git a/tests/fixtures.py b/tests/fixtures.py
index 2ae8ff2..e780058 100644
--- a/tests/fixtures.py
+++ b/tests/fixtures.py
@@ -64,7 +64,7 @@ def array_model() -> (
         shape_str = ", ".join([str(s) for s in shape])
 
         class MyModel(BaseModel):
-            array: NDArray[Shape[shape_str], python_to_nptyping[dtype]]
+            array: NDArray[Shape[shape_str], dtype]
 
         return MyModel
 
diff --git a/tests/test_ndarray.py b/tests/test_ndarray.py
index 0305371..0276001 100644
--- a/tests/test_ndarray.py
+++ b/tests/test_ndarray.py
@@ -9,6 +9,7 @@
 from nptyping import Shape, Number
 
 from numpydantic import NDArray
 from numpydantic.exceptions import ShapeError, DtypeError
+from numpydantic import dtype
 
 # from .fixtures import tmp_output_dir_func
@@ -99,23 +100,83 @@ def test_ndarray_serialize():
     assert isinstance(mod_dict["array"], np.ndarray)
 
 
-# def test_ndarray_proxy(tmp_output_dir_func):
-#     h5f_source = tmp_output_dir_func / 'test.h5'
-#     with h5py.File(h5f_source, 'w') as h5f:
-#         dset_good = h5f.create_dataset('/data', data=np.random.random((1024,1024,3)))
-#         dset_bad = h5f.create_dataset('/data_bad', data=np.random.random((1024, 1024, 4)))
-#
-#     class Model(BaseModel):
-#         array: NDArray[Shape["* x, * y, 3 z"], Number]
-#
-#     mod = Model(array=NDArrayProxy(h5f_file=h5f_source, path='/data'))
-#     subarray = mod.array[0:5, 0:5, :]
-#     assert isinstance(subarray, np.ndarray)
-#     assert isinstance(subarray.sum(), float)
-#     assert mod.array.name == '/data'
-#
-#     with pytest.raises(NotImplementedError):
-#         mod.array[0] = 5
-#
-#     with pytest.raises(ValidationError):
-#         mod = Model(array=NDArrayProxy(h5f_file=h5f_source, path='/data_bad'))
+_json_schema_types = [
+    *[(t, float) for t in dtype.Float],
+    *[(t, int) for t in dtype.Integer],
+]
+
+
+def test_json_schema_basic(array_model):
+    """
+    NDArray types should correctly generate a list of lists JSON schema
+    """
+    shape = (15, 10)
+    dtype = float
+    model = array_model(shape, dtype)
+    schema = model.model_json_schema()
+    field = schema["properties"]["array"]
+
+    # outer shape
+    assert field["maxItems"] == shape[0]
+    assert field["minItems"] == shape[0]
+    assert field["type"] == "array"
+
+    # inner shape
+    inner = field["items"]
+    assert inner["minItems"] == shape[1]
+    assert inner["maxItems"] == shape[1]
+    assert inner["items"]["type"] == "number"
+
+
+@pytest.mark.parametrize("dtype", [*dtype.Integer, *dtype.Float])
+def test_json_schema_dtype_single(dtype, array_model):
+    """
+    dtypes should have correct mins and maxes set, and store the source dtype
+    """
+    if issubclass(dtype, np.floating):
+        info = np.finfo(dtype)
+        min_val = info.min
+        max_val = info.max
+        schema_type = "number"
+    elif issubclass(dtype, np.integer):
+        info = np.iinfo(dtype)
+        min_val = info.min
+        max_val = info.max
+        schema_type = "integer"
+    else:
+        raise ValueError("These should all be numpy types!")
+
+    shape = (15, 10)
+    model = array_model(shape, dtype)
+    schema = model.model_json_schema()
+    inner_type = schema["properties"]["array"]["items"]["items"]
+    assert inner_type["minimum"] == min_val
+    assert inner_type["maximum"] == max_val
+    assert inner_type["type"] == schema_type
+    assert schema["properties"]["array"]["dtype"] == ".".join(
+        [dtype.__module__, dtype.__name__]
+    )
+
+
+@pytest.mark.skip("Not implemented yet")
+def test_json_schema_dtype_builtin():
+    """
+    Using builtin or generic (e.g. `dtype.Integer`) dtypes should
+    make a simple json schema without mins/maxes/dtypes.
+    """
+
+
+@pytest.mark.skip("Not implemented yet")
+def test_json_schema_wildcard():
+    """
+    NDArray types should generate a JSON schema without shape constraints
+    """
+    pass
+
+
+@pytest.mark.skip("Not implemented yet")
+def test_json_schema_ellipsis():
+    """
+    NDArray types should create a recursive JSON schema for any-shaped arrays
+    """
+    pass
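
Usage sketch (illustrative, not part of the patch): a minimal example of what the new JSON schema generation should produce for a numpy-dtyped array field. `ImageModel` and its shape/dtype are hypothetical; the expected values in the comments follow from `_numeric_dtype` (bounds from `np.iinfo`) and `__get_pydantic_json_schema__` (the added "dtype" key), mirroring the assertions in `test_json_schema_dtype_single` above.

import numpy as np
from nptyping import Shape
from pydantic import BaseModel

from numpydantic import NDArray


class ImageModel(BaseModel):
    # hypothetical example field: a 2 x 3 array of uint8
    array: NDArray[Shape["2 x, 3 y"], np.uint8]


field = ImageModel.model_json_schema()["properties"]["array"]
# outer/inner list lengths come from the shape constraint
assert field["minItems"] == 2 and field["maxItems"] == 2
# the source dtype is recorded alongside the list-of-lists schema
assert field["dtype"] == "numpy.uint8"
# innermost items carry bounds derived from np.iinfo(np.uint8)
assert field["items"]["items"]["type"] == "integer"
assert field["items"]["items"]["minimum"] == 0
assert field["items"]["items"]["maximum"] == 255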