mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2025-01-09 21:44:27 +00:00
support datetimes in hdf5 proxies
This commit is contained in:
parent
c46015d306
commit
56c5b9ac79
4 changed files with 92 additions and 11 deletions
|
@ -25,8 +25,9 @@ Interfaces for HDF5 Datasets
|
|||
"""
|
||||
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, List, NamedTuple, Optional, Tuple, Union
|
||||
from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import SerializationInfo
|
||||
|
@ -46,6 +47,8 @@ else:
|
|||
|
||||
H5Arraylike: TypeAlias = Tuple[Union[Path, str], str]
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class H5ArrayPath(NamedTuple):
|
||||
"""Location specifier for arrays within an HDF5 file"""
|
||||
|
@ -77,6 +80,7 @@ class H5Proxy:
|
|||
path (str): Path to array within hdf5 file
|
||||
field (str, list[str]): Optional - refer to a specific field within
|
||||
a compound dtype
|
||||
annotation_dtype (dtype): Optional - the dtype of our type annotation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -84,11 +88,13 @@ class H5Proxy:
|
|||
file: Union[Path, str],
|
||||
path: str,
|
||||
field: Optional[Union[str, List[str]]] = None,
|
||||
annotation_dtype: Optional[DtypeType] = None,
|
||||
):
|
||||
self._h5f = None
|
||||
self.file = Path(file)
|
||||
self.path = path
|
||||
self.field = field
|
||||
self._annotation_dtype = annotation_dtype
|
||||
|
||||
def array_exists(self) -> bool:
|
||||
"""Check that there is in fact an array at :attr:`.path` within :attr:`.file`"""
|
||||
|
@ -120,10 +126,12 @@ class H5Proxy:
|
|||
|
||||
def __getitem__(
|
||||
self, item: Union[int, slice, Tuple[Union[int, slice], ...]]
|
||||
) -> np.ndarray:
|
||||
) -> Union[np.ndarray, DtypeType]:
|
||||
with h5py.File(self.file, "r") as h5f:
|
||||
obj = h5f.get(self.path)
|
||||
# handle compound dtypes
|
||||
if self.field is not None:
|
||||
# handle compound string dtype
|
||||
if encoding := h5py.h5t.check_string_dtype(obj.dtype[self.field]):
|
||||
if isinstance(item, tuple):
|
||||
item = (*item, self.field)
|
||||
|
@ -132,24 +140,41 @@ class H5Proxy:
|
|||
|
||||
try:
|
||||
# single string
|
||||
return obj[item].decode(encoding.encoding)
|
||||
val = obj[item].decode(encoding.encoding)
|
||||
if self._annotation_dtype is np.datetime64:
|
||||
return np.datetime64(val)
|
||||
else:
|
||||
return val
|
||||
except AttributeError:
|
||||
# numpy array of bytes
|
||||
return np.char.decode(obj[item], encoding=encoding.encoding)
|
||||
|
||||
val = np.char.decode(obj[item], encoding=encoding.encoding)
|
||||
if self._annotation_dtype is np.datetime64:
|
||||
return val.astype(np.datetime64)
|
||||
else:
|
||||
return val
|
||||
# normal compound type
|
||||
else:
|
||||
obj = obj.fields(self.field)
|
||||
else:
|
||||
if h5py.h5t.check_string_dtype(obj.dtype):
|
||||
obj = obj.asstr()
|
||||
|
||||
return obj[item]
|
||||
val = obj[item]
|
||||
if self._annotation_dtype is np.datetime64:
|
||||
if isinstance(val, str):
|
||||
return np.datetime64(val)
|
||||
else:
|
||||
return val.astype(np.datetime64)
|
||||
else:
|
||||
return val
|
||||
|
||||
def __setitem__(
|
||||
self,
|
||||
key: Union[int, slice, Tuple[Union[int, slice], ...]],
|
||||
value: Union[int, float, np.ndarray],
|
||||
value: Union[int, float, datetime, np.ndarray],
|
||||
):
|
||||
# TODO: Make a generalized value serdes system instead of ad-hoc type conversion
|
||||
value = self._serialize_datetime(value)
|
||||
with h5py.File(self.file, "r+", locking=True) as h5f:
|
||||
obj = h5f.get(self.path)
|
||||
if self.field is None:
|
||||
|
@ -184,6 +209,16 @@ class H5Proxy:
|
|||
self._h5f.close()
|
||||
self._h5f = None
|
||||
|
||||
def _serialize_datetime(self, v: Union[T, datetime]) -> Union[T, bytes]:
|
||||
"""
|
||||
Convert a datetime into a bytestring
|
||||
"""
|
||||
if self._annotation_dtype is np.datetime64:
|
||||
if not isinstance(v, Iterable):
|
||||
v = [v]
|
||||
v = np.array(v).astype("S32")
|
||||
return v
|
||||
|
||||
|
||||
class H5Interface(Interface):
|
||||
"""
|
||||
|
@ -253,6 +288,7 @@ class H5Interface(Interface):
|
|||
"Need to specify a file and a path within an HDF5 file to use the HDF5 "
|
||||
"Interface"
|
||||
)
|
||||
array._annotation_dtype = self.dtype
|
||||
|
||||
if not array.array_exists():
|
||||
raise ValueError(
|
||||
|
@ -269,7 +305,14 @@ class H5Interface(Interface):
|
|||
Subclasses to correctly handle
|
||||
"""
|
||||
if h5py.h5t.check_string_dtype(array.dtype):
|
||||
return str
|
||||
# check for datetimes
|
||||
try:
|
||||
if array[0].dtype.type is np.datetime64:
|
||||
return np.datetime64
|
||||
else:
|
||||
return str
|
||||
except (AttributeError, ValueError, TypeError):
|
||||
return str
|
||||
else:
|
||||
return array.dtype
|
||||
|
||||
|
|
|
@ -166,7 +166,11 @@ def _hash_schema(schema: CoreSchema) -> str:
|
|||
to produce the same hash.
|
||||
"""
|
||||
schema_str = json.dumps(
|
||||
schema, sort_keys=True, indent=None, separators=(",", ":")
|
||||
schema,
|
||||
sort_keys=True,
|
||||
indent=None,
|
||||
separators=(",", ":"),
|
||||
default=lambda x: None,
|
||||
).encode("utf-8")
|
||||
hasher = hashlib.blake2b(digest_size=8)
|
||||
hasher.update(schema_str)
|
||||
|
|
|
@ -2,6 +2,7 @@ import shutil
|
|||
from pathlib import Path
|
||||
from typing import Any, Callable, Optional, Tuple, Type, Union
|
||||
from warnings import warn
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
|
@ -126,15 +127,24 @@ def hdf5_array(
|
|||
if not compound:
|
||||
if dtype is str:
|
||||
data = np.random.random(shape).astype(bytes)
|
||||
elif dtype is datetime:
|
||||
data = np.empty(shape, dtype="S32")
|
||||
data.fill(datetime.now(timezone.utc).isoformat().encode("utf-8"))
|
||||
else:
|
||||
data = np.random.random(shape).astype(dtype)
|
||||
_ = hdf5_file.create_dataset(array_path, data=data)
|
||||
return H5ArrayPath(Path(hdf5_file.filename), array_path)
|
||||
else:
|
||||
|
||||
if dtype is str:
|
||||
dt = np.dtype([("data", np.dtype("S10")), ("extra", "i8")])
|
||||
data = np.array([("hey", 0)] * np.prod(shape), dtype=dt).reshape(shape)
|
||||
elif dtype is datetime:
|
||||
dt = np.dtype([("data", np.dtype("S32")), ("extra", "i8")])
|
||||
data = np.array(
|
||||
[(datetime.now(timezone.utc).isoformat().encode("utf-8"), 0)]
|
||||
* np.prod(shape),
|
||||
dtype=dt,
|
||||
).reshape(shape)
|
||||
else:
|
||||
dt = np.dtype([("data", dtype), ("extra", "i8")])
|
||||
data = np.zeros(shape, dtype=dt)
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import h5py
|
||||
import pytest
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
import numpy as np
|
||||
|
@ -174,3 +175,26 @@ def test_strings(hdf5_array, compound):
|
|||
|
||||
instance.array[1] = "sup"
|
||||
assert all(instance.array[1] == "sup")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("compound", [True, False])
def test_datetime(hdf5_array, compound):
    """
    An S32 byte dataset is exposed as datetimes when the annotation asks
    for them: the model validates, reads come back as ``np.datetime64``,
    and both single cells and whole rows accept ``datetime`` writes.
    """
    array_path = hdf5_array((10, 10), datetime, compound=compound)

    class MyModel(BaseModel):
        array: NDArray[Any, datetime]

    model = MyModel(array=array_path)

    # reads: a scalar index and a slice both come back as datetime64
    assert isinstance(model.array[0, 0], np.datetime64)
    assert model.array[0:5].dtype.type is np.datetime64

    stamp = datetime.now()

    # write one cell, then read it back
    model.array[0, 0] = stamp
    assert model.array[0, 0] == stamp

    # write a whole row with a single value, then read it back
    model.array[0] = stamp
    assert all(model.array[0] == stamp)
|
||||
|
|
Loading…
Reference in a new issue