2024-09-03 20:10:26 +00:00
|
|
|
"""
|
|
|
|
Tests that should be applied to all interfaces
|
|
|
|
"""
|
|
|
|
|
2024-09-21 06:44:59 +00:00
|
|
|
import pytest
|
2024-09-21 01:28:38 +00:00
|
|
|
from typing import Callable
|
|
|
|
import numpy as np
|
2024-09-21 06:44:59 +00:00
|
|
|
import dask.array as da
|
|
|
|
from zarr.core import Array as ZarrArray
|
2024-09-21 01:28:38 +00:00
|
|
|
from numpydantic.interface import Interface
|
|
|
|
|
2024-09-03 20:10:26 +00:00
|
|
|
|
|
|
|
def test_interface_revalidate(all_interfaces):
    """
    Feeding a model's already-validated array back into the same model class
    must validate again without error.

    See: https://github.com/p2p-ld/numpydantic/pull/14
    """
    model_cls = type(all_interfaces)
    validated_array = all_interfaces.array
    # Re-running validation on the interface's own output is the whole test;
    # any exception raised here is the failure signal.
    _ = model_cls(array=validated_array)
|
2024-09-21 01:28:38 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_interface_rematch(interface_type):
    """
    Validating an array with an interface should yield an object that
    ``Interface.match`` maps back to that same interface.

    ``interface_type`` unpacks into ``(array, interface)``. ``array`` may be a
    zero-argument factory instead of a concrete array, in which case it is
    called to produce the input before validation.
    """
    array, interface = interface_type

    # Some fixture entries supply a factory rather than an array. The builtin
    # ``callable`` is the idiomatic check; ``typing.Callable`` is only a
    # deprecated alias of ``collections.abc.Callable`` for runtime checks.
    if callable(array):
        array = array()

    validated = interface().validate(array)
    assert Interface.match(validated) is interface
|
|
|
|
|
|
|
|
|
|
|
|
def test_interface_to_numpy_array(all_interfaces):
    """
    Whatever an interface produces during validation must be convertible to
    a numpy array by passing it to ``np.array()``.
    """
    validated_array = all_interfaces.array
    # Success is simply that the conversion does not raise.
    _ = np.array(validated_array)
|
|
|
|
|
|
|
|
|
2024-09-21 06:44:59 +00:00
|
|
|
@pytest.mark.serialization
def test_interface_dump_json(all_interfaces):
    """
    Every interface's model must be serializable with ``model_dump_json``.
    """
    # The serialized string is not inspected; the test passes if dumping
    # completes without raising.
    _ = all_interfaces.model_dump_json()
|
2024-09-21 06:44:59 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.serialization
@pytest.mark.parametrize("round_trip", [True, False])
def test_interface_roundtrip_json(all_interfaces, round_trip):
    """
    A model dumped to JSON and validated back should reproduce its array.

    With ``round_trip=False`` only the values must survive (compared after
    conversion to numpy). With ``round_trip=True`` the restored array must
    also keep the original's exact type and dtype.
    """
    dumped = all_interfaces.model_dump_json(round_trip=round_trip)
    model = all_interfaces.model_validate_json(dumped)

    original = all_interfaces.array
    restored = model.array

    if not round_trip:
        # Plain dumps are only required to preserve values.
        assert np.array_equal(restored, np.array(original))
        return

    assert type(restored) is type(original)

    # Choose an equality check that suits the concrete array type.
    if isinstance(original, (np.ndarray, ZarrArray)):
        assert np.array_equal(restored, np.array(original))
    elif isinstance(original, da.Array):
        # da.equal is lazy; the truth test of np.all(...) forces evaluation.
        assert np.all(da.equal(restored, original))
    else:
        assert restored == original

    assert restored.dtype == original.dtype
|
2024-09-23 20:28:38 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_dunder_len(all_interfaces):
    """
    ``len()`` on any interface's array (or proxy object) must work and agree
    with the size of its first dimension.
    """
    arr = all_interfaces.array
    first_dim = arr.shape[0]
    assert len(arr) == first_dim
|