checkpointing

This commit is contained in:
sneakers-the-rat 2024-09-26 16:56:56 -07:00
parent 66ab444ec2
commit 0d0b310b6e
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
2 changed files with 102 additions and 5 deletions

View file

@ -39,13 +39,25 @@ as ``S32`` isoformatted byte strings (timezones optional) like:
""" """
import pdb
import sys import sys
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, TypeVar, Union from typing import (
Any,
Iterable,
List,
NamedTuple,
Optional,
Tuple,
TypedDict,
TypeVar,
Union,
)
import numpy as np import numpy as np
from pydantic import SerializationInfo from pydantic import GetCoreSchemaHandler, SerializationInfo
from pydantic_core import CoreSchema, core_schema
from numpydantic.interface.interface import Interface, JsonDict from numpydantic.interface.interface import Interface, JsonDict
from numpydantic.types import DtypeType, NDArrayType from numpydantic.types import DtypeType, NDArrayType
@ -76,6 +88,17 @@ class H5ArrayPath(NamedTuple):
"""Refer to a specific field within a compound dtype""" """Refer to a specific field within a compound dtype"""
class H5ArrayPathDict(TypedDict):
"""Location specifier for arrays within an HDF5 file"""
file: Union[Path, str]
"""Location of HDF5 file"""
path: str
"""Path within the HDF5 file"""
field: Optional[Union[str, List[str]]] = None
"""Refer to a specific field within a compound dtype"""
class H5JsonDict(JsonDict): class H5JsonDict(JsonDict):
"""Round-trip Json-able version of an HDF5 dataset""" """Round-trip Json-able version of an HDF5 dataset"""
@ -156,9 +179,12 @@ class H5Proxy:
return obj[:] return obj[:]
def __getattr__(self, item: str): def __getattr__(self, item: str):
if item == "__name__": if item == "__name__":
# special case for H5Proxies that don't refer to a real file during testing # special case for H5Proxies that don't refer to a real file during testing
return "H5Proxy" return "H5Proxy"
elif item.startswith("__"):
return object.__getattribute__(self, item)
with h5py.File(self.file, "r") as h5f: with h5py.File(self.file, "r") as h5f:
obj = h5f.get(self.path) obj = h5f.get(self.path)
val = getattr(obj, item) val = getattr(obj, item)
@ -268,6 +294,68 @@ class H5Proxy:
v = np.array(v).astype("S32") v = np.array(v).astype("S32")
return v return v
@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema:
pdb.set_trace()
return core_schema.typed_dict_schema(
{
"file": core_schema.typed_dict_field(
core_schema.str_schema(), required=True
),
"path": core_schema.typed_dict_field(
core_schema.str_schema(), required=True
),
"field": core_schema.typed_dict_field(
core_schema.union_schema(
[
core_schema.str_schema(),
core_schema.list_schema(core_schema.str_schema()),
],
),
required=True,
),
},
serialization=core_schema.plain_serializer_function_ser_schema(
cls.to_json, when_used="json"
),
)
# file: Union[Path, str]
# """Location of HDF5 file"""
# path: str
# """Path within the HDF5 file"""
# field: Optional[Union[str, List[str]]] = None
# """Refer to a specific field within a compound dtype"""
# }
# )
#
#
# @model_serializer(when_used="json")
@staticmethod
def to_json(self, info: SerializationInfo):
"""
Serialize H5Proxy to JSON, as the interface does,
in cases when the interface is not able to be used
(eg. like when used as an `extra` field in a model without a type annotation)
"""
from numpydantic.serialization import postprocess_json
if info.round_trip:
as_json = {
"type": H5Interface.name,
}
as_json.update(self._h5arraypath._asdict())
else:
try:
dset = self.open()
as_json = dset[:].tolist()
finally:
self.close()
return postprocess_json(as_json, info)
class H5Interface(Interface): class H5Interface(Interface):
""" """

View file

@ -16,11 +16,20 @@ U = TypeVar("U")
def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]: def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
"""Use an interface class to render an array as JSON""" """Use an interface class to render an array as JSON"""
# perf: keys to skip in generation - anything named "value" is array data.
skip = ["value"]
interface_cls = Interface.match_output(value) interface_cls = Interface.match_output(value)
array = interface_cls.to_json(value, info) array = interface_cls.to_json(value, info)
array = postprocess_json(array, info)
return array
def postprocess_json(
array: Union[dict, list], info: SerializationInfo
) -> Union[dict, list]:
"""
Modify json after dumping from an interface
"""
# perf: keys to skip in generation - anything named "value" is array data.
skip = ["value"]
if isinstance(array, JsonDict): if isinstance(array, JsonDict):
array = array.model_dump(exclude_none=True) array = array.model_dump(exclude_none=True)