numpydantic/tests/test_interface/test_interfaces.py

238 lines
7.5 KiB
Python
Raw Normal View History

2024-09-03 20:10:26 +00:00
"""
Tests that should be applied to all interfaces
"""
import json
2024-10-04 02:57:54 +00:00
from importlib.metadata import version
2024-10-19 03:15:15 +00:00
from typing import Generic, TypeVar, Union
import dask.array as da
2024-10-04 02:57:54 +00:00
import numpy as np
import pytest
from pydantic import BaseModel, ConfigDict
2024-10-04 02:57:54 +00:00
from zarr.core import Array as ZarrArray
from numpydantic import NDArray
from numpydantic.interface import Interface, InterfaceMark, MarkedJson
from numpydantic.testing.cases import (
ALL_CASES_PASSING,
DTYPE_AND_INTERFACE_CASES_PASSING,
INTERFACE_CASES,
)
from numpydantic.testing.helpers import ValidationCase
def _test_roundtrip(source: BaseModel, target: BaseModel):
"""Test model equality for roundtrip tests"""
assert type(target.array) is type(source.array)
if isinstance(source.array, (np.ndarray, ZarrArray)):
assert np.array_equal(target.array, np.array(source.array))
elif isinstance(source.array, da.Array):
if target.array.dtype == object:
# object equality doesn't really work well with dask
# just check that the types match
target_type = type(target.array.ravel()[0].compute())
source_type = type(source.array.ravel()[0].compute())
assert target_type is source_type
else:
assert np.all(da.equal(target.array, source.array))
elif isinstance(source.array, BaseModel):
return _test_roundtrip(source.array, target.array)
else:
assert target.array == source.array
assert target.array.dtype == source.array.dtype
@pytest.mark.parametrize(
"interface",
[
pytest.param(i, marks=getattr(pytest.mark, i.interface.interface.name))
for i in INTERFACE_CASES
],
)
@pytest.mark.parametrize(
"cases", [ALL_CASES_PASSING, DTYPE_AND_INTERFACE_CASES_PASSING]
)
def test_cases_include_all_interfaces(interface: ValidationCase, cases):
"""
Test our test cases - we should hit all interfaces in the common "test all" fixtures
"""
cases = list(cases)
assert any(
[case.interface is interface.interface for case in cases]
), f"Interface case unused in general test cases: {interface.interface}"
def test_dunder_len(interface_cases, tmp_output_dir_func):
"""
Each interface or proxy type should support __len__
"""
case = ValidationCase(interface=interface_cases)
if interface_cases.interface.name == "video":
case.shape = (10, 10, 2, 3)
case.dtype = np.uint8
case.annotation_dtype = np.uint8
case.annotation_shape = (10, 10, "*", 3)
array = case.array(path=tmp_output_dir_func)
instance = case.model(array=array)
assert len(instance.array) == case.shape[0]
2024-09-03 20:10:26 +00:00
def test_interface_revalidate(all_passing_cases_instance):
2024-09-03 20:10:26 +00:00
"""
An interface should revalidate with the output of its initial validation
See: https://github.com/p2p-ld/numpydantic/pull/14
"""
_ = type(all_passing_cases_instance)(array=all_passing_cases_instance.array)
@pytest.mark.xfail
def test_interface_rematch(interface_cases, tmp_output_dir_func):
"""
All interfaces should match the results of the object they return after validation
"""
array = interface_cases.make_array(path=tmp_output_dir_func)
assert (
Interface.match(interface_cases.interface.validate(array))
is interface_cases.interface
)
def test_interface_to_numpy_array(dtype_by_interface_instance):
"""
All interfaces should be able to have the output of their validation stage
coerced to a numpy array with np.array()
"""
_ = np.array(dtype_by_interface_instance.array)
@pytest.mark.serialization
def test_interface_dump_json(dtype_by_interface_instance):
"""
All interfaces should be able to dump to json
"""
dtype_by_interface_instance.model_dump_json()
@pytest.mark.serialization
2024-10-18 05:15:30 +00:00
def test_interface_roundtrip_json(all_passing_cases, tmp_output_dir_func):
"""
All interfaces should be able to roundtrip to and from json
"""
2024-10-18 05:15:30 +00:00
array = all_passing_cases.array(path=tmp_output_dir_func)
case = all_passing_cases.model(array=array)
dumped_json = case.model_dump_json(round_trip=True)
model = case.model_validate_json(dumped_json)
_test_roundtrip(case, model)
@pytest.mark.serialization
@pytest.mark.parametrize("an_interface", Interface.interfaces())
def test_interface_mark_interface(an_interface):
"""
All interfaces should be able to mark the current version and interface info
"""
mark = an_interface.mark_interface()
assert isinstance(mark, InterfaceMark)
assert mark.name == an_interface.name
assert mark.cls == an_interface.__name__
assert mark.module == an_interface.__module__
assert mark.version == version(mark.module.split(".")[0])
2024-09-23 20:28:38 +00:00
@pytest.mark.serialization
@pytest.mark.parametrize("valid", [True, False])
@pytest.mark.filterwarnings("ignore:Mismatch between serialized mark")
2024-10-18 05:15:30 +00:00
def test_interface_mark_roundtrip(all_passing_cases, valid, tmp_output_dir_func):
2024-09-23 20:28:38 +00:00
"""
All interfaces should be able to roundtrip with the marked interface,
and a mismatch should raise a warning and attempt to proceed
2024-09-23 20:28:38 +00:00
"""
2024-10-18 05:15:30 +00:00
if "subclass" in all_passing_cases.id.lower():
pytest.xfail()
2024-10-18 05:15:30 +00:00
array = all_passing_cases.array(path=tmp_output_dir_func)
case = all_passing_cases.model(array=array)
dumped_json = case.model_dump_json(
round_trip=True, context={"mark_interface": True}
)
data = json.loads(dumped_json)
# ensure that we are a MarkedJson
_ = MarkedJson.model_validate_json(json.dumps(data["array"]))
if not valid:
# ruin the version
data["array"]["interface"]["version"] = "v99999999"
dumped_json = json.dumps(data)
with pytest.warns(match="Mismatch.*"):
model = case.model_validate_json(dumped_json)
else:
model = case.model_validate_json(dumped_json)
_test_roundtrip(case, model)
@pytest.mark.serialization
def test_roundtrip_from_extra(dtype_by_interface, tmp_output_dir_func):
"""
Arrays can be dumped when they are specified in an `__extra__` field
"""
class Model(BaseModel):
__pydantic_extra__: dict[str, dtype_by_interface.annotation]
model_config = ConfigDict(extra="allow")
instance = Model(array=dtype_by_interface.array(path=tmp_output_dir_func))
dumped = instance.model_dump_json(round_trip=True)
roundtripped = Model.model_validate_json(dumped)
_test_roundtrip(instance, roundtripped)
@pytest.mark.serialization
def test_roundtrip_from_union(dtype_by_interface, tmp_output_dir_func):
"""
Arrays can be dumped when they are specified along with a
union of another type field
"""
class Model(BaseModel):
2024-10-19 03:15:15 +00:00
array: Union[str, dtype_by_interface.annotation]
array = dtype_by_interface.array(path=tmp_output_dir_func)
instance = Model(array=array)
dumped = instance.model_dump_json(round_trip=True)
roundtripped = Model.model_validate_json(dumped)
_test_roundtrip(instance, roundtripped)
@pytest.mark.serialization
def test_roundtrip_from_generic(dtype_by_interface, tmp_output_dir_func):
"""
Arrays can be dumped when they are specified in an `__extra__` field
"""
T = TypeVar("T", bound=NDArray)
class GenType(BaseModel, Generic[T]):
array: T
class Model(BaseModel):
array: GenType[dtype_by_interface.annotation]
array = dtype_by_interface.array(path=tmp_output_dir_func)
instance = Model(**{"array": {"array": array}})
dumped = instance.model_dump_json(round_trip=True)
roundtripped = Model.model_validate_json(dumped)
_test_roundtrip(instance, roundtripped)