add ability to index hdf5 compound datasets

This commit is contained in:
sneakers-the-rat 2024-09-02 16:45:56 -07:00
parent 6a397a9aba
commit 03fe97b7e0
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
3 changed files with 96 additions and 20 deletions

View file

@ -4,7 +4,7 @@ Interfaces for HDF5 Datasets
import sys
from pathlib import Path
from typing import Any, NamedTuple, Optional, Tuple, Union
from typing import Any, List, NamedTuple, Optional, Tuple, Union
import numpy as np
from pydantic import SerializationInfo
@ -32,6 +32,8 @@ class H5ArrayPath(NamedTuple):
"""Location of HDF5 file"""
path: str
"""Path within the HDF5 file"""
field: Optional[Union[str, List[str]]] = None
"""Refer to a specific field within a compound dtype"""
class H5Proxy:
@ -51,12 +53,20 @@ class H5Proxy:
Args:
file (pathlib.Path | str): Location of hdf5 file on filesystem
path (str): Path to array within hdf5 file
field (str, list[str]): Optional - refer to a specific field within
a compound dtype
"""
def __init__(self, file: Union[Path, str], path: str):
def __init__(
self,
file: Union[Path, str],
path: str,
field: Optional[Union[str, List[str]]] = None,
):
self._h5f = None
self.file = Path(file)
self.path = path
self.field = field
def array_exists(self) -> bool:
"""Check that there is in fact an array at :attr:`.path` within :attr:`.file`"""
@ -67,21 +77,43 @@ class H5Proxy:
@classmethod
def from_h5array(cls, h5array: H5ArrayPath) -> "H5Proxy":
"""Instantiate using :class:`.H5ArrayPath`"""
return H5Proxy(file=h5array.file, path=h5array.path)
return H5Proxy(file=h5array.file, path=h5array.path, field=h5array.field)
@property
def dtype(self) -> np.dtype:
"""
Get dtype of array, using :attr:`.field` if present
"""
with h5py.File(self.file, "r") as h5f:
obj = h5f.get(self.path)
if self.field is None:
return obj.dtype
else:
return obj.dtype[self.field]
def __getattr__(self, item: str):
with h5py.File(self.file, "r") as h5f:
obj = h5f.get(self.path)
return getattr(obj, item)
def __getitem__(self, item: Union[int, slice]) -> np.ndarray:
def __getitem__(
self, item: Union[int, slice, Tuple[Union[int, slice], ...]]
) -> np.ndarray:
with h5py.File(self.file, "r") as h5f:
obj = h5f.get(self.path)
if self.field is not None:
obj = obj.fields(self.field)
return obj[item]
def __setitem__(self, key: Union[int, slice], value: Union[int, float, np.ndarray]):
def __setitem__(
self,
key: Union[int, slice, Tuple[Union[int, slice], ...]],
value: Union[int, float, np.ndarray],
):
with h5py.File(self.file, "r+", locking=True) as h5f:
obj = h5f.get(self.path)
if self.field is not None:
obj = obj.fields(self.field)
obj[key] = value
def open(self, mode: str = "r") -> "h5py.Dataset":
@ -133,7 +165,7 @@ class H5Interface(Interface):
if isinstance(array, H5ArrayPath):
return True
if isinstance(array, (tuple, list)) and len(array) == 2:
if isinstance(array, (tuple, list)) and len(array) in (2, 3):
# check that the first arg is an hdf5 file
try:
file = Path(array[0])
@ -163,6 +195,8 @@ class H5Interface(Interface):
array = H5Proxy.from_h5array(h5array=array)
elif isinstance(array, (tuple, list)) and len(array) == 2: # pragma: no cover
array = H5Proxy(file=array[0], path=array[1])
elif isinstance(array, (tuple, list)) and len(array) == 3:
array = H5Proxy(file=array[0], path=array[1], field=array[2])
else: # pragma: no cover
# this should never happen really since `check` confirms this before
# we'd reach here, but just to complete the if else...

View file

@ -116,12 +116,20 @@ def hdf5_array(
) -> Callable[[Tuple[int, ...], Union[np.dtype, type]], H5ArrayPath]:
def _hdf5_array(
shape: Tuple[int, ...] = (10, 10), dtype: Union[np.dtype, type] = float
shape: Tuple[int, ...] = (10, 10),
dtype: Union[np.dtype, type] = float,
compound: bool = False,
) -> H5ArrayPath:
array_path = "/" + "_".join([str(s) for s in shape]) + "__" + dtype.__name__
if not compound:
data = np.random.random(shape).astype(dtype)
_ = hdf5_file.create_dataset(array_path, data=data)
return H5ArrayPath(Path(hdf5_file.filename), array_path)
else:
dt = np.dtype([("data", dtype), ("extra", "i8")])
data = np.zeros(shape, dtype=dt)
_ = hdf5_file.create_dataset(array_path, data=data)
return H5ArrayPath(Path(hdf5_file.filename), array_path, "data")
return _hdf5_array

View file

@ -1,17 +1,22 @@
import pdb
import json
import h5py
import pytest
from pydantic import BaseModel, ValidationError
import numpy as np
from numpydantic import NDArray, Shape
from numpydantic.interface import H5Interface
from numpydantic.interface.hdf5 import H5ArrayPath
from numpydantic.interface.hdf5 import H5ArrayPath, H5Proxy
from numpydantic.exceptions import DtypeError, ShapeError
from tests.conftest import ValidationCase
def hdf5_array_case(case: ValidationCase, array_func) -> H5ArrayPath:
def hdf5_array_case(
case: ValidationCase, array_func, compound: bool = False
) -> H5ArrayPath:
"""
Args:
case:
@ -22,11 +27,11 @@ def hdf5_array_case(case: ValidationCase, array_func) -> H5ArrayPath:
"""
if issubclass(case.dtype, BaseModel):
pytest.skip("hdf5 cant support arbitrary python objects")
return array_func(case.shape, case.dtype)
return array_func(case.shape, case.dtype, compound)
def _test_hdf5_case(case: ValidationCase, array_func):
array = hdf5_array_case(case, array_func)
def _test_hdf5_case(case: ValidationCase, array_func, compound: bool = False) -> None:
array = hdf5_array_case(case, array_func, compound)
if case.passes:
case.model(array=array)
else:
@ -66,14 +71,16 @@ def test_hdf5_check_not_hdf5(tmp_path):
assert not H5Interface.check(spec)
def test_hdf5_shape(shape_cases, hdf5_array):
_test_hdf5_case(shape_cases, hdf5_array)
@pytest.mark.parametrize("compound", [True, False])
def test_hdf5_shape(shape_cases, hdf5_array, compound):
_test_hdf5_case(shape_cases, hdf5_array, compound)
def test_hdf5_dtype(dtype_cases, hdf5_array):
@pytest.mark.parametrize("compound", [True, False])
def test_hdf5_dtype(dtype_cases, hdf5_array, compound):
if dtype_cases.dtype is str:
pytest.skip("hdf5 cant do string arrays")
_test_hdf5_case(dtype_cases, hdf5_array)
_test_hdf5_case(dtype_cases, hdf5_array, compound)
def test_hdf5_dataset_not_exists(hdf5_array, model_blank):
@ -116,3 +123,30 @@ def test_to_json(hdf5_array, array_model):
assert json_dict["path"] == str(array.path)
assert json_dict["attrs"] == {}
assert json_dict["array"] == instance.array[:].tolist()
def test_compound_dtype(tmp_path):
"""
hdf5 proxy indexes compound dtypes as single fields when field is given
"""
h5f_path = tmp_path / "test.h5"
dataset_path = "/dataset"
field = "data"
dtype = np.dtype([(field, "i8"), ("extra", "f8")])
data = np.zeros((10, 20), dtype=dtype)
with h5py.File(h5f_path, "w") as h5f:
dset = h5f.create_dataset(dataset_path, data=data)
assert dset.dtype == dtype
proxy = H5Proxy(h5f_path, dataset_path, field=field)
assert proxy.dtype == np.dtype("int64")
assert proxy.shape == (10, 20)
assert proxy[0, 0] == 0
class MyModel(BaseModel):
array: NDArray[Shape["10, 20"], np.int64]
instance = MyModel(array=(h5f_path, dataset_path, field))
assert instance.array.dtype == np.dtype("int64")
assert instance.array.shape == (10, 20)
assert instance.array[0, 0] == 0