numpydantic/tests/fixtures.py

202 lines
6.2 KiB
Python
Raw Permalink Normal View History

import shutil
from pathlib import Path
2024-05-15 03:18:04 +00:00
from typing import Any, Callable, Optional, Tuple, Type, Union
from warnings import warn
2024-09-03 23:54:31 +00:00
from datetime import datetime, timezone
import h5py
import numpy as np
import pytest
from pydantic import BaseModel, Field
2024-04-30 02:49:38 +00:00
import zarr
import cv2
from numpydantic.interface.hdf5 import H5ArrayPath
2024-04-30 02:49:38 +00:00
from numpydantic.interface.zarr import ZarrArrayPath
from numpydantic import NDArray, Shape
from numpydantic.maps import python_to_nptyping
from numpydantic.dtype import Number
@pytest.fixture(scope="session")
def tmp_output_dir(request: pytest.FixtureRequest) -> Path:
path = Path(__file__).parent.resolve() / "__tmp__"
if path.exists():
shutil.rmtree(str(path))
path.mkdir()
yield path
if not request.config.getvalue("--with-output"):
try:
shutil.rmtree(str(path))
except PermissionError as e:
# sporadic error on windows machines...
warn(
f"Temporary directory could not be removed due to a permissions error: \n{str(e)}"
)
@pytest.fixture(scope="function")
def tmp_output_dir_func(tmp_output_dir, request: pytest.FixtureRequest) -> Path:
"""
tmp output dir that gets cleared between every function
cleans at the start rather than at cleanup in case the output is to be inspected
"""
subpath = tmp_output_dir / f"__tmpfunc_{request.node.name}__"
if subpath.exists():
shutil.rmtree(str(subpath))
subpath.mkdir()
return subpath
@pytest.fixture(scope="module")
def tmp_output_dir_mod(tmp_output_dir, request: pytest.FixtureRequest) -> Path:
"""
tmp output dir that gets cleared between every function
cleans at the start rather than at cleanup in case the output is to be inspected
"""
subpath = tmp_output_dir / f"__tmpmod_{request.module}__"
if subpath.exists():
shutil.rmtree(str(subpath))
subpath.mkdir()
return subpath
2024-02-05 23:39:29 +00:00
@pytest.fixture(scope="function")
def array_model() -> (
Callable[[Tuple[int, ...], Union[Type, np.dtype]], Type[BaseModel]]
):
def _model(
shape: Tuple[int, ...] = (10, 10), dtype: Union[Type, np.dtype] = float
) -> Type[BaseModel]:
shape_str = ", ".join([str(s) for s in shape])
class MyModel(BaseModel):
2024-05-16 03:49:15 +00:00
array: NDArray[Shape[shape_str], dtype]
return MyModel
return _model
@pytest.fixture(scope="session")
def model_rgb() -> Type[BaseModel]:
class RGB(BaseModel):
array: Optional[
Union[
NDArray[Shape["* x, * y"], Number],
NDArray[Shape["* x, * y, 3 r_g_b"], Number],
NDArray[Shape["* x, * y, 3 r_g_b, 4 r_g_b_a"], Number],
]
] = Field(None)
return RGB
2024-05-15 03:18:04 +00:00
@pytest.fixture(scope="session")
def model_blank() -> Type[BaseModel]:
"""A model with any shape and dtype"""
class BlankModel(BaseModel):
array: NDArray[Shape["*, ..."], Any]
return BlankModel
@pytest.fixture(scope="function")
def hdf5_file(tmp_output_dir_func) -> h5py.File:
h5f_file = tmp_output_dir_func / "h5f.h5"
h5f = h5py.File(h5f_file, "w")
yield h5f
h5f.close()
@pytest.fixture(scope="function")
def hdf5_array(
hdf5_file, request
) -> Callable[[Tuple[int, ...], Union[np.dtype, type]], H5ArrayPath]:
def _hdf5_array(
shape: Tuple[int, ...] = (10, 10),
dtype: Union[np.dtype, type] = float,
compound: bool = False,
) -> H5ArrayPath:
array_path = "/" + "_".join([str(s) for s in shape]) + "__" + dtype.__name__
2024-09-03 05:29:58 +00:00
if not compound:
2024-09-03 05:29:58 +00:00
if dtype is str:
data = np.random.random(shape).astype(bytes)
2024-09-03 23:54:31 +00:00
elif dtype is datetime:
data = np.empty(shape, dtype="S32")
data.fill(datetime.now(timezone.utc).isoformat().encode("utf-8"))
2024-09-03 05:29:58 +00:00
else:
data = np.random.random(shape).astype(dtype)
_ = hdf5_file.create_dataset(array_path, data=data)
return H5ArrayPath(Path(hdf5_file.filename), array_path)
else:
2024-09-03 05:14:47 +00:00
if dtype is str:
dt = np.dtype([("data", np.dtype("S10")), ("extra", "i8")])
data = np.array([("hey", 0)] * np.prod(shape), dtype=dt).reshape(shape)
2024-09-03 23:54:31 +00:00
elif dtype is datetime:
dt = np.dtype([("data", np.dtype("S32")), ("extra", "i8")])
data = np.array(
[(datetime.now(timezone.utc).isoformat().encode("utf-8"), 0)]
* np.prod(shape),
dtype=dt,
).reshape(shape)
2024-09-03 05:14:47 +00:00
else:
dt = np.dtype([("data", dtype), ("extra", "i8")])
data = np.zeros(shape, dtype=dt)
_ = hdf5_file.create_dataset(array_path, data=data)
return H5ArrayPath(Path(hdf5_file.filename), array_path, "data")
return _hdf5_array
2024-04-30 02:49:38 +00:00
@pytest.fixture(scope="function")
def zarr_nested_array(tmp_output_dir_func) -> ZarrArrayPath:
"""Zarr array within a nested array"""
file = tmp_output_dir_func / "nested.zarr"
path = "a/b/c"
root = zarr.open(str(file), mode="w")
array = root.zeros(path, shape=(100, 100), chunks=(10, 10))
return ZarrArrayPath(file=file, path=path)
@pytest.fixture(scope="function")
def zarr_array(tmp_output_dir_func) -> Path:
file = tmp_output_dir_func / "array.zarr"
array = zarr.open(str(file), mode="w", shape=(100, 100), chunks=(10, 10))
array[:] = 0
return file
@pytest.fixture(scope="function")
def avi_video(tmp_path) -> Callable[[Tuple[int, int], int, bool], Path]:
video_path = tmp_path / "test.avi"
def _make_video(shape=(100, 50), frames=10, is_color=True) -> Path:
writer = cv2.VideoWriter(
str(video_path),
cv2.VideoWriter_fourcc(*"RGBA"), # raw video for testing purposes
30,
(shape[1], shape[0]),
is_color,
)
if is_color:
shape = (*shape, 3)
for i in range(frames):
# make fresh array every time bc opencv eats them
array = np.zeros(shape, dtype=np.uint8)
if not is_color:
array[i, i] = i
else:
array[i, i, :] = i
writer.write(array)
writer.release()
return video_path
return _make_video