mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2025-01-09 21:44:27 +00:00
Better string dtype checking support, restructuring the validation hooks to allow finer grained control over the process.
This commit is contained in:
parent
880dafb151
commit
b2db1014bd
4 changed files with 146 additions and 34 deletions
|
@ -41,19 +41,42 @@ when cast to an `ndarray`, we only try as a last resort.
|
||||||
|
|
||||||
## Validation
|
## Validation
|
||||||
|
|
||||||
Validation is a chain of lifecycle methods, with a single argument passed and returned
|
Validation is a chain of lifecycle methods, each of which can be overridden
|
||||||
to and from each:
|
for interfaces to implement custom behavior that matches the array format.
|
||||||
|
|
||||||
{meth}`.Interface.validate` calls in order:
|
{meth}`.Interface.validate` calls the following methods, in order:
|
||||||
|
|
||||||
|
An initial hook for modifying the input data before validation, eg.
|
||||||
|
if it needs to be coerced or wrapped in some proxy class. This method
|
||||||
|
should accept all and only the types specified in that interface's
|
||||||
|
{attr}`~.Interface.input_types`.
|
||||||
|
|
||||||
- {meth}`.Interface.before_validation`
|
- {meth}`.Interface.before_validation`
|
||||||
- {meth}`.Interface.validate_dtype`
|
|
||||||
- {meth}`.Interface.validate_shape`
|
|
||||||
- {meth}`.Interface.after_validation`
|
|
||||||
|
|
||||||
The `before` and `after` methods provide hooks for coercion, loading, etc. such that
|
A cluster of methods for validating dtype.
|
||||||
`validate` can accept one of the types in the interface's
|
Separating these methods allow for array formats that store dtype information
|
||||||
{attr}`~.Interface.input_types` and return the {attr}`~.Interface.return_type` .
|
in a nonstandard attribute, require additional coercion, or for implementing
|
||||||
|
custom exception handlers or rescuers.
|
||||||
|
Check the method signatures and return types
|
||||||
|
when overriding and the docstrings for details.
|
||||||
|
|
||||||
|
- {meth}`.Interface.get_dtype`
|
||||||
|
- {meth}`.Interface.validate_dtype`
|
||||||
|
- {meth}`.Interface.raise_for_dtype`
|
||||||
|
|
||||||
|
A halftime hook for modifying the array or bailing early between validation phases.
|
||||||
|
|
||||||
|
- {meth}`.Interface.after_validate_dtype`
|
||||||
|
|
||||||
|
A cluster of methods for validating shape, similar to the dtype cluster.
|
||||||
|
|
||||||
|
- {meth}`.Interface.get_shape`
|
||||||
|
- {meth}`.Interface.validate_shape`
|
||||||
|
- {meth}`.Interface.raise_for_shape`
|
||||||
|
|
||||||
|
A final hook for modifying the array before passing it to be assigned to the field.
|
||||||
|
This method should return an object matching the interface's {attr}`~.Interface.return_type`.
|
||||||
|
- {meth}`.Interface.after_validation`
|
||||||
|
|
||||||
## Diagram
|
## Diagram
|
||||||
|
|
||||||
|
|
|
@ -40,12 +40,29 @@ class Interface(ABC, Generic[T]):
|
||||||
|
|
||||||
Calls the methods, in order:
|
Calls the methods, in order:
|
||||||
|
|
||||||
* :meth:`.before_validation`
|
* array = :meth:`.before_validation` (array)
|
||||||
* :meth:`.validate_dtype`
|
* dtype = :meth:`.get_dtype` (array) - get the dtype from the array,
|
||||||
* :meth:`.validate_shape`
|
override if eg. the dtype is not contained in ``array.dtype``
|
||||||
* :meth:`.after_validation`
|
* valid = :meth:`.validate_dtype` (dtype) - check that the dtype matches
|
||||||
|
the one in the NDArray specification. Override if special
|
||||||
|
validation logic is needed for a given format
|
||||||
|
* :meth:`.raise_for_dtype` (valid, dtype) - after checking dtype validity,
|
||||||
|
raise an exception if it was invalid. Override to implement custom
|
||||||
|
exceptions or error conditions, or make validation errors conditional.
|
||||||
|
* array = :meth:`.after_validate_dtype` (array) - hook for additional
|
||||||
|
validation or array modification mid-validation
|
||||||
|
* shape = :meth:`.get_shape` (array) - get the shape from the array,
|
||||||
|
override if eg. the shape is not contained in ``array.shape``
|
||||||
|
* valid = :meth:`.validate_shape` (shape) - check that the shape matches
|
||||||
|
the one in the NDArray specification. Override if special validation
|
||||||
|
logic is needed.
|
||||||
|
* :meth:`.raise_for_shape` (valid, shape) - after checking shape validity,
|
||||||
|
raise an exception if it was invalid. You know the deal bc it's the same
|
||||||
|
as raise for dtype.
|
||||||
|
* :meth:`.after_validation` - hook after validation for modifying the array
|
||||||
|
that is set as the model field value
|
||||||
|
|
||||||
passing the ``array`` argument and returning it from each.
|
Follow the method signatures and return types to override
|
||||||
|
|
||||||
Implementing an interface subclass largely consists of overriding these methods
|
Implementing an interface subclass largely consists of overriding these methods
|
||||||
as needed.
|
as needed.
|
||||||
|
@ -58,8 +75,16 @@ class Interface(ABC, Generic[T]):
|
||||||
of :class:`.InterfaceError` )
|
of :class:`.InterfaceError` )
|
||||||
"""
|
"""
|
||||||
array = self.before_validation(array)
|
array = self.before_validation(array)
|
||||||
array = self.validate_dtype(array)
|
|
||||||
array = self.validate_shape(array)
|
dtype = self.get_dtype(array)
|
||||||
|
dtype_valid = self.validate_dtype(dtype)
|
||||||
|
self.raise_for_dtype(dtype_valid, dtype)
|
||||||
|
array = self.after_validate_dtype(array)
|
||||||
|
|
||||||
|
shape = self.get_shape(array)
|
||||||
|
shape_valid = self.validate_shape(shape)
|
||||||
|
self.raise_for_shape(shape_valid, shape)
|
||||||
|
|
||||||
array = self.after_validation(array)
|
array = self.after_validation(array)
|
||||||
return array
|
return array
|
||||||
|
|
||||||
|
@ -72,40 +97,76 @@ class Interface(ABC, Generic[T]):
|
||||||
"""
|
"""
|
||||||
return array
|
return array
|
||||||
|
|
||||||
def validate_dtype(self, array: NDArrayType) -> NDArrayType:
|
def get_dtype(self, array: NDArrayType) -> DtypeType:
|
||||||
"""
|
"""
|
||||||
Validate the dtype of the given array, returning it unmutated.
|
Get the dtype from the input array
|
||||||
|
"""
|
||||||
|
return array.dtype
|
||||||
|
|
||||||
|
def validate_dtype(self, dtype: DtypeType) -> bool:
|
||||||
|
"""
|
||||||
|
Validate the dtype of the given array, returning
|
||||||
|
``True`` if valid, ``False`` if not.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
if self.dtype is Any:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if isinstance(self.dtype, tuple):
|
||||||
|
valid = dtype in self.dtype
|
||||||
|
elif self.dtype is np.str_:
|
||||||
|
valid = getattr(dtype, "type", None) is np.str_ or dtype is np.str_
|
||||||
|
else:
|
||||||
|
valid = dtype == self.dtype
|
||||||
|
return valid
|
||||||
|
|
||||||
|
def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None:
|
||||||
|
"""
|
||||||
|
After validating, raise an exception if invalid
|
||||||
Raises:
|
Raises:
|
||||||
:class:`~numpydantic.exceptions.DtypeError`
|
:class:`~numpydantic.exceptions.DtypeError`
|
||||||
"""
|
"""
|
||||||
if self.dtype is Any:
|
|
||||||
return array
|
|
||||||
|
|
||||||
if isinstance(self.dtype, tuple):
|
|
||||||
valid = array.dtype in self.dtype
|
|
||||||
else:
|
|
||||||
valid = array.dtype == self.dtype
|
|
||||||
|
|
||||||
if not valid:
|
if not valid:
|
||||||
raise DtypeError(f"Invalid dtype! expected {self.dtype}, got {array.dtype}")
|
raise DtypeError(f"Invalid dtype! expected {self.dtype}, got {dtype}")
|
||||||
|
|
||||||
|
def after_validate_dtype(self, array: NDArrayType) -> NDArrayType:
|
||||||
|
"""
|
||||||
|
Hook to modify array after validating dtype.
|
||||||
|
Default is a no-op.
|
||||||
|
"""
|
||||||
return array
|
return array
|
||||||
|
|
||||||
def validate_shape(self, array: NDArrayType) -> NDArrayType:
|
def get_shape(self, array: NDArrayType) -> Tuple[int, ...]:
|
||||||
"""
|
"""
|
||||||
Validate the shape of the given array, returning it unmutated
|
Get the shape from the array as a tuple of integers
|
||||||
|
"""
|
||||||
|
return array.shape
|
||||||
|
|
||||||
|
def validate_shape(self, shape: Tuple[int, ...]) -> bool:
|
||||||
|
"""
|
||||||
|
Validate the shape of the given array against the shape
|
||||||
|
specifier, returning ``True`` if valid, ``False`` if not.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
if self.shape is Any:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return check_shape(shape, self.shape)
|
||||||
|
|
||||||
|
def raise_for_shape(self, valid: bool, shape: Tuple[int, ...]) -> None:
|
||||||
|
"""
|
||||||
|
Raise a ShapeError if the shape is invalid.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
:class:`~numpydantic.exceptions.ShapeError`
|
:class:`~numpydantic.exceptions.ShapeError`
|
||||||
"""
|
"""
|
||||||
if self.shape is Any:
|
if not valid:
|
||||||
return array
|
|
||||||
if not check_shape(array.shape, self.shape):
|
|
||||||
raise ShapeError(
|
raise ShapeError(
|
||||||
f"Invalid shape! expected shape {self.shape.prepared_args}, "
|
f"Invalid shape! expected shape {self.shape.prepared_args}, "
|
||||||
f"got shape {array.shape}"
|
f"got shape {shape}"
|
||||||
)
|
)
|
||||||
return array
|
|
||||||
|
|
||||||
def after_validation(self, array: NDArrayType) -> T:
|
def after_validation(self, array: NDArrayType) -> T:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -7,18 +7,22 @@ from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional, Sequence, Union
|
from typing import Any, Optional, Sequence, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from pydantic import SerializationInfo
|
from pydantic import SerializationInfo
|
||||||
|
|
||||||
from numpydantic.interface.interface import Interface
|
from numpydantic.interface.interface import Interface
|
||||||
|
from numpydantic.types import DtypeType
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import zarr
|
import zarr
|
||||||
|
from numcodecs import VLenUTF8
|
||||||
from zarr.core import Array as ZarrArray
|
from zarr.core import Array as ZarrArray
|
||||||
from zarr.storage import StoreLike
|
from zarr.storage import StoreLike
|
||||||
except ImportError: # pragma: no cover
|
except ImportError: # pragma: no cover
|
||||||
ZarrArray = None
|
ZarrArray = None
|
||||||
StoreLike = None
|
StoreLike = None
|
||||||
storage = None
|
storage = None
|
||||||
|
VLenUTF8 = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -113,6 +117,19 @@ class ZarrInterface(Interface):
|
||||||
"""
|
"""
|
||||||
return self._get_array(array)
|
return self._get_array(array)
|
||||||
|
|
||||||
|
def get_dtype(self, array: ZarrArray) -> DtypeType:
|
||||||
|
"""
|
||||||
|
Override base dtype getter to handle zarr's string-as-object encoding.
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
getattr(array.dtype, "type", None) is np.object_
|
||||||
|
and array.filters
|
||||||
|
and any([isinstance(f, VLenUTF8) for f in array.filters])
|
||||||
|
):
|
||||||
|
return np.str_
|
||||||
|
else:
|
||||||
|
return array.dtype
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def to_json(
|
def to_json(
|
||||||
cls,
|
cls,
|
||||||
|
|
|
@ -67,6 +67,7 @@ RGB_UNION: TypeAlias = Union[
|
||||||
NUMBER: TypeAlias = NDArray[Shape["*, *, *"], Number]
|
NUMBER: TypeAlias = NDArray[Shape["*, *, *"], Number]
|
||||||
INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer]
|
INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer]
|
||||||
FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float]
|
FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float]
|
||||||
|
STRING: TypeAlias = NDArray[Shape["*, *, *"], str]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(
|
@pytest.fixture(
|
||||||
|
@ -121,10 +122,15 @@ def shape_cases(request) -> ValidationCase:
|
||||||
ValidationCase(annotation=INTEGER, dtype=np.uint8, passes=True),
|
ValidationCase(annotation=INTEGER, dtype=np.uint8, passes=True),
|
||||||
ValidationCase(annotation=INTEGER, dtype=float, passes=False),
|
ValidationCase(annotation=INTEGER, dtype=float, passes=False),
|
||||||
ValidationCase(annotation=INTEGER, dtype=np.float32, passes=False),
|
ValidationCase(annotation=INTEGER, dtype=np.float32, passes=False),
|
||||||
|
ValidationCase(annotation=INTEGER, dtype=str, passes=False),
|
||||||
ValidationCase(annotation=FLOAT, dtype=float, passes=True),
|
ValidationCase(annotation=FLOAT, dtype=float, passes=True),
|
||||||
ValidationCase(annotation=FLOAT, dtype=np.float32, passes=True),
|
ValidationCase(annotation=FLOAT, dtype=np.float32, passes=True),
|
||||||
ValidationCase(annotation=FLOAT, dtype=int, passes=False),
|
ValidationCase(annotation=FLOAT, dtype=int, passes=False),
|
||||||
ValidationCase(annotation=FLOAT, dtype=np.uint8, passes=False),
|
ValidationCase(annotation=FLOAT, dtype=np.uint8, passes=False),
|
||||||
|
ValidationCase(annotation=FLOAT, dtype=str, passes=False),
|
||||||
|
ValidationCase(annotation=STRING, dtype=str, passes=True),
|
||||||
|
ValidationCase(annotation=STRING, dtype=int, passes=False),
|
||||||
|
ValidationCase(annotation=STRING, dtype=float, passes=False),
|
||||||
],
|
],
|
||||||
ids=[
|
ids=[
|
||||||
"float",
|
"float",
|
||||||
|
@ -139,10 +145,15 @@ def shape_cases(request) -> ValidationCase:
|
||||||
"integer-uint8",
|
"integer-uint8",
|
||||||
"integer-float",
|
"integer-float",
|
||||||
"integer-float32",
|
"integer-float32",
|
||||||
|
"integer-str",
|
||||||
"float-float",
|
"float-float",
|
||||||
"float-float32",
|
"float-float32",
|
||||||
"float-int",
|
"float-int",
|
||||||
"float-uint8",
|
"float-uint8",
|
||||||
|
"float-str",
|
||||||
|
"str-str",
|
||||||
|
"str-int",
|
||||||
|
"str-float",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def dtype_cases(request) -> ValidationCase:
|
def dtype_cases(request) -> ValidationCase:
|
||||||
|
|
Loading…
Reference in a new issue