From 5e90c1bee1251072d34c7919060c2456f56a0201 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Fri, 17 May 2024 17:39:30 -0700 Subject: [PATCH] add serialization info to to_json methods, zarr dump array context option, remove coerce_list and just let numpy interface handle it --- src/numpydantic/interface/dask.py | 7 ++- src/numpydantic/interface/hdf5.py | 5 +- src/numpydantic/interface/interface.py | 7 ++- src/numpydantic/interface/zarr.py | 21 ++++++++- src/numpydantic/ndarray.py | 12 ++--- src/numpydantic/schema.py | 15 ++---- tests/test_interface/test_interface.py | 11 +++++ tests/test_interface/test_zarr.py | 65 +++++++++++++++++++++++++- 8 files changed, 112 insertions(+), 31 deletions(-) diff --git a/src/numpydantic/interface/dask.py b/src/numpydantic/interface/dask.py index 8073c92..5f3f3c2 100644 --- a/src/numpydantic/interface/dask.py +++ b/src/numpydantic/interface/dask.py @@ -2,9 +2,10 @@ Interface for Dask arrays """ -from typing import Any +from typing import Any, Optional import numpy as np +from pydantic import SerializationInfo from numpydantic.interface.interface import Interface @@ -37,7 +38,9 @@ class DaskInterface(Interface): return DaskArray is not None @classmethod - def to_json(cls, array: DaskArray) -> list: + def to_json( + cls, array: DaskArray, info: Optional[SerializationInfo] = None + ) -> list: """ Convert an array to a JSON serializable array by first converting to a numpy array and then to a list. diff --git a/src/numpydantic/interface/hdf5.py b/src/numpydantic/interface/hdf5.py index ae294ae..0bac99f 100644 --- a/src/numpydantic/interface/hdf5.py +++ b/src/numpydantic/interface/hdf5.py @@ -4,9 +4,10 @@ Interfaces for HDF5 Datasets import sys from pathlib import Path -from typing import Any, NamedTuple, Tuple, Union +from typing import Any, NamedTuple, Optional, Tuple, Union import numpy as np +from pydantic import SerializationInfo from numpydantic.interface.interface import Interface from numpydantic.types import NDArrayType @@ -179,7 +180,7 @@ class H5Interface(Interface): return array @classmethod - def to_json(cls, array: H5Proxy) -> dict: + def to_json(cls, array: H5Proxy, info: Optional[SerializationInfo] = None) -> dict: """ Dump to a dictionary containing diff --git a/src/numpydantic/interface/interface.py b/src/numpydantic/interface/interface.py index 3dfddaf..825c4c1 100644 --- a/src/numpydantic/interface/interface.py +++ b/src/numpydantic/interface/interface.py @@ -4,10 +4,11 @@ Base Interface metaclass from abc import ABC, abstractmethod from operator import attrgetter -from typing import Any, Generic, Tuple, Type, TypeVar, Union +from typing import Any, Generic, Optional, Tuple, Type, TypeVar, Union import numpy as np from nptyping.shape_expression import check_shape +from pydantic import SerializationInfo from numpydantic.exceptions import DtypeError, ShapeError from numpydantic.types import DtypeType, NDArrayType, ShapeType @@ -107,7 +108,9 @@ class Interface(ABC, Generic[T]): """ @classmethod - def to_json(cls, array: Type[T]) -> Union[list, dict]: + def to_json( + cls, array: Type[T], info: Optional[SerializationInfo] = None + ) -> Union[list, dict]: """ Convert an array of :attr:`.return_type` to a JSON-compatible format using base python types diff --git a/src/numpydantic/interface/zarr.py b/src/numpydantic/interface/zarr.py index 4244b0e..5a300e6 100644 --- a/src/numpydantic/interface/zarr.py +++ b/src/numpydantic/interface/zarr.py @@ -7,6 +7,8 @@ from dataclasses import dataclass from pathlib import Path from typing import Any, Optional, Sequence, Union +from pydantic import SerializationInfo + from numpydantic.interface.interface import Interface try: @@ -113,14 +115,29 @@ class ZarrInterface(Interface): @classmethod def to_json( - cls, array: Union[ZarrArray, str, Path, ZarrArrayPath, Sequence] + cls, + array: Union[ZarrArray, str, Path, ZarrArrayPath, Sequence], + info: Optional[SerializationInfo] = None, ) -> dict: """ Dump just the metadata for an array from :meth:`zarr.core.Array.info_items` - plus the :meth:`zarr.core.Array.hexdigest` + plus the :meth:`zarr.core.Array.hexdigest`. + + The full array can be returned by passing ``'zarr_dump_array': True`` to the + serialization ``context`` :: + + model.model_dump_json(context={'zarr_dump_array': True}) """ + dump_array = False + if info is not None and info.context is not None: + dump_array = info.context.get("zarr_dump_array", False) + array = cls._get_array(array) info = array.info_items() info_dict = {i[0]: i[1] for i in info} info_dict["hexdigest"] = array.hexdigest() + + if dump_array: + info_dict["array"] = array[:].tolist() + return info_dict diff --git a/src/numpydantic/ndarray.py b/src/numpydantic/ndarray.py index bcb35bd..2ef7da7 100644 --- a/src/numpydantic/ndarray.py +++ b/src/numpydantic/ndarray.py @@ -32,7 +32,6 @@ from numpydantic.maps import python_to_nptyping from numpydantic.schema import ( _handler_type, _jsonize_array, - coerce_list, get_validate_interface, make_json_schema, ) @@ -119,16 +118,11 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta): return core_schema.json_or_python_schema( json_schema=list_schema, - python_schema=core_schema.chain_schema( - [ - core_schema.no_info_plain_validator_function(coerce_list), - core_schema.with_info_plain_validator_function( - get_validate_interface(shape, dtype) - ), - ] + python_schema=core_schema.with_info_plain_validator_function( + get_validate_interface(shape, dtype) ), serialization=core_schema.plain_serializer_function_ser_schema( - _jsonize_array, when_used="json" + _jsonize_array, when_used="json", info_arg=True ), ) diff --git a/src/numpydantic/schema.py b/src/numpydantic/schema.py index 568ab5d..b7733c6 100644 --- a/src/numpydantic/schema.py +++ b/src/numpydantic/schema.py @@ -8,6 +8,7 @@ from typing import Any, Callable, Union import nptyping.structure import numpy as np from nptyping import Shape +from pydantic import SerializationInfo from pydantic_core import CoreSchema, core_schema from pydantic_core.core_schema import ListSchema, ValidationInfo @@ -173,17 +174,7 @@ def get_validate_interface(shape: ShapeType, dtype: DtypeType) -> Callable: return validate_interface -def _jsonize_array(value: Any) -> Union[list, dict]: +def _jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]: """Use an interface class to render an array as JSON""" interface_cls = Interface.match_output(value) - return interface_cls.to_json(value) - - -def coerce_list(value: Any) -> np.ndarray: - """ - If a value is passed as a list or list of lists, try and coerce it into an array - rather than failing validation. - """ - if isinstance(value, list): - value = np.array(value) - return value + return interface_cls.to_json(value, info) diff --git a/tests/test_interface/test_interface.py b/tests/test_interface/test_interface.py index 7555997..bbafb7a 100644 --- a/tests/test_interface/test_interface.py +++ b/tests/test_interface/test_interface.py @@ -1,5 +1,7 @@ import pytest +import numpy as np + from numpydantic.interface import Interface @@ -88,3 +90,12 @@ def test_interface_type_lists(): assert atype in Interface.return_types() else: assert interface.return_type in Interface.return_types() + + +def test_interfaces_sorting(): + """ + Interfaces should be returned in descending order of priority + """ + ifaces = Interface.interfaces() + priorities = [i.priority for i in ifaces] + assert (np.diff(priorities) <= 0).all() diff --git a/tests/test_interface/test_zarr.py b/tests/test_interface/test_zarr.py index 6913b40..eab3e52 100644 --- a/tests/test_interface/test_zarr.py +++ b/tests/test_interface/test_zarr.py @@ -1,9 +1,12 @@ +import json + import pytest import zarr from pydantic import ValidationError from numpydantic.interface import ZarrInterface +from numpydantic.interface.zarr import ZarrArrayPath from numpydantic.exceptions import DtypeError, ShapeError from tests.conftest import ValidationCase @@ -27,12 +30,12 @@ def nested_dir_array(tmp_output_dir_func) -> zarr.NestedDirectoryStore: return store -def zarr_array(case: ValidationCase, store) -> zarr.core.Array: +def _zarr_array(case: ValidationCase, store) -> zarr.core.Array: return zarr.zeros(shape=case.shape, dtype=case.dtype, store=store) def _test_zarr_case(case: ValidationCase, store): - array = zarr_array(case, store) + array = _zarr_array(case, store) if case.passes: case.model(array=array) else: @@ -76,3 +79,61 @@ def test_zarr_shape(store, shape_cases): def test_zarr_dtype(dtype_cases, store): _test_zarr_case(dtype_cases, store) + + +@pytest.mark.parametrize("array", ["zarr_nested_array", "zarr_array"]) +def test_zarr_from_tuple(array, model_blank, request): + """Should be able to do the same validation logic from tuples as an input""" + array = request.getfixturevalue(array) + if isinstance(array, ZarrArrayPath): + instance = model_blank(array=(array.file, array.path)) + else: + instance = model_blank(array=(array,)) + + +def test_zarr_from_path(zarr_array, model_blank): + """Should be able to just pass a path""" + instance = model_blank(array=zarr_array) + + +def test_zarr_array_path_from_iterable(zarr_array): + """Construct a zarr array path from some iterable!!!""" + # from a single path + apath = ZarrArrayPath.from_iterable((zarr_array,)) + assert apath.file == zarr_array + assert apath.path is None + + inner_path = "/test/array" + apath = ZarrArrayPath.from_iterable((zarr_array, inner_path)) + assert apath.file == zarr_array + assert apath.path == inner_path + + +def test_zarr_to_json(store, model_blank): + expected_fields = ( + "Type", + "Data type", + "Shape", + "Chunk shape", + "Compressor", + "Store type", + "hexdigest", + ) + lol_array = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] + + array = zarr.array(lol_array, store=store) + instance = model_blank(array=array) + as_json = json.loads(instance.model_dump_json())["array"] + assert "array" not in as_json + for field in expected_fields: + assert field in as_json + assert len(as_json["hexdigest"]) == 40 + + # dump the array itself too + as_json = json.loads(instance.model_dump_json(context={"zarr_dump_array": True}))[ + "array" + ] + for field in expected_fields: + assert field in as_json + assert len(as_json["hexdigest"]) == 40 + assert as_json["array"] == lol_array