mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2024-11-14 18:54:28 +00:00
checkpointing
This commit is contained in:
parent
66ab444ec2
commit
0d0b310b6e
2 changed files with 102 additions and 5 deletions
|
@ -39,13 +39,25 @@ as ``S32`` isoformatted byte strings (timezones optional) like:
|
|||
|
||||
"""
|
||||
|
||||
import pdb
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, TypeVar, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Iterable,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from pydantic import SerializationInfo
|
||||
from pydantic import GetCoreSchemaHandler, SerializationInfo
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
|
||||
from numpydantic.interface.interface import Interface, JsonDict
|
||||
from numpydantic.types import DtypeType, NDArrayType
|
||||
|
@ -76,6 +88,17 @@ class H5ArrayPath(NamedTuple):
|
|||
"""Refer to a specific field within a compound dtype"""
|
||||
|
||||
|
||||
class _H5ArrayPathRequiredKeys(TypedDict):
    """Keys that every :class:`H5ArrayPathDict` must carry"""

    file: Union[Path, str]
    """Location of HDF5 file"""
    path: str
    """Path within the HDF5 file"""


class H5ArrayPathDict(_H5ArrayPathRequiredKeys, total=False):
    """
    Location specifier for arrays within an HDF5 file

    ``file`` and ``path`` are required keys. ``field`` is an optional key:
    a ``TypedDict`` body cannot declare default values, so "may be omitted"
    is expressed with ``total=False`` rather than with ``= None``.
    """

    field: Optional[Union[str, List[str]]]
    """Refer to a specific field within a compound dtype"""
|
||||
|
||||
|
||||
class H5JsonDict(JsonDict):
|
||||
"""Round-trip Json-able version of an HDF5 dataset"""
|
||||
|
||||
|
@ -156,9 +179,12 @@ class H5Proxy:
|
|||
return obj[:]
|
||||
|
||||
def __getattr__(self, item: str):
|
||||
|
||||
if item == "__name__":
|
||||
# special case for H5Proxies that don't refer to a real file during testing
|
||||
return "H5Proxy"
|
||||
elif item.startswith("__"):
|
||||
return object.__getattribute__(self, item)
|
||||
with h5py.File(self.file, "r") as h5f:
|
||||
obj = h5f.get(self.path)
|
||||
val = getattr(obj, item)
|
||||
|
@ -268,6 +294,68 @@ class H5Proxy:
|
|||
v = np.array(v).astype("S32")
|
||||
return v
|
||||
|
||||
@classmethod
def __get_pydantic_core_schema__(
    cls, source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema:
    """
    Build the pydantic core schema used when this class appears in a model.

    Validation is described as a typed dict with ``file``, ``path`` and
    ``field`` keys. When dumping to JSON, serialization is delegated to
    :meth:`to_json`.

    Args:
        source_type: The annotation pydantic is generating a schema for
            (unused).
        handler: pydantic's schema handler (unused).

    Returns:
        A typed-dict core schema with a JSON-only plain serializer attached.
    """
    # ``field`` is declared ``Optional[...]`` in the array path specifier,
    # so accept ``None`` as well as a missing key.
    field_schema = core_schema.nullable_schema(
        core_schema.union_schema(
            [
                core_schema.str_schema(),
                core_schema.list_schema(core_schema.str_schema()),
            ],
        )
    )
    return core_schema.typed_dict_schema(
        {
            "file": core_schema.typed_dict_field(
                core_schema.str_schema(), required=True
            ),
            "path": core_schema.typed_dict_field(
                core_schema.str_schema(), required=True
            ),
            "field": core_schema.typed_dict_field(field_schema, required=False),
        },
        serialization=core_schema.plain_serializer_function_ser_schema(
            # ``to_json`` takes ``(value, info)``; without ``info_arg=True``
            # pydantic would call it with the value only.
            cls.to_json,
            info_arg=True,
            when_used="json",
        ),
    )
|
||||
|
||||
# file: Union[Path, str]
|
||||
# """Location of HDF5 file"""
|
||||
# path: str
|
||||
# """Path within the HDF5 file"""
|
||||
# field: Optional[Union[str, List[str]]] = None
|
||||
# """Refer to a specific field within a compound dtype"""
|
||||
# }
|
||||
# )
|
||||
#
|
||||
|
||||
#
|
||||
# @model_serializer(when_used="json")
|
||||
@staticmethod
def to_json(self, info: SerializationInfo):
    """
    Dump an H5Proxy to JSON the same way the interface would.

    Used when the interface cannot be applied, eg. when the proxy is an
    `extra` field in a model without a type annotation.

    In round-trip mode the result is the interface name plus the proxy's
    array path fields; otherwise the dataset is opened, read in full, and
    converted to nested lists. Either result is passed through
    ``postprocess_json`` before being returned.
    """
    # imported here rather than at module level, as in the original
    from numpydantic.serialization import postprocess_json

    if not info.round_trip:
        try:
            payload = self.open()[:].tolist()
        finally:
            # always release the handle, even if the read fails
            self.close()
    else:
        payload = {"type": H5Interface.name, **self._h5arraypath._asdict()}

    return postprocess_json(payload, info)
|
||||
|
||||
|
||||
class H5Interface(Interface):
|
||||
"""
|
||||
|
|
|
@ -16,11 +16,20 @@ U = TypeVar("U")
|
|||
|
||||
def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
    """
    Use an interface class to render an array as JSON

    Args:
        value: The array-like object to serialize. The interface class is
            chosen with ``Interface.match_output``.
        info: Serialization info, forwarded to the interface's ``to_json``
            and to ``postprocess_json``.

    Returns:
        The interface's JSON output after ``postprocess_json`` has been
        applied to it.
    """
    interface_cls = Interface.match_output(value)
    array = interface_cls.to_json(value, info)
    return postprocess_json(array, info)
|
||||
|
||||
|
||||
def postprocess_json(
|
||||
array: Union[dict, list], info: SerializationInfo
|
||||
) -> Union[dict, list]:
|
||||
"""
|
||||
Modify json after dumping from an interface
|
||||
"""
|
||||
# perf: keys to skip in generation - anything named "value" is array data.
|
||||
skip = ["value"]
|
||||
if isinstance(array, JsonDict):
|
||||
array = array.model_dump(exclude_none=True)
|
||||
|
||||
|
|
Loading…
Reference in a new issue