From 4ad3dc06968c3bd2617db34854a3af9169c88abd Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Wed, 25 Sep 2024 17:30:26 -0700 Subject: [PATCH] unbreak numpydantic --- src/numpydantic/validation/dtype.py | 52 ++++++++++++++--------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/src/numpydantic/validation/dtype.py b/src/numpydantic/validation/dtype.py index a5bb6ac..5eeb124 100644 --- a/src/numpydantic/validation/dtype.py +++ b/src/numpydantic/validation/dtype.py @@ -5,7 +5,9 @@ For literal dtypes intended for use by end-users, see :mod:`numpydantic.dtype` """ import sys -from typing import Any, Union, get_origin +from typing import Any, Union, get_args, get_origin + +import numpy as np from numpydantic.types import DtypeType @@ -26,31 +28,29 @@ def validate_dtype(dtype: Any, target: DtypeType) -> bool: Returns: bool: ``True`` if valid, ``False`` otherwise """ - return False - # - # if target is Any: - # return True - # - # if isinstance(target, tuple): - # valid = dtype in target - # elif is_union(target): - # valid = any( - # [validate_dtype(dtype, target_dt) for target_dt in get_args(target)] - # ) - # elif target is np.str_: - # valid = getattr(dtype, "type", None) in (np.str_, str) or dtype in ( - # np.str_, - # str, - # ) - # else: - # # try to match as any subclass, if target is a class - # try: - # valid = issubclass(dtype, target) - # except TypeError: - # # expected, if dtype or target is not a class - # valid = dtype == target - # - # return valid + if target is Any: + return True + + if isinstance(target, tuple): + valid = dtype in target + elif is_union(target): + valid = any( + [validate_dtype(dtype, target_dt) for target_dt in get_args(target)] + ) + elif target is np.str_: + valid = getattr(dtype, "type", None) in (np.str_, str) or dtype in ( + np.str_, + str, + ) + else: + # try to match as any subclass, if target is a class + try: + valid = issubclass(dtype, target) + except TypeError: + # expected, if dtype or target is not a class + valid = dtype == target + + return valid def is_union(dtype: DtypeType) -> bool: