mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2025-01-09 13:44:26 +00:00
Add support for strings in hdf5
This commit is contained in:
parent
f699d2ab7b
commit
0e6ea07d5e
4 changed files with 73 additions and 7 deletions
|
@ -1,5 +1,27 @@
|
|||
"""
|
||||
Interfaces for HDF5 Datasets
|
||||
|
||||
.. note::
|
||||
|
||||
HDF5 arrays are accessed through a proxy class :class:`.H5Proxy` .
|
||||
Getting/setting values should work as normal, **except** that setting
|
||||
values on nested views is impossible -
|
||||
|
||||
Specifically this doesn't work:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
my_model.array[0][0] = 1
|
||||
|
||||
But this does work:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
my_model.array[0,0] = 1
|
||||
|
||||
To have direct access to the hdf5 dataset, use the
|
||||
:meth:`.H5Proxy.open` method.
|
||||
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
@ -10,7 +32,7 @@ import numpy as np
|
|||
from pydantic import SerializationInfo
|
||||
|
||||
from numpydantic.interface.interface import Interface
|
||||
from numpydantic.types import NDArrayType
|
||||
from numpydantic.types import DtypeType, NDArrayType
|
||||
|
||||
try:
|
||||
import h5py
|
||||
|
@ -102,7 +124,14 @@ class H5Proxy:
|
|||
with h5py.File(self.file, "r") as h5f:
|
||||
obj = h5f.get(self.path)
|
||||
if self.field is not None:
|
||||
obj = obj.fields(self.field)
|
||||
if h5py.h5t.check_string_dtype(obj.dtype[self.field]):
|
||||
obj = obj.fields(self.field).asstr()
|
||||
else:
|
||||
obj = obj.fields(self.field)
|
||||
else:
|
||||
if h5py.h5t.check_string_dtype(obj.dtype):
|
||||
obj = obj.asstr()
|
||||
|
||||
return obj[item]
|
||||
|
||||
def __setitem__(
|
||||
|
@ -222,6 +251,22 @@ class H5Interface(Interface):
|
|||
|
||||
return array
|
||||
|
||||
def get_dtype(self, array: NDArrayType) -> DtypeType:
|
||||
"""
|
||||
Get the dtype from the input array
|
||||
|
||||
Subclasses to correctly handle
|
||||
"""
|
||||
if hasattr(array.dtype, "type") and array.dtype.type is np.object_:
|
||||
if h5py.h5t.check_string_dtype(array.dtype):
|
||||
return str
|
||||
else:
|
||||
return self.get_object_dtype(array)
|
||||
elif h5py.h5t.check_string_dtype(array.dtype):
|
||||
return str
|
||||
else:
|
||||
return array.dtype
|
||||
|
||||
@classmethod
|
||||
def to_json(cls, array: H5Proxy, info: Optional[SerializationInfo] = None) -> dict:
|
||||
"""
|
||||
|
|
|
@ -126,7 +126,10 @@ class Interface(ABC, Generic[T]):
|
|||
if isinstance(self.dtype, tuple):
|
||||
valid = dtype in self.dtype
|
||||
elif self.dtype is np.str_:
|
||||
valid = getattr(dtype, "type", None) is np.str_ or dtype is np.str_
|
||||
valid = getattr(dtype, "type", None) in (np.str_, str) or dtype in (
|
||||
np.str_,
|
||||
str,
|
||||
)
|
||||
else:
|
||||
# try to match as any subclass, if self.dtype is a class
|
||||
try:
|
||||
|
|
|
@ -127,8 +127,13 @@ def hdf5_array(
|
|||
_ = hdf5_file.create_dataset(array_path, data=data)
|
||||
return H5ArrayPath(Path(hdf5_file.filename), array_path)
|
||||
else:
|
||||
dt = np.dtype([("data", dtype), ("extra", "i8")])
|
||||
data = np.zeros(shape, dtype=dt)
|
||||
|
||||
if dtype is str:
|
||||
dt = np.dtype([("data", np.dtype("S10")), ("extra", "i8")])
|
||||
data = np.array([("hey", 0)] * np.prod(shape), dtype=dt).reshape(shape)
|
||||
else:
|
||||
dt = np.dtype([("data", dtype), ("extra", "i8")])
|
||||
data = np.zeros(shape, dtype=dt)
|
||||
_ = hdf5_file.create_dataset(array_path, data=data)
|
||||
return H5ArrayPath(Path(hdf5_file.filename), array_path, "data")
|
||||
|
||||
|
|
|
@ -78,8 +78,6 @@ def test_hdf5_shape(shape_cases, hdf5_array, compound):
|
|||
|
||||
@pytest.mark.parametrize("compound", [True, False])
|
||||
def test_hdf5_dtype(dtype_cases, hdf5_array, compound):
|
||||
if dtype_cases.dtype is str:
|
||||
pytest.skip("hdf5 cant do string arrays")
|
||||
_test_hdf5_case(dtype_cases, hdf5_array, compound)
|
||||
|
||||
|
||||
|
@ -157,3 +155,18 @@ def test_compound_dtype(tmp_path):
|
|||
assert all(instance.array[1, :] == 0)
|
||||
instance.array[1] = 2
|
||||
assert all(instance.array[1] == 2)
|
||||
|
||||
|
||||
def test_strings(hdf5_array):
|
||||
"""
|
||||
HDF5 proxy can get and set strings just like any other dtype
|
||||
"""
|
||||
array = hdf5_array((10, 10), str)
|
||||
|
||||
class MyModel(BaseModel):
|
||||
array: NDArray[Shape["10, 10"], str]
|
||||
|
||||
instance = MyModel(array=array)
|
||||
instance.array[0, 0] = "hey"
|
||||
assert instance.array[0, 0] == "hey"
|
||||
assert isinstance(instance.array[0, 1], str)
|
||||
|
|
Loading…
Reference in a new issue