diff --git a/src/numpydantic/interface/hdf5.py b/src/numpydantic/interface/hdf5.py
index 0bb8720..7861f3c 100644
--- a/src/numpydantic/interface/hdf5.py
+++ b/src/numpydantic/interface/hdf5.py
@@ -25,8 +25,9 @@ Interfaces for HDF5 Datasets
 """
 
 import sys
+from datetime import datetime
 from pathlib import Path
-from typing import Any, List, NamedTuple, Optional, Tuple, Union
+from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, TypeVar, Union
 
 import numpy as np
 from pydantic import SerializationInfo
@@ -46,6 +47,8 @@ else:
 
 H5Arraylike: TypeAlias = Tuple[Union[Path, str], str]
 
+T = TypeVar("T")
+
 
 class H5ArrayPath(NamedTuple):
     """Location specifier for arrays within an HDF5 file"""
@@ -77,6 +80,7 @@ class H5Proxy:
         path (str): Path to array within hdf5 file
         field (str, list[str]): Optional - refer to a specific field within
             a compound dtype
+        annotation_dtype (dtype): Optional - the dtype of our type annotation
     """
 
     def __init__(
@@ -84,11 +88,13 @@ class H5Proxy:
         file: Union[Path, str],
         path: str,
         field: Optional[Union[str, List[str]]] = None,
+        annotation_dtype: Optional[DtypeType] = None,
    ):
         self._h5f = None
         self.file = Path(file)
         self.path = path
         self.field = field
+        self._annotation_dtype = annotation_dtype
 
     def array_exists(self) -> bool:
         """Check that there is in fact an array at :attr:`.path` within :attr:`.file`"""
@@ -120,10 +126,12 @@ class H5Proxy:
     def __getitem__(
         self, item: Union[int, slice, Tuple[Union[int, slice], ...]]
-    ) -> np.ndarray:
+    ) -> Union[np.ndarray, DtypeType]:
         with h5py.File(self.file, "r") as h5f:
             obj = h5f.get(self.path)
 
+            # handle compound dtypes
             if self.field is not None:
+                # handle compound string dtype
                 if encoding := h5py.h5t.check_string_dtype(obj.dtype[self.field]):
                     if isinstance(item, tuple):
                         item = (*item, self.field)
@@ -132,24 +140,41 @@
 
                     try:
                         # single string
-                        return obj[item].decode(encoding.encoding)
+                        val = obj[item].decode(encoding.encoding)
+                        if self._annotation_dtype is np.datetime64:
+                            return np.datetime64(val)
+                        else:
+                            return val
                     except AttributeError:
                         # numpy array of bytes
-                        return np.char.decode(obj[item], encoding=encoding.encoding)
-
+                        val = np.char.decode(obj[item], encoding=encoding.encoding)
+                        if self._annotation_dtype is np.datetime64:
+                            return val.astype(np.datetime64)
+                        else:
+                            return val
+                # normal compound type
                 else:
                     obj = obj.fields(self.field)
             else:
                 if h5py.h5t.check_string_dtype(obj.dtype):
                     obj = obj.asstr()
 
-            return obj[item]
+            val = obj[item]
+            if self._annotation_dtype is np.datetime64:
+                if isinstance(val, str):
+                    return np.datetime64(val)
+                else:
+                    return val.astype(np.datetime64)
+            else:
+                return val
 
     def __setitem__(
         self,
         key: Union[int, slice, Tuple[Union[int, slice], ...]],
-        value: Union[int, float, np.ndarray],
+        value: Union[int, float, datetime, np.ndarray],
     ):
+        # TODO: Make a generalized value serdes system instead of ad-hoc type conversion
+        value = self._serialize_datetime(value)
         with h5py.File(self.file, "r+", locking=True) as h5f:
             obj = h5f.get(self.path)
             if self.field is None:
@@ -184,6 +209,16 @@
             self._h5f.close()
             self._h5f = None
 
+    def _serialize_datetime(self, v: Union[T, datetime]) -> Union[T, bytes]:
+        """
+        Convert a datetime into a bytestring
+        """
+        if self._annotation_dtype is np.datetime64:
+            if not isinstance(v, Iterable):
+                v = [v]
+            v = np.array(v).astype("S32")
+        return v
+
 
 class H5Interface(Interface):
     """
@@ -253,6 +288,7 @@
                 "Need to specify a file and a path within an HDF5 file to use the HDF5 "
                 "Interface"
             )
+        array._annotation_dtype = self.dtype
 
         if not array.array_exists():
             raise ValueError(
@@ -269,7 +305,14 @@
         Subclasses to correctly handle
         """
         if h5py.h5t.check_string_dtype(array.dtype):
-            return str
+            # check for datetimes
+            try:
+                if array[0].dtype.type is np.datetime64:
+                    return np.datetime64
+                else:
+                    return str
+            except (AttributeError, ValueError, TypeError):
+                return str
         else:
             return array.dtype
 
diff --git a/src/numpydantic/schema.py b/src/numpydantic/schema.py
index 552c27a..d98f880 100644
--- a/src/numpydantic/schema.py
+++ b/src/numpydantic/schema.py
@@ -166,7 +166,11 @@ def _hash_schema(schema: CoreSchema) -> str:
     to produce the same hash.
     """
     schema_str = json.dumps(
-        schema, sort_keys=True, indent=None, separators=(",", ":")
+        schema,
+        sort_keys=True,
+        indent=None,
+        separators=(",", ":"),
+        default=lambda x: None,
     ).encode("utf-8")
     hasher = hashlib.blake2b(digest_size=8)
     hasher.update(schema_str)
diff --git a/tests/fixtures.py b/tests/fixtures.py
index 9d5bba6..89359ae 100644
--- a/tests/fixtures.py
+++ b/tests/fixtures.py
@@ -2,6 +2,7 @@ import shutil
 from pathlib import Path
 from typing import Any, Callable, Optional, Tuple, Type, Union
 from warnings import warn
+from datetime import datetime, timezone
 
 import h5py
 import numpy as np
@@ -126,15 +127,24 @@
         if not compound:
             if dtype is str:
                 data = np.random.random(shape).astype(bytes)
+            elif dtype is datetime:
+                data = np.empty(shape, dtype="S32")
+                data.fill(datetime.now(timezone.utc).isoformat().encode("utf-8"))
             else:
                 data = np.random.random(shape).astype(dtype)
             _ = hdf5_file.create_dataset(array_path, data=data)
             return H5ArrayPath(Path(hdf5_file.filename), array_path)
         else:
-
             if dtype is str:
                 dt = np.dtype([("data", np.dtype("S10")), ("extra", "i8")])
                 data = np.array([("hey", 0)] * np.prod(shape), dtype=dt).reshape(shape)
+            elif dtype is datetime:
+                dt = np.dtype([("data", np.dtype("S32")), ("extra", "i8")])
+                data = np.array(
+                    [(datetime.now(timezone.utc).isoformat().encode("utf-8"), 0)]
+                    * np.prod(shape),
+                    dtype=dt,
+                ).reshape(shape)
             else:
                 dt = np.dtype([("data", dtype), ("extra", "i8")])
                 data = np.zeros(shape, dtype=dt)
diff --git a/tests/test_interface/test_hdf5.py b/tests/test_interface/test_hdf5.py
index 891dd9f..bf47a8d 100644
--- a/tests/test_interface/test_hdf5.py
+++ b/tests/test_interface/test_hdf5.py
@@ -1,8 +1,9 @@
 import json
+from datetime import datetime, timezone
+from typing import Any
 
 import h5py
 import pytest
-
 from pydantic import BaseModel, ValidationError
 
 import numpy as np
@@ -174,3 +175,26 @@ def test_strings(hdf5_array, compound):
 
     instance.array[1] = "sup"
     assert all(instance.array[1] == "sup")
+
+
+@pytest.mark.parametrize("compound", [True, False])
+def test_datetime(hdf5_array, compound):
+    """
+    We can treat S32 byte arrays as datetimes if our type annotation
+    says to, including validation, setting and getting values
+    """
+    array = hdf5_array((10, 10), datetime, compound=compound)
+
+    class MyModel(BaseModel):
+        array: NDArray[Any, datetime]
+
+    instance = MyModel(array=array)
+    assert isinstance(instance.array[0, 0], np.datetime64)
+    assert instance.array[0:5].dtype.type is np.datetime64
+
+    now = datetime.now()
+
+    instance.array[0, 0] = now
+    assert instance.array[0, 0] == now
+    instance.array[0] = now
+    assert all(instance.array[0] == now)
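
For reference, the sketch below is not part of the diff: it mirrors the new test_datetime end to end, writing ISO-8601 timestamps as fixed-width S32 bytestrings (as the fixture does) and reading them back as np.datetime64 through the datetime annotation. The file name "example.h5" and dataset path "/data" are illustrative placeholders.

# Usage sketch (assumed names: "example.h5", "/data"); follows the diff's own test.
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

import h5py
import numpy as np
from pydantic import BaseModel

from numpydantic import NDArray
from numpydantic.interface.hdf5 import H5ArrayPath

path = Path("example.h5")
with h5py.File(path, "w") as h5f:
    # store ISO-8601 timestamps as fixed-width S32 bytestrings, like the fixture
    data = np.empty((10, 10), dtype="S32")
    data.fill(datetime.now(timezone.utc).isoformat().encode("utf-8"))
    h5f.create_dataset("/data", data=data)


class MyModel(BaseModel):
    # the datetime annotation is what opts the proxy into np.datetime64 conversion
    array: NDArray[Any, datetime]


instance = MyModel(array=H5ArrayPath(path, "/data"))
assert isinstance(instance.array[0, 0], np.datetime64)

# writes pass through _serialize_datetime, which re-encodes datetimes as S32 bytes
instance.array[0, 0] = datetime.now(timezone.utc)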