Add support for strings in hdf5

This commit is contained in:
sneakers-the-rat 2024-09-02 22:14:47 -07:00
parent f699d2ab7b
commit 0e6ea07d5e
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
4 changed files with 73 additions and 7 deletions

View file

@ -1,5 +1,27 @@
""" """
Interfaces for HDF5 Datasets Interfaces for HDF5 Datasets
.. note::
HDF5 arrays are accessed through a proxy class :class:`.H5Proxy` .
Getting/setting values should work as normal, **except** that setting
values on nested views is impossible -
Specifically this doesn't work:
.. code-block:: python
my_model.array[0][0] = 1
But this does work:
.. code-block:: python
my_model.array[0,0] = 1
To have direct access to the hdf5 dataset, use the
:meth:`.H5Proxy.open` method.
""" """
import sys import sys
@ -10,7 +32,7 @@ import numpy as np
from pydantic import SerializationInfo from pydantic import SerializationInfo
from numpydantic.interface.interface import Interface from numpydantic.interface.interface import Interface
from numpydantic.types import NDArrayType from numpydantic.types import DtypeType, NDArrayType
try: try:
import h5py import h5py
@ -102,7 +124,14 @@ class H5Proxy:
with h5py.File(self.file, "r") as h5f: with h5py.File(self.file, "r") as h5f:
obj = h5f.get(self.path) obj = h5f.get(self.path)
if self.field is not None: if self.field is not None:
if h5py.h5t.check_string_dtype(obj.dtype[self.field]):
obj = obj.fields(self.field).asstr()
else:
obj = obj.fields(self.field) obj = obj.fields(self.field)
else:
if h5py.h5t.check_string_dtype(obj.dtype):
obj = obj.asstr()
return obj[item] return obj[item]
def __setitem__( def __setitem__(
@ -222,6 +251,22 @@ class H5Interface(Interface):
return array return array
def get_dtype(self, array: NDArrayType) -> DtypeType:
"""
Get the dtype from the input array
Subclasses to correctly handle
"""
if hasattr(array.dtype, "type") and array.dtype.type is np.object_:
if h5py.h5t.check_string_dtype(array.dtype):
return str
else:
return self.get_object_dtype(array)
elif h5py.h5t.check_string_dtype(array.dtype):
return str
else:
return array.dtype
@classmethod @classmethod
def to_json(cls, array: H5Proxy, info: Optional[SerializationInfo] = None) -> dict: def to_json(cls, array: H5Proxy, info: Optional[SerializationInfo] = None) -> dict:
""" """

View file

@ -126,7 +126,10 @@ class Interface(ABC, Generic[T]):
if isinstance(self.dtype, tuple): if isinstance(self.dtype, tuple):
valid = dtype in self.dtype valid = dtype in self.dtype
elif self.dtype is np.str_: elif self.dtype is np.str_:
valid = getattr(dtype, "type", None) is np.str_ or dtype is np.str_ valid = getattr(dtype, "type", None) in (np.str_, str) or dtype in (
np.str_,
str,
)
else: else:
# try to match as any subclass, if self.dtype is a class # try to match as any subclass, if self.dtype is a class
try: try:

View file

@ -127,6 +127,11 @@ def hdf5_array(
_ = hdf5_file.create_dataset(array_path, data=data) _ = hdf5_file.create_dataset(array_path, data=data)
return H5ArrayPath(Path(hdf5_file.filename), array_path) return H5ArrayPath(Path(hdf5_file.filename), array_path)
else: else:
if dtype is str:
dt = np.dtype([("data", np.dtype("S10")), ("extra", "i8")])
data = np.array([("hey", 0)] * np.prod(shape), dtype=dt).reshape(shape)
else:
dt = np.dtype([("data", dtype), ("extra", "i8")]) dt = np.dtype([("data", dtype), ("extra", "i8")])
data = np.zeros(shape, dtype=dt) data = np.zeros(shape, dtype=dt)
_ = hdf5_file.create_dataset(array_path, data=data) _ = hdf5_file.create_dataset(array_path, data=data)

View file

@ -78,8 +78,6 @@ def test_hdf5_shape(shape_cases, hdf5_array, compound):
@pytest.mark.parametrize("compound", [True, False]) @pytest.mark.parametrize("compound", [True, False])
def test_hdf5_dtype(dtype_cases, hdf5_array, compound): def test_hdf5_dtype(dtype_cases, hdf5_array, compound):
if dtype_cases.dtype is str:
pytest.skip("hdf5 cant do string arrays")
_test_hdf5_case(dtype_cases, hdf5_array, compound) _test_hdf5_case(dtype_cases, hdf5_array, compound)
@ -157,3 +155,18 @@ def test_compound_dtype(tmp_path):
assert all(instance.array[1, :] == 0) assert all(instance.array[1, :] == 0)
instance.array[1] = 2 instance.array[1] = 2
assert all(instance.array[1] == 2) assert all(instance.array[1] == 2)
def test_strings(hdf5_array):
"""
HDF5 proxy can get and set strings just like any other dtype
"""
array = hdf5_array((10, 10), str)
class MyModel(BaseModel):
array: NDArray[Shape["10, 10"], str]
instance = MyModel(array=array)
instance.array[0, 0] = "hey"
assert instance.array[0, 0] == "hey"
assert isinstance(instance.array[0, 1], str)