From 03fe97b7e09406481d6863eecd2ef4d9fb7dc9c0 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Mon, 2 Sep 2024 16:45:56 -0700 Subject: [PATCH] add ability to index hdf5 compound datasets --- src/numpydantic/interface/hdf5.py | 46 ++++++++++++++++++++++---- tests/fixtures.py | 16 ++++++--- tests/test_interface/test_hdf5.py | 54 +++++++++++++++++++++++++------ 3 files changed, 96 insertions(+), 20 deletions(-) diff --git a/src/numpydantic/interface/hdf5.py b/src/numpydantic/interface/hdf5.py index 0bac99f..0dac4c6 100644 --- a/src/numpydantic/interface/hdf5.py +++ b/src/numpydantic/interface/hdf5.py @@ -4,7 +4,7 @@ Interfaces for HDF5 Datasets import sys from pathlib import Path -from typing import Any, NamedTuple, Optional, Tuple, Union +from typing import Any, List, NamedTuple, Optional, Tuple, Union import numpy as np from pydantic import SerializationInfo @@ -32,6 +32,8 @@ class H5ArrayPath(NamedTuple): """Location of HDF5 file""" path: str """Path within the HDF5 file""" + field: Optional[Union[str, List[str]]] = None + """Refer to a specific field within a compound dtype""" class H5Proxy: @@ -51,12 +53,20 @@ class H5Proxy: Args: file (pathlib.Path | str): Location of hdf5 file on filesystem path (str): Path to array within hdf5 file + field (str, list[str]): Optional - refer to a specific field within + a compound dtype """ - def __init__(self, file: Union[Path, str], path: str): + def __init__( + self, + file: Union[Path, str], + path: str, + field: Optional[Union[str, List[str]]] = None, + ): self._h5f = None self.file = Path(file) self.path = path + self.field = field def array_exists(self) -> bool: """Check that there is in fact an array at :attr:`.path` within :attr:`.file`""" @@ -67,21 +77,43 @@ class H5Proxy: @classmethod def from_h5array(cls, h5array: H5ArrayPath) -> "H5Proxy": """Instantiate using :class:`.H5ArrayPath`""" - return H5Proxy(file=h5array.file, path=h5array.path) + return H5Proxy(file=h5array.file, path=h5array.path, field=h5array.field) + + @property + def dtype(self) -> np.dtype: + """ + Get dtype of array, using :attr:`.field` if present + """ + with h5py.File(self.file, "r") as h5f: + obj = h5f.get(self.path) + if self.field is None: + return obj.dtype + else: + return obj.dtype[self.field] def __getattr__(self, item: str): with h5py.File(self.file, "r") as h5f: obj = h5f.get(self.path) return getattr(obj, item) - def __getitem__(self, item: Union[int, slice]) -> np.ndarray: + def __getitem__( + self, item: Union[int, slice, Tuple[Union[int, slice], ...]] + ) -> np.ndarray: with h5py.File(self.file, "r") as h5f: obj = h5f.get(self.path) + if self.field is not None: + obj = obj.fields(self.field) return obj[item] - def __setitem__(self, key: Union[int, slice], value: Union[int, float, np.ndarray]): + def __setitem__( + self, + key: Union[int, slice, Tuple[Union[int, slice], ...]], + value: Union[int, float, np.ndarray], + ): with h5py.File(self.file, "r+", locking=True) as h5f: obj = h5f.get(self.path) + if self.field is not None: + obj = obj.fields(self.field) obj[key] = value def open(self, mode: str = "r") -> "h5py.Dataset": @@ -133,7 +165,7 @@ class H5Interface(Interface): if isinstance(array, H5ArrayPath): return True - if isinstance(array, (tuple, list)) and len(array) == 2: + if isinstance(array, (tuple, list)) and len(array) in (2, 3): # check that the first arg is an hdf5 file try: file = Path(array[0]) @@ -163,6 +195,8 @@ class H5Interface(Interface): array = H5Proxy.from_h5array(h5array=array) elif isinstance(array, (tuple, list)) and len(array) == 2: # pragma: no cover array = H5Proxy(file=array[0], path=array[1]) + elif isinstance(array, (tuple, list)) and len(array) == 3: + array = H5Proxy(file=array[0], path=array[1], field=array[2]) else: # pragma: no cover # this should never happen really since `check` confirms this before # we'd reach here, but just to complete the if else... diff --git a/tests/fixtures.py b/tests/fixtures.py index 7c14b35..347152c 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -116,12 +116,20 @@ def hdf5_array( ) -> Callable[[Tuple[int, ...], Union[np.dtype, type]], H5ArrayPath]: def _hdf5_array( - shape: Tuple[int, ...] = (10, 10), dtype: Union[np.dtype, type] = float + shape: Tuple[int, ...] = (10, 10), + dtype: Union[np.dtype, type] = float, + compound: bool = False, ) -> H5ArrayPath: array_path = "/" + "_".join([str(s) for s in shape]) + "__" + dtype.__name__ - data = np.random.random(shape).astype(dtype) - _ = hdf5_file.create_dataset(array_path, data=data) - return H5ArrayPath(Path(hdf5_file.filename), array_path) + if not compound: + data = np.random.random(shape).astype(dtype) + _ = hdf5_file.create_dataset(array_path, data=data) + return H5ArrayPath(Path(hdf5_file.filename), array_path) + else: + dt = np.dtype([("data", dtype), ("extra", "i8")]) + data = np.zeros(shape, dtype=dt) + _ = hdf5_file.create_dataset(array_path, data=data) + return H5ArrayPath(Path(hdf5_file.filename), array_path, "data") return _hdf5_array diff --git a/tests/test_interface/test_hdf5.py b/tests/test_interface/test_hdf5.py index 78af785..cdb255a 100644 --- a/tests/test_interface/test_hdf5.py +++ b/tests/test_interface/test_hdf5.py @@ -1,17 +1,22 @@ -import pdb import json + +import h5py import pytest from pydantic import BaseModel, ValidationError +import numpy as np +from numpydantic import NDArray, Shape from numpydantic.interface import H5Interface -from numpydantic.interface.hdf5 import H5ArrayPath +from numpydantic.interface.hdf5 import H5ArrayPath, H5Proxy from numpydantic.exceptions import DtypeError, ShapeError from tests.conftest import ValidationCase -def hdf5_array_case(case: ValidationCase, array_func) -> H5ArrayPath: +def hdf5_array_case( + case: ValidationCase, array_func, compound: bool = False +) -> H5ArrayPath: """ Args: case: @@ -22,11 +27,11 @@ def hdf5_array_case(case: ValidationCase, array_func) -> H5ArrayPath: """ if issubclass(case.dtype, BaseModel): pytest.skip("hdf5 cant support arbitrary python objects") - return array_func(case.shape, case.dtype) + return array_func(case.shape, case.dtype, compound) -def _test_hdf5_case(case: ValidationCase, array_func): - array = hdf5_array_case(case, array_func) +def _test_hdf5_case(case: ValidationCase, array_func, compound: bool = False) -> None: + array = hdf5_array_case(case, array_func, compound) if case.passes: case.model(array=array) else: @@ -66,14 +71,16 @@ def test_hdf5_check_not_hdf5(tmp_path): assert not H5Interface.check(spec) -def test_hdf5_shape(shape_cases, hdf5_array): - _test_hdf5_case(shape_cases, hdf5_array) +@pytest.mark.parametrize("compound", [True, False]) +def test_hdf5_shape(shape_cases, hdf5_array, compound): + _test_hdf5_case(shape_cases, hdf5_array, compound) -def test_hdf5_dtype(dtype_cases, hdf5_array): +@pytest.mark.parametrize("compound", [True, False]) +def test_hdf5_dtype(dtype_cases, hdf5_array, compound): if dtype_cases.dtype is str: pytest.skip("hdf5 cant do string arrays") - _test_hdf5_case(dtype_cases, hdf5_array) + _test_hdf5_case(dtype_cases, hdf5_array, compound) def test_hdf5_dataset_not_exists(hdf5_array, model_blank): @@ -116,3 +123,30 @@ def test_to_json(hdf5_array, array_model): assert json_dict["path"] == str(array.path) assert json_dict["attrs"] == {} assert json_dict["array"] == instance.array[:].tolist() + + +def test_compound_dtype(tmp_path): + """ + hdf5 proxy indexes compound dtypes as single fields when field is given + """ + h5f_path = tmp_path / "test.h5" + dataset_path = "/dataset" + field = "data" + dtype = np.dtype([(field, "i8"), ("extra", "f8")]) + data = np.zeros((10, 20), dtype=dtype) + with h5py.File(h5f_path, "w") as h5f: + dset = h5f.create_dataset(dataset_path, data=data) + assert dset.dtype == dtype + + proxy = H5Proxy(h5f_path, dataset_path, field=field) + assert proxy.dtype == np.dtype("int64") + assert proxy.shape == (10, 20) + assert proxy[0, 0] == 0 + + class MyModel(BaseModel): + array: NDArray[Shape["10, 20"], np.int64] + + instance = MyModel(array=(h5f_path, dataset_path, field)) + assert instance.array.dtype == np.dtype("int64") + assert instance.array.shape == (10, 20) + assert instance.array[0, 0] == 0