support datetimes in hdf5 proxies

This commit is contained in:
sneakers-the-rat 2024-09-03 16:54:31 -07:00
parent c46015d306
commit 56c5b9ac79
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
4 changed files with 92 additions and 11 deletions

View file

@ -25,8 +25,9 @@ Interfaces for HDF5 Datasets
"""
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, List, NamedTuple, Optional, Tuple, Union
from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, TypeVar, Union
import numpy as np
from pydantic import SerializationInfo
@ -46,6 +47,8 @@ else:
H5Arraylike: TypeAlias = Tuple[Union[Path, str], str]
T = TypeVar("T")
class H5ArrayPath(NamedTuple):
"""Location specifier for arrays within an HDF5 file"""
@ -77,6 +80,7 @@ class H5Proxy:
path (str): Path to array within hdf5 file
field (str, list[str]): Optional - refer to a specific field within
a compound dtype
annotation_dtype (dtype): Optional - the dtype of our type annotation
"""
def __init__(
@ -84,11 +88,13 @@ class H5Proxy:
file: Union[Path, str],
path: str,
field: Optional[Union[str, List[str]]] = None,
annotation_dtype: Optional[DtypeType] = None,
):
self._h5f = None
self.file = Path(file)
self.path = path
self.field = field
self._annotation_dtype = annotation_dtype
def array_exists(self) -> bool:
"""Check that there is in fact an array at :attr:`.path` within :attr:`.file`"""
@ -120,10 +126,12 @@ class H5Proxy:
def __getitem__(
self, item: Union[int, slice, Tuple[Union[int, slice], ...]]
) -> np.ndarray:
) -> Union[np.ndarray, DtypeType]:
with h5py.File(self.file, "r") as h5f:
obj = h5f.get(self.path)
# handle compound dtypes
if self.field is not None:
# handle compound string dtype
if encoding := h5py.h5t.check_string_dtype(obj.dtype[self.field]):
if isinstance(item, tuple):
item = (*item, self.field)
@ -132,24 +140,41 @@ class H5Proxy:
try:
# single string
return obj[item].decode(encoding.encoding)
val = obj[item].decode(encoding.encoding)
if self._annotation_dtype is np.datetime64:
return np.datetime64(val)
else:
return val
except AttributeError:
# numpy array of bytes
return np.char.decode(obj[item], encoding=encoding.encoding)
val = np.char.decode(obj[item], encoding=encoding.encoding)
if self._annotation_dtype is np.datetime64:
return val.astype(np.datetime64)
else:
return val
# normal compound type
else:
obj = obj.fields(self.field)
else:
if h5py.h5t.check_string_dtype(obj.dtype):
obj = obj.asstr()
return obj[item]
val = obj[item]
if self._annotation_dtype is np.datetime64:
if isinstance(val, str):
return np.datetime64(val)
else:
return val.astype(np.datetime64)
else:
return val
def __setitem__(
self,
key: Union[int, slice, Tuple[Union[int, slice], ...]],
value: Union[int, float, np.ndarray],
value: Union[int, float, datetime, np.ndarray],
):
# TODO: Make a generalized value serdes system instead of ad-hoc type conversion
value = self._serialize_datetime(value)
with h5py.File(self.file, "r+", locking=True) as h5f:
obj = h5f.get(self.path)
if self.field is None:
@ -184,6 +209,16 @@ class H5Proxy:
self._h5f.close()
self._h5f = None
def _serialize_datetime(self, v: Union[T, datetime]) -> Union[T, bytes]:
"""
Convert a datetime into a bytestring
"""
if self._annotation_dtype is np.datetime64:
if not isinstance(v, Iterable):
v = [v]
v = np.array(v).astype("S32")
return v
class H5Interface(Interface):
"""
@ -253,6 +288,7 @@ class H5Interface(Interface):
"Need to specify a file and a path within an HDF5 file to use the HDF5 "
"Interface"
)
array._annotation_dtype = self.dtype
if not array.array_exists():
raise ValueError(
@ -269,7 +305,14 @@ class H5Interface(Interface):
Subclasses to correctly handle
"""
if h5py.h5t.check_string_dtype(array.dtype):
return str
# check for datetimes
try:
if array[0].dtype.type is np.datetime64:
return np.datetime64
else:
return str
except (AttributeError, ValueError, TypeError):
return str
else:
return array.dtype

View file

@ -166,7 +166,11 @@ def _hash_schema(schema: CoreSchema) -> str:
to produce the same hash.
"""
schema_str = json.dumps(
schema, sort_keys=True, indent=None, separators=(",", ":")
schema,
sort_keys=True,
indent=None,
separators=(",", ":"),
default=lambda x: None,
).encode("utf-8")
hasher = hashlib.blake2b(digest_size=8)
hasher.update(schema_str)

View file

@ -2,6 +2,7 @@ import shutil
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Type, Union
from warnings import warn
from datetime import datetime, timezone
import h5py
import numpy as np
@ -126,15 +127,24 @@ def hdf5_array(
if not compound:
if dtype is str:
data = np.random.random(shape).astype(bytes)
elif dtype is datetime:
data = np.empty(shape, dtype="S32")
data.fill(datetime.now(timezone.utc).isoformat().encode("utf-8"))
else:
data = np.random.random(shape).astype(dtype)
_ = hdf5_file.create_dataset(array_path, data=data)
return H5ArrayPath(Path(hdf5_file.filename), array_path)
else:
if dtype is str:
dt = np.dtype([("data", np.dtype("S10")), ("extra", "i8")])
data = np.array([("hey", 0)] * np.prod(shape), dtype=dt).reshape(shape)
elif dtype is datetime:
dt = np.dtype([("data", np.dtype("S32")), ("extra", "i8")])
data = np.array(
[(datetime.now(timezone.utc).isoformat().encode("utf-8"), 0)]
* np.prod(shape),
dtype=dt,
).reshape(shape)
else:
dt = np.dtype([("data", dtype), ("extra", "i8")])
data = np.zeros(shape, dtype=dt)

View file

@ -1,8 +1,9 @@
import json
from datetime import datetime, timezone
from typing import Any
import h5py
import pytest
from pydantic import BaseModel, ValidationError
import numpy as np
@ -174,3 +175,26 @@ def test_strings(hdf5_array, compound):
instance.array[1] = "sup"
assert all(instance.array[1] == "sup")
@pytest.mark.parametrize("compound", [True, False])
def test_datetime(hdf5_array, compound):
"""
We can treat S32 byte arrays as datetimes if our type annotation
says to, including validation, setting and getting values
"""
array = hdf5_array((10, 10), datetime, compound=compound)
class MyModel(BaseModel):
array: NDArray[Any, datetime]
instance = MyModel(array=array)
assert isinstance(instance.array[0, 0], np.datetime64)
assert instance.array[0:5].dtype.type is np.datetime64
now = datetime.now()
instance.array[0, 0] = now
assert instance.array[0, 0] == now
instance.array[0] = now
assert all(instance.array[0] == now)