Better test cases, dtype module, compound types

This commit is contained in:
sneakers-the-rat 2024-05-08 21:29:13 -07:00
parent d884055067
commit b82a49df1b
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
11 changed files with 368 additions and 46 deletions

View file

@ -1,5 +1,24 @@
# TODO
## Validation
```{todo}
Support pydantic value/range constraints - less than, greater than, etc.
```
```{todo}
Support different precision modes - eg. exact precision, or minimum precision
where specifying a float32 would also accept a float64, etc.
```
## Metadata
```{todo}
Use names in nptyping annotations in generated JSON schema metadata
```
## All TODOs
```{todolist}
```

View file

@ -116,9 +116,6 @@ ignore = [
fixable = ["ALL"]
[tool.black]
exclude = "tests"
[tool.mypy]
plugins = [
"pydantic.mypy"

115
src/numpydantic/dtype.py Normal file
View file

@ -0,0 +1,115 @@
"""
Replacement of :mod:`nptyping.typing_`
In the transition away from using nptyping, we want to allow for greater
control of dtype specifications - like different precision modes, etc.
and allow for abstract specifications of dtype that can be checked across
interfaces.
This module also allows for convenient access to all abstract dtypes in a single
module, rather than needing to import each individually.
Some types like :ref:`Integer` are compound types - tuples of multiple dtypes.
Check these using ``in`` rather than ``==``. This interface will develop in future
versions to allow a single dtype check.
"""
from typing import Tuple, TypeAlias, Union
import numpy as np
ShapeExpression: TypeAlias = str
StructureExpression: TypeAlias = str
DType: TypeAlias = Union[np.generic, StructureExpression, Tuple["DType"]]
ShapeTuple: TypeAlias = Tuple[int, ...]
Bool = np.bool_
Obj = np.object_ # Obj is a common abbreviation and should be usable.
Object = np.object_
Datetime64 = np.datetime64
Inexact = np.inexact
Int8 = np.int8
Int16 = np.int16
Int32 = np.int32
Int64 = np.int64
Byte = np.byte
Short = np.short
IntC = np.intc
IntP = np.intp
Int_ = np.int_
UInt8 = np.uint8
UInt16 = np.uint16
UInt32 = np.uint32
UInt64 = np.uint64
UByte = np.ubyte
UShort = np.ushort
UIntC = np.uintc
UIntP = np.uintp
UInt = np.uint
ULongLong = np.ulonglong
LongLong = np.longlong
Timedelta64 = np.timedelta64
SignedInteger = (np.int8, np.int16, np.int32, np.int64, np.short)
UnsignedInteger = (np.uint8, np.uint16, np.uint32, np.uint64, np.ushort)
Integer = tuple([np.integer, *SignedInteger, *UnsignedInteger])
Int = Integer # Int should translate to the "generic" int type.
Float16 = np.float16
Float32 = np.float32
Float64 = np.float64
Half = np.half
Single = np.single
Double = np.double
LongDouble = np.longdouble
LongFloat = np.longfloat
Float = (
np.float_,
np.float16,
np.float32,
np.float64,
np.floating,
np.single,
np.double,
)
Floating = Float
ComplexFloating = np.complexfloating
Complex64 = np.complex64
Complex128 = np.complex128
CSingle = np.csingle
SingleComplex = np.singlecomplex
CDouble = np.cdouble
CFloat = np.cfloat
CLongDouble = np.clongdouble
CLongFloat = np.clongfloat
Complex = (
np.complex_,
np.complexfloating,
np.complex64,
np.complex128,
np.csingle,
np.singlecomplex,
np.cdouble,
np.cfloat,
np.clongdouble,
np.clongfloat,
)
LongComplex = np.longcomplex
Flexible = np.flexible
Void = np.void
Character = np.character
Bytes = np.bytes_
Str = np.str_
String = np.string_
Unicode = np.unicode_
Number = tuple(
[
np.number,
*Integer,
*Float,
*Complex,
]
)

View file

@ -56,7 +56,13 @@ class Interface(ABC, Generic[T]):
"""
if self.dtype is Any:
return array
if not array.dtype == self.dtype:
if isinstance(self.dtype, tuple):
valid = array.dtype in self.dtype
else:
valid = array.dtype == self.dtype
if not valid:
raise DtypeError(f"Invalid dtype! expected {self.dtype}, got {array.dtype}")
return array
@ -151,6 +157,7 @@ class Interface(ABC, Generic[T]):
msg += "\n".join([f" - {i}" for i in matches])
raise ValueError(msg)
elif len(matches) == 0:
pdb.set_trace()
raise ValueError(f"No matching interfaces found for input {array}")
else:
return matches[0]

View file

@ -5,14 +5,14 @@ Interface to zarr arrays
import contextlib
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional, Union, Sequence
from typing import Any, Optional, Sequence, Union
from numpydantic.interface.interface import Interface
try:
import zarr
from zarr.core import Array as ZarrArray
from zarr.storage import StoreLike
import zarr
except ImportError:
ZarrArray = None
StoreLike = None
@ -32,11 +32,16 @@ class ZarrArrayPath:
path: Optional[str] = None
"""Path to array within hierarchical zarr store"""
def open(self, **kwargs) -> ZarrArray:
def open(self, **kwargs: dict) -> ZarrArray:
"""Open the zarr array at the provided path"""
return zarr.open(str(self.file), path=self.path, **kwargs)
@classmethod
def from_iterable(cls, spec: Sequence) -> "ZarrArrayPath":
"""
Construct a :class:`.ZarrArrayPath` specifier from an iterable,
rather than kwargs
"""
if len(spec) == 1:
return ZarrArrayPath(file=spec[0])
elif len(spec) == 2:

View file

@ -8,6 +8,8 @@ from typing import Any
import numpy as np
from nptyping import Bool, Float, Int, String
from numpydantic import dtype as dt
np_to_python = {
Any: Any,
np.number: float,
@ -17,34 +19,9 @@ np_to_python = {
np.byte: bytes,
np.bytes_: bytes,
np.datetime64: datetime,
**{
n: int
for n in (
np.int8,
np.int16,
np.int32,
np.int64,
np.short,
np.uint8,
np.uint16,
np.uint32,
np.uint64,
np.uint,
)
},
**{
n: float
for n in (
np.float16,
np.float32,
np.floating,
np.float32,
np.float64,
np.single,
np.double,
np.float_,
)
},
**{n: int for n in dt.Integer},
**{n: float for n in dt.Float},
**{n: complex for n in dt.Complex},
**{n: str for n in (np.character, np.str_, np.string_, np.unicode_)},
}
"""Map from python types to numpy"""

View file

@ -10,11 +10,18 @@ from typing import TYPE_CHECKING, Any, Tuple, Union
import nptyping.structure
import numpy as np
from nptyping import Shape
from nptyping.error import InvalidArgumentsError
from nptyping.ndarray import NDArrayMeta as _NDArrayMeta
from nptyping.nptyping_type import NPTypingType
from nptyping.structure import Structure
from nptyping.structure_expression import check_type_names
from nptyping.typing_ import (
dtype_per_name,
)
from pydantic_core import core_schema
from pydantic_core.core_schema import ListSchema
from numpydantic.dtype import DType
from numpydantic.interface import Interface
from numpydantic.maps import np_to_python
@ -24,12 +31,6 @@ from numpydantic.types import DtypeType, NDArrayType, ShapeType
if TYPE_CHECKING:
from pydantic import ValidationInfo
COMPRESSION_THRESHOLD = 16 * 1024
"""
Arrays larger than this size (in bytes) will be compressed and b64 encoded when
serializing to JSON.
"""
def list_of_lists_schema(shape: Shape, array_type_handler: dict) -> ListSchema:
"""Make a pydantic JSON schema for an array as a list of lists."""
@ -95,11 +96,40 @@ def coerce_list(value: Any) -> np.ndarray:
class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
"""
Kept here to allow for hooking into metaclass, which has
been necessary on and off as we work this class into a stable
state
Hooking into nptyping's array metaclass to override methods pending
completion of the transition away from nptyping
"""
def _get_dtype(cls, dtype_candidate: Any) -> DType:
"""
Override of base _get_dtype method to allow for compound tuple types
"""
is_dtype = isinstance(dtype_candidate, type) and issubclass(
dtype_candidate, np.generic
)
if dtype_candidate is Any:
dtype = Any
elif is_dtype:
dtype = dtype_candidate
elif issubclass(dtype_candidate, Structure):
dtype = dtype_candidate
check_type_names(dtype, dtype_per_name)
elif cls._is_literal_like(dtype_candidate):
structure_expression = dtype_candidate.__args__[0]
dtype = Structure[structure_expression]
check_type_names(dtype, dtype_per_name)
elif isinstance(dtype_candidate, tuple):
dtype = tuple([cls._get_dtype(dt) for dt in dtype_candidate])
else:
raise InvalidArgumentsError(
f"Unexpected argument '{dtype_candidate}', expecting"
" Structure[<StructureExpression>]"
" or Literal[<StructureExpression>]"
" or a dtype"
" or typing.Any."
)
return dtype
class NDArray(NPTypingType, metaclass=NDArrayMeta):
"""
@ -133,6 +163,15 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
if isinstance(dtype, nptyping.structure.StructureMeta):
raise NotImplementedError("Finish handling structured dtypes!")
# functools.reduce(operator.or_, [int, float, str])
else:
if isinstance(dtype, tuple):
types_ = list(set([np_to_python[dt] for dt in dtype]))
# TODO: better type filtering - explicitly model what
# numeric types are supported by JSON schema
types_ = [t for t in types_ if t not in (complex,)]
schemas = [_handler.generate_schema(dt) for dt in types_]
array_type_handler = core_schema.union_schema(schemas)
else:
array_type_handler = _handler.generate_schema(np_to_python[dtype])

View file

@ -1,4 +1,12 @@
import pdb
import pytest
from typing import Any, Tuple, Union, Type, TypeAlias
from pydantic import BaseModel, computed_field, ConfigDict
from numpydantic import NDArray, Shape
from numpydantic.ndarray import NDArrayMeta
from numpydantic.dtype import Float, Number, Integer
import numpy as np
from tests.fixtures import *
@ -9,3 +17,127 @@ def pytest_addoption(parser):
action="store_true",
help="Keep test outputs in the __tmp__ directory",
)
class ValidationCase(BaseModel):
"""
Test case for validating an array.
Contains both the validating model and the parameterization for an array to
test in a given interface
"""
annotation: Any = NDArray[Shape["10, 10, *"], Float]
"""
Array annotation used in the validating model
Any typed because the types of type annotations are weird
"""
shape: Tuple[int, ...] = (10, 10, 10)
"""Shape of the array to validate"""
dtype: Union[Type, np.dtype] = float
"""Dtype of the array to validate"""
passes: bool
"""Whether the validation should pass or not"""
model_config = ConfigDict(arbitrary_types_allowed=True)
@computed_field()
def model(self) -> Type[BaseModel]:
"""A model with a field ``array`` with the given annotation"""
annotation = self.annotation
class Model(BaseModel):
array: annotation
return Model
RGB_UNION: TypeAlias = Union[
NDArray[Shape["* x, * y"], Number],
NDArray[Shape["* x, * y, 3 r_g_b"], Number],
NDArray[Shape["* x, * y, 3 r_g_b, 4 r_g_b_a"], Number],
]
NUMBER: TypeAlias = NDArray[Shape["*, *, *"], Number]
INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer]
FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float]
@pytest.fixture(
scope="session",
params=[
ValidationCase(shape=(10, 10, 10), passes=True),
ValidationCase(shape=(10, 10), passes=False),
ValidationCase(shape=(10, 10, 10, 10), passes=False),
ValidationCase(shape=(11, 10, 10), passes=False),
ValidationCase(shape=(9, 10, 10), passes=False),
ValidationCase(shape=(10, 10, 9), passes=True),
ValidationCase(shape=(10, 10, 11), passes=True),
ValidationCase(annotation=RGB_UNION, shape=(5, 5), passes=True),
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3), passes=True),
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3, 4), passes=True),
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 4), passes=False),
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3, 6), passes=False),
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 4, 6), passes=False),
],
ids=[
"valid shape",
"missing dimension",
"extra dimension",
"dimension too large",
"dimension too small",
"wildcard smaller",
"wildcard larger",
"Union 2D",
"Union 3D",
"Union 4D",
"Union incorrect 3D",
"Union incorrect 4D",
"Union incorrect both",
],
)
def shape_cases(request) -> ValidationCase:
return request.param
@pytest.fixture(
scope="session",
params=[
ValidationCase(dtype=float, passes=True),
ValidationCase(dtype=int, passes=False),
ValidationCase(dtype=np.uint8, passes=False),
ValidationCase(annotation=NUMBER, dtype=int, passes=True),
ValidationCase(annotation=NUMBER, dtype=float, passes=True),
ValidationCase(annotation=NUMBER, dtype=np.uint8, passes=True),
ValidationCase(annotation=NUMBER, dtype=np.float16, passes=True),
ValidationCase(annotation=NUMBER, dtype=str, passes=False),
ValidationCase(annotation=INTEGER, dtype=int, passes=True),
ValidationCase(annotation=INTEGER, dtype=np.uint8, passes=True),
ValidationCase(annotation=INTEGER, dtype=float, passes=False),
ValidationCase(annotation=INTEGER, dtype=np.float32, passes=False),
ValidationCase(annotation=FLOAT, dtype=float, passes=True),
ValidationCase(annotation=FLOAT, dtype=np.float32, passes=True),
ValidationCase(annotation=FLOAT, dtype=int, passes=False),
ValidationCase(annotation=FLOAT, dtype=np.uint8, passes=False),
],
ids=[
"float",
"int",
"uint8",
"number-int",
"number-float",
"number-uint8",
"number-float16",
"number-str",
"integer-int",
"integer-uint8",
"integer-float",
"integer-float32",
"float-float",
"float-float32",
"float-int",
"float-uint8",
],
)
def dtype_cases(request) -> ValidationCase:
return request.param

