From 0e6ea07d5e9eb1be5718b45874107a17271911f8 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Mon, 2 Sep 2024 22:14:47 -0700 Subject: [PATCH] Add support for strings in hdf5 --- src/numpydantic/interface/hdf5.py | 49 ++++++++++++++++++++++++-- src/numpydantic/interface/interface.py | 5 ++- tests/fixtures.py | 9 +++-- tests/test_interface/test_hdf5.py | 17 +++++++-- 4 files changed, 73 insertions(+), 7 deletions(-) diff --git a/src/numpydantic/interface/hdf5.py b/src/numpydantic/interface/hdf5.py index 656273d..b63e760 100644 --- a/src/numpydantic/interface/hdf5.py +++ b/src/numpydantic/interface/hdf5.py @@ -1,5 +1,27 @@ """ Interfaces for HDF5 Datasets + +.. note:: + + HDF5 arrays are accessed through a proxy class :class:`.H5Proxy` . + Getting/setting values should work as normal, **except** that setting + values on nested views is impossible - + + Specifically this doesn't work: + + .. code-block:: python + + my_model.array[0][0] = 1 + + But this does work: + + .. code-block:: python + + my_model.array[0,0] = 1 + + To have direct access to the hdf5 dataset, use the + :meth:`.H5Proxy.open` method. + """ import sys @@ -10,7 +32,7 @@ import numpy as np from pydantic import SerializationInfo from numpydantic.interface.interface import Interface -from numpydantic.types import NDArrayType +from numpydantic.types import DtypeType, NDArrayType try: import h5py @@ -102,7 +124,14 @@ class H5Proxy: with h5py.File(self.file, "r") as h5f: obj = h5f.get(self.path) if self.field is not None: - obj = obj.fields(self.field) + if h5py.h5t.check_string_dtype(obj.dtype[self.field]): + obj = obj.fields(self.field).asstr() + else: + obj = obj.fields(self.field) + else: + if h5py.h5t.check_string_dtype(obj.dtype): + obj = obj.asstr() + return obj[item] def __setitem__( @@ -222,6 +251,22 @@ class H5Interface(Interface): return array + def get_dtype(self, array: NDArrayType) -> DtypeType: + """ + Get the dtype from the input array + + Subclasses to correctly handle + """ + if hasattr(array.dtype, "type") and array.dtype.type is np.object_: + if h5py.h5t.check_string_dtype(array.dtype): + return str + else: + return self.get_object_dtype(array) + elif h5py.h5t.check_string_dtype(array.dtype): + return str + else: + return array.dtype + @classmethod def to_json(cls, array: H5Proxy, info: Optional[SerializationInfo] = None) -> dict: """ diff --git a/src/numpydantic/interface/interface.py b/src/numpydantic/interface/interface.py index 3dc3fdc..1ef307f 100644 --- a/src/numpydantic/interface/interface.py +++ b/src/numpydantic/interface/interface.py @@ -126,7 +126,10 @@ class Interface(ABC, Generic[T]): if isinstance(self.dtype, tuple): valid = dtype in self.dtype elif self.dtype is np.str_: - valid = getattr(dtype, "type", None) is np.str_ or dtype is np.str_ + valid = getattr(dtype, "type", None) in (np.str_, str) or dtype in ( + np.str_, + str, + ) else: # try to match as any subclass, if self.dtype is a class try: diff --git a/tests/fixtures.py b/tests/fixtures.py index 6fbc5ad..cb5b59b 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -127,8 +127,13 @@ def hdf5_array( _ = hdf5_file.create_dataset(array_path, data=data) return H5ArrayPath(Path(hdf5_file.filename), array_path) else: - dt = np.dtype([("data", dtype), ("extra", "i8")]) - data = np.zeros(shape, dtype=dt) + + if dtype is str: + dt = np.dtype([("data", np.dtype("S10")), ("extra", "i8")]) + data = np.array([("hey", 0)] * np.prod(shape), dtype=dt).reshape(shape) + else: + dt = np.dtype([("data", dtype), ("extra", "i8")]) + data = np.zeros(shape, dtype=dt) _ = hdf5_file.create_dataset(array_path, data=data) return H5ArrayPath(Path(hdf5_file.filename), array_path, "data") diff --git a/tests/test_interface/test_hdf5.py b/tests/test_interface/test_hdf5.py index 23d26a3..67dcc44 100644 --- a/tests/test_interface/test_hdf5.py +++ b/tests/test_interface/test_hdf5.py @@ -78,8 +78,6 @@ def test_hdf5_shape(shape_cases, hdf5_array, compound): @pytest.mark.parametrize("compound", [True, False]) def test_hdf5_dtype(dtype_cases, hdf5_array, compound): - if dtype_cases.dtype is str: - pytest.skip("hdf5 cant do string arrays") _test_hdf5_case(dtype_cases, hdf5_array, compound) @@ -157,3 +155,18 @@ def test_compound_dtype(tmp_path): assert all(instance.array[1, :] == 0) instance.array[1] = 2 assert all(instance.array[1] == 2) + + +def test_strings(hdf5_array): + """ + HDF5 proxy can get and set strings just like any other dtype + """ + array = hdf5_array((10, 10), str) + + class MyModel(BaseModel): + array: NDArray[Shape["10, 10"], str] + + instance = MyModel(array=array) + instance.array[0, 0] = "hey" + assert instance.array[0, 0] == "hey" + assert isinstance(instance.array[0, 1], str)