mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2025-01-09 13:44:26 +00:00
Better string dtype checking support, restructuring the validation hooks to allow finer grained control over the process.
This commit is contained in:
parent
880dafb151
commit
b2db1014bd
4 changed files with 146 additions and 34 deletions
|
@ -41,19 +41,42 @@ when cast to an `ndarray`, we only try as a last resort.
|
|||
|
||||
## Validation
|
||||
|
||||
Validation is a chain of lifecycle methods, with a single argument passed and returned
|
||||
to and from each:
|
||||
Validation is a chain of lifecycle methods, each of which can be overridden
|
||||
for interfaces to implement custom behavior that matches the array format.
|
||||
|
||||
{meth}`.Interface.validate` calls in order:
|
||||
{meth}`.Interface.validate` calls the following methods, in order:
|
||||
|
||||
An initial hook for modifying the input data before validation, eg.
|
||||
if it needs to be coerced or wrapped in some proxy class. This method
|
||||
should accept all and only the types specified in that interface's
|
||||
{attr}`~.Interface.input_types`.
|
||||
|
||||
- {meth}`.Interface.before_validation`
|
||||
- {meth}`.Interface.validate_dtype`
|
||||
- {meth}`.Interface.validate_shape`
|
||||
- {meth}`.Interface.after_validation`
|
||||
|
||||
The `before` and `after` methods provide hooks for coercion, loading, etc. such that
|
||||
`validate` can accept one of the types in the interface's
|
||||
{attr}`~.Interface.input_types` and return the {attr}`~.Interface.return_type` .
|
||||
A cluster of methods for validating dtype.
|
||||
Separating these methods allow for array formats that store dtype information
|
||||
in a nonstandard attribute, require additional coercion, or for implementing
|
||||
custom exception handlers or rescuers.
|
||||
Check the method signatures and return types
|
||||
when overriding and the docstrings for details.
|
||||
|
||||
- {meth}`.Interface.get_dtype`
|
||||
- {meth}`.Interface.validate_dtype`
|
||||
- {meth}`.Interface.raise_for_dtype`
|
||||
|
||||
A halftime hook for modifying the array or bailing early between validation phases.
|
||||
|
||||
- {meth}`.Interface.after_validate_dtype`
|
||||
|
||||
A cluster of methods for validating shape, similar to the dtype cluster.
|
||||
|
||||
- {meth}`.Interface.get_shape`
|
||||
- {meth}`.Interface.validate_shape`
|
||||
- {meth}`.Interface.raise_for_shape`
|
||||
|
||||
A final hook for modifying the array before passing it to be assigned to the field.
|
||||
This method should return an object matching the interface's {attr}`~.Interface.return_type`.
|
||||
- {meth}`.Interface.after_validation`
|
||||
|
||||
## Diagram
|
||||
|
||||
|
|
|
@ -40,12 +40,29 @@ class Interface(ABC, Generic[T]):
|
|||
|
||||
Calls the methods, in order:
|
||||
|
||||
* :meth:`.before_validation`
|
||||
* :meth:`.validate_dtype`
|
||||
* :meth:`.validate_shape`
|
||||
* :meth:`.after_validation`
|
||||
* array = :meth:`.before_validation` (array)
|
||||
* dtype = :meth:`.get_dtype` (array) - get the dtype from the array,
|
||||
override if eg. the dtype is not contained in ``array.dtype``
|
||||
* valid = :meth:`.validate_dtype` (dtype) - check that the dtype matches
|
||||
the one in the NDArray specification. Override if special
|
||||
validation logic is needed for a given format
|
||||
* :meth:`.raise_for_dtype` (valid, dtype) - after checking dtype validity,
|
||||
raise an exception if it was invalid. Override to implement custom
|
||||
exceptions or error conditions, or make validation errors conditional.
|
||||
* array = :meth:`.after_validate_dtype` (array) - hook for additional
|
||||
validation or array modification mid-validation
|
||||
* shape = :meth:`.get_shape` (array) - get the shape from the array,
|
||||
override if eg. the shape is not contained in ``array.shape``
|
||||
* valid = :meth:`.validate_shape` (shape) - check that the shape matches
|
||||
the one in the NDArray specification. Override if special validation
|
||||
logic is needed.
|
||||
* :meth:`.raise_for_shape` (valid, shape) - after checking shape validity,
|
||||
raise an exception if it was invalid. You know the deal bc it's the same
|
||||
as raise for dtype.
|
||||
* :meth:`.after_validation` - hook after validation for modifying the array
|
||||
that is set as the model field value
|
||||
|
||||
passing the ``array`` argument and returning it from each.
|
||||
Follow the method signatures and return types to override
|
||||
|
||||
Implementing an interface subclass largely consists of overriding these methods
|
||||
as needed.
|
||||
|
@ -58,8 +75,16 @@ class Interface(ABC, Generic[T]):
|
|||
of :class:`.InterfaceError` )
|
||||
"""
|
||||
array = self.before_validation(array)
|
||||
array = self.validate_dtype(array)
|
||||
array = self.validate_shape(array)
|
||||
|
||||
dtype = self.get_dtype(array)
|
||||
dtype_valid = self.validate_dtype(dtype)
|
||||
self.raise_for_dtype(dtype_valid, dtype)
|
||||
array = self.after_validate_dtype(array)
|
||||
|
||||
shape = self.get_shape(array)
|
||||
shape_valid = self.validate_shape(shape)
|
||||
self.raise_for_shape(shape_valid, shape)
|
||||
|
||||
array = self.after_validation(array)
|
||||
return array
|
||||
|
||||
|
@ -72,40 +97,76 @@ class Interface(ABC, Generic[T]):
|
|||
"""
|
||||
return array
|
||||
|
||||
def validate_dtype(self, array: NDArrayType) -> NDArrayType:
|
||||
def get_dtype(self, array: NDArrayType) -> DtypeType:
|
||||
"""
|
||||
Validate the dtype of the given array, returning it unmutated.
|
||||
Get the dtype from the input array
|
||||
"""
|
||||
return array.dtype
|
||||
|
||||
def validate_dtype(self, dtype: DtypeType) -> bool:
|
||||
"""
|
||||
Validate the dtype of the given array, returning
|
||||
``True`` if valid, ``False`` if not.
|
||||
|
||||
|
||||
"""
|
||||
if self.dtype is Any:
|
||||
return True
|
||||
|
||||
if isinstance(self.dtype, tuple):
|
||||
valid = dtype in self.dtype
|
||||
elif self.dtype is np.str_:
|
||||
valid = getattr(dtype, "type", None) is np.str_ or dtype is np.str_
|
||||
else:
|
||||
valid = dtype == self.dtype
|
||||
return valid
|
||||
|
||||
def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None:
|
||||
"""
|
||||
After validating, raise an exception if invalid
|
||||
Raises:
|
||||
:class:`~numpydantic.exceptions.DtypeError`
|
||||
"""
|
||||
if self.dtype is Any:
|
||||
return array
|
||||
|
||||
if isinstance(self.dtype, tuple):
|
||||
valid = array.dtype in self.dtype
|
||||
else:
|
||||
valid = array.dtype == self.dtype
|
||||
|
||||
if not valid:
|
||||
raise DtypeError(f"Invalid dtype! expected {self.dtype}, got {array.dtype}")
|
||||
raise DtypeError(f"Invalid dtype! expected {self.dtype}, got {dtype}")
|
||||
|
||||
def after_validate_dtype(self, array: NDArrayType) -> NDArrayType:
|
||||
"""
|
||||
Hook to modify array after validating dtype.
|
||||
Default is a no-op.
|
||||
"""
|
||||
return array
|
||||
|
||||
def validate_shape(self, array: NDArrayType) -> NDArrayType:
|
||||
def get_shape(self, array: NDArrayType) -> Tuple[int, ...]:
|
||||
"""
|
||||
Validate the shape of the given array, returning it unmutated
|
||||
Get the shape from the array as a tuple of integers
|
||||
"""
|
||||
return array.shape
|
||||
|
||||
def validate_shape(self, shape: Tuple[int, ...]) -> bool:
|
||||
"""
|
||||
Validate the shape of the given array against the shape
|
||||
specifier, returning ``True`` if valid, ``False`` if not.
|
||||
|
||||
|
||||
"""
|
||||
if self.shape is Any:
|
||||
return True
|
||||
|
||||
return check_shape(shape, self.shape)
|
||||
|
||||
def raise_for_shape(self, valid: bool, shape: Tuple[int, ...]) -> None:
|
||||
"""
|
||||
Raise a ShapeError if the shape is invalid.
|
||||
|
||||
Raises:
|
||||
:class:`~numpydantic.exceptions.ShapeError`
|
||||
"""
|
||||
if self.shape is Any:
|
||||
return array
|
||||
if not check_shape(array.shape, self.shape):
|
||||
if not valid:
|
||||
raise ShapeError(
|
||||
f"Invalid shape! expected shape {self.shape.prepared_args}, "
|
||||
f"got shape {array.shape}"
|
||||
f"got shape {shape}"
|
||||
)
|
||||
return array
|
||||
|
||||
def after_validation(self, array: NDArrayType) -> T:
|
||||
"""
|
||||
|
|
|
@ -7,18 +7,22 @@ from dataclasses import dataclass
|
|||
from pathlib import Path
|
||||
from typing import Any, Optional, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import SerializationInfo
|
||||
|
||||
from numpydantic.interface.interface import Interface
|
||||
from numpydantic.types import DtypeType
|
||||
|
||||
try:
|
||||
import zarr
|
||||
from numcodecs import VLenUTF8
|
||||
from zarr.core import Array as ZarrArray
|
||||
from zarr.storage import StoreLike
|
||||
except ImportError: # pragma: no cover
|
||||
ZarrArray = None
|
||||
StoreLike = None
|
||||
storage = None
|
||||
VLenUTF8 = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -113,6 +117,19 @@ class ZarrInterface(Interface):
|
|||
"""
|
||||
return self._get_array(array)
|
||||
|
||||
def get_dtype(self, array: ZarrArray) -> DtypeType:
|
||||
"""
|
||||
Override base dtype getter to handle zarr's string-as-object encoding.
|
||||
"""
|
||||
if (
|
||||
getattr(array.dtype, "type", None) is np.object_
|
||||
and array.filters
|
||||
and any([isinstance(f, VLenUTF8) for f in array.filters])
|
||||
):
|
||||
return np.str_
|
||||
else:
|
||||
return array.dtype
|
||||
|
||||
@classmethod
|
||||
def to_json(
|
||||
cls,
|
||||
|
|
|
@ -67,6 +67,7 @@ RGB_UNION: TypeAlias = Union[
|
|||
NUMBER: TypeAlias = NDArray[Shape["*, *, *"], Number]
|
||||
INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer]
|
||||
FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float]
|
||||
STRING: TypeAlias = NDArray[Shape["*, *, *"], str]
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
|
@ -121,10 +122,15 @@ def shape_cases(request) -> ValidationCase:
|
|||
ValidationCase(annotation=INTEGER, dtype=np.uint8, passes=True),
|
||||
ValidationCase(annotation=INTEGER, dtype=float, passes=False),
|
||||
ValidationCase(annotation=INTEGER, dtype=np.float32, passes=False),
|
||||
ValidationCase(annotation=INTEGER, dtype=str, passes=False),
|
||||
ValidationCase(annotation=FLOAT, dtype=float, passes=True),
|
||||
ValidationCase(annotation=FLOAT, dtype=np.float32, passes=True),
|
||||
ValidationCase(annotation=FLOAT, dtype=int, passes=False),
|
||||
ValidationCase(annotation=FLOAT, dtype=np.uint8, passes=False),
|
||||
ValidationCase(annotation=FLOAT, dtype=str, passes=False),
|
||||
ValidationCase(annotation=STRING, dtype=str, passes=True),
|
||||
ValidationCase(annotation=STRING, dtype=int, passes=False),
|
||||
ValidationCase(annotation=STRING, dtype=float, passes=False),
|
||||
],
|
||||
ids=[
|
||||
"float",
|
||||
|
@ -139,10 +145,15 @@ def shape_cases(request) -> ValidationCase:
|
|||
"integer-uint8",
|
||||
"integer-float",
|
||||
"integer-float32",
|
||||
"integer-str",
|
||||
"float-float",
|
||||
"float-float32",
|
||||
"float-int",
|
||||
"float-uint8",
|
||||
"float-str",
|
||||
"str-str",
|
||||
"str-int",
|
||||
"str-float",
|
||||
],
|
||||
)
|
||||
def dtype_cases(request) -> ValidationCase:
|
||||
|
|
Loading…
Reference in a new issue