add serialization info to to_json methods, zarr dump array context option, remove coerce_list and just let numpy interface handle it

This commit is contained in:
sneakers-the-rat 2024-05-17 17:39:30 -07:00
parent f3fd0a0ed2
commit 5e90c1bee1
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
8 changed files with 112 additions and 31 deletions

View file

@ -2,9 +2,10 @@
Interface for Dask arrays Interface for Dask arrays
""" """
from typing import Any from typing import Any, Optional
import numpy as np import numpy as np
from pydantic import SerializationInfo
from numpydantic.interface.interface import Interface from numpydantic.interface.interface import Interface
@ -37,7 +38,9 @@ class DaskInterface(Interface):
return DaskArray is not None return DaskArray is not None
@classmethod @classmethod
def to_json(cls, array: DaskArray) -> list: def to_json(
cls, array: DaskArray, info: Optional[SerializationInfo] = None
) -> list:
""" """
Convert an array to a JSON serializable array by first converting to a numpy Convert an array to a JSON serializable array by first converting to a numpy
array and then to a list. array and then to a list.

View file

@ -4,9 +4,10 @@ Interfaces for HDF5 Datasets
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Any, NamedTuple, Tuple, Union from typing import Any, NamedTuple, Optional, Tuple, Union
import numpy as np import numpy as np
from pydantic import SerializationInfo
from numpydantic.interface.interface import Interface from numpydantic.interface.interface import Interface
from numpydantic.types import NDArrayType from numpydantic.types import NDArrayType
@ -179,7 +180,7 @@ class H5Interface(Interface):
return array return array
@classmethod @classmethod
def to_json(cls, array: H5Proxy) -> dict: def to_json(cls, array: H5Proxy, info: Optional[SerializationInfo] = None) -> dict:
""" """
Dump to a dictionary containing Dump to a dictionary containing

View file

@ -4,10 +4,11 @@ Base Interface metaclass
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from operator import attrgetter from operator import attrgetter
from typing import Any, Generic, Tuple, Type, TypeVar, Union from typing import Any, Generic, Optional, Tuple, Type, TypeVar, Union
import numpy as np import numpy as np
from nptyping.shape_expression import check_shape from nptyping.shape_expression import check_shape
from pydantic import SerializationInfo
from numpydantic.exceptions import DtypeError, ShapeError from numpydantic.exceptions import DtypeError, ShapeError
from numpydantic.types import DtypeType, NDArrayType, ShapeType from numpydantic.types import DtypeType, NDArrayType, ShapeType
@ -107,7 +108,9 @@ class Interface(ABC, Generic[T]):
""" """
@classmethod @classmethod
def to_json(cls, array: Type[T]) -> Union[list, dict]: def to_json(
cls, array: Type[T], info: Optional[SerializationInfo] = None
) -> Union[list, dict]:
""" """
Convert an array of :attr:`.return_type` to a JSON-compatible format using Convert an array of :attr:`.return_type` to a JSON-compatible format using
base python types base python types

View file

@ -7,6 +7,8 @@ from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any, Optional, Sequence, Union from typing import Any, Optional, Sequence, Union
from pydantic import SerializationInfo
from numpydantic.interface.interface import Interface from numpydantic.interface.interface import Interface
try: try:
@ -113,14 +115,29 @@ class ZarrInterface(Interface):
@classmethod @classmethod
def to_json( def to_json(
cls, array: Union[ZarrArray, str, Path, ZarrArrayPath, Sequence] cls,
array: Union[ZarrArray, str, Path, ZarrArrayPath, Sequence],
info: Optional[SerializationInfo] = None,
) -> dict: ) -> dict:
""" """
Dump just the metadata for an array from :meth:`zarr.core.Array.info_items` Dump just the metadata for an array from :meth:`zarr.core.Array.info_items`
plus the :meth:`zarr.core.Array.hexdigest` plus the :meth:`zarr.core.Array.hexdigest`.
The full array can be returned by passing ``'zarr_dump_array': True`` to the
serialization ``context`` ::
model.model_dump_json(context={'zarr_dump_array': True})
""" """
dump_array = False
if info is not None and info.context is not None:
dump_array = info.context.get("zarr_dump_array", False)
array = cls._get_array(array) array = cls._get_array(array)
info = array.info_items() info = array.info_items()
info_dict = {i[0]: i[1] for i in info} info_dict = {i[0]: i[1] for i in info}
info_dict["hexdigest"] = array.hexdigest() info_dict["hexdigest"] = array.hexdigest()
if dump_array:
info_dict["array"] = array[:].tolist()
return info_dict return info_dict

View file

@ -32,7 +32,6 @@ from numpydantic.maps import python_to_nptyping
from numpydantic.schema import ( from numpydantic.schema import (
_handler_type, _handler_type,
_jsonize_array, _jsonize_array,
coerce_list,
get_validate_interface, get_validate_interface,
make_json_schema, make_json_schema,
) )
@ -119,16 +118,11 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
return core_schema.json_or_python_schema( return core_schema.json_or_python_schema(
json_schema=list_schema, json_schema=list_schema,
python_schema=core_schema.chain_schema( python_schema=core_schema.with_info_plain_validator_function(
[
core_schema.no_info_plain_validator_function(coerce_list),
core_schema.with_info_plain_validator_function(
get_validate_interface(shape, dtype) get_validate_interface(shape, dtype)
), ),
]
),
serialization=core_schema.plain_serializer_function_ser_schema( serialization=core_schema.plain_serializer_function_ser_schema(
_jsonize_array, when_used="json" _jsonize_array, when_used="json", info_arg=True
), ),
) )

