mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2024-11-14 02:34:28 +00:00
175 lines
5.6 KiB
Python
175 lines
5.6 KiB
Python
import pdb
|
|
import sys
|
|
|
|
import pytest
|
|
from typing import Any, Tuple, Union, Type
|
|
|
|
if sys.version_info.minor >= 10:
|
|
from typing import TypeAlias
|
|
else:
|
|
from typing_extensions import TypeAlias
|
|
from pydantic import BaseModel, computed_field, ConfigDict
|
|
from numpydantic import NDArray, Shape
|
|
from numpydantic.ndarray import NDArrayMeta
|
|
from numpydantic.dtype import Float, Number, Integer
|
|
import numpy as np
|
|
|
|
from tests.fixtures import *
|
|
|
|
|
|
def pytest_addoption(parser):
|
|
parser.addoption(
|
|
"--with-output",
|
|
action="store_true",
|
|
help="Keep test outputs in the __tmp__ directory",
|
|
)
|
|
|
|
|
|
class ValidationCase(BaseModel):
|
|
"""
|
|
Test case for validating an array.
|
|
|
|
Contains both the validating model and the parameterization for an array to
|
|
test in a given interface
|
|
"""
|
|
|
|
annotation: Any = NDArray[Shape["10, 10, *"], Float]
|
|
"""
|
|
Array annotation used in the validating model
|
|
Any typed because the types of type annotations are weird
|
|
"""
|
|
shape: Tuple[int, ...] = (10, 10, 10)
|
|
"""Shape of the array to validate"""
|
|
dtype: Union[Type, np.dtype] = float
|
|
"""Dtype of the array to validate"""
|
|
passes: bool
|
|
"""Whether the validation should pass or not"""
|
|
|
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
|
|
@computed_field()
|
|
def model(self) -> Type[BaseModel]:
|
|
"""A model with a field ``array`` with the given annotation"""
|
|
annotation = self.annotation
|
|
|
|
class Model(BaseModel):
|
|
array: annotation
|
|
|
|
return Model
|
|
|
|
|
|
class BasicModel(BaseModel):
|
|
x: int
|
|
|
|
|
|
class BadModel(BaseModel):
|
|
x: int
|
|
|
|
|
|
RGB_UNION: TypeAlias = Union[
|
|
NDArray[Shape["* x, * y"], Number],
|
|
NDArray[Shape["* x, * y, 3 r_g_b"], Number],
|
|
NDArray[Shape["* x, * y, 3 r_g_b, 4 r_g_b_a"], Number],
|
|
]
|
|
|
|
NUMBER: TypeAlias = NDArray[Shape["*, *, *"], Number]
|
|
INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer]
|
|
FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float]
|
|
STRING: TypeAlias = NDArray[Shape["*, *, *"], str]
|
|
MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel]
|
|
|
|
|
|
@pytest.fixture(
|
|
scope="module",
|
|
params=[
|
|
ValidationCase(shape=(10, 10, 10), passes=True),
|
|
ValidationCase(shape=(10, 10), passes=False),
|
|
ValidationCase(shape=(10, 10, 10, 10), passes=False),
|
|
ValidationCase(shape=(11, 10, 10), passes=False),
|
|
ValidationCase(shape=(9, 10, 10), passes=False),
|
|
ValidationCase(shape=(10, 10, 9), passes=True),
|
|
ValidationCase(shape=(10, 10, 11), passes=True),
|
|
ValidationCase(annotation=RGB_UNION, shape=(5, 5), passes=True),
|
|
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3), passes=True),
|
|
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3, 4), passes=True),
|
|
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 4), passes=False),
|
|
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 3, 6), passes=False),
|
|
ValidationCase(annotation=RGB_UNION, shape=(5, 5, 4, 6), passes=False),
|
|
],
|
|
ids=[
|
|
"valid shape",
|
|
"missing dimension",
|
|
"extra dimension",
|
|
"dimension too large",
|
|
"dimension too small",
|
|
"wildcard smaller",
|
|
"wildcard larger",
|
|
"Union 2D",
|
|
"Union 3D",
|
|
"Union 4D",
|
|
"Union incorrect 3D",
|
|
"Union incorrect 4D",
|
|
"Union incorrect both",
|
|
],
|
|
)
|
|
def shape_cases(request) -> ValidationCase:
|
|
return request.param
|
|
|
|
|
|
@pytest.fixture(
|
|
scope="module",
|
|
params=[
|
|
ValidationCase(dtype=float, passes=True),
|
|
ValidationCase(dtype=int, passes=False),
|
|
ValidationCase(dtype=np.uint8, passes=False),
|
|
ValidationCase(annotation=NUMBER, dtype=int, passes=True),
|
|
ValidationCase(annotation=NUMBER, dtype=float, passes=True),
|
|
ValidationCase(annotation=NUMBER, dtype=np.uint8, passes=True),
|
|
ValidationCase(annotation=NUMBER, dtype=np.float16, passes=True),
|
|
ValidationCase(annotation=NUMBER, dtype=str, passes=False),
|
|
ValidationCase(annotation=INTEGER, dtype=int, passes=True),
|
|
ValidationCase(annotation=INTEGER, dtype=np.uint8, passes=True),
|
|
ValidationCase(annotation=INTEGER, dtype=float, passes=False),
|
|
ValidationCase(annotation=INTEGER, dtype=np.float32, passes=False),
|
|
ValidationCase(annotation=INTEGER, dtype=str, passes=False),
|
|
ValidationCase(annotation=FLOAT, dtype=float, passes=True),
|
|
ValidationCase(annotation=FLOAT, dtype=np.float32, passes=True),
|
|
ValidationCase(annotation=FLOAT, dtype=int, passes=False),
|
|
ValidationCase(annotation=FLOAT, dtype=np.uint8, passes=False),
|
|
ValidationCase(annotation=FLOAT, dtype=str, passes=False),
|
|
ValidationCase(annotation=STRING, dtype=str, passes=True),
|
|
ValidationCase(annotation=STRING, dtype=int, passes=False),
|
|
ValidationCase(annotation=STRING, dtype=float, passes=False),
|
|
ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True),
|
|
ValidationCase(annotation=MODEL, dtype=BadModel, passes=False),
|
|
ValidationCase(annotation=MODEL, dtype=int, passes=False),
|
|
],
|
|
ids=[
|
|
"float",
|
|
"int",
|
|
"uint8",
|
|
"number-int",
|
|
"number-float",
|
|
"number-uint8",
|
|
"number-float16",
|
|
"number-str",
|
|
"integer-int",
|
|
"integer-uint8",
|
|
"integer-float",
|
|
"integer-float32",
|
|
"integer-str",
|
|
"float-float",
|
|
"float-float32",
|
|
"float-int",
|
|
"float-uint8",
|
|
"float-str",
|
|
"str-str",
|
|
"str-int",
|
|
"str-float",
|
|
"model-model",
|
|
"model-badmodel",
|
|
"model-int",
|
|
],
|
|
)
|
|
def dtype_cases(request) -> ValidationCase:
|
|
return request.param
|