From e78c170a2b6e22aa13d13130edfcbb49dcafb398 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Mon, 2 Sep 2024 22:49:13 -0700 Subject: [PATCH] correct decoding of byte arrays --- src/numpydantic/interface/hdf5.py | 12 +++++------- tests/test_interface/test_hdf5.py | 3 +++ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/numpydantic/interface/hdf5.py b/src/numpydantic/interface/hdf5.py index 0e72669..1f75409 100644 --- a/src/numpydantic/interface/hdf5.py +++ b/src/numpydantic/interface/hdf5.py @@ -130,7 +130,10 @@ class H5Proxy: else: item = (item, self.field) - return obj[item].decode(encoding.encoding) + try: + return obj[item].decode(encoding.encoding) + except AttributeError: + return np.strings.decode(obj[item], encoding=encoding.encoding) else: obj = obj.fields(self.field) else: @@ -262,12 +265,7 @@ class H5Interface(Interface): Subclasses to correctly handle """ - if hasattr(array.dtype, "type") and array.dtype.type is np.object_: - if h5py.h5t.check_string_dtype(array.dtype): - return str - else: - return self.get_object_dtype(array) - elif h5py.h5t.check_string_dtype(array.dtype): + if h5py.h5t.check_string_dtype(array.dtype): return str else: return array.dtype diff --git a/tests/test_interface/test_hdf5.py b/tests/test_interface/test_hdf5.py index 7d00174..891dd9f 100644 --- a/tests/test_interface/test_hdf5.py +++ b/tests/test_interface/test_hdf5.py @@ -171,3 +171,6 @@ def test_strings(hdf5_array, compound): instance.array[0, 0] = "hey" assert instance.array[0, 0] == "hey" assert isinstance(instance.array[0, 1], str) + + instance.array[1] = "sup" + assert all(instance.array[1] == "sup")