mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2024-11-12 17:54:29 +00:00
add ability to index hdf5 compound datasets
This commit is contained in:
parent
6a397a9aba
commit
03fe97b7e0
3 changed files with 96 additions and 20 deletions
|
@ -4,7 +4,7 @@ Interfaces for HDF5 Datasets
|
|||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, NamedTuple, Optional, Tuple, Union
|
||||
from typing import Any, List, NamedTuple, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import SerializationInfo
|
||||
|
@ -32,6 +32,8 @@ class H5ArrayPath(NamedTuple):
|
|||
"""Location of HDF5 file"""
|
||||
path: str
|
||||
"""Path within the HDF5 file"""
|
||||
field: Optional[Union[str, List[str]]] = None
|
||||
"""Refer to a specific field within a compound dtype"""
|
||||
|
||||
|
||||
class H5Proxy:
|
||||
|
@ -51,12 +53,20 @@ class H5Proxy:
|
|||
Args:
|
||||
file (pathlib.Path | str): Location of hdf5 file on filesystem
|
||||
path (str): Path to array within hdf5 file
|
||||
field (str, list[str]): Optional - refer to a specific field within
|
||||
a compound dtype
|
||||
"""
|
||||
|
||||
def __init__(self, file: Union[Path, str], path: str):
|
||||
def __init__(
|
||||
self,
|
||||
file: Union[Path, str],
|
||||
path: str,
|
||||
field: Optional[Union[str, List[str]]] = None,
|
||||
):
|
||||
self._h5f = None
|
||||
self.file = Path(file)
|
||||
self.path = path
|
||||
self.field = field
|
||||
|
||||
def array_exists(self) -> bool:
|
||||
"""Check that there is in fact an array at :attr:`.path` within :attr:`.file`"""
|
||||
|
@ -67,21 +77,43 @@ class H5Proxy:
|
|||
@classmethod
|
||||
def from_h5array(cls, h5array: H5ArrayPath) -> "H5Proxy":
|
||||
"""Instantiate using :class:`.H5ArrayPath`"""
|
||||
return H5Proxy(file=h5array.file, path=h5array.path)
|
||||
return H5Proxy(file=h5array.file, path=h5array.path, field=h5array.field)
|
||||
|
||||
@property
|
||||
def dtype(self) -> np.dtype:
|
||||
"""
|
||||
Get dtype of array, using :attr:`.field` if present
|
||||
"""
|
||||
with h5py.File(self.file, "r") as h5f:
|
||||
obj = h5f.get(self.path)
|
||||
if self.field is None:
|
||||
return obj.dtype
|
||||
else:
|
||||
return obj.dtype[self.field]
|
||||
|
||||
def __getattr__(self, item: str):
|
||||
with h5py.File(self.file, "r") as h5f:
|
||||
obj = h5f.get(self.path)
|
||||
return getattr(obj, item)
|
||||
|
||||
def __getitem__(self, item: Union[int, slice]) -> np.ndarray:
|
||||
def __getitem__(
|
||||
self, item: Union[int, slice, Tuple[Union[int, slice], ...]]
|
||||
) -> np.ndarray:
|
||||
with h5py.File(self.file, "r") as h5f:
|
||||
obj = h5f.get(self.path)
|
||||
if self.field is not None:
|
||||
obj = obj.fields(self.field)
|
||||
return obj[item]
|
||||
|
||||
def __setitem__(self, key: Union[int, slice], value: Union[int, float, np.ndarray]):
|
||||
def __setitem__(
|
||||
self,
|
||||
key: Union[int, slice, Tuple[Union[int, slice], ...]],
|
||||
value: Union[int, float, np.ndarray],
|
||||
):
|
||||
with h5py.File(self.file, "r+", locking=True) as h5f:
|
||||
obj = h5f.get(self.path)
|
||||
if self.field is not None:
|
||||
obj = obj.fields(self.field)
|
||||
obj[key] = value
|
||||
|
||||
def open(self, mode: str = "r") -> "h5py.Dataset":
|
||||
|
@ -133,7 +165,7 @@ class H5Interface(Interface):
|
|||
if isinstance(array, H5ArrayPath):
|
||||
return True
|
||||
|
||||
if isinstance(array, (tuple, list)) and len(array) == 2:
|
||||
if isinstance(array, (tuple, list)) and len(array) in (2, 3):
|
||||
# check that the first arg is an hdf5 file
|
||||
try:
|
||||
file = Path(array[0])
|
||||
|
@ -163,6 +195,8 @@ class H5Interface(Interface):
|
|||
array = H5Proxy.from_h5array(h5array=array)
|
||||
elif isinstance(array, (tuple, list)) and len(array) == 2: # pragma: no cover
|
||||
array = H5Proxy(file=array[0], path=array[1])
|
||||
elif isinstance(array, (tuple, list)) and len(array) == 3:
|
||||
array = H5Proxy(file=array[0], path=array[1], field=array[2])
|
||||
else: # pragma: no cover
|
||||
# this should never happen really since `check` confirms this before
|
||||
# we'd reach here, but just to complete the if else...
|
||||
|
|
|
@ -116,12 +116,20 @@ def hdf5_array(
|
|||
) -> Callable[[Tuple[int, ...], Union[np.dtype, type]], H5ArrayPath]:
|
||||
|
||||
def _hdf5_array(
|
||||
shape: Tuple[int, ...] = (10, 10), dtype: Union[np.dtype, type] = float
|
||||
shape: Tuple[int, ...] = (10, 10),
|
||||
dtype: Union[np.dtype, type] = float,
|
||||
compound: bool = False,
|
||||
) -> H5ArrayPath:
|
||||
array_path = "/" + "_".join([str(s) for s in shape]) + "__" + dtype.__name__
|
||||
data = np.random.random(shape).astype(dtype)
|
||||
_ = hdf5_file.create_dataset(array_path, data=data)
|
||||
return H5ArrayPath(Path(hdf5_file.filename), array_path)
|
||||
if not compound:
|
||||
data = np.random.random(shape).astype(dtype)
|
||||
_ = hdf5_file.create_dataset(array_path, data=data)
|
||||
return H5ArrayPath(Path(hdf5_file.filename), array_path)
|
||||
else:
|
||||
dt = np.dtype([("data", dtype), ("extra", "i8")])
|
||||
data = np.zeros(shape, dtype=dt)
|
||||
_ = hdf5_file.create_dataset(array_path, data=data)
|
||||
return H5ArrayPath(Path(hdf5_file.filename), array_path, "data")
|
||||
|
||||
return _hdf5_array
|
||||
|
||||
|
|
|
@ -1,17 +1,22 @@
|
|||
import pdb
|
||||
import json
|
||||
|
||||
import h5py
|
||||
import pytest
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
import numpy as np
|
||||
from numpydantic import NDArray, Shape
|
||||
from numpydantic.interface import H5Interface
|
||||
from numpydantic.interface.hdf5 import H5ArrayPath
|
||||
from numpydantic.interface.hdf5 import H5ArrayPath, H5Proxy
|
||||
from numpydantic.exceptions import DtypeError, ShapeError
|
||||
|
||||
from tests.conftest import ValidationCase
|
||||
|
||||
|
||||
def hdf5_array_case(case: ValidationCase, array_func) -> H5ArrayPath:
|
||||
def hdf5_array_case(
|
||||
case: ValidationCase, array_func, compound: bool = False
|
||||
) -> H5ArrayPath:
|
||||
"""
|
||||
Args:
|
||||
case:
|
||||
|
@ -22,11 +27,11 @@ def hdf5_array_case(case: ValidationCase, array_func) -> H5ArrayPath:
|
|||
"""
|
||||
if issubclass(case.dtype, BaseModel):
|
||||
pytest.skip("hdf5 cant support arbitrary python objects")
|
||||
return array_func(case.shape, case.dtype)
|
||||
return array_func(case.shape, case.dtype, compound)
|
||||
|
||||
|
||||
def _test_hdf5_case(case: ValidationCase, array_func):
|
||||
array = hdf5_array_case(case, array_func)
|
||||
def _test_hdf5_case(case: ValidationCase, array_func, compound: bool = False) -> None:
|
||||
array = hdf5_array_case(case, array_func, compound)
|
||||
if case.passes:
|
||||
case.model(array=array)
|
||||
else:
|
||||
|
@ -66,14 +71,16 @@ def test_hdf5_check_not_hdf5(tmp_path):
|
|||
assert not H5Interface.check(spec)
|
||||
|
||||
|
||||
def test_hdf5_shape(shape_cases, hdf5_array):
|
||||
_test_hdf5_case(shape_cases, hdf5_array)
|
||||
@pytest.mark.parametrize("compound", [True, False])
|
||||
def test_hdf5_shape(shape_cases, hdf5_array, compound):
|
||||
_test_hdf5_case(shape_cases, hdf5_array, compound)
|
||||
|
||||
|
||||
def test_hdf5_dtype(dtype_cases, hdf5_array):
|
||||
@pytest.mark.parametrize("compound", [True, False])
|
||||
def test_hdf5_dtype(dtype_cases, hdf5_array, compound):
|
||||
if dtype_cases.dtype is str:
|
||||
pytest.skip("hdf5 cant do string arrays")
|
||||
_test_hdf5_case(dtype_cases, hdf5_array)
|
||||
_test_hdf5_case(dtype_cases, hdf5_array, compound)
|
||||
|
||||
|
||||
def test_hdf5_dataset_not_exists(hdf5_array, model_blank):
|
||||
|
@ -116,3 +123,30 @@ def test_to_json(hdf5_array, array_model):
|
|||
assert json_dict["path"] == str(array.path)
|
||||
assert json_dict["attrs"] == {}
|
||||
assert json_dict["array"] == instance.array[:].tolist()
|
||||
|
||||
|
||||
def test_compound_dtype(tmp_path):
|
||||
"""
|
||||
hdf5 proxy indexes compound dtypes as single fields when field is given
|
||||
"""
|
||||
h5f_path = tmp_path / "test.h5"
|
||||
dataset_path = "/dataset"
|
||||
field = "data"
|
||||
dtype = np.dtype([(field, "i8"), ("extra", "f8")])
|
||||
data = np.zeros((10, 20), dtype=dtype)
|
||||
with h5py.File(h5f_path, "w") as h5f:
|
||||
dset = h5f.create_dataset(dataset_path, data=data)
|
||||
assert dset.dtype == dtype
|
||||
|
||||
proxy = H5Proxy(h5f_path, dataset_path, field=field)
|
||||
assert proxy.dtype == np.dtype("int64")
|
||||
assert proxy.shape == (10, 20)
|
||||
assert proxy[0, 0] == 0
|
||||
|
||||
class MyModel(BaseModel):
|
||||
array: NDArray[Shape["10, 20"], np.int64]
|
||||
|
||||
instance = MyModel(array=(h5f_path, dataset_path, field))
|
||||
assert instance.array.dtype == np.dtype("int64")
|
||||
assert instance.array.shape == (10, 20)
|
||||
assert instance.array[0, 0] == 0
|
||||
|
|
Loading…
Reference in a new issue