python 3.9 compat

This commit is contained in:
sneakers-the-rat 2024-09-23 23:30:40 -07:00
parent 85cef50603
commit e63d9268b1
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D

View file

@ -3,14 +3,18 @@ Helper functions for validation of dtype.
For literal dtypes intended for use by end-users, see :mod:`numpydantic.dtype` For literal dtypes intended for use by end-users, see :mod:`numpydantic.dtype`
""" """
import sys
from types import UnionType
from typing import Any, Union, get_args, get_origin from typing import Any, Union, get_args, get_origin
import numpy as np import numpy as np
from numpydantic.types import DtypeType from numpydantic.types import DtypeType
if sys.version_info >= (3, 10):
from types import UnionType
else:
UnionType = None
def validate_dtype(dtype: Any, target: DtypeType) -> bool: def validate_dtype(dtype: Any, target: DtypeType) -> bool:
""" """
@ -52,4 +56,7 @@ def is_union(dtype: DtypeType) -> bool:
""" """
Check if a dtype is a union Check if a dtype is a union
""" """
return get_origin(dtype) in (Union, UnionType) if UnionType is None:
return get_origin(dtype) is Union
else:
return get_origin(dtype) in (Union, UnionType)