mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2025-01-10 05:54:26 +00:00
support datetimes in hdf5 proxies
This commit is contained in:
parent
c46015d306
commit
56c5b9ac79
4 changed files with 92 additions and 11 deletions
|
@ -25,8 +25,9 @@ Interfaces for HDF5 Datasets
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, List, NamedTuple, Optional, Tuple, Union
|
from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, TypeVar, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import SerializationInfo
|
from pydantic import SerializationInfo
|
||||||
|
@ -46,6 +47,8 @@ else:
|
||||||
|
|
||||||
H5Arraylike: TypeAlias = Tuple[Union[Path, str], str]
|
H5Arraylike: TypeAlias = Tuple[Union[Path, str], str]
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class H5ArrayPath(NamedTuple):
|
class H5ArrayPath(NamedTuple):
|
||||||
"""Location specifier for arrays within an HDF5 file"""
|
"""Location specifier for arrays within an HDF5 file"""
|
||||||
|
@ -77,6 +80,7 @@ class H5Proxy:
|
||||||
path (str): Path to array within hdf5 file
|
path (str): Path to array within hdf5 file
|
||||||
field (str, list[str]): Optional - refer to a specific field within
|
field (str, list[str]): Optional - refer to a specific field within
|
||||||
a compound dtype
|
a compound dtype
|
||||||
|
annotation_dtype (dtype): Optional - the dtype of our type annotation
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -84,11 +88,13 @@ class H5Proxy:
|
||||||
file: Union[Path, str],
|
file: Union[Path, str],
|
||||||
path: str,
|
path: str,
|
||||||
field: Optional[Union[str, List[str]]] = None,
|
field: Optional[Union[str, List[str]]] = None,
|
||||||
|
annotation_dtype: Optional[DtypeType] = None,
|
||||||
):
|
):
|
||||||
self._h5f = None
|
self._h5f = None
|
||||||
self.file = Path(file)
|
self.file = Path(file)
|
||||||
self.path = path
|
self.path = path
|
||||||
self.field = field
|
self.field = field
|
||||||
|
self._annotation_dtype = annotation_dtype
|
||||||
|
|
||||||
def array_exists(self) -> bool:
|
def array_exists(self) -> bool:
|
||||||
"""Check that there is in fact an array at :attr:`.path` within :attr:`.file`"""
|
"""Check that there is in fact an array at :attr:`.path` within :attr:`.file`"""
|
||||||
|
@ -120,10 +126,12 @@ class H5Proxy:
|
||||||
|
|
||||||
def __getitem__(
|
def __getitem__(
|
||||||
self, item: Union[int, slice, Tuple[Union[int, slice], ...]]
|
self, item: Union[int, slice, Tuple[Union[int, slice], ...]]
|
||||||
) -> np.ndarray:
|
) -> Union[np.ndarray, DtypeType]:
|
||||||
with h5py.File(self.file, "r") as h5f:
|
with h5py.File(self.file, "r") as h5f:
|
||||||
obj = h5f.get(self.path)
|
obj = h5f.get(self.path)
|
||||||
|
# handle compound dtypes
|
||||||
if self.field is not None:
|
if self.field is not None:
|
||||||
|
# handle compound string dtype
|
||||||
if encoding := h5py.h5t.check_string_dtype(obj.dtype[self.field]):
|
if encoding := h5py.h5t.check_string_dtype(obj.dtype[self.field]):
|
||||||
if isinstance(item, tuple):
|
if isinstance(item, tuple):
|
||||||
item = (*item, self.field)
|
item = (*item, self.field)
|
||||||
|
@ -132,24 +140,41 @@ class H5Proxy:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# single string
|
# single string
|
||||||
return obj[item].decode(encoding.encoding)
|
val = obj[item].decode(encoding.encoding)
|
||||||
|
if self._annotation_dtype is np.datetime64:
|
||||||
|
return np.datetime64(val)
|
||||||
|
else:
|
||||||
|
return val
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
# numpy array of bytes
|
# numpy array of bytes
|
||||||
return np.char.decode(obj[item], encoding=encoding.encoding)
|
val = np.char.decode(obj[item], encoding=encoding.encoding)
|
||||||
|
if self._annotation_dtype is np.datetime64:
|
||||||
|
return val.astype(np.datetime64)
|
||||||
|
else:
|
||||||
|
return val
|
||||||
|
# normal compound type
|
||||||
else:
|
else:
|
||||||
obj = obj.fields(self.field)
|
obj = obj.fields(self.field)
|
||||||
else:
|
else:
|
||||||
if h5py.h5t.check_string_dtype(obj.dtype):
|
if h5py.h5t.check_string_dtype(obj.dtype):
|
||||||
obj = obj.asstr()
|
obj = obj.asstr()
|
||||||
|
|
||||||
return obj[item]
|
val = obj[item]
|
||||||
|
if self._annotation_dtype is np.datetime64:
|
||||||
|
if isinstance(val, str):
|
||||||
|
return np.datetime64(val)
|
||||||
|
else:
|
||||||
|
return val.astype(np.datetime64)
|
||||||
|
else:
|
||||||
|
return val
|
||||||
|
|
||||||
def __setitem__(
|
def __setitem__(
|
||||||
self,
|
self,
|
||||||
key: Union[int, slice, Tuple[Union[int, slice], ...]],
|
key: Union[int, slice, Tuple[Union[int, slice], ...]],
|
||||||
value: Union[int, float, np.ndarray],
|
value: Union[int, float, datetime, np.ndarray],
|
||||||
):
|
):
|
||||||
|
# TODO: Make a generalized value serdes system instead of ad-hoc type conversion
|
||||||
|
value = self._serialize_datetime(value)
|
||||||
with h5py.File(self.file, "r+", locking=True) as h5f:
|
with h5py.File(self.file, "r+", locking=True) as h5f:
|
||||||
obj = h5f.get(self.path)
|
obj = h5f.get(self.path)
|
||||||
if self.field is None:
|
if self.field is None:
|
||||||
|
@ -184,6 +209,16 @@ class H5Proxy:
|
||||||
self._h5f.close()
|
self._h5f.close()
|
||||||
self._h5f = None
|
self._h5f = None
|
||||||
|
|
||||||
|
def _serialize_datetime(self, v: Union[T, datetime]) -> Union[T, bytes]:
|
||||||
|
"""
|
||||||
|
Convert a datetime into a bytestring
|
||||||
|
"""
|
||||||
|
if self._annotation_dtype is np.datetime64:
|
||||||
|
if not isinstance(v, Iterable):
|
||||||
|
v = [v]
|
||||||
|
v = np.array(v).astype("S32")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
class H5Interface(Interface):
|
class H5Interface(Interface):
|
||||||
"""
|
"""
|
||||||
|
@ -253,6 +288,7 @@ class H5Interface(Interface):
|
||||||
"Need to specify a file and a path within an HDF5 file to use the HDF5 "
|
"Need to specify a file and a path within an HDF5 file to use the HDF5 "
|
||||||
"Interface"
|
"Interface"
|
||||||
)
|
)
|
||||||
|
array._annotation_dtype = self.dtype
|
||||||
|
|
||||||
if not array.array_exists():
|
if not array.array_exists():
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -269,6 +305,13 @@ class H5Interface(Interface):
|
||||||
Subclasses to correctly handle
|
Subclasses to correctly handle
|
||||||
"""
|
"""
|
||||||
if h5py.h5t.check_string_dtype(array.dtype):
|
if h5py.h5t.check_string_dtype(array.dtype):
|
||||||
|
# check for datetimes
|
||||||
|
try:
|
||||||
|
if array[0].dtype.type is np.datetime64:
|
||||||
|
return np.datetime64
|
||||||
|
else:
|
||||||
|
return str
|
||||||
|
except (AttributeError, ValueError, TypeError):
|
||||||
return str
|
return str
|
||||||
else:
|
else:
|
||||||
return array.dtype
|
return array.dtype
|
||||||
|
|
|
@ -166,7 +166,11 @@ def _hash_schema(schema: CoreSchema) -> str:
|
||||||
to produce the same hash.
|
to produce the same hash.
|
||||||
"""
|
"""
|
||||||
schema_str = json.dumps(
|
schema_str = json.dumps(
|
||||||
schema, sort_keys=True, indent=None, separators=(",", ":")
|
schema,
|
||||||
|
sort_keys=True,
|
||||||
|
indent=None,
|
||||||
|
separators=(",", ":"),
|
||||||
|
default=lambda x: None,
|
||||||
).encode("utf-8")
|
).encode("utf-8")
|
||||||
hasher = hashlib.blake2b(digest_size=8)
|
hasher = hashlib.blake2b(digest_size=8)
|
||||||
hasher.update(schema_str)
|
hasher.update(schema_str)
|
||||||
|
|
|
@ -2,6 +2,7 @@ import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Optional, Tuple, Type, Union
|
from typing import Any, Callable, Optional, Tuple, Type, Union
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
import h5py
|
import h5py
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -126,15 +127,24 @@ def hdf5_array(
|
||||||
if not compound:
|
if not compound:
|
||||||
if dtype is str:
|
if dtype is str:
|
||||||
data = np.random.random(shape).astype(bytes)
|
data = np.random.random(shape).astype(bytes)
|
||||||
|
elif dtype is datetime:
|
||||||
|
data = np.empty(shape, dtype="S32")
|
||||||
|
data.fill(datetime.now(timezone.utc).isoformat().encode("utf-8"))
|
||||||
else:
|
else:
|
||||||
data = np.random.random(shape).astype(dtype)
|
data = np.random.random(shape).astype(dtype)
|
||||||
_ = hdf5_file.create_dataset(array_path, data=data)
|
_ = hdf5_file.create_dataset(array_path, data=data)
|
||||||
return H5ArrayPath(Path(hdf5_file.filename), array_path)
|
return H5ArrayPath(Path(hdf5_file.filename), array_path)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
if dtype is str:
|
if dtype is str:
|
||||||
dt = np.dtype([("data", np.dtype("S10")), ("extra", "i8")])
|
dt = np.dtype([("data", np.dtype("S10")), ("extra", "i8")])
|
||||||
data = np.array([("hey", 0)] * np.prod(shape), dtype=dt).reshape(shape)
|
data = np.array([("hey", 0)] * np.prod(shape), dtype=dt).reshape(shape)
|
||||||
|
elif dtype is datetime:
|
||||||
|
dt = np.dtype([("data", np.dtype("S32")), ("extra", "i8")])
|
||||||
|
data = np.array(
|
||||||
|
[(datetime.now(timezone.utc).isoformat().encode("utf-8"), 0)]
|
||||||
|
* np.prod(shape),
|
||||||
|
dtype=dt,
|
||||||
|
).reshape(shape)
|
||||||
else:
|
else:
|
||||||
dt = np.dtype([("data", dtype), ("extra", "i8")])
|
dt = np.dtype([("data", dtype), ("extra", "i8")])
|
||||||
data = np.zeros(shape, dtype=dt)
|
data = np.zeros(shape, dtype=dt)
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
import json
|
import json
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import h5py
|
import h5py
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -174,3 +175,26 @@ def test_strings(hdf5_array, compound):
|
||||||
|
|
||||||
instance.array[1] = "sup"
|
instance.array[1] = "sup"
|
||||||
assert all(instance.array[1] == "sup")
|
assert all(instance.array[1] == "sup")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("compound", [True, False])
|
||||||
|
def test_datetime(hdf5_array, compound):
|
||||||
|
"""
|
||||||
|
We can treat S32 byte arrays as datetimes if our type annotation
|
||||||
|
says to, including validation, setting and getting values
|
||||||
|
"""
|
||||||
|
array = hdf5_array((10, 10), datetime, compound=compound)
|
||||||
|
|
||||||
|
class MyModel(BaseModel):
|
||||||
|
array: NDArray[Any, datetime]
|
||||||
|
|
||||||
|
instance = MyModel(array=array)
|
||||||
|
assert isinstance(instance.array[0, 0], np.datetime64)
|
||||||
|
assert instance.array[0:5].dtype.type is np.datetime64
|
||||||
|
|
||||||
|
now = datetime.now()
|
||||||
|
|
||||||
|
instance.array[0, 0] = now
|
||||||
|
assert instance.array[0, 0] == now
|
||||||
|
instance.array[0] = now
|
||||||
|
assert all(instance.array[0] == now)
|
||||||
|
|
Loading…
Reference in a new issue