unbreak numpydantic

This commit is contained in:
sneakers-the-rat 2024-09-25 17:30:26 -07:00
parent 5b59e9fc07
commit 4ad3dc0696
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D

View file

@ -5,7 +5,9 @@ For literal dtypes intended for use by end-users, see :mod:`numpydantic.dtype`
"""
import sys
from typing import Any, Union, get_origin
from typing import Any, Union, get_args, get_origin
import numpy as np
from numpydantic.types import DtypeType
@ -26,31 +28,29 @@ def validate_dtype(dtype: Any, target: DtypeType) -> bool:
Returns:
bool: ``True`` if valid, ``False`` otherwise
"""
return False
#
# if target is Any:
# return True
#
# if isinstance(target, tuple):
# valid = dtype in target
# elif is_union(target):
# valid = any(
# [validate_dtype(dtype, target_dt) for target_dt in get_args(target)]
# )
# elif target is np.str_:
# valid = getattr(dtype, "type", None) in (np.str_, str) or dtype in (
# np.str_,
# str,
# )
# else:
# # try to match as any subclass, if target is a class
# try:
# valid = issubclass(dtype, target)
# except TypeError:
# # expected, if dtype or target is not a class
# valid = dtype == target
#
# return valid
if target is Any:
return True
if isinstance(target, tuple):
valid = dtype in target
elif is_union(target):
valid = any(
[validate_dtype(dtype, target_dt) for target_dt in get_args(target)]
)
elif target is np.str_:
valid = getattr(dtype, "type", None) in (np.str_, str) or dtype in (
np.str_,
str,
)
else:
# try to match as any subclass, if target is a class
try:
valid = issubclass(dtype, target)
except TypeError:
# expected, if dtype or target is not a class
valid = dtype == target
return valid
def is_union(dtype: DtypeType) -> bool: