From 1701ef9d7ed9be18dd84663f82bd6504919a48b6 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Thu, 17 Oct 2024 22:15:30 -0700 Subject: [PATCH] get tests working again --- src/numpydantic/serialization.py | 4 ++-- tests/test_interface/conftest.py | 8 ++++---- tests/test_interface/test_hdf5.py | 9 +++------ tests/test_interface/test_interfaces.py | 16 ++++++++-------- 4 files changed, 17 insertions(+), 20 deletions(-) diff --git a/src/numpydantic/serialization.py b/src/numpydantic/serialization.py index 357e3c9..2b570cf 100644 --- a/src/numpydantic/serialization.py +++ b/src/numpydantic/serialization.py @@ -20,12 +20,12 @@ def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]: # pdb.set_trace() interface_cls = Interface.match_output(value) array = interface_cls.to_json(value, info) - array = postprocess_json(array, info) + array = postprocess_json(array, info, interface_cls) return array def postprocess_json( - array: Union[dict, list], info: SerializationInfo + array: Union[dict, list], info: SerializationInfo, interface_cls: type[Interface] ) -> Union[dict, list]: """ Modify json after dumping from an interface diff --git a/tests/test_interface/conftest.py b/tests/test_interface/conftest.py index 0f1048a..5ba52fb 100644 --- a/tests/test_interface/conftest.py +++ b/tests/test_interface/conftest.py @@ -136,7 +136,7 @@ def all_passing_cases_instance(all_passing_cases, tmp_output_dir_func): for p in DTYPE_AND_INTERFACE_CASES_PASSING ) ) -def dtype_by_interface(request): +def all_passing_cases(request): """ Tests for all dtypes by all interfaces """ @@ -144,7 +144,7 @@ def dtype_by_interface(request): @pytest.fixture() -def dtype_by_interface_instance(dtype_by_interface, tmp_output_dir_func): - array = dtype_by_interface.array(path=tmp_output_dir_func) - instance = dtype_by_interface.model(array=array) +def dtype_by_interface_instance(all_passing_cases, tmp_output_dir_func): + array = all_passing_cases.array(path=tmp_output_dir_func) + instance = all_passing_cases.model(array=array) return instance diff --git a/tests/test_interface/test_hdf5.py b/tests/test_interface/test_hdf5.py index 42d1a5b..e5905d0 100644 --- a/tests/test_interface/test_hdf5.py +++ b/tests/test_interface/test_hdf5.py @@ -220,8 +220,8 @@ def test_empty_dataset(dtype, tmp_path): (H5Proxy(file="test_file.h5", path="/subpath", field="sup"), True), (H5Proxy(file="test_file.h5", path="/subpath"), False), (H5Proxy(file="different_file.h5", path="/subpath"), False), - (("different_file.h5", "/subpath", "sup"), ValueError), - ("not even a proxy-like thing", ValueError), + (("different_file.h5", "/subpath", "sup"), False), + ("not even a proxy-like thing", False), ], ) def test_proxy_eq(comparison, valid): @@ -232,8 +232,5 @@ def test_proxy_eq(comparison, valid): proxy_a = H5Proxy(file="test_file.h5", path="/subpath", field="sup") if valid is True: assert proxy_a == comparison - elif valid is False: - assert proxy_a != comparison else: - with pytest.raises(valid): - assert proxy_a == comparison + assert proxy_a != comparison diff --git a/tests/test_interface/test_interfaces.py b/tests/test_interface/test_interfaces.py index a368aee..10adec3 100644 --- a/tests/test_interface/test_interfaces.py +++ b/tests/test_interface/test_interfaces.py @@ -91,15 +91,15 @@ def test_interface_dump_json(dtype_by_interface_instance): @pytest.mark.serialization -def test_interface_roundtrip_json(dtype_by_interface, tmp_output_dir_func): +def test_interface_roundtrip_json(all_passing_cases, tmp_output_dir_func): """ All interfaces should be able to roundtrip to and from json """ - if "subclass" in dtype_by_interface.id.lower(): + if "subclass" in all_passing_cases.id.lower(): pytest.xfail() - array = dtype_by_interface.array(path=tmp_output_dir_func) - case = dtype_by_interface.model(array=array) + array = all_passing_cases.array(path=tmp_output_dir_func) + case = all_passing_cases.model(array=array) dumped_json = case.model_dump_json(round_trip=True) model = case.model_validate_json(dumped_json) @@ -123,16 +123,16 @@ def test_interface_mark_interface(an_interface): @pytest.mark.serialization @pytest.mark.parametrize("valid", [True, False]) @pytest.mark.filterwarnings("ignore:Mismatch between serialized mark") -def test_interface_mark_roundtrip(dtype_by_interface, valid, tmp_output_dir_func): +def test_interface_mark_roundtrip(all_passing_cases, valid, tmp_output_dir_func): """ All interfaces should be able to roundtrip with the marked interface, and a mismatch should raise a warning and attempt to proceed """ - if "subclass" in dtype_by_interface.id.lower(): + if "subclass" in all_passing_cases.id.lower(): pytest.xfail() - array = dtype_by_interface.array(path=tmp_output_dir_func) - case = dtype_by_interface.model(array=array) + array = all_passing_cases.array(path=tmp_output_dir_func) + case = all_passing_cases.model(array=array) dumped_json = case.model_dump_json( round_trip=True, context={"mark_interface": True}