diff --git a/docs/todo.md b/docs/todo.md index 9626352..a790241 100644 --- a/docs/todo.md +++ b/docs/todo.md @@ -1,5 +1,24 @@ # TODO +## Validation + +```{todo} +Support pydantic value/range constraints - less than, greater than, etc. +``` + +```{todo} +Support different precision modes - eg. exact precision, or minimum precision +where specifying a float32 would also accept a float64, etc. +``` + +## Metadata + +```{todo} +Use names in nptyping annotations in generated JSON schema metadata +``` + +## All TODOs + ```{todolist} ``` \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 9584aa1..bfb816f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,9 +116,6 @@ ignore = [ fixable = ["ALL"] -[tool.black] -exclude = "tests" - [tool.mypy] plugins = [ "pydantic.mypy" diff --git a/src/numpydantic/dtype.py b/src/numpydantic/dtype.py new file mode 100644 index 0000000..76631a2 --- /dev/null +++ b/src/numpydantic/dtype.py @@ -0,0 +1,115 @@ +""" +Replacement of :mod:`nptyping.typing_` + +In the transition away from using nptyping, we want to allow for greater +control of dtype specifications - like different precision modes, etc. +and allow for abstract specifications of dtype that can be checked across +interfaces. + +This module also allows for convenient access to all abstract dtypes in a single +module, rather than needing to import each individually. + +Some types like :ref:`Integer` are compound types - tuples of multiple dtypes. +Check these using ``in`` rather than ``==``. This interface will develop in future +versions to allow a single dtype check. +""" + +from typing import Tuple, TypeAlias, Union + +import numpy as np + +ShapeExpression: TypeAlias = str +StructureExpression: TypeAlias = str +DType: TypeAlias = Union[np.generic, StructureExpression, Tuple["DType"]] +ShapeTuple: TypeAlias = Tuple[int, ...] + +Bool = np.bool_ +Obj = np.object_ # Obj is a common abbreviation and should be usable. +Object = np.object_ +Datetime64 = np.datetime64 +Inexact = np.inexact + +Int8 = np.int8 +Int16 = np.int16 +Int32 = np.int32 +Int64 = np.int64 +Byte = np.byte +Short = np.short +IntC = np.intc +IntP = np.intp +Int_ = np.int_ +UInt8 = np.uint8 +UInt16 = np.uint16 +UInt32 = np.uint32 +UInt64 = np.uint64 +UByte = np.ubyte +UShort = np.ushort +UIntC = np.uintc +UIntP = np.uintp +UInt = np.uint +ULongLong = np.ulonglong +LongLong = np.longlong +Timedelta64 = np.timedelta64 +SignedInteger = (np.int8, np.int16, np.int32, np.int64, np.short) +UnsignedInteger = (np.uint8, np.uint16, np.uint32, np.uint64, np.ushort) +Integer = tuple([np.integer, *SignedInteger, *UnsignedInteger]) +Int = Integer # Int should translate to the "generic" int type. + +Float16 = np.float16 +Float32 = np.float32 +Float64 = np.float64 +Half = np.half +Single = np.single +Double = np.double +LongDouble = np.longdouble +LongFloat = np.longfloat +Float = ( + np.float_, + np.float16, + np.float32, + np.float64, + np.floating, + np.single, + np.double, +) +Floating = Float + +ComplexFloating = np.complexfloating +Complex64 = np.complex64 +Complex128 = np.complex128 +CSingle = np.csingle +SingleComplex = np.singlecomplex +CDouble = np.cdouble +CFloat = np.cfloat +CLongDouble = np.clongdouble +CLongFloat = np.clongfloat +Complex = ( + np.complex_, + np.complexfloating, + np.complex64, + np.complex128, + np.csingle, + np.singlecomplex, + np.cdouble, + np.cfloat, + np.clongdouble, + np.clongfloat, +) + +LongComplex = np.longcomplex +Flexible = np.flexible +Void = np.void +Character = np.character +Bytes = np.bytes_ +Str = np.str_ +String = np.string_ +Unicode = np.unicode_ + +Number = tuple( + [ + np.number, + *Integer, + *Float, + *Complex, + ] +) diff --git a/src/numpydantic/interface/interface.py b/src/numpydantic/interface/interface.py index 50991ea..5cb074d 100644 --- a/src/numpydantic/interface/interface.py +++ b/src/numpydantic/interface/interface.py @@ -56,7 +56,13 @@ class Interface(ABC, Generic[T]): """ if self.dtype is Any: return array - if not array.dtype == self.dtype: + + if isinstance(self.dtype, tuple): + valid = array.dtype in self.dtype + else: + valid = array.dtype == self.dtype + + if not valid: raise DtypeError(f"Invalid dtype! expected {self.dtype}, got {array.dtype}") return array @@ -151,6 +157,7 @@ class Interface(ABC, Generic[T]): msg += "\n".join([f" - {i}" for i in matches]) raise ValueError(msg) elif len(matches) == 0: + pdb.set_trace() raise ValueError(f"No matching interfaces found for input {array}") else: return matches[0] diff --git a/src/numpydantic/interface/zarr.py b/src/numpydantic/interface/zarr.py index 8e8a2ea..23bee57 100644 --- a/src/numpydantic/interface/zarr.py +++ b/src/numpydantic/interface/zarr.py @@ -5,14 +5,14 @@ Interface to zarr arrays import contextlib from dataclasses import dataclass from pathlib import Path -from typing import Any, Optional, Union, Sequence +from typing import Any, Optional, Sequence, Union from numpydantic.interface.interface import Interface try: + import zarr from zarr.core import Array as ZarrArray from zarr.storage import StoreLike - import zarr except ImportError: ZarrArray = None StoreLike = None @@ -32,11 +32,16 @@ class ZarrArrayPath: path: Optional[str] = None """Path to array within hierarchical zarr store""" - def open(self, **kwargs) -> ZarrArray: + def open(self, **kwargs: dict) -> ZarrArray: + """Open the zarr array at the provided path""" return zarr.open(str(self.file), path=self.path, **kwargs) @classmethod def from_iterable(cls, spec: Sequence) -> "ZarrArrayPath": + """ + Construct a :class:`.ZarrArrayPath` specifier from an iterable, + rather than kwargs + """ if len(spec) == 1: return ZarrArrayPath(file=spec[0]) elif len(spec) == 2: diff --git a/src/numpydantic/maps.py b/src/numpydantic/maps.py index 37edd7f..61a7879 100644 --- a/src/numpydantic/maps.py +++ b/src/numpydantic/maps.py @@ -8,6 +8,8 @@ from typing import Any import numpy as np from nptyping import Bool, Float, Int, String +from numpydantic import dtype as dt + np_to_python = { Any: Any, np.number: float, @@ -17,34 +19,9 @@ np_to_python = { np.byte: bytes, np.bytes_: bytes, np.datetime64: datetime, - **{ - n: int - for n in ( - np.int8, - np.int16, - np.int32, - np.int64, - np.short, - np.uint8, - np.uint16, - np.uint32, - np.uint64, - np.uint, - ) - }, - **{ - n: float - for n in ( - np.float16, - np.float32, - np.floating, - np.float32, - np.float64, - np.single, - np.double, - np.float_, - ) - }, + **{n: int for n in dt.Integer}, + **{n: float for n in dt.Float}, + **{n: complex for n in dt.Complex}, **{n: str for n in (np.character, np.str_, np.string_, np.unicode_)}, } """Map from python types to numpy""" diff --git a/src/numpydantic/ndarray.py b/src/numpydantic/ndarray.py index a5de7d3..13ed7d7 100644 --- a/src/numpydantic/ndarray.py +++ b/src/numpydantic/ndarray.py @@ -10,11 +10,18 @@ from typing import TYPE_CHECKING, Any, Tuple, Union import nptyping.structure import numpy as np from nptyping import Shape +from nptyping.error import InvalidArgumentsError from nptyping.ndarray import NDArrayMeta as _NDArrayMeta from nptyping.nptyping_type import NPTypingType +from nptyping.structure import Structure +from nptyping.structure_expression import check_type_names +from nptyping.typing_ import ( + dtype_per_name, +) from pydantic_core import core_schema from pydantic_core.core_schema import ListSchema +from numpydantic.dtype import DType from numpydantic.interface import Interface from numpydantic.maps import np_to_python @@ -24,12 +31,6 @@ from numpydantic.types import DtypeType, NDArrayType, ShapeType if TYPE_CHECKING: from pydantic import ValidationInfo -COMPRESSION_THRESHOLD = 16 * 1024 -""" -Arrays larger than this size (in bytes) will be compressed and b64 encoded when -serializing to JSON. -""" - def list_of_lists_schema(shape: Shape, array_type_handler: dict) -> ListSchema: """Make a pydantic JSON schema for an array as a list of lists.""" @@ -95,11 +96,40 @@ def coerce_list(value: Any) -> np.ndarray: class NDArrayMeta(_NDArrayMeta, implementation="NDArray"): """ - Kept here to allow for hooking into metaclass, which has - been necessary on and off as we work this class into a stable - state + Hooking into nptyping's array metaclass to override methods pending + completion of the transition away from nptyping """ + def _get_dtype(cls, dtype_candidate: Any) -> DType: + """ + Override of base _get_dtype method to allow for compound tuple types + """ + is_dtype = isinstance(dtype_candidate, type) and issubclass( + dtype_candidate, np.generic + ) + if dtype_candidate is Any: + dtype = Any + elif is_dtype: + dtype = dtype_candidate + elif issubclass(dtype_candidate, Structure): + dtype = dtype_candidate + check_type_names(dtype, dtype_per_name) + elif cls._is_literal_like(dtype_candidate): + structure_expression = dtype_candidate.__args__[0] + dtype = Structure[structure_expression] + check_type_names(dtype, dtype_per_name) + elif isinstance(dtype_candidate, tuple): + dtype = tuple([cls._get_dtype(dt) for dt in dtype_candidate]) + else: + raise InvalidArgumentsError( + f"Unexpected argument '{dtype_candidate}', expecting" + " Structure[]" + " or Literal[]" + " or a dtype" + " or typing.Any." + ) + return dtype + class NDArray(NPTypingType, metaclass=NDArrayMeta): """ @@ -134,7 +164,16 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta): raise NotImplementedError("Finish handling structured dtypes!") # functools.reduce(operator.or_, [int, float, str]) else: - array_type_handler = _handler.generate_schema(np_to_python[dtype]) + if isinstance(dtype, tuple): + types_ = list(set([np_to_python[dt] for dt in dtype])) + # TODO: better type filtering - explicitly model what + # numeric types are supported by JSON schema + types_ = [t for t in types_ if t not in (complex,)] + schemas = [_handler.generate_schema(dt) for dt in types_] + array_type_handler = core_schema.union_schema(schemas) + + else: + array_type_handler = _handler.generate_schema(np_to_python[dtype]) # get the names of the shape constraints, if any if shape is Any: diff --git a/tests/conftest.py b/tests/conftest.py index dd2b597..b22d54a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,12 @@ +import pdb + import pytest +from typing import Any, Tuple, Union, Type, TypeAlias +from pydantic import BaseModel, computed_field, ConfigDict +from numpydantic import NDArray, Shape +from numpydantic.ndarray import NDArrayMeta +from numpydantic.dtype import Float, Number, Integer +import numpy as np from tests.fixtures import * @@ -9,3 +17,127 @@ def pytest_addoption(parser): action="store_true", help="Keep test outputs in the __tmp__ directory", ) + + +class ValidationCase(BaseModel): + """ + Test case for validating an array. + + Contains both the validating model and the parameterization for an array to + test in a given interface + """ + + annotation: Any = NDArray[Shape["10, 10, *"], Float] + """ + Array annotation used in the validating model + Any typed because the types of type annotations are weird + """ + shape: Tuple[int, ...] = (10, 10, 10) + """Shape of the array to validate""" + dtype: Union[Type, np.dtype] = float + """Dtype of the array to validate""" + passes: bool + """Whether the validation should pass or not""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + @computed_field() + def model(self) -> Type[BaseModel]: + """A model with a field ``array`` with the given annotation""" + annotation = self.annotation + + class Model(BaseModel): + array: annotation + + return Model + + +RGB_UNION: TypeAlias = Union[ + NDArray[Shape["* x, * y"], Number], + NDArray[Shape["* x, * y, 3 r_g_b"], Number], + NDArray[Shape["* x, * y, 3 r_g_b, 4 r_g_b_a"], Number], +] + +NUMBER: TypeAlias = NDArray[Shape["*, *, *"], Number] +INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer] +FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float] + + +@pytest.fixture( + scope="session", + params=[ + ValidationCase(shape=(10, 10, 10), passes=True), + ValidationCase(shape=(10, 10), passes=False), + ValidationCase(shape=(10, 10, 10, 10), passes=False), + ValidationCase(shape=(11, 10, 10), passes=False), + ValidationCase(shape=(9, 10, 10), passes=False), + ValidationCase(shape=(10, 10, 9), passes=True), + ValidationCase(shape=(10, 10, 11), passes=True), + ValidationCase(annotation=RGB_UNION, shape=(5, 5), passes=True), + ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3), passes=True), + ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3, 4), passes=True), + ValidationCase(annotation=RGB_UNION, shape=(5, 5, 4), passes=False), + ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3, 6), passes=False), + ValidationCase(annotation=RGB_UNION, shape=(5, 5, 4, 6), passes=False), + ], + ids=[ + "valid shape", + "missing dimension", + "extra dimension", + "dimension too large", + "dimension too small", + "wildcard smaller", + "wildcard larger", + "Union 2D", + "Union 3D", + "Union 4D", + "Union incorrect 3D", + "Union incorrect 4D", + "Union incorrect both", + ], +) +def shape_cases(request) -> ValidationCase: + return request.param + + +@pytest.fixture( + scope="session", + params=[ + ValidationCase(dtype=float, passes=True), + ValidationCase(dtype=int, passes=False), + ValidationCase(dtype=np.uint8, passes=False), + ValidationCase(annotation=NUMBER, dtype=int, passes=True), + ValidationCase(annotation=NUMBER, dtype=float, passes=True), + ValidationCase(annotation=NUMBER, dtype=np.uint8, passes=True), + ValidationCase(annotation=NUMBER, dtype=np.float16, passes=True), + ValidationCase(annotation=NUMBER, dtype=str, passes=False), + ValidationCase(annotation=INTEGER, dtype=int, passes=True), + ValidationCase(annotation=INTEGER, dtype=np.uint8, passes=True), + ValidationCase(annotation=INTEGER, dtype=float, passes=False), + ValidationCase(annotation=INTEGER, dtype=np.float32, passes=False), + ValidationCase(annotation=FLOAT, dtype=float, passes=True), + ValidationCase(annotation=FLOAT, dtype=np.float32, passes=True), + ValidationCase(annotation=FLOAT, dtype=int, passes=False), + ValidationCase(annotation=FLOAT, dtype=np.uint8, passes=False), + ], + ids=[ + "float", + "int", + "uint8", + "number-int", + "number-float", + "number-uint8", + "number-float16", + "number-str", + "integer-int", + "integer-uint8", + "integer-float", + "integer-float32", + "float-float", + "float-float32", + "float-int", + "float-uint8", + ], +) +def dtype_cases(request) -> ValidationCase: + return request.param diff --git a/tests/fixtures.py b/tests/fixtures.py index 138ccb4..6364962 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -5,7 +5,6 @@ from typing import Callable, Optional, Tuple, Type, Union import h5py import numpy as np import pytest -from nptyping import Number from pydantic import BaseModel, Field import zarr @@ -13,6 +12,7 @@ from numpydantic.interface.hdf5 import H5ArrayPath from numpydantic.interface.zarr import ZarrArrayPath from numpydantic import NDArray, Shape from numpydantic.maps import python_to_nptyping +from numpydantic.dtype import Number @pytest.fixture(scope="session") diff --git a/tests/test_interface/conftest.py b/tests/test_interface/conftest.py index 85f1c7e..b2021c3 100644 --- a/tests/test_interface/conftest.py +++ b/tests/test_interface/conftest.py @@ -30,4 +30,8 @@ from tests.fixtures import hdf5_array, zarr_nested_array, zarr_array ], ) def interface_type(request): + """ + Test cases for each interface's ``check`` method - each input should match the + provided interface and that interface only + """ return request.param diff --git a/tests/test_interface/test_numpy.py b/tests/test_interface/test_numpy.py index e69de29..0ffea50 100644 --- a/tests/test_interface/test_numpy.py +++ b/tests/test_interface/test_numpy.py @@ -0,0 +1,27 @@ +import numpy as np +import pytest +from pydantic import ValidationError +from numpydantic.exceptions import DtypeError, ShapeError + +from tests.conftest import ValidationCase + + +def numpy_array(case: ValidationCase) -> np.ndarray: + return np.zeros(shape=case.shape, dtype=case.dtype) + + +def _test_np_case(case: ValidationCase): + array = numpy_array(case) + if case.passes: + case.model(array=array) + else: + with pytest.raises((ValidationError, DtypeError, ShapeError)): + case.model(array=array) + + +def test_numpy_shape(shape_cases): + _test_np_case(shape_cases) + + +def test_numpy_dtype(dtype_cases): + _test_np_case(dtype_cases)