numpydantic/tests/test_interface/test_dask.py

62 lines
1.6 KiB
Python
Raw Normal View History

2024-05-15 03:18:04 +00:00
import json
2024-04-09 01:36:47 +00:00
import dask.array as da
2024-10-04 02:57:54 +00:00
import pytest
from pydantic import BaseModel, ValidationError
2024-04-09 01:36:47 +00:00
2024-05-09 05:06:41 +00:00
from numpydantic.exceptions import DtypeError, ShapeError
2024-10-04 02:57:54 +00:00
from numpydantic.interface import DaskInterface
2024-10-04 02:33:40 +00:00
from numpydantic.testing.helpers import ValidationCase
2024-05-09 05:06:41 +00:00
2024-09-23 20:28:38 +00:00
# Module-level marker: pytest applies every mark in this list to all tests
# collected from this module, so they can be selected with ``-m dask``.
pytestmark = [pytest.mark.dask]

def dask_array(case: ValidationCase) -> da.Array:
    """
    Build a dask array with the shape and dtype described by ``case``.

    When ``case.dtype`` is a pydantic model class, the array is filled with
    instances constructed as ``case.dtype(x=1)`` and kept in a single chunk.
    Otherwise a zero-filled array of ``case.dtype`` is created with chunks
    of size 10.

    Args:
        case: The validation case providing ``shape`` and ``dtype``.

    Returns:
        A dask array of shape ``case.shape``.
    """
    # ``issubclass`` raises TypeError when its first argument is not a class
    # (e.g. an ``np.dtype`` instance or a typing construct), so confirm the
    # dtype is a class before asking whether it is a pydantic model.
    if isinstance(case.dtype, type) and issubclass(case.dtype, BaseModel):
        # chunks=-1 makes each axis a single chunk spanning the full dimension
        return da.full(shape=case.shape, fill_value=case.dtype(x=1), chunks=-1)
    return da.zeros(shape=case.shape, dtype=case.dtype, chunks=10)


def _test_dask_case(case: ValidationCase):
    """
    Validate a dask array built from ``case`` against ``case.model``.

    If ``case.passes`` is true, constructing the model must succeed.
    Otherwise construction must raise ``ValidationError``, ``DtypeError``
    or ``ShapeError``.
    """
    arr = dask_array(case)
    if not case.passes:
        with pytest.raises((ValidationError, DtypeError, ShapeError)):
            case.model(array=arr)
        return
    case.model(array=arr)


def test_dask_enabled():
    """
    Require ``DaskInterface.enabled()`` to be true.

    The remaining tests in this module need dask to be available.
    """
    interface_enabled = DaskInterface.enabled()
    assert interface_enabled


def test_dask_check(interface_type):
    """
    ``DaskInterface.check`` accepts an input only when the fixture pairs it
    with ``DaskInterface``, and rejects it otherwise.
    """
    # interface_type is indexed as (input, interface class).
    # NOTE(review): fixture is defined elsewhere (presumably conftest) — confirm.
    candidate = interface_type[0]
    expected_interface = interface_type[1]
    should_match = expected_interface is DaskInterface
    assert bool(DaskInterface.check(candidate)) is should_match


@pytest.mark.shape
def test_dask_shape(shape_cases):
    """Check a shape validation case (from ``shape_cases``) against a dask array."""
    case = shape_cases
    _test_dask_case(case)


@pytest.mark.dtype
def test_dask_dtype(dtype_cases):
    """Check a dtype validation case (from ``dtype_cases``) against a dask array."""
    case = dtype_cases
    _test_dask_case(case)


@pytest.mark.serialization
def test_dask_to_json(array_model):
    """
    A model holding a 3x3 integer dask array dumps to JSON whose ``array``
    field equals the nested lists the array was built from.
    """
    expected = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    model_cls = array_model((3, 3), int)
    instance = model_cls(array=da.array(expected))
    dumped = json.loads(instance.model_dump_json())
    assert dumped["array"] == expected