diff --git a/src/numpydantic/serialization.py b/src/numpydantic/serialization.py index 07924eb..f901994 100644 --- a/src/numpydantic/serialization.py +++ b/src/numpydantic/serialization.py @@ -93,7 +93,7 @@ def _absolutize_paths(value: dict, skip: Iterable = tuple()) -> dict: return _walk_and_apply(value, _a_path, skip) -def _walk_and_apply(value: T, f: Callable[[U], U], skip: Iterable = tuple()) -> T: +def _walk_and_apply(value: T, f: Callable[[U, bool], U], skip: Iterable = tuple()) -> T: """ Walk an object, applying a function """ diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 702dc1a..5d0b2d8 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -10,6 +10,8 @@ from typing import Callable import numpy as np import json +from numpydantic.serialization import _walk_and_apply + pytestmark = pytest.mark.serialization @@ -93,3 +95,34 @@ def test_relative_to_path(hdf5_at_path, tmp_output_dir, model_blank): # shouldn't have absolutized subpath even if it's pathlike assert data["path"] == expected_dataset + + +def test_walk_and_apply(): + """ + Walk and apply should recursively apply a function to everything in a nesty structure + """ + test = { + "a": 1, + "b": 1, + "c": [ + {"a": 1, "b": {"a": 1, "b": 1}, "c": [1, 1, 1]}, + {"a": 1, "b": [1, 1, 1]}, + ], + } + + def _mult_2(v, skip: bool = False): + return v * 2 + + def _assert_2(v, skip: bool = False): + assert v == 2 + return v + + walked = _walk_and_apply(test, _mult_2) + _walk_and_apply(walked, _assert_2) + + assert walked["a"] == 2 + assert walked["c"][0]["a"] == 2 + assert walked["c"][0]["b"]["a"] == 2 + assert all([w == 2 for w in walked["c"][0]["c"]]) + assert walked["c"][1]["a"] == 2 + assert all([w == 2 for w in walked["c"][1]["b"]])