# numpydantic/tests/test_interface/test_zarr.py

import json
from collections.abc import Iterator

import pytest
import zarr
from numcodecs import Pickle
from pydantic import BaseModel, ValidationError

from numpydantic.exceptions import DtypeError, ShapeError
from numpydantic.interface import ZarrInterface
from numpydantic.interface.zarr import ZarrArrayPath
from tests.conftest import ValidationCase
2024-04-30 02:49:38 +00:00
@pytest.fixture()
def dir_array(tmp_output_dir_func) -> zarr.DirectoryStore:
    """Provide a ``zarr.DirectoryStore`` located at ``array.zarr``.

    The store path is placed under the directory supplied by the
    ``tmp_output_dir_func`` fixture.
    """
    return zarr.DirectoryStore(tmp_output_dir_func / "array.zarr")
@pytest.fixture()
def zip_array(tmp_output_dir_func) -> Iterator[zarr.ZipStore]:
    """Provide a writable ``zarr.ZipStore`` backed by ``array.zip``.

    The archive is created under the directory supplied by the
    ``tmp_output_dir_func`` fixture and opened with ``mode="w"``.

    Yields:
        The open store. It is closed after the requesting test finishes.
    """
    store = zarr.ZipStore(tmp_output_dir_func / "array.zip", mode="w")
    yield store
    # zarr's documentation requires ``close()`` on a ZipStore after writing;
    # closing here also releases the underlying file handle at teardown.
    store.close()
@pytest.fixture()
def nested_dir_array(tmp_output_dir_func) -> zarr.NestedDirectoryStore:
    """Provide a ``zarr.NestedDirectoryStore`` located at ``nested``.

    The store path is placed under the directory supplied by the
    ``tmp_output_dir_func`` fixture.
    """
    return zarr.NestedDirectoryStore(tmp_output_dir_func / "nested")
def _zarr_array(case: ValidationCase, store) -> zarr.core.Array:
    """Create a zero-filled zarr array matching a validation case.

    The calling test is skipped when ``case.dtype`` is a pydantic model class.

    Args:
        case: Validation case whose ``shape`` and ``dtype`` describe the
            array to create.
        store: Passed through unchanged as the ``store`` argument of
            ``zarr.zeros``.

    Returns:
        The array created by ``zarr.zeros``.
    """
    dtype = case.dtype
    # ``issubclass`` raises ``TypeError`` for non-class arguments, so only
    # consult it when the dtype actually is a class.
    if isinstance(dtype, type) and issubclass(dtype, BaseModel):
        # NOTE: once the linked issue is resolved this could presumably build
        # an object array instead, e.g. ``zarr.full(shape=case.shape,
        # fill_value=case.dtype(x=1), dtype=object, object_codec=Pickle())``.
        pytest.skip(
            "Zarr can't handle objects properly at the moment, "
            "see https://github.com/zarr-developers/zarr-python/issues/2081"
        )
    return zarr.zeros(shape=case.shape, dtype=dtype, store=store)
2024-05-09 05:06:41 +00:00
def _test_zarr_case(case: ValidationCase, store):
    """Validate a zarr array built from ``case`` against ``case.model``.

    Cases flagged as passing must instantiate without error; all others must
    raise one of ``ValidationError``, ``DtypeError`` or ``ShapeError``.
    """
    zarr_arr = _zarr_array(case, store)

    if not case.passes:
        with pytest.raises((ValidationError, DtypeError, ShapeError)):
            case.model(array=zarr_arr)
        return

    case.model(array=zarr_arr)
@pytest.fixture(
    params=[
        pytest.param(None, id="MutableMapping"),  # use the default store
        pytest.param("dir_array", id="DirectoryStore"),
        pytest.param("zip_array", id="ZipStore"),
        pytest.param("nested_dir_array", id="NestedDirectoryStore"),
    ],
)
def store(request):
    """Parametrized store argument for the zarr tests.

    String parameters are treated as fixture names and resolved through
    ``request.getfixturevalue``; any other parameter (here ``None``) is
    returned unchanged.
    """
    param = request.param
    return request.getfixturevalue(param) if isinstance(param, str) else param
2024-04-30 02:49:38 +00:00
def test_zarr_enabled():
    """The zarr interface should report itself as enabled."""
    is_enabled = ZarrInterface.enabled()
    assert is_enabled
def test_zarr_check(interface_type):
    """
    We should only use the zarr interface for zarr-like things
    """
    # ``interface_type`` is indexed as (input value, interface class)
    expect_match = interface_type[1] is ZarrInterface
    matched = bool(ZarrInterface.check(interface_type[0]))
    assert matched is expect_match
2024-05-09 05:06:41 +00:00
def test_zarr_shape(store, shape_cases):
    """Run each shape validation case against a zarr array in ``store``."""
    _test_zarr_case(case=shape_cases, store=store)
def test_zarr_dtype(dtype_cases, store):
    """Run each dtype validation case against a zarr array in ``store``."""
    _test_zarr_case(case=dtype_cases, store=store)
@pytest.mark.parametrize("array", ["zarr_nested_array", "zarr_array"])
def test_zarr_from_tuple(array, model_blank, request):
    """Should be able to do the same validation logic from tuples as an input"""
    # the parameter is a fixture name; resolve it to the fixture's value
    source = request.getfixturevalue(array)
    as_tuple = (
        (source.file, source.path)
        if isinstance(source, ZarrArrayPath)
        else (source,)
    )
    # constructing the model is the check: it must not raise
    model_blank(array=as_tuple)
def test_zarr_from_path(zarr_array, model_blank):
    """Should be able to just pass a path"""
    # constructing the model is the check: it must not raise
    model_blank(array=zarr_array)
def test_zarr_array_path_from_iterable(zarr_array):
    """Construct a zarr array path from some iterable!!!"""
    # a one-element iterable sets only the file
    file_only = ZarrArrayPath.from_iterable((zarr_array,))
    assert file_only.file == zarr_array
    assert file_only.path is None

    # a two-element iterable sets the file and the inner path
    inner_path = "/test/array"
    with_inner = ZarrArrayPath.from_iterable((zarr_array, inner_path))
    assert with_inner.file == zarr_array
    assert with_inner.path == inner_path
@pytest.mark.serialization
@pytest.mark.parametrize("dump_array", [True, False])
@pytest.mark.parametrize("roundtrip", [True, False])
def test_zarr_to_json(store, model_blank, roundtrip, dump_array):
    """Dump a model holding a zarr array to JSON and check its ``array`` field.

    Without ``round_trip`` the field is expected to equal the plain nested
    list. With ``round_trip`` it is expected to carry an ``info`` mapping
    (including a 40-character ``hexdigest``), plus the array values when
    ``dump_array`` is set in the serialization context.
    """
    values = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    instance = model_blank(array=zarr.array(values, store=store))

    raw = instance.model_dump_json(
        round_trip=roundtrip, context={"dump_array": dump_array}
    )
    dumped = json.loads(raw)["array"]

    if not roundtrip:
        assert dumped == values
        return

    if dump_array:
        assert dumped["array"] == values
    elif dumped.get("file", False):
        # file-backed arrays should be referenced, not inlined
        assert "array" not in dumped

    info = dumped["info"]
    for field in (
        "Type",
        "Data type",
        "Shape",
        "Chunk shape",
        "Compressor",
        "Store type",
        "hexdigest",
    ):
        assert field in info
    assert len(info["hexdigest"]) == 40