View file

@ -8,6 +8,7 @@ from typing import Any, Callable, Union
import nptyping.structure import nptyping.structure
import numpy as np import numpy as np
from nptyping import Shape from nptyping import Shape
from pydantic import SerializationInfo
from pydantic_core import CoreSchema, core_schema from pydantic_core import CoreSchema, core_schema
from pydantic_core.core_schema import ListSchema, ValidationInfo from pydantic_core.core_schema import ListSchema, ValidationInfo
@ -173,17 +174,7 @@ def get_validate_interface(shape: ShapeType, dtype: DtypeType) -> Callable:
return validate_interface return validate_interface
def _jsonize_array(value: Any) -> Union[list, dict]: def _jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
"""Use an interface class to render an array as JSON""" """Use an interface class to render an array as JSON"""
interface_cls = Interface.match_output(value) interface_cls = Interface.match_output(value)
return interface_cls.to_json(value) return interface_cls.to_json(value, info)
def coerce_list(value: Any) -> np.ndarray:
"""
If a value is passed as a list or list of lists, try and coerce it into an array
rather than failing validation.
"""
if isinstance(value, list):
value = np.array(value)
return value

View file

@ -1,5 +1,7 @@
import pytest import pytest
import numpy as np
from numpydantic.interface import Interface from numpydantic.interface import Interface
@ -88,3 +90,12 @@ def test_interface_type_lists():
assert atype in Interface.return_types() assert atype in Interface.return_types()
else: else:
assert interface.return_type in Interface.return_types() assert interface.return_type in Interface.return_types()
def test_interfaces_sorting():
"""
Interfaces should be returned in descending order of priority
"""
ifaces = Interface.interfaces()
priorities = [i.priority for i in ifaces]
assert (np.diff(priorities) <= 0).all()

View file

@ -1,9 +1,12 @@
import json
import pytest import pytest
import zarr import zarr
from pydantic import ValidationError from pydantic import ValidationError
from numpydantic.interface import ZarrInterface from numpydantic.interface import ZarrInterface
from numpydantic.interface.zarr import ZarrArrayPath
from numpydantic.exceptions import DtypeError, ShapeError from numpydantic.exceptions import DtypeError, ShapeError
from tests.conftest import ValidationCase from tests.conftest import ValidationCase
@ -27,12 +30,12 @@ def nested_dir_array(tmp_output_dir_func) -> zarr.NestedDirectoryStore:
return store return store
def zarr_array(case: ValidationCase, store) -> zarr.core.Array: def _zarr_array(case: ValidationCase, store) -> zarr.core.Array:
return zarr.zeros(shape=case.shape, dtype=case.dtype, store=store) return zarr.zeros(shape=case.shape, dtype=case.dtype, store=store)
def _test_zarr_case(case: ValidationCase, store): def _test_zarr_case(case: ValidationCase, store):
array = zarr_array(case, store) array = _zarr_array(case, store)
if case.passes: if case.passes:
case.model(array=array) case.model(array=array)
else: else:
@ -76,3 +79,61 @@ def test_zarr_shape(store, shape_cases):
def test_zarr_dtype(dtype_cases, store): def test_zarr_dtype(dtype_cases, store):
_test_zarr_case(dtype_cases, store) _test_zarr_case(dtype_cases, store)
@pytest.mark.parametrize("array", ["zarr_nested_array", "zarr_array"])
def test_zarr_from_tuple(array, model_blank, request):
"""Should be able to do the same validation logic from tuples as an input"""
array = request.getfixturevalue(array)
if isinstance(array, ZarrArrayPath):
instance = model_blank(array=(array.file, array.path))
else:
instance = model_blank(array=(array,))
def test_zarr_from_path(zarr_array, model_blank):
"""Should be able to just pass a path"""
instance = model_blank(array=zarr_array)
def test_zarr_array_path_from_iterable(zarr_array):
"""Construct a zarr array path from some iterable!!!"""
# from a single path
apath = ZarrArrayPath.from_iterable((zarr_array,))
assert apath.file == zarr_array
assert apath.path is None
inner_path = "/test/array"
apath = ZarrArrayPath.from_iterable((zarr_array, inner_path))
assert apath.file == zarr_array
assert apath.path == inner_path
def test_zarr_to_json(store, model_blank):
expected_fields = (
"Type",
"Data type",
"Shape",
"Chunk shape",
"Compressor",
"Store type",
"hexdigest",
)
lol_array = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
array = zarr.array(lol_array, store=store)
instance = model_blank(array=array)
as_json = json.loads(instance.model_dump_json())["array"]
assert "array" not in as_json
for field in expected_fields:
assert field in as_json
assert len(as_json["hexdigest"]) == 40
# dump the array itself too
as_json = json.loads(instance.model_dump_json(context={"zarr_dump_array": True}))[
"array"
]
for field in expected_fields:
assert field in as_json
assert len(as_json["hexdigest"]) == 40
assert as_json["array"] == lol_array