From 0d0b310b6ec9578108989570eb5b1c25cb9a252e Mon Sep 17 00:00:00 2001
From: sneakers-the-rat
Date: Thu, 26 Sep 2024 16:56:56 -0700
Subject: [PATCH] checkpointing

---
 src/numpydantic/interface/hdf5.py | 92 ++++++++++++++++++++++++++++++-
 src/numpydantic/serialization.py  | 15 ++++-
 2 files changed, 102 insertions(+), 5 deletions(-)

diff --git a/src/numpydantic/interface/hdf5.py b/src/numpydantic/interface/hdf5.py
index 9215ec2..6cbb03c 100644
--- a/src/numpydantic/interface/hdf5.py
+++ b/src/numpydantic/interface/hdf5.py
@@ -39,13 +39,25 @@ as ``S32`` isoformatted byte strings (timezones optional) like:
 """
 
 import sys
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, TypeVar, Union
+from typing import (
+    Any,
+    Iterable,
+    List,
+    NamedTuple,
+    Optional,
+    Tuple,
+    TypedDict,
+    TypeVar,
+    Union,
+)
 
 import numpy as np
-from pydantic import SerializationInfo
+from pydantic import GetCoreSchemaHandler, SerializationInfo
+from pydantic_core import CoreSchema, core_schema
 
 from numpydantic.interface.interface import Interface, JsonDict
 from numpydantic.types import DtypeType, NDArrayType
@@ -76,6 +88,17 @@ class H5ArrayPath(NamedTuple):
     """Refer to a specific field within a compound dtype"""
 
 
+class H5ArrayPathDict(TypedDict):
+    """Location specifier for arrays within an HDF5 file"""
+
+    file: Union[Path, str]
+    """Location of HDF5 file"""
+    path: str
+    """Path within the HDF5 file"""
+    field: Optional[Union[str, List[str]]]
+    """Refer to a specific field within a compound dtype"""
+
+
 class H5JsonDict(JsonDict):
     """Round-trip Json-able version of an HDF5 dataset"""
 
@@ -156,9 +179,12 @@ class H5Proxy:
             return obj[:]
 
     def __getattr__(self, item: str):
+
         if item == "__name__":
             # special case for H5Proxies that don't refer to a real file during testing
             return "H5Proxy"
+        elif item.startswith("__"):
+            return object.__getattribute__(self, item)
         with h5py.File(self.file, "r") as h5f:
             obj = h5f.get(self.path)
             val = getattr(obj, item)
@@ -268,6 +294,68 @@ class H5Proxy:
             v = np.array(v).astype("S32")
         return v
 
+    @classmethod
+    def __get_pydantic_core_schema__(
+        cls, source_type: Any, handler: GetCoreSchemaHandler
+    ) -> CoreSchema:
+        return core_schema.typed_dict_schema(
+            {
+                "file": core_schema.typed_dict_field(
+                    core_schema.str_schema(), required=True
+                ),
+                "path": core_schema.typed_dict_field(
+                    core_schema.str_schema(), required=True
+                ),
+                "field": core_schema.typed_dict_field(
+                    core_schema.union_schema(
+                        [
+                            core_schema.str_schema(),
+                            core_schema.list_schema(core_schema.str_schema()),
+                        ],
+                    ),
+                    required=True,
+                ),
+            },
+            serialization=core_schema.plain_serializer_function_ser_schema(
+                # info_arg=True so to_json receives the SerializationInfo
+                cls.to_json, info_arg=True, when_used="json"
+            ),
+        )
+
+    def to_json(self, info: SerializationInfo):
+        """
+        Serialize H5Proxy to JSON, as the interface does,
+        in cases when the interface is not able to be used
+        (e.g. when used as an `extra` field in a model without a type annotation)
+        """
+        from numpydantic.serialization import postprocess_json
+
+        if info.round_trip:
+            as_json = {
+                "type": H5Interface.name,
+            }
+            as_json.update(self._h5arraypath._asdict())
+        else:
+            try:
+                dset = self.open()
+                as_json = dset[:].tolist()
+            finally:
+                self.close()
+        return postprocess_json(as_json, info)
+
 
 class H5Interface(Interface):
     """
diff --git a/src/numpydantic/serialization.py b/src/numpydantic/serialization.py
index eb5b4bc..502d43d 100644
--- a/src/numpydantic/serialization.py
+++ b/src/numpydantic/serialization.py
@@ -16,11 +16,20 @@ U = TypeVar("U")
 
 def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
     """Use an interface class to render an array as JSON"""
-    # perf: keys to skip in generation - anything named "value" is array data.
-    skip = ["value"]
-
     interface_cls = Interface.match_output(value)
     array = interface_cls.to_json(value, info)
+    array = postprocess_json(array, info)
+    return array
+
+
+def postprocess_json(
+    array: Union[dict, list], info: SerializationInfo
+) -> Union[dict, list]:
+    """
+    Modify JSON after dumping from an interface
+    """
+    # perf: keys to skip in generation - anything named "value" is array data.
+    skip = ["value"]
 
     if isinstance(array, JsonDict):
         array = array.model_dump(exclude_none=True)
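
A rough usage sketch (not part of the patch) of what this checkpoint appears to be working toward: an `H5Proxy` attached to a pydantic model as an untyped `extra` field, where the `H5Interface` machinery is bypassed and serialization falls through to the core schema and `to_json` added above. The file name, dataset path, model name, and exact round-trip output below are assumptions for illustration, and this work-in-progress schema may not support the full flow yet:

    import h5py
    import numpy as np
    from pydantic import BaseModel, ConfigDict

    from numpydantic.interface.hdf5 import H5Proxy

    # hypothetical file and dataset, created only for the example
    with h5py.File("example.h5", "w") as h5f:
        h5f.create_dataset("/data/array", data=np.arange(10))

    class Container(BaseModel):
        # extra fields carry no annotation, so pydantic cannot route them
        # through the NDArray / H5Interface serializer
        model_config = ConfigDict(extra="allow")

    model = Container(array=H5Proxy(file="example.h5", path="/data/array"))

    # plain JSON dump: to_json opens the dataset and dumps its values
    print(model.model_dump_json())
    # round-trip dump: expected to emit an interface-style locator instead, e.g.
    # {"array": {"type": "hdf5", "file": "example.h5", "path": "/data/array", "field": null}}
    print(model.model_dump_json(round_trip=True))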