From 22341c8b06dfa0c53e68847ed25b9898b4aa1890 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Fri, 13 Dec 2024 17:53:20 -0800 Subject: [PATCH] correctly failing tests for np.str_ in a tuple --- src/numpydantic/testing/cases.py | 21 +++++++++++++++++++++ src/numpydantic/testing/interfaces.py | 4 +++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/numpydantic/testing/cases.py b/src/numpydantic/testing/cases.py index 042317d..4f715fd 100644 --- a/src/numpydantic/testing/cases.py +++ b/src/numpydantic/testing/cases.py @@ -126,6 +126,27 @@ DTYPE_CASES = [ ValidationCase(annotation_dtype=str, dtype=str, passes=True, id="str-str"), ValidationCase(annotation_dtype=str, dtype=int, passes=False, id="str-int"), ValidationCase(annotation_dtype=str, dtype=float, passes=False, id="str-float"), + ValidationCase( + annotation_dtype=np.str_, + dtype=str, + passes=True, + id="np_str-str", + marks={"np_str", "str"}, + ), + ValidationCase( + annotation_dtype=np.str_, + dtype=np.str_, + passes=True, + id="np_str-np_str", + marks={"np_str", "str"}, + ), + ValidationCase( + annotation_dtype=(int, np.str_), + dtype=str, + passes=True, + id="tuple_np_str-str", + marks={"np_str", "str", "tuple"}, + ), ValidationCase( annotation_dtype=BasicModel, dtype=BasicModel, passes=True, id="model-model" ), diff --git a/src/numpydantic/testing/interfaces.py b/src/numpydantic/testing/interfaces.py index a11bc6a..c247169 100644 --- a/src/numpydantic/testing/interfaces.py +++ b/src/numpydantic/testing/interfaces.py @@ -75,6 +75,8 @@ class HDF5Case(_HDF5MetaCase): data = np.array(array, dtype=dtype) elif dtype is str: data = generator.random(shape).astype(bytes) + elif dtype is np.str_: + data = generator.random(shape).astype("S32") elif dtype is datetime: data = np.empty(shape, dtype="S32") data.fill(datetime.now(timezone.utc).isoformat().encode("utf-8")) @@ -106,7 +108,7 @@ class HDF5CompoundCase(_HDF5MetaCase): array_path = "/" + "_".join([str(s) for s in shape]) + "__" + dtype.__name__ if array is not None: data = np.array(array, dtype=dtype) - elif dtype is str: + elif dtype in (str, np.str_): dt = np.dtype([("data", np.dtype("S10")), ("extra", "i8")]) data = np.array([("hey", 0)] * np.prod(shape), dtype=dt).reshape(shape) elif dtype is datetime: