mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2025-01-10 05:54:26 +00:00
Better test cases, dtype module, compound types
This commit is contained in:
parent
d884055067
commit
b82a49df1b
11 changed files with 368 additions and 46 deletions
19
docs/todo.md
19
docs/todo.md
|
@ -1,5 +1,24 @@
|
||||||
# TODO
|
# TODO
|
||||||
|
|
||||||
|
## Validation
|
||||||
|
|
||||||
|
```{todo}
|
||||||
|
Support pydantic value/range constraints - less than, greater than, etc.
|
||||||
|
```
|
||||||
|
|
||||||
|
```{todo}
|
||||||
|
Support different precision modes - eg. exact precision, or minimum precision
|
||||||
|
where specifying a float32 would also accept a float64, etc.
|
||||||
|
```
|
||||||
|
|
||||||
|
## Metadata
|
||||||
|
|
||||||
|
```{todo}
|
||||||
|
Use names in nptyping annotations in generated JSON schema metadata
|
||||||
|
```
|
||||||
|
|
||||||
|
## All TODOs
|
||||||
|
|
||||||
```{todolist}
|
```{todolist}
|
||||||
|
|
||||||
```
|
```
|
|
@ -116,9 +116,6 @@ ignore = [
|
||||||
|
|
||||||
fixable = ["ALL"]
|
fixable = ["ALL"]
|
||||||
|
|
||||||
[tool.black]
|
|
||||||
exclude = "tests"
|
|
||||||
|
|
||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
plugins = [
|
plugins = [
|
||||||
"pydantic.mypy"
|
"pydantic.mypy"
|
||||||
|
|
115
src/numpydantic/dtype.py
Normal file
115
src/numpydantic/dtype.py
Normal file
|
@ -0,0 +1,115 @@
|
||||||
|
"""
|
||||||
|
Replacement of :mod:`nptyping.typing_`
|
||||||
|
|
||||||
|
In the transition away from using nptyping, we want to allow for greater
|
||||||
|
control of dtype specifications - like different precision modes, etc.
|
||||||
|
and allow for abstract specifications of dtype that can be checked across
|
||||||
|
interfaces.
|
||||||
|
|
||||||
|
This module also allows for convenient access to all abstract dtypes in a single
|
||||||
|
module, rather than needing to import each individually.
|
||||||
|
|
||||||
|
Some types like :ref:`Integer` are compound types - tuples of multiple dtypes.
|
||||||
|
Check these using ``in`` rather than ``==``. This interface will develop in future
|
||||||
|
versions to allow a single dtype check.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Tuple, TypeAlias, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
ShapeExpression: TypeAlias = str
|
||||||
|
StructureExpression: TypeAlias = str
|
||||||
|
DType: TypeAlias = Union[np.generic, StructureExpression, Tuple["DType"]]
|
||||||
|
ShapeTuple: TypeAlias = Tuple[int, ...]
|
||||||
|
|
||||||
|
Bool = np.bool_
|
||||||
|
Obj = np.object_ # Obj is a common abbreviation and should be usable.
|
||||||
|
Object = np.object_
|
||||||
|
Datetime64 = np.datetime64
|
||||||
|
Inexact = np.inexact
|
||||||
|
|
||||||
|
Int8 = np.int8
|
||||||
|
Int16 = np.int16
|
||||||
|
Int32 = np.int32
|
||||||
|
Int64 = np.int64
|
||||||
|
Byte = np.byte
|
||||||
|
Short = np.short
|
||||||
|
IntC = np.intc
|
||||||
|
IntP = np.intp
|
||||||
|
Int_ = np.int_
|
||||||
|
UInt8 = np.uint8
|
||||||
|
UInt16 = np.uint16
|
||||||
|
UInt32 = np.uint32
|
||||||
|
UInt64 = np.uint64
|
||||||
|
UByte = np.ubyte
|
||||||
|
UShort = np.ushort
|
||||||
|
UIntC = np.uintc
|
||||||
|
UIntP = np.uintp
|
||||||
|
UInt = np.uint
|
||||||
|
ULongLong = np.ulonglong
|
||||||
|
LongLong = np.longlong
|
||||||
|
Timedelta64 = np.timedelta64
|
||||||
|
SignedInteger = (np.int8, np.int16, np.int32, np.int64, np.short)
|
||||||
|
UnsignedInteger = (np.uint8, np.uint16, np.uint32, np.uint64, np.ushort)
|
||||||
|
Integer = tuple([np.integer, *SignedInteger, *UnsignedInteger])
|
||||||
|
Int = Integer # Int should translate to the "generic" int type.
|
||||||
|
|
||||||
|
Float16 = np.float16
|
||||||
|
Float32 = np.float32
|
||||||
|
Float64 = np.float64
|
||||||
|
Half = np.half
|
||||||
|
Single = np.single
|
||||||
|
Double = np.double
|
||||||
|
LongDouble = np.longdouble
|
||||||
|
LongFloat = np.longfloat
|
||||||
|
Float = (
|
||||||
|
np.float_,
|
||||||
|
np.float16,
|
||||||
|
np.float32,
|
||||||
|
np.float64,
|
||||||
|
np.floating,
|
||||||
|
np.single,
|
||||||
|
np.double,
|
||||||
|
)
|
||||||
|
Floating = Float
|
||||||
|
|
||||||
|
ComplexFloating = np.complexfloating
|
||||||
|
Complex64 = np.complex64
|
||||||
|
Complex128 = np.complex128
|
||||||
|
CSingle = np.csingle
|
||||||
|
SingleComplex = np.singlecomplex
|
||||||
|
CDouble = np.cdouble
|
||||||
|
CFloat = np.cfloat
|
||||||
|
CLongDouble = np.clongdouble
|
||||||
|
CLongFloat = np.clongfloat
|
||||||
|
Complex = (
|
||||||
|
np.complex_,
|
||||||
|
np.complexfloating,
|
||||||
|
np.complex64,
|
||||||
|
np.complex128,
|
||||||
|
np.csingle,
|
||||||
|
np.singlecomplex,
|
||||||
|
np.cdouble,
|
||||||
|
np.cfloat,
|
||||||
|
np.clongdouble,
|
||||||
|
np.clongfloat,
|
||||||
|
)
|
||||||
|
|
||||||
|
LongComplex = np.longcomplex
|
||||||
|
Flexible = np.flexible
|
||||||
|
Void = np.void
|
||||||
|
Character = np.character
|
||||||
|
Bytes = np.bytes_
|
||||||
|
Str = np.str_
|
||||||
|
String = np.string_
|
||||||
|
Unicode = np.unicode_
|
||||||
|
|
||||||
|
Number = tuple(
|
||||||
|
[
|
||||||
|
np.number,
|
||||||
|
*Integer,
|
||||||
|
*Float,
|
||||||
|
*Complex,
|
||||||
|
]
|
||||||
|
)
|
|
@ -56,7 +56,13 @@ class Interface(ABC, Generic[T]):
|
||||||
"""
|
"""
|
||||||
if self.dtype is Any:
|
if self.dtype is Any:
|
||||||
return array
|
return array
|
||||||
if not array.dtype == self.dtype:
|
|
||||||
|
if isinstance(self.dtype, tuple):
|
||||||
|
valid = array.dtype in self.dtype
|
||||||
|
else:
|
||||||
|
valid = array.dtype == self.dtype
|
||||||
|
|
||||||
|
if not valid:
|
||||||
raise DtypeError(f"Invalid dtype! expected {self.dtype}, got {array.dtype}")
|
raise DtypeError(f"Invalid dtype! expected {self.dtype}, got {array.dtype}")
|
||||||
return array
|
return array
|
||||||
|
|
||||||
|
@ -151,6 +157,7 @@ class Interface(ABC, Generic[T]):
|
||||||
msg += "\n".join([f" - {i}" for i in matches])
|
msg += "\n".join([f" - {i}" for i in matches])
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
elif len(matches) == 0:
|
elif len(matches) == 0:
|
||||||
|
pdb.set_trace()
|
||||||
raise ValueError(f"No matching interfaces found for input {array}")
|
raise ValueError(f"No matching interfaces found for input {array}")
|
||||||
else:
|
else:
|
||||||
return matches[0]
|
return matches[0]
|
||||||
|
|
|
@ -5,14 +5,14 @@ Interface to zarr arrays
|
||||||
import contextlib
|
import contextlib
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional, Union, Sequence
|
from typing import Any, Optional, Sequence, Union
|
||||||
|
|
||||||
from numpydantic.interface.interface import Interface
|
from numpydantic.interface.interface import Interface
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
import zarr
|
||||||
from zarr.core import Array as ZarrArray
|
from zarr.core import Array as ZarrArray
|
||||||
from zarr.storage import StoreLike
|
from zarr.storage import StoreLike
|
||||||
import zarr
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
ZarrArray = None
|
ZarrArray = None
|
||||||
StoreLike = None
|
StoreLike = None
|
||||||
|
@ -32,11 +32,16 @@ class ZarrArrayPath:
|
||||||
path: Optional[str] = None
|
path: Optional[str] = None
|
||||||
"""Path to array within hierarchical zarr store"""
|
"""Path to array within hierarchical zarr store"""
|
||||||
|
|
||||||
def open(self, **kwargs) -> ZarrArray:
|
def open(self, **kwargs: dict) -> ZarrArray:
|
||||||
|
"""Open the zarr array at the provided path"""
|
||||||
return zarr.open(str(self.file), path=self.path, **kwargs)
|
return zarr.open(str(self.file), path=self.path, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_iterable(cls, spec: Sequence) -> "ZarrArrayPath":
|
def from_iterable(cls, spec: Sequence) -> "ZarrArrayPath":
|
||||||
|
"""
|
||||||
|
Construct a :class:`.ZarrArrayPath` specifier from an iterable,
|
||||||
|
rather than kwargs
|
||||||
|
"""
|
||||||
if len(spec) == 1:
|
if len(spec) == 1:
|
||||||
return ZarrArrayPath(file=spec[0])
|
return ZarrArrayPath(file=spec[0])
|
||||||
elif len(spec) == 2:
|
elif len(spec) == 2:
|
||||||
|
|
|
@ -8,6 +8,8 @@ from typing import Any
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from nptyping import Bool, Float, Int, String
|
from nptyping import Bool, Float, Int, String
|
||||||
|
|
||||||
|
from numpydantic import dtype as dt
|
||||||
|
|
||||||
np_to_python = {
|
np_to_python = {
|
||||||
Any: Any,
|
Any: Any,
|
||||||
np.number: float,
|
np.number: float,
|
||||||
|
@ -17,34 +19,9 @@ np_to_python = {
|
||||||
np.byte: bytes,
|
np.byte: bytes,
|
||||||
np.bytes_: bytes,
|
np.bytes_: bytes,
|
||||||
np.datetime64: datetime,
|
np.datetime64: datetime,
|
||||||
**{
|
**{n: int for n in dt.Integer},
|
||||||
n: int
|
**{n: float for n in dt.Float},
|
||||||
for n in (
|
**{n: complex for n in dt.Complex},
|
||||||
np.int8,
|
|
||||||
np.int16,
|
|
||||||
np.int32,
|
|
||||||
np.int64,
|
|
||||||
np.short,
|
|
||||||
np.uint8,
|
|
||||||
np.uint16,
|
|
||||||
np.uint32,
|
|
||||||
np.uint64,
|
|
||||||
np.uint,
|
|
||||||
)
|
|
||||||
},
|
|
||||||
**{
|
|
||||||
n: float
|
|
||||||
for n in (
|
|
||||||
np.float16,
|
|
||||||
np.float32,
|
|
||||||
np.floating,
|
|
||||||
np.float32,
|
|
||||||
np.float64,
|
|
||||||
np.single,
|
|
||||||
np.double,
|
|
||||||
np.float_,
|
|
||||||
)
|
|
||||||
},
|
|
||||||
**{n: str for n in (np.character, np.str_, np.string_, np.unicode_)},
|
**{n: str for n in (np.character, np.str_, np.string_, np.unicode_)},
|
||||||
}
|
}
|
||||||
"""Map from python types to numpy"""
|
"""Map from python types to numpy"""
|
||||||
|
|
|
@ -10,11 +10,18 @@ from typing import TYPE_CHECKING, Any, Tuple, Union
|
||||||
import nptyping.structure
|
import nptyping.structure
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from nptyping import Shape
|
from nptyping import Shape
|
||||||
|
from nptyping.error import InvalidArgumentsError
|
||||||
from nptyping.ndarray import NDArrayMeta as _NDArrayMeta
|
from nptyping.ndarray import NDArrayMeta as _NDArrayMeta
|
||||||
from nptyping.nptyping_type import NPTypingType
|
from nptyping.nptyping_type import NPTypingType
|
||||||
|
from nptyping.structure import Structure
|
||||||
|
from nptyping.structure_expression import check_type_names
|
||||||
|
from nptyping.typing_ import (
|
||||||
|
dtype_per_name,
|
||||||
|
)
|
||||||
from pydantic_core import core_schema
|
from pydantic_core import core_schema
|
||||||
from pydantic_core.core_schema import ListSchema
|
from pydantic_core.core_schema import ListSchema
|
||||||
|
|
||||||
|
from numpydantic.dtype import DType
|
||||||
from numpydantic.interface import Interface
|
from numpydantic.interface import Interface
|
||||||
from numpydantic.maps import np_to_python
|
from numpydantic.maps import np_to_python
|
||||||
|
|
||||||
|
@ -24,12 +31,6 @@ from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from pydantic import ValidationInfo
|
from pydantic import ValidationInfo
|
||||||
|
|
||||||
COMPRESSION_THRESHOLD = 16 * 1024
|
|
||||||
"""
|
|
||||||
Arrays larger than this size (in bytes) will be compressed and b64 encoded when
|
|
||||||
serializing to JSON.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def list_of_lists_schema(shape: Shape, array_type_handler: dict) -> ListSchema:
|
def list_of_lists_schema(shape: Shape, array_type_handler: dict) -> ListSchema:
|
||||||
"""Make a pydantic JSON schema for an array as a list of lists."""
|
"""Make a pydantic JSON schema for an array as a list of lists."""
|
||||||
|
@ -95,11 +96,40 @@ def coerce_list(value: Any) -> np.ndarray:
|
||||||
|
|
||||||
class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
|
class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
|
||||||
"""
|
"""
|
||||||
Kept here to allow for hooking into metaclass, which has
|
Hooking into nptyping's array metaclass to override methods pending
|
||||||
been necessary on and off as we work this class into a stable
|
completion of the transition away from nptyping
|
||||||
state
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def _get_dtype(cls, dtype_candidate: Any) -> DType:
|
||||||
|
"""
|
||||||
|
Override of base _get_dtype method to allow for compound tuple types
|
||||||
|
"""
|
||||||
|
is_dtype = isinstance(dtype_candidate, type) and issubclass(
|
||||||
|
dtype_candidate, np.generic
|
||||||
|
)
|
||||||
|
if dtype_candidate is Any:
|
||||||
|
dtype = Any
|
||||||
|
elif is_dtype:
|
||||||
|
dtype = dtype_candidate
|
||||||
|
elif issubclass(dtype_candidate, Structure):
|
||||||
|
dtype = dtype_candidate
|
||||||
|
check_type_names(dtype, dtype_per_name)
|
||||||
|
elif cls._is_literal_like(dtype_candidate):
|
||||||
|
structure_expression = dtype_candidate.__args__[0]
|
||||||
|
dtype = Structure[structure_expression]
|
||||||
|
check_type_names(dtype, dtype_per_name)
|
||||||
|
elif isinstance(dtype_candidate, tuple):
|
||||||
|
dtype = tuple([cls._get_dtype(dt) for dt in dtype_candidate])
|
||||||
|
else:
|
||||||
|
raise InvalidArgumentsError(
|
||||||
|
f"Unexpected argument '{dtype_candidate}', expecting"
|
||||||
|
" Structure[<StructureExpression>]"
|
||||||
|
" or Literal[<StructureExpression>]"
|
||||||
|
" or a dtype"
|
||||||
|
" or typing.Any."
|
||||||
|
)
|
||||||
|
return dtype
|
||||||
|
|
||||||
|
|
||||||
class NDArray(NPTypingType, metaclass=NDArrayMeta):
|
class NDArray(NPTypingType, metaclass=NDArrayMeta):
|
||||||
"""
|
"""
|
||||||
|
@ -134,7 +164,16 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
|
||||||
raise NotImplementedError("Finish handling structured dtypes!")
|
raise NotImplementedError("Finish handling structured dtypes!")
|
||||||
# functools.reduce(operator.or_, [int, float, str])
|
# functools.reduce(operator.or_, [int, float, str])
|
||||||
else:
|
else:
|
||||||
array_type_handler = _handler.generate_schema(np_to_python[dtype])
|
if isinstance(dtype, tuple):
|
||||||
|
types_ = list(set([np_to_python[dt] for dt in dtype]))
|
||||||
|
# TODO: better type filtering - explicitly model what
|
||||||
|
# numeric types are supported by JSON schema
|
||||||
|
types_ = [t for t in types_ if t not in (complex,)]
|
||||||
|
schemas = [_handler.generate_schema(dt) for dt in types_]
|
||||||
|
array_type_handler = core_schema.union_schema(schemas)
|
||||||
|
|
||||||
|
else:
|
||||||
|
array_type_handler = _handler.generate_schema(np_to_python[dtype])
|
||||||
|
|
||||||
# get the names of the shape constraints, if any
|
# get the names of the shape constraints, if any
|
||||||
if shape is Any:
|
if shape is Any:
|
||||||
|
|
|
@ -1,4 +1,12 @@
|
||||||
|
import pdb
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from typing import Any, Tuple, Union, Type, TypeAlias
|
||||||
|
from pydantic import BaseModel, computed_field, ConfigDict
|
||||||
|
from numpydantic import NDArray, Shape
|
||||||
|
from numpydantic.ndarray import NDArrayMeta
|
||||||
|
from numpydantic.dtype import Float, Number, Integer
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from tests.fixtures import *
|
from tests.fixtures import *
|
||||||
|
|
||||||
|
@ -9,3 +17,127 @@ def pytest_addoption(parser):
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Keep test outputs in the __tmp__ directory",
|
help="Keep test outputs in the __tmp__ directory",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationCase(BaseModel):
|
||||||
|
"""
|
||||||
|
Test case for validating an array.
|
||||||
|
|
||||||
|
Contains both the validating model and the parameterization for an array to
|
||||||
|
test in a given interface
|
||||||
|
"""
|
||||||
|
|
||||||
|
annotation: Any = NDArray[Shape["10, 10, *"], Float]
|
||||||
|
"""
|
||||||
|
Array annotation used in the validating model
|
||||||
|
Any typed because the types of type annotations are weird
|
||||||
|
"""
|
||||||
|
shape: Tuple[int, ...] = (10, 10, 10)
|
||||||
|
"""Shape of the array to validate"""
|
||||||
|
dtype: Union[Type, np.dtype] = float
|
||||||
|
"""Dtype of the array to validate"""
|
||||||
|
passes: bool
|
||||||
|
"""Whether the validation should pass or not"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
|
@computed_field()
|
||||||
|
def model(self) -> Type[BaseModel]:
|
||||||
|
"""A model with a field ``array`` with the given annotation"""
|
||||||
|
annotation = self.annotation
|
||||||
|
|
||||||
|
class Model(BaseModel):
|
||||||
|
array: annotation
|
||||||
|
|
||||||
|
return Model
|
||||||
|
|
||||||
|
|
||||||
|
RGB_UNION: TypeAlias = Union[
|
||||||
|
NDArray[Shape["* x, * y"], Number],
|
||||||
|
NDArray[Shape["* x, * y, 3 r_g_b"], Number],
|
||||||
|
NDArray[Shape["* x, * y, 3 r_g_b, 4 r_g_b_a"], Number],
|
||||||
|
]
|
||||||
|
|
||||||
|
NUMBER: TypeAlias = NDArray[Shape["*, *, *"], Number]
|
||||||
|
INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer]
|
||||||
|
FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(
|
||||||
|
scope="session",
|
||||||
|
params=[
|
||||||
|
ValidationCase(shape=(10, 10, 10), passes=True),
|
||||||
|
ValidationCase(shape=(10, 10), passes=False),
|
||||||
|
ValidationCase(shape=(10, 10, 10, 10), passes=False),
|
||||||
|
ValidationCase(shape=(11, 10, 10), passes=False),
|
||||||
|
ValidationCase(shape=(9, 10, 10), passes=False),
|
||||||
|
ValidationCase(shape=(10, 10, 9), passes=True),
|
||||||
|
ValidationCase(shape=(10, 10, 11), passes=True),
|
||||||
|
ValidationCase(annotation=RGB_UNION, shape=(5, 5), passes=True),
|
||||||
|
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3), passes=True),
|
||||||
|
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3, 4), passes=True),
|
||||||
|
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 4), passes=False),
|
||||||
|
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3, 6), passes=False),
|
||||||
|
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 4, 6), passes=False),
|
||||||
|
],
|
||||||
|
ids=[
|
||||||
|
"valid shape",
|
||||||
|
"missing dimension",
|
||||||
|
"extra dimension",
|
||||||
|
"dimension too large",
|
||||||
|
"dimension too small",
|
||||||
|
"wildcard smaller",
|
||||||
|
"wildcard larger",
|
||||||
|
"Union 2D",
|
||||||
|
"Union 3D",
|
||||||
|
"Union 4D",
|
||||||
|
"Union incorrect 3D",
|
||||||
|
"Union incorrect 4D",
|
||||||
|
"Union incorrect both",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def shape_cases(request) -> ValidationCase:
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(
|
||||||
|
scope="session",
|
||||||
|
params=[
|
||||||
|
ValidationCase(dtype=float, passes=True),
|
||||||
|
ValidationCase(dtype=int, passes=False),
|
||||||
|
ValidationCase(dtype=np.uint8, passes=False),
|
||||||
|
ValidationCase(annotation=NUMBER, dtype=int, passes=True),
|
||||||
|
ValidationCase(annotation=NUMBER, dtype=float, passes=True),
|
||||||
|
ValidationCase(annotation=NUMBER, dtype=np.uint8, passes=True),
|
||||||
|
ValidationCase(annotation=NUMBER, dtype=np.float16, passes=True),
|
||||||
|
ValidationCase(annotation=NUMBER, dtype=str, passes=False),
|
||||||
|
ValidationCase(annotation=INTEGER, dtype=int, passes=True),
|
||||||
|
ValidationCase(annotation=INTEGER, dtype=np.uint8, passes=True),
|
||||||
|
ValidationCase(annotation=INTEGER, dtype=float, passes=False),
|
||||||
|
ValidationCase(annotation=INTEGER, dtype=np.float32, passes=False),
|
||||||
|
ValidationCase(annotation=FLOAT, dtype=float, passes=True),
|
||||||
|
ValidationCase(annotation=FLOAT, dtype=np.float32, passes=True),
|
||||||
|
ValidationCase(annotation=FLOAT, dtype=int, passes=False),
|
||||||
|
ValidationCase(annotation=FLOAT, dtype=np.uint8, passes=False),
|
||||||
|
],
|
||||||
|
ids=[
|
||||||
|
"float",
|
||||||
|
"int",
|
||||||
|
"uint8",
|
||||||
|
"number-int",
|
||||||
|
"number-float",
|
||||||
|
"number-uint8",
|
||||||
|
"number-float16",
|
||||||
|
"number-str",
|
||||||
|
"integer-int",
|
||||||
|
"integer-uint8",
|
||||||
|
"integer-float",
|
||||||
|
"integer-float32",
|
||||||
|
"float-float",
|
||||||
|
"float-float32",
|
||||||
|
"float-int",
|
||||||
|
"float-uint8",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def dtype_cases(request) -> ValidationCase:
|
||||||
|
return request.param
|
||||||
|
|
|
@ -5,7 +5,6 @@ from typing import Callable, Optional, Tuple, Type, Union
|
||||||
import h5py
|
import h5py
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from nptyping import Number
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
import zarr
|
import zarr
|
||||||
|
|
||||||
|
@ -13,6 +12,7 @@ from numpydantic.interface.hdf5 import H5ArrayPath
|
||||||
from numpydantic.interface.zarr import ZarrArrayPath
|
from numpydantic.interface.zarr import ZarrArrayPath
|
||||||
from numpydantic import NDArray, Shape
|
from numpydantic import NDArray, Shape
|
||||||
from numpydantic.maps import python_to_nptyping
|
from numpydantic.maps import python_to_nptyping
|
||||||
|
from numpydantic.dtype import Number
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
|
|
|
@ -30,4 +30,8 @@ from tests.fixtures import hdf5_array, zarr_nested_array, zarr_array
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def interface_type(request):
|
def interface_type(request):
|
||||||
|
"""
|
||||||
|
Test cases for each interface's ``check`` method - each input should match the
|
||||||
|
provided interface and that interface only
|
||||||
|
"""
|
||||||
return request.param
|
return request.param
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from numpydantic.exceptions import DtypeError, ShapeError
|
||||||
|
|
||||||
|
from tests.conftest import ValidationCase
|
||||||
|
|
||||||
|
|
||||||
|
def numpy_array(case: ValidationCase) -> np.ndarray:
|
||||||
|
return np.zeros(shape=case.shape, dtype=case.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def _test_np_case(case: ValidationCase):
|
||||||
|
array = numpy_array(case)
|
||||||
|
if case.passes:
|
||||||
|
case.model(array=array)
|
||||||
|
else:
|
||||||
|
with pytest.raises((ValidationError, DtypeError, ShapeError)):
|
||||||
|
case.model(array=array)
|
||||||
|
|
||||||
|
|
||||||
|
def test_numpy_shape(shape_cases):
|
||||||
|
_test_np_case(shape_cases)
|
||||||
|
|
||||||
|
|
||||||
|
def test_numpy_dtype(dtype_cases):
|
||||||
|
_test_np_case(dtype_cases)
|
Loading…
Reference in a new issue