View file

@ -5,7 +5,6 @@ from typing import Callable, Optional, Tuple, Type, Union
import h5py
import numpy as np
import pytest
from nptyping import Number
from pydantic import BaseModel, Field
import zarr
@ -13,6 +12,7 @@ from numpydantic.interface.hdf5 import H5ArrayPath
from numpydantic.interface.zarr import ZarrArrayPath
from numpydantic import NDArray, Shape
from numpydantic.maps import python_to_nptyping
from numpydantic.dtype import Number
@pytest.fixture(scope="session")

View file

@ -30,4 +30,8 @@ from tests.fixtures import hdf5_array, zarr_nested_array, zarr_array
],
)
def interface_type(request):
"""
Test cases for each interface's ``check`` method - each input should match the
provided interface and that interface only
"""
return request.param

View file

@ -0,0 +1,27 @@
import numpy as np
import pytest
from pydantic import ValidationError
from numpydantic.exceptions import DtypeError, ShapeError
from tests.conftest import ValidationCase
def numpy_array(case: ValidationCase) -> np.ndarray:
return np.zeros(shape=case.shape, dtype=case.dtype)
def _test_np_case(case: ValidationCase):
array = numpy_array(case)
if case.passes:
case.model(array=array)
else:
with pytest.raises((ValidationError, DtypeError, ShapeError)):
case.model(array=array)
def test_numpy_shape(shape_cases):
_test_np_case(shape_cases)
def test_numpy_dtype(dtype_cases):
_test_np_case(dtype_cases)