recursively validate tuples instead of simple containment checking

This commit is contained in:
sneakers-the-rat 2024-12-13 17:55:35 -08:00
parent 22341c8b06
commit e942da4bec
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D

View file

@ -32,7 +32,7 @@ def validate_dtype(dtype: Any, target: DtypeType) -> bool:
return True return True
if isinstance(target, tuple): if isinstance(target, tuple):
valid = dtype in target valid = any(validate_dtype(dtype, target_dt) for target_dt in target)
elif is_union(target): elif is_union(target):
valid = any( valid = any(
[validate_dtype(dtype, target_dt) for target_dt in get_args(target)] [validate_dtype(dtype, target_dt) for target_dt in get_args(target)]