scratch work on serializing from a proxy

This commit is contained in:
sneakers-the-rat 2024-10-02 16:07:44 -07:00
parent 0d0b310b6e
commit 1f7955d6ef
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D

View file

@ -50,14 +50,13 @@ from typing import (
NamedTuple,
Optional,
Tuple,
TypedDict,
TypeVar,
Union,
)
import numpy as np
from pydantic import GetCoreSchemaHandler, SerializationInfo
from pydantic_core import CoreSchema, core_schema
from pydantic import SerializationInfo
from pydantic_core import SchemaSerializer, core_schema
from numpydantic.interface.interface import Interface, JsonDict
from numpydantic.types import DtypeType, NDArrayType
@ -88,17 +87,6 @@ class H5ArrayPath(NamedTuple):
"""Refer to a specific field within a compound dtype"""
class H5ArrayPathDict(TypedDict):
"""Location specifier for arrays within an HDF5 file"""
file: Union[Path, str]
"""Location of HDF5 file"""
path: str
"""Path within the HDF5 file"""
field: Optional[Union[str, List[str]]] = None
"""Refer to a specific field within a compound dtype"""
class H5JsonDict(JsonDict):
"""Round-trip Json-able version of an HDF5 dataset"""
@ -113,6 +101,51 @@ class H5JsonDict(JsonDict):
)
def to_json(self, info: SerializationInfo):
"""
Serialize H5Proxy to JSON, as the interface does,
in cases when the interface is not able to be used
(eg. like when used as an `extra` field in a model without a type annotation)
"""
from numpydantic.serialization import postprocess_json
if info.round_trip:
as_json = {
"type": H5Interface.name,
}
as_json.update(self._h5arraypath._asdict())
else:
try:
dset = self.open()
as_json = dset[:].tolist()
finally:
self.close()
return postprocess_json(as_json, info)
def _make_pydantic_schema():
return core_schema.typed_dict_schema(
{
"file": core_schema.typed_dict_field(
core_schema.str_schema(), required=True
),
"path": core_schema.typed_dict_field(
core_schema.str_schema(), required=True
),
"field": core_schema.typed_dict_field(
core_schema.union_schema(
[
core_schema.str_schema(),
core_schema.list_schema(core_schema.str_schema()),
],
),
required=True,
),
},
# serialization=
)
class H5Proxy:
"""
Proxy class to mimic numpy-like array behavior with an HDF5 array
@ -135,6 +168,12 @@ class H5Proxy:
annotation_dtype (dtype): Optional - the dtype of our type annotation
"""
__pydantic_serializer__ = SchemaSerializer(
core_schema.plain_serializer_function_ser_schema(
to_json, when_used="json", info_arg=True
)
)
def __init__(
self,
file: Union[Path, str],
@ -179,7 +218,8 @@ class H5Proxy:
return obj[:]
def __getattr__(self, item: str):
if item not in ("shape", "__pydantic_validator__"):
pdb.set_trace()
if item == "__name__":
# special case for H5Proxies that don't refer to a real file during testing
return "H5Proxy"
@ -294,67 +334,24 @@ class H5Proxy:
v = np.array(v).astype("S32")
return v
@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema:
pdb.set_trace()
return core_schema.typed_dict_schema(
{
"file": core_schema.typed_dict_field(
core_schema.str_schema(), required=True
),
"path": core_schema.typed_dict_field(
core_schema.str_schema(), required=True
),
"field": core_schema.typed_dict_field(
core_schema.union_schema(
[
core_schema.str_schema(),
core_schema.list_schema(core_schema.str_schema()),
],
),
required=True,
),
},
serialization=core_schema.plain_serializer_function_ser_schema(
cls.to_json, when_used="json"
),
)
# @classmethod
# def __get_pydantic_core_schema__(
# cls, source_type: Any, handler: GetCoreSchemaHandler
# ) -> CoreSchema:
# return cls._make_pydantic_schema()
# file: Union[Path, str]
# """Location of HDF5 file"""
# path: str
# """Path within the HDF5 file"""
# field: Optional[Union[str, List[str]]] = None
# """Refer to a specific field within a compound dtype"""
# }
# )
#
# file: Union[Path, str]
# """Location of HDF5 file"""
# path: str
# """Path within the HDF5 file"""
# field: Optional[Union[str, List[str]]] = None
# """Refer to a specific field within a compound dtype"""
# }
# )
#
#
# @model_serializer(when_used="json")
@staticmethod
def to_json(self, info: SerializationInfo):
"""
Serialize H5Proxy to JSON, as the interface does,
in cases when the interface is not able to be used
(eg. like when used as an `extra` field in a model without a type annotation)
"""
from numpydantic.serialization import postprocess_json
if info.round_trip:
as_json = {
"type": H5Interface.name,
}
as_json.update(self._h5arraypath._asdict())
else:
try:
dset = self.open()
as_json = dset[:].tolist()
finally:
self.close()
return postprocess_json(as_json, info)
class H5Interface(Interface):