Mirror of https://github.com/p2p-ld/numpydantic.git (synced 2024-11-12 17:54:29 +00:00)
Better test cases, dtype module, compound types

Commit b82a49df1b (parent d884055067)
11 changed files with 368 additions and 46 deletions

docs/todo.md (+19)
@@ -1,5 +1,24 @@
# TODO

## Validation

```{todo}
Support pydantic value/range constraints - less than, greater than, etc.
```

```{todo}
Support different precision modes - e.g. exact precision, or minimum precision
where specifying a float32 would also accept a float64, etc.
```

## Metadata

```{todo}
Use names in nptyping annotations in generated JSON schema metadata
```

## All TODOs

```{todolist}

```

@@ -116,9 +116,6 @@ ignore = [
fixable = ["ALL"]

[tool.black]
exclude = "tests"

[tool.mypy]
plugins = [
    "pydantic.mypy"

src/numpydantic/dtype.py (new file, +115)
@@ -0,0 +1,115 @@
"""
|
||||
Replacement of :mod:`nptyping.typing_`
|
||||
|
||||
In the transition away from using nptyping, we want to allow for greater
|
||||
control of dtype specifications - like different precision modes, etc.
|
||||
and allow for abstract specifications of dtype that can be checked across
|
||||
interfaces.
|
||||
|
||||
This module also allows for convenient access to all abstract dtypes in a single
|
||||
module, rather than needing to import each individually.
|
||||
|
||||
Some types like :ref:`Integer` are compound types - tuples of multiple dtypes.
|
||||
Check these using ``in`` rather than ``==``. This interface will develop in future
|
||||
versions to allow a single dtype check.
|
||||
"""
|
||||
|
||||
from typing import Tuple, TypeAlias, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
ShapeExpression: TypeAlias = str
|
||||
StructureExpression: TypeAlias = str
|
||||
DType: TypeAlias = Union[np.generic, StructureExpression, Tuple["DType"]]
|
||||
ShapeTuple: TypeAlias = Tuple[int, ...]
|
||||
|
||||
Bool = np.bool_
|
||||
Obj = np.object_ # Obj is a common abbreviation and should be usable.
|
||||
Object = np.object_
|
||||
Datetime64 = np.datetime64
|
||||
Inexact = np.inexact
|
||||
|
||||
Int8 = np.int8
|
||||
Int16 = np.int16
|
||||
Int32 = np.int32
|
||||
Int64 = np.int64
|
||||
Byte = np.byte
|
||||
Short = np.short
|
||||
IntC = np.intc
|
||||
IntP = np.intp
|
||||
Int_ = np.int_
|
||||
UInt8 = np.uint8
|
||||
UInt16 = np.uint16
|
||||
UInt32 = np.uint32
|
||||
UInt64 = np.uint64
|
||||
UByte = np.ubyte
|
||||
UShort = np.ushort
|
||||
UIntC = np.uintc
|
||||
UIntP = np.uintp
|
||||
UInt = np.uint
|
||||
ULongLong = np.ulonglong
|
||||
LongLong = np.longlong
|
||||
Timedelta64 = np.timedelta64
|
||||
SignedInteger = (np.int8, np.int16, np.int32, np.int64, np.short)
|
||||
UnsignedInteger = (np.uint8, np.uint16, np.uint32, np.uint64, np.ushort)
|
||||
Integer = tuple([np.integer, *SignedInteger, *UnsignedInteger])
|
||||
Int = Integer # Int should translate to the "generic" int type.
|
||||
|
||||
Float16 = np.float16
|
||||
Float32 = np.float32
|
||||
Float64 = np.float64
|
||||
Half = np.half
|
||||
Single = np.single
|
||||
Double = np.double
|
||||
LongDouble = np.longdouble
|
||||
LongFloat = np.longfloat
|
||||
Float = (
|
||||
np.float_,
|
||||
np.float16,
|
||||
np.float32,
|
||||
np.float64,
|
||||
np.floating,
|
||||
np.single,
|
||||
np.double,
|
||||
)
|
||||
Floating = Float
|
||||
|
||||
ComplexFloating = np.complexfloating
|
||||
Complex64 = np.complex64
|
||||
Complex128 = np.complex128
|
||||
CSingle = np.csingle
|
||||
SingleComplex = np.singlecomplex
|
||||
CDouble = np.cdouble
|
||||
CFloat = np.cfloat
|
||||
CLongDouble = np.clongdouble
|
||||
CLongFloat = np.clongfloat
|
||||
Complex = (
|
||||
np.complex_,
|
||||
np.complexfloating,
|
||||
np.complex64,
|
||||
np.complex128,
|
||||
np.csingle,
|
||||
np.singlecomplex,
|
||||
np.cdouble,
|
||||
np.cfloat,
|
||||
np.clongdouble,
|
||||
np.clongfloat,
|
||||
)
|
||||
|
||||
LongComplex = np.longcomplex
|
||||
Flexible = np.flexible
|
||||
Void = np.void
|
||||
Character = np.character
|
||||
Bytes = np.bytes_
|
||||
Str = np.str_
|
||||
String = np.string_
|
||||
Unicode = np.unicode_
|
||||
|
||||
Number = tuple(
|
||||
[
|
||||
np.number,
|
||||
*Integer,
|
||||
*Float,
|
||||
*Complex,
|
||||
]
|
||||
)
|
|
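Since compound dtypes such as `Integer`, `Float`, `Complex`, and `Number` are plain tuples of numpy scalar types, membership (`in`) is the right check rather than equality, as the module docstring notes. A minimal sketch of the distinction; the array below is illustrative and not part of the module:

```python
import numpy as np

from numpydantic import dtype as dt

arr = np.zeros((3, 3), dtype=np.uint8)

# Compound dtypes are tuples of scalar types, so test membership rather than equality
assert arr.dtype.type in dt.Integer
assert arr.dtype.type in dt.Number

# Scalar aliases are single numpy types and can be compared directly
assert np.zeros(1, dtype="float32").dtype == dt.Float32
```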
@@ -56,7 +56,13 @@ class Interface(ABC, Generic[T]):
        """
        if self.dtype is Any:
            return array
        if not array.dtype == self.dtype:

        if isinstance(self.dtype, tuple):
            valid = array.dtype in self.dtype
        else:
            valid = array.dtype == self.dtype

        if not valid:
            raise DtypeError(f"Invalid dtype! expected {self.dtype}, got {array.dtype}")
        return array
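The tuple branch above is what makes the compound dtypes from `numpydantic.dtype` usable end to end. A hedged sketch of the resulting behavior; the model and field names are illustrative, and the accepted/rejected dtypes mirror the test cases added later in this diff:

```python
import numpy as np
import pytest
from pydantic import BaseModel, ValidationError

from numpydantic import NDArray, Shape
from numpydantic.dtype import Integer
from numpydantic.exceptions import DtypeError


class Image(BaseModel):
    # any member of the Integer tuple should validate
    array: NDArray[Shape["* x, * y"], Integer]


Image(array=np.zeros((5, 5), dtype=np.uint8))
Image(array=np.zeros((5, 5), dtype=np.int64))

# a float array should fail the dtype check; depending on where the error is
# raised it may surface as a pydantic ValidationError or a DtypeError
with pytest.raises((ValidationError, DtypeError)):
    Image(array=np.zeros((5, 5), dtype=np.float32))
```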
@@ -151,6 +157,7 @@ class Interface(ABC, Generic[T]):
            msg += "\n".join([f"  - {i}" for i in matches])
            raise ValueError(msg)
        elif len(matches) == 0:
            raise ValueError(f"No matching interfaces found for input {array}")
        else:
            return matches[0]
@@ -5,14 +5,14 @@ Interface to zarr arrays
import contextlib
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional, Union, Sequence
from typing import Any, Optional, Sequence, Union

from numpydantic.interface.interface import Interface

try:
    import zarr
    from zarr.core import Array as ZarrArray
    from zarr.storage import StoreLike
    import zarr
except ImportError:
    ZarrArray = None
    StoreLike = None
@@ -32,11 +32,16 @@ class ZarrArrayPath:
    path: Optional[str] = None
    """Path to array within hierarchical zarr store"""

    def open(self, **kwargs) -> ZarrArray:
    def open(self, **kwargs: dict) -> ZarrArray:
        """Open the zarr array at the provided path"""
        return zarr.open(str(self.file), path=self.path, **kwargs)

    @classmethod
    def from_iterable(cls, spec: Sequence) -> "ZarrArrayPath":
        """
        Construct a :class:`.ZarrArrayPath` specifier from an iterable,
        rather than kwargs
        """
        if len(spec) == 1:
            return ZarrArrayPath(file=spec[0])
        elif len(spec) == 2:
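For context, a sketch of how a `ZarrArrayPath` specifier might be constructed; the store path and array name below are hypothetical:

```python
from numpydantic.interface.zarr import ZarrArrayPath

# explicit construction
spec = ZarrArrayPath(file="data/mystore.zarr", path="group/array")

# equivalent construction from a 2-item (file, path) iterable
spec = ZarrArrayPath.from_iterable(("data/mystore.zarr", "group/array"))

# opening delegates to zarr.open(file, path=...), so the store must exist on disk
# array = spec.open(mode="r")
```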
@@ -8,6 +8,8 @@ from typing import Any
import numpy as np
from nptyping import Bool, Float, Int, String

from numpydantic import dtype as dt

np_to_python = {
    Any: Any,
    np.number: float,

@@ -17,34 +19,9 @@ np_to_python = {
    np.byte: bytes,
    np.bytes_: bytes,
    np.datetime64: datetime,
    **{
        n: int
        for n in (
            np.int8,
            np.int16,
            np.int32,
            np.int64,
            np.short,
            np.uint8,
            np.uint16,
            np.uint32,
            np.uint64,
            np.uint,
        )
    },
    **{
        n: float
        for n in (
            np.float16,
            np.float32,
            np.floating,
            np.float32,
            np.float64,
            np.single,
            np.double,
            np.float_,
        )
    },
    **{n: int for n in dt.Integer},
    **{n: float for n in dt.Float},
    **{n: complex for n in dt.Complex},
    **{n: str for n in (np.character, np.str_, np.string_, np.unicode_)},
}
"""Map from numpy dtypes to python types"""
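With the compound dtypes spliced in via dict comprehensions, every member of `Integer`, `Float`, and `Complex` maps to the corresponding builtin. A quick illustration, assuming the module imports as shown in the diff:

```python
import numpy as np

from numpydantic.maps import np_to_python

assert np_to_python[np.float32] is float
assert np_to_python[np.uint8] is int
assert np_to_python[np.complex64] is complex
```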
@@ -10,11 +10,18 @@ from typing import TYPE_CHECKING, Any, Tuple, Union
import nptyping.structure
import numpy as np
from nptyping import Shape
from nptyping.error import InvalidArgumentsError
from nptyping.ndarray import NDArrayMeta as _NDArrayMeta
from nptyping.nptyping_type import NPTypingType
from nptyping.structure import Structure
from nptyping.structure_expression import check_type_names
from nptyping.typing_ import (
    dtype_per_name,
)
from pydantic_core import core_schema
from pydantic_core.core_schema import ListSchema

from numpydantic.dtype import DType
from numpydantic.interface import Interface
from numpydantic.maps import np_to_python
@@ -24,12 +31,6 @@ from numpydantic.types import DtypeType, NDArrayType, ShapeType
if TYPE_CHECKING:
    from pydantic import ValidationInfo

COMPRESSION_THRESHOLD = 16 * 1024
"""
Arrays larger than this size (in bytes) will be compressed and b64 encoded when
serializing to JSON.
"""


def list_of_lists_schema(shape: Shape, array_type_handler: dict) -> ListSchema:
    """Make a pydantic JSON schema for an array as a list of lists."""
@@ -95,11 +96,40 @@ def coerce_list(value: Any) -> np.ndarray:

class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
    """
    Kept here to allow for hooking into metaclass, which has
    been necessary on and off as we work this class into a stable
    state
    Hooking into nptyping's array metaclass to override methods pending
    completion of the transition away from nptyping
    """

    def _get_dtype(cls, dtype_candidate: Any) -> DType:
        """
        Override of base _get_dtype method to allow for compound tuple types
        """
        is_dtype = isinstance(dtype_candidate, type) and issubclass(
            dtype_candidate, np.generic
        )
        if dtype_candidate is Any:
            dtype = Any
        elif is_dtype:
            dtype = dtype_candidate
        elif issubclass(dtype_candidate, Structure):
            dtype = dtype_candidate
            check_type_names(dtype, dtype_per_name)
        elif cls._is_literal_like(dtype_candidate):
            structure_expression = dtype_candidate.__args__[0]
            dtype = Structure[structure_expression]
            check_type_names(dtype, dtype_per_name)
        elif isinstance(dtype_candidate, tuple):
            dtype = tuple([cls._get_dtype(dt) for dt in dtype_candidate])
        else:
            raise InvalidArgumentsError(
                f"Unexpected argument '{dtype_candidate}', expecting"
                " Structure[<StructureExpression>]"
                " or Literal[<StructureExpression>]"
                " or a dtype"
                " or typing.Any."
            )
        return dtype


class NDArray(NPTypingType, metaclass=NDArrayMeta):
    """
@@ -134,7 +164,16 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
            raise NotImplementedError("Finish handling structured dtypes!")
            # functools.reduce(operator.or_, [int, float, str])
        else:
            array_type_handler = _handler.generate_schema(np_to_python[dtype])
        if isinstance(dtype, tuple):
            types_ = list(set([np_to_python[dt] for dt in dtype]))
            # TODO: better type filtering - explicitly model what
            # numeric types are supported by JSON schema
            types_ = [t for t in types_ if t not in (complex,)]
            schemas = [_handler.generate_schema(dt) for dt in types_]
            array_type_handler = core_schema.union_schema(schemas)

        else:
            array_type_handler = _handler.generate_schema(np_to_python[dtype])

        # get the names of the shape constraints, if any
        if shape is Any:
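For compound dtypes, the generated item schema should therefore be a union of the mapped python types, with `complex` dropped since JSON schema has no complex number type. A small, hedged way to inspect that result; the model and field names are illustrative:

```python
import json

from pydantic import BaseModel

from numpydantic import NDArray, Shape
from numpydantic.dtype import Number


class Measurements(BaseModel):
    values: NDArray[Shape["*, *, *"], Number]


# the ``values`` item schema should come out as a union over integer / number,
# since ``complex`` has no JSON schema counterpart and is filtered out above
print(json.dumps(Measurements.model_json_schema(), indent=2))
```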
@@ -1,4 +1,12 @@
import pytest
from typing import Any, Tuple, Union, Type, TypeAlias
from pydantic import BaseModel, computed_field, ConfigDict
from numpydantic import NDArray, Shape
from numpydantic.ndarray import NDArrayMeta
from numpydantic.dtype import Float, Number, Integer
import numpy as np

from tests.fixtures import *
@@ -9,3 +17,127 @@ def pytest_addoption(parser):
        action="store_true",
        help="Keep test outputs in the __tmp__ directory",
    )


class ValidationCase(BaseModel):
    """
    Test case for validating an array.

    Contains both the validating model and the parameterization for an array to
    test in a given interface
    """

    annotation: Any = NDArray[Shape["10, 10, *"], Float]
    """
    Array annotation used in the validating model
    Any typed because the types of type annotations are weird
    """
    shape: Tuple[int, ...] = (10, 10, 10)
    """Shape of the array to validate"""
    dtype: Union[Type, np.dtype] = float
    """Dtype of the array to validate"""
    passes: bool
    """Whether the validation should pass or not"""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    @computed_field()
    def model(self) -> Type[BaseModel]:
        """A model with a field ``array`` with the given annotation"""
        annotation = self.annotation

        class Model(BaseModel):
            array: annotation

        return Model


RGB_UNION: TypeAlias = Union[
    NDArray[Shape["* x, * y"], Number],
    NDArray[Shape["* x, * y, 3 r_g_b"], Number],
    NDArray[Shape["* x, * y, 3 r_g_b, 4 r_g_b_a"], Number],
]

NUMBER: TypeAlias = NDArray[Shape["*, *, *"], Number]
INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer]
FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float]


@pytest.fixture(
    scope="session",
    params=[
        ValidationCase(shape=(10, 10, 10), passes=True),
        ValidationCase(shape=(10, 10), passes=False),
        ValidationCase(shape=(10, 10, 10, 10), passes=False),
        ValidationCase(shape=(11, 10, 10), passes=False),
        ValidationCase(shape=(9, 10, 10), passes=False),
        ValidationCase(shape=(10, 10, 9), passes=True),
        ValidationCase(shape=(10, 10, 11), passes=True),
        ValidationCase(annotation=RGB_UNION, shape=(5, 5), passes=True),
        ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3), passes=True),
        ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3, 4), passes=True),
        ValidationCase(annotation=RGB_UNION, shape=(5, 5, 4), passes=False),
        ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3, 6), passes=False),
        ValidationCase(annotation=RGB_UNION, shape=(5, 5, 4, 6), passes=False),
    ],
    ids=[
        "valid shape",
        "missing dimension",
        "extra dimension",
        "dimension too large",
        "dimension too small",
        "wildcard smaller",
        "wildcard larger",
        "Union 2D",
        "Union 3D",
        "Union 4D",
        "Union incorrect 3D",
        "Union incorrect 4D",
        "Union incorrect both",
    ],
)
def shape_cases(request) -> ValidationCase:
    return request.param


@pytest.fixture(
    scope="session",
    params=[
        ValidationCase(dtype=float, passes=True),
        ValidationCase(dtype=int, passes=False),
        ValidationCase(dtype=np.uint8, passes=False),
        ValidationCase(annotation=NUMBER, dtype=int, passes=True),
        ValidationCase(annotation=NUMBER, dtype=float, passes=True),
        ValidationCase(annotation=NUMBER, dtype=np.uint8, passes=True),
        ValidationCase(annotation=NUMBER, dtype=np.float16, passes=True),
        ValidationCase(annotation=NUMBER, dtype=str, passes=False),
        ValidationCase(annotation=INTEGER, dtype=int, passes=True),
        ValidationCase(annotation=INTEGER, dtype=np.uint8, passes=True),
        ValidationCase(annotation=INTEGER, dtype=float, passes=False),
        ValidationCase(annotation=INTEGER, dtype=np.float32, passes=False),
        ValidationCase(annotation=FLOAT, dtype=float, passes=True),
        ValidationCase(annotation=FLOAT, dtype=np.float32, passes=True),
        ValidationCase(annotation=FLOAT, dtype=int, passes=False),
        ValidationCase(annotation=FLOAT, dtype=np.uint8, passes=False),
    ],
    ids=[
        "float",
        "int",
        "uint8",
        "number-int",
        "number-float",
        "number-uint8",
        "number-float16",
        "number-str",
        "integer-int",
        "integer-uint8",
        "integer-float",
        "integer-float32",
        "float-float",
        "float-float32",
        "float-int",
        "float-uint8",
    ],
)
def dtype_cases(request) -> ValidationCase:
    return request.param
@@ -5,7 +5,6 @@ from typing import Callable, Optional, Tuple, Type, Union
import h5py
import numpy as np
import pytest
from nptyping import Number
from pydantic import BaseModel, Field
import zarr

@@ -13,6 +12,7 @@ from numpydantic.interface.hdf5 import H5ArrayPath
from numpydantic.interface.zarr import ZarrArrayPath
from numpydantic import NDArray, Shape
from numpydantic.maps import python_to_nptyping
from numpydantic.dtype import Number


@pytest.fixture(scope="session")
@@ -30,4 +30,8 @@ from tests.fixtures import hdf5_array, zarr_nested_array, zarr_array
    ],
)
def interface_type(request):
    """
    Test cases for each interface's ``check`` method - each input should match the
    provided interface and that interface only
    """
    return request.param
New file (+27)
@@ -0,0 +1,27 @@
import numpy as np
import pytest
from pydantic import ValidationError
from numpydantic.exceptions import DtypeError, ShapeError

from tests.conftest import ValidationCase


def numpy_array(case: ValidationCase) -> np.ndarray:
    return np.zeros(shape=case.shape, dtype=case.dtype)


def _test_np_case(case: ValidationCase):
    array = numpy_array(case)
    if case.passes:
        case.model(array=array)
    else:
        with pytest.raises((ValidationError, DtypeError, ShapeError)):
            case.model(array=array)


def test_numpy_shape(shape_cases):
    _test_np_case(shape_cases)


def test_numpy_dtype(dtype_cases):
    _test_np_case(dtype_cases)