From f28c766b96d87f5f3d938588ce16d97878fbd095 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Mon, 2 Sep 2024 22:40:17 -0700 Subject: [PATCH] strings in compound dtypes --- src/numpydantic/interface/hdf5.py | 9 +++++++-- tests/test_interface/test_hdf5.py | 5 +++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/numpydantic/interface/hdf5.py b/src/numpydantic/interface/hdf5.py index b63e760..0e72669 100644 --- a/src/numpydantic/interface/hdf5.py +++ b/src/numpydantic/interface/hdf5.py @@ -124,8 +124,13 @@ class H5Proxy: with h5py.File(self.file, "r") as h5f: obj = h5f.get(self.path) if self.field is not None: - if h5py.h5t.check_string_dtype(obj.dtype[self.field]): - obj = obj.fields(self.field).asstr() + if encoding := h5py.h5t.check_string_dtype(obj.dtype[self.field]): + if isinstance(item, tuple): + item = (*item, self.field) + else: + item = (item, self.field) + + return obj[item].decode(encoding.encoding) else: obj = obj.fields(self.field) else: diff --git a/tests/test_interface/test_hdf5.py b/tests/test_interface/test_hdf5.py index 67dcc44..7d00174 100644 --- a/tests/test_interface/test_hdf5.py +++ b/tests/test_interface/test_hdf5.py @@ -157,11 +157,12 @@ def test_compound_dtype(tmp_path): assert all(instance.array[1] == 2) -def test_strings(hdf5_array): +@pytest.mark.parametrize("compound", [True, False]) +def test_strings(hdf5_array, compound): """ HDF5 proxy can get and set strings just like any other dtype """ - array = hdf5_array((10, 10), str) + array = hdf5_array((10, 10), str, compound=compound) class MyModel(BaseModel): array: NDArray[Shape["10, 10"], str]