mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2024-11-10 00:34:29 +00:00
commit
13a8fce4ef
10 changed files with 242 additions and 19 deletions
|
@ -17,6 +17,7 @@ relatively low. Its `Dtype[ArrayClass, "{shape_expression}"]` syntax is not well
|
||||||
suited for modeling arrays intended to be general across implementations, and
|
suited for modeling arrays intended to be general across implementations, and
|
||||||
makes it challenging to adapt to pydantic's schema generation system.
|
makes it challenging to adapt to pydantic's schema generation system.
|
||||||
|
|
||||||
|
(design_challenges)=
|
||||||
## Challenges
|
## Challenges
|
||||||
|
|
||||||
The Python type annotation system is weird and not like the rest of Python!
|
The Python type annotation system is weird and not like the rest of Python!
|
||||||
|
|
|
@ -57,6 +57,25 @@ model = MyModel(array=('data.zarr', '/nested/dataset'))
|
||||||
model = MyModel(array="data.mp4")
|
model = MyModel(array="data.mp4")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
And use the `NDArray` type annotation like a regular type outside
|
||||||
|
of pydantic -- eg. to validate an array anywhere, use `isinstance`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
array_type = NDArray[Shape["1, 2, 3"], int]
|
||||||
|
isinstance(np.zeros((1,2,3), dtype=int), array_type)
|
||||||
|
# True
|
||||||
|
isinstance(zarr.zeros((1,2,3), dtype=int), array_type)
|
||||||
|
# True
|
||||||
|
isinstance(np.zeros((4,5,6), dtype=int), array_type)
|
||||||
|
# False
|
||||||
|
isinstance(np.zeros((1,2,3), dtype=float), array_type)
|
||||||
|
# False
|
||||||
|
```
|
||||||
|
|
||||||
|
```{note}
|
||||||
|
`NDArray` can't do validation with static type checkers yet, see
|
||||||
|
{ref}`design_challenges` and {ref}`type_checkers`
|
||||||
|
```
|
||||||
|
|
||||||
## Features:
|
## Features:
|
||||||
- **Types** - Annotations (based on [npytyping](https://github.com/ramonhagenaars/nptyping))
|
- **Types** - Annotations (based on [npytyping](https://github.com/ramonhagenaars/nptyping))
|
||||||
|
|
15
docs/todo.md
15
docs/todo.md
|
@ -10,6 +10,21 @@ type system and is no longer actively maintained. We will be reimplementing a sy
|
||||||
that extends its array specification syntax to include things like ranges and extensible
|
that extends its array specification syntax to include things like ranges and extensible
|
||||||
dtypes with varying precision (and is much less finnicky to deal with).
|
dtypes with varying precision (and is much less finnicky to deal with).
|
||||||
|
|
||||||
|
(type_checkers)=
|
||||||
|
## Type Checker Integration
|
||||||
|
|
||||||
|
The `.pyi` stubfile generation ({mod}`numpydantic.meta`) works for
|
||||||
|
keeping type checkers from complaining about various array formats
|
||||||
|
not literally being `NDArray` objects, but it doesn't do the kind of
|
||||||
|
validation we would want to be able to use `NDArray` objects as full-fledged
|
||||||
|
python types, including validation propagation through scopes and
|
||||||
|
IDE type checking for invalid literals.
|
||||||
|
|
||||||
|
We want to hook into the type checking process to satisfy these type checkers:
|
||||||
|
- mypy - has hooks, can be done with an extension
|
||||||
|
- pyright - unclear if has hooks, might nee to monkeypatch
|
||||||
|
- pycharm - unlikely this is possible, extensions need to be in Java and installed separately
|
||||||
|
|
||||||
|
|
||||||
## Validation
|
## Validation
|
||||||
|
|
||||||
|
|
|
@ -3,9 +3,25 @@ Exceptions used within numpydantic
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class DtypeError(TypeError):
|
class InterfaceError(Exception):
|
||||||
|
"""Parent mixin class for errors raised by :class:`.Interface` subclasses"""
|
||||||
|
|
||||||
|
|
||||||
|
class DtypeError(TypeError, InterfaceError):
|
||||||
"""Exception raised for invalid dtypes"""
|
"""Exception raised for invalid dtypes"""
|
||||||
|
|
||||||
|
|
||||||
class ShapeError(ValueError):
|
class ShapeError(ValueError, InterfaceError):
|
||||||
"""Exception raise for invalid shapes"""
|
"""Exception raise for invalid shapes"""
|
||||||
|
|
||||||
|
|
||||||
|
class MatchError(ValueError, InterfaceError):
|
||||||
|
"""Exception for errors raised during :class:`.Interface.match`-ing"""
|
||||||
|
|
||||||
|
|
||||||
|
class NoMatchError(MatchError):
|
||||||
|
"""No match was found by :class:`.Interface.match`"""
|
||||||
|
|
||||||
|
|
||||||
|
class TooManyMatchesError(MatchError):
|
||||||
|
"""Too many matches found by :class:`.Interface.match`"""
|
||||||
|
|
|
@ -10,7 +10,12 @@ import numpy as np
|
||||||
from nptyping.shape_expression import check_shape
|
from nptyping.shape_expression import check_shape
|
||||||
from pydantic import SerializationInfo
|
from pydantic import SerializationInfo
|
||||||
|
|
||||||
from numpydantic.exceptions import DtypeError, ShapeError
|
from numpydantic.exceptions import (
|
||||||
|
DtypeError,
|
||||||
|
NoMatchError,
|
||||||
|
ShapeError,
|
||||||
|
TooManyMatchesError,
|
||||||
|
)
|
||||||
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||||
|
|
||||||
T = TypeVar("T", bound=NDArrayType)
|
T = TypeVar("T", bound=NDArrayType)
|
||||||
|
@ -32,6 +37,25 @@ class Interface(ABC, Generic[T]):
|
||||||
def validate(self, array: Any) -> T:
|
def validate(self, array: Any) -> T:
|
||||||
"""
|
"""
|
||||||
Validate input, returning final array type
|
Validate input, returning final array type
|
||||||
|
|
||||||
|
Calls the methods, in order:
|
||||||
|
|
||||||
|
* :meth:`.before_validation`
|
||||||
|
* :meth:`.validate_dtype`
|
||||||
|
* :meth:`.validate_shape`
|
||||||
|
* :meth:`.after_validation`
|
||||||
|
|
||||||
|
passing the ``array`` argument and returning it from each.
|
||||||
|
|
||||||
|
Implementing an interface subclass largely consists of overriding these methods
|
||||||
|
as needed.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
If validation fails, rather than eg. returning ``False``, exceptions will
|
||||||
|
be raised (to halt the rest of the pydantic validation process).
|
||||||
|
When using interfaces outside of pydantic, you must catch both
|
||||||
|
:class:`.DtypeError` and :class:`.ShapeError` (both of which are children
|
||||||
|
of :class:`.InterfaceError` )
|
||||||
"""
|
"""
|
||||||
array = self.before_validation(array)
|
array = self.before_validation(array)
|
||||||
array = self.validate_dtype(array)
|
array = self.validate_dtype(array)
|
||||||
|
@ -120,17 +144,38 @@ class Interface(ABC, Generic[T]):
|
||||||
return array.tolist()
|
return array.tolist()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def interfaces(cls) -> Tuple[Type["Interface"], ...]:
|
def interfaces(
|
||||||
|
cls, with_disabled: bool = False, sort: bool = True
|
||||||
|
) -> Tuple[Type["Interface"], ...]:
|
||||||
"""
|
"""
|
||||||
Enabled interface subclasses
|
Enabled interface subclasses
|
||||||
|
|
||||||
|
Args:
|
||||||
|
with_disabled (bool): If ``True`` , get every known interface.
|
||||||
|
If ``False`` (default), get only enabled interfaces.
|
||||||
|
sort (bool): If ``True`` (default), sort interfaces by priority.
|
||||||
|
If ``False`` , sorted by definition order. Used for recursion:
|
||||||
|
we only want to sort once at the top level.
|
||||||
"""
|
"""
|
||||||
return tuple(
|
# get recursively
|
||||||
sorted(
|
subclasses = []
|
||||||
[i for i in Interface.__subclasses__() if i.enabled()],
|
for i in cls.__subclasses__():
|
||||||
|
if with_disabled:
|
||||||
|
subclasses.append(i)
|
||||||
|
|
||||||
|
if i.enabled():
|
||||||
|
subclasses.append(i)
|
||||||
|
|
||||||
|
subclasses.extend(i.interfaces(with_disabled=with_disabled, sort=False))
|
||||||
|
|
||||||
|
if sort:
|
||||||
|
subclasses = sorted(
|
||||||
|
subclasses,
|
||||||
key=attrgetter("priority"),
|
key=attrgetter("priority"),
|
||||||
reverse=True,
|
reverse=True,
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
return tuple(subclasses)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def return_types(cls) -> Tuple[NDArrayType, ...]:
|
def return_types(cls) -> Tuple[NDArrayType, ...]:
|
||||||
|
@ -150,9 +195,21 @@ class Interface(ABC, Generic[T]):
|
||||||
return tuple(in_types)
|
return tuple(in_types)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def match(cls, array: Any) -> Type["Interface"]:
|
def match(cls, array: Any, fast: bool = False) -> Type["Interface"]:
|
||||||
"""
|
"""
|
||||||
Find the interface that should be used for this array based on its input type
|
Find the interface that should be used for this array based on its input type
|
||||||
|
|
||||||
|
First runs the ``check`` method for all interfaces returned by
|
||||||
|
:meth:`.Interface.interfaces` **except** for :class:`.NumpyInterface` ,
|
||||||
|
and if no match is found then try the numpy interface. This is because
|
||||||
|
:meth:`.NumpyInterface.check` can be expensive, as we could potentially
|
||||||
|
try to
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fast (bool): if ``False`` , check all interfaces and raise exceptions for
|
||||||
|
having multiple matching interfaces (default). If ``True`` ,
|
||||||
|
check each interface (as ordered by its ``priority`` , decreasing),
|
||||||
|
and return on the first match.
|
||||||
"""
|
"""
|
||||||
# first try and find a non-numpy interface, since the numpy interface
|
# first try and find a non-numpy interface, since the numpy interface
|
||||||
# will try and load the array into memory in its check method
|
# will try and load the array into memory in its check method
|
||||||
|
@ -160,17 +217,24 @@ class Interface(ABC, Generic[T]):
|
||||||
non_np_interfaces = [i for i in interfaces if i.__name__ != "NumpyInterface"]
|
non_np_interfaces = [i for i in interfaces if i.__name__ != "NumpyInterface"]
|
||||||
np_interface = [i for i in interfaces if i.__name__ == "NumpyInterface"][0]
|
np_interface = [i for i in interfaces if i.__name__ == "NumpyInterface"][0]
|
||||||
|
|
||||||
|
if fast:
|
||||||
|
matches = []
|
||||||
|
for i in non_np_interfaces:
|
||||||
|
if i.check(array):
|
||||||
|
return i
|
||||||
|
else:
|
||||||
matches = [i for i in non_np_interfaces if i.check(array)]
|
matches = [i for i in non_np_interfaces if i.check(array)]
|
||||||
|
|
||||||
if len(matches) > 1:
|
if len(matches) > 1:
|
||||||
msg = f"More than one interface matches input {array}:\n"
|
msg = f"More than one interface matches input {array}:\n"
|
||||||
msg += "\n".join([f" - {i}" for i in matches])
|
msg += "\n".join([f" - {i}" for i in matches])
|
||||||
raise ValueError(msg)
|
raise TooManyMatchesError(msg)
|
||||||
elif len(matches) == 0:
|
elif len(matches) == 0:
|
||||||
# now try the numpy interface
|
# now try the numpy interface
|
||||||
if np_interface.check(array):
|
if np_interface.check(array):
|
||||||
return np_interface
|
return np_interface
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"No matching interfaces found for input {array}")
|
raise NoMatchError(f"No matching interfaces found for input {array}")
|
||||||
else:
|
else:
|
||||||
return matches[0]
|
return matches[0]
|
||||||
|
|
||||||
|
@ -186,8 +250,8 @@ class Interface(ABC, Generic[T]):
|
||||||
if len(matches) > 1:
|
if len(matches) > 1:
|
||||||
msg = f"More than one interface matches output {array}:\n"
|
msg = f"More than one interface matches output {array}:\n"
|
||||||
msg += "\n".join([f" - {i}" for i in matches])
|
msg += "\n".join([f" - {i}" for i in matches])
|
||||||
raise ValueError(msg)
|
raise TooManyMatchesError(msg)
|
||||||
elif len(matches) == 0:
|
elif len(matches) == 0:
|
||||||
raise ValueError(f"No matching interfaces found for output {array}")
|
raise NoMatchError(f"No matching interfaces found for output {array}")
|
||||||
else:
|
else:
|
||||||
return matches[0]
|
return matches[0]
|
||||||
|
|
|
@ -24,6 +24,7 @@ def generate_ndarray_stub() -> str:
|
||||||
# Create import statements, saving aliased name of type if needed
|
# Create import statements, saving aliased name of type if needed
|
||||||
if arr.__module__.startswith("numpydantic") or arr.__module__ == "typing":
|
if arr.__module__.startswith("numpydantic") or arr.__module__ == "typing":
|
||||||
type_name = str(arr) if arr.__module__ == "typing" else arr.__name__
|
type_name = str(arr) if arr.__module__ == "typing" else arr.__name__
|
||||||
|
if arr.__module__ != "typing":
|
||||||
import_strings.append(f"from {arr.__module__} import {type_name}")
|
import_strings.append(f"from {arr.__module__} import {type_name}")
|
||||||
else:
|
else:
|
||||||
# since other packages could use the same name for an imported object
|
# since other packages could use the same name for an imported object
|
||||||
|
@ -39,6 +40,7 @@ def generate_ndarray_stub() -> str:
|
||||||
type_names.append(type_name)
|
type_names.append(type_name)
|
||||||
|
|
||||||
import_strings.extend(_BUILTIN_IMPORTS)
|
import_strings.extend(_BUILTIN_IMPORTS)
|
||||||
|
import_strings = list(dict.fromkeys(import_strings))
|
||||||
import_string = "\n".join(import_strings)
|
import_string = "\n".join(import_strings)
|
||||||
|
|
||||||
class_union = " | ".join(type_names)
|
class_union = " | ".join(type_names)
|
||||||
|
|
|
@ -13,7 +13,7 @@ Extension of nptyping NDArray for pydantic that allows for JSON-Schema serializa
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Tuple
|
from typing import TYPE_CHECKING, Any, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from nptyping.error import InvalidArgumentsError
|
from nptyping.error import InvalidArgumentsError
|
||||||
|
@ -28,6 +28,8 @@ from pydantic import GetJsonSchemaHandler
|
||||||
from pydantic_core import core_schema
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
from numpydantic.dtype import DType
|
from numpydantic.dtype import DType
|
||||||
|
from numpydantic.exceptions import InterfaceError
|
||||||
|
from numpydantic.interface import Interface
|
||||||
from numpydantic.maps import python_to_nptyping
|
from numpydantic.maps import python_to_nptyping
|
||||||
from numpydantic.schema import (
|
from numpydantic.schema import (
|
||||||
_handler_type,
|
_handler_type,
|
||||||
|
@ -37,6 +39,9 @@ from numpydantic.schema import (
|
||||||
)
|
)
|
||||||
from numpydantic.types import DtypeType, ShapeType
|
from numpydantic.types import DtypeType, ShapeType
|
||||||
|
|
||||||
|
if TYPE_CHECKING: # pragma: no cover
|
||||||
|
from nptyping.base_meta_classes import SubscriptableMeta
|
||||||
|
|
||||||
|
|
||||||
class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
|
class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
|
||||||
"""
|
"""
|
||||||
|
@ -44,6 +49,35 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
|
||||||
completion of the transition away from nptyping
|
completion of the transition away from nptyping
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if TYPE_CHECKING: # pragma: no cover
|
||||||
|
__getitem__ = SubscriptableMeta.__getitem__
|
||||||
|
|
||||||
|
def __instancecheck__(self, instance: Any):
|
||||||
|
"""
|
||||||
|
Extended type checking that determines whether
|
||||||
|
|
||||||
|
1) the ``type`` of the given instance is one of those in
|
||||||
|
:meth:`.Interface.input_types`
|
||||||
|
|
||||||
|
but also
|
||||||
|
|
||||||
|
2) it satisfies the constraints set on the :class:`.NDArray` annotation
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instance (:class:`typing.Any`): Thing to check!
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: ``True`` if matches constraints, ``False`` otherwise.
|
||||||
|
"""
|
||||||
|
shape, dtype = self.__args__
|
||||||
|
try:
|
||||||
|
interface_cls = Interface.match(instance, fast=True)
|
||||||
|
interface = interface_cls(shape, dtype)
|
||||||
|
_ = interface.validate(instance)
|
||||||
|
return True
|
||||||
|
except InterfaceError:
|
||||||
|
return False
|
||||||
|
|
||||||
def _get_dtype(cls, dtype_candidate: Any) -> DType:
|
def _get_dtype(cls, dtype_candidate: Any) -> DType:
|
||||||
"""
|
"""
|
||||||
Override of base _get_dtype method to allow for compound tuple types
|
Override of base _get_dtype method to allow for compound tuple types
|
||||||
|
|
|
@ -225,7 +225,9 @@ def get_validate_interface(shape: ShapeType, dtype: DtypeType) -> Callable:
|
||||||
:meth:`.Interface.validate` method
|
:meth:`.Interface.validate` method
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def validate_interface(value: Any, info: "ValidationInfo") -> NDArrayType:
|
def validate_interface(
|
||||||
|
value: Any, info: Optional["ValidationInfo"] = None
|
||||||
|
) -> NDArrayType:
|
||||||
interface_cls = Interface.match(value)
|
interface_cls = Interface.match(value)
|
||||||
interface = interface_cls(shape, dtype)
|
interface = interface_cls(shape, dtype)
|
||||||
value = interface.validate(value)
|
value = interface.validate(value)
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
import pdb
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -14,9 +12,12 @@ def interfaces():
|
||||||
class Interface1(Interface):
|
class Interface1(Interface):
|
||||||
input_types = (list,)
|
input_types = (list,)
|
||||||
return_type = tuple
|
return_type = tuple
|
||||||
|
priority = 1000
|
||||||
|
checked = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def check(cls, array):
|
def check(cls, array):
|
||||||
|
cls.checked = True
|
||||||
if isinstance(array, list):
|
if isinstance(array, list):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
@ -26,18 +27,34 @@ def interfaces():
|
||||||
return True
|
return True
|
||||||
|
|
||||||
Interface2 = type("Interface2", Interface1.__bases__, dict(Interface1.__dict__))
|
Interface2 = type("Interface2", Interface1.__bases__, dict(Interface1.__dict__))
|
||||||
|
Interface2.checked = False
|
||||||
|
Interface2.priority = 999
|
||||||
|
|
||||||
class Interface3(Interface1):
|
class Interface3(Interface1):
|
||||||
|
priority = 998
|
||||||
|
checked = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def enabled(cls) -> bool:
|
def enabled(cls) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
class Interface4(Interface3):
|
||||||
|
priority = 997
|
||||||
|
checked = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def enabled(cls) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
class Interfaces:
|
class Interfaces:
|
||||||
interface1 = Interface1
|
interface1 = Interface1
|
||||||
interface2 = Interface2
|
interface2 = Interface2
|
||||||
interface3 = Interface3
|
interface3 = Interface3
|
||||||
|
interface4 = Interface4
|
||||||
|
|
||||||
yield Interfaces
|
yield Interfaces
|
||||||
|
# Interface.__subclasses__().remove(Interface1)
|
||||||
|
# Interface.__subclasses__().remove(Interface2)
|
||||||
del Interface1
|
del Interface1
|
||||||
del Interface2
|
del Interface2
|
||||||
del Interface3
|
del Interface3
|
||||||
|
@ -66,6 +83,20 @@ def test_interface_match_error(interfaces):
|
||||||
assert "No matching interfaces" in str(e.value)
|
assert "No matching interfaces" in str(e.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_interface_match_fast(interfaces):
|
||||||
|
"""
|
||||||
|
fast matching should return as soon as an interface is found
|
||||||
|
and not raise an error for duplicates
|
||||||
|
"""
|
||||||
|
Interface.interfaces()[0].checked = False
|
||||||
|
Interface.interfaces()[1].checked = False
|
||||||
|
# this doesnt' raise an error
|
||||||
|
matched = Interface.match([1, 2, 3], fast=True)
|
||||||
|
assert matched == Interface.interfaces()[0]
|
||||||
|
assert Interface.interfaces()[0].checked
|
||||||
|
assert not Interface.interfaces()[1].checked
|
||||||
|
|
||||||
|
|
||||||
def test_interface_enabled(interfaces):
|
def test_interface_enabled(interfaces):
|
||||||
"""
|
"""
|
||||||
An interface shouldn't be included if it's not enabled
|
An interface shouldn't be included if it's not enabled
|
||||||
|
@ -101,3 +132,22 @@ def test_interfaces_sorting():
|
||||||
ifaces = Interface.interfaces()
|
ifaces = Interface.interfaces()
|
||||||
priorities = [i.priority for i in ifaces]
|
priorities = [i.priority for i in ifaces]
|
||||||
assert (np.diff(priorities) <= 0).all()
|
assert (np.diff(priorities) <= 0).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_interface_with_disabled(interfaces):
|
||||||
|
"""
|
||||||
|
Get all interfaces, even if not enabled
|
||||||
|
"""
|
||||||
|
ifaces = Interface.interfaces(with_disabled=True)
|
||||||
|
assert interfaces.interface3 in ifaces
|
||||||
|
|
||||||
|
|
||||||
|
def test_interface_recursive(interfaces):
|
||||||
|
"""
|
||||||
|
Get all interfaces, including subclasses of subclasses
|
||||||
|
"""
|
||||||
|
ifaces = Interface.interfaces()
|
||||||
|
assert issubclass(interfaces.interface4, interfaces.interface3)
|
||||||
|
assert issubclass(interfaces.interface3, interfaces.interface1)
|
||||||
|
assert issubclass(interfaces.interface1, Interface)
|
||||||
|
assert interfaces.interface4 in ifaces
|
||||||
|
|
|
@ -223,3 +223,23 @@ def test_json_schema_ellipsis():
|
||||||
|
|
||||||
schema = ConstrainedAnyShape.model_json_schema()
|
schema = ConstrainedAnyShape.model_json_schema()
|
||||||
_recursive_array(schema)
|
_recursive_array(schema)
|
||||||
|
|
||||||
|
|
||||||
|
def test_instancecheck():
|
||||||
|
"""
|
||||||
|
NDArray should handle ``isinstance()`` s.t. valid arrays are ``True``
|
||||||
|
and invalid arrays are ``False``
|
||||||
|
|
||||||
|
We don't make this test exhaustive because correctness of validation
|
||||||
|
is tested elsewhere. We are just testing that the type checking works
|
||||||
|
"""
|
||||||
|
array_type = NDArray[Shape["1, 2, 3"], int]
|
||||||
|
|
||||||
|
assert isinstance(np.zeros((1, 2, 3), dtype=int), array_type)
|
||||||
|
assert not isinstance(np.zeros((2, 2, 3), dtype=int), array_type)
|
||||||
|
assert not isinstance(np.zeros((1, 2, 3), dtype=float), array_type)
|
||||||
|
|
||||||
|
def my_function(array: NDArray[Shape["1, 2, 3"], int]):
|
||||||
|
return array
|
||||||
|
|
||||||
|
my_function(np.zeros((1, 2, 3), int))
|
||||||
|
|
Loading…
Reference in a new issue