Better string dtype checking support, restructuring the validation hooks to allow finer grained control over the process.

This commit is contained in:
sneakers-the-rat 2024-08-05 19:43:15 -07:00
parent 880dafb151
commit b2db1014bd
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
4 changed files with 146 additions and 34 deletions

View file

@ -41,19 +41,42 @@ when cast to an `ndarray`, we only try as a last resort.
## Validation
Validation is a chain of lifecycle methods, with a single argument passed and returned
to and from each:
Validation is a chain of lifecycle methods, each of which can be overridden
for interfaces to implement custom behavior that matches the array format.
{meth}`.Interface.validate` calls in order:
{meth}`.Interface.validate` calls the following methods, in order:
An initial hook for modifying the input data before validation, eg.
if it needs to be coerced or wrapped in some proxy class. This method
should accept all and only the types specified in that interface's
{attr}`~.Interface.input_types`.
- {meth}`.Interface.before_validation`
- {meth}`.Interface.validate_dtype`
- {meth}`.Interface.validate_shape`
- {meth}`.Interface.after_validation`
The `before` and `after` methods provide hooks for coercion, loading, etc. such that
`validate` can accept one of the types in the interface's
{attr}`~.Interface.input_types` and return the {attr}`~.Interface.return_type` .
A cluster of methods for validating dtype.
Separating these methods allow for array formats that store dtype information
in a nonstandard attribute, require additional coercion, or for implementing
custom exception handlers or rescuers.
Check the method signatures and return types
when overriding and the docstrings for details.
- {meth}`.Interface.get_dtype`
- {meth}`.Interface.validate_dtype`
- {meth}`.Interface.raise_for_dtype`
A halftime hook for modifying the array or bailing early between validation phases.
- {meth}`.Interface.after_validate_dtype`
A cluster of methods for validating shape, similar to the dtype cluster.
- {meth}`.Interface.get_shape`
- {meth}`.Interface.validate_shape`
- {meth}`.Interface.raise_for_shape`
A final hook for modifying the array before passing it to be assigned to the field.
This method should return an object matching the interface's {attr}`~.Interface.return_type`.
- {meth}`.Interface.after_validation`
## Diagram

View file

@ -40,12 +40,29 @@ class Interface(ABC, Generic[T]):
Calls the methods, in order:
* :meth:`.before_validation`
* :meth:`.validate_dtype`
* :meth:`.validate_shape`
* :meth:`.after_validation`
* array = :meth:`.before_validation` (array)
* dtype = :meth:`.get_dtype` (array) - get the dtype from the array,
override if eg. the dtype is not contained in ``array.dtype``
* valid = :meth:`.validate_dtype` (dtype) - check that the dtype matches
the one in the NDArray specification. Override if special
validation logic is needed for a given format
* :meth:`.raise_for_dtype` (valid, dtype) - after checking dtype validity,
raise an exception if it was invalid. Override to implement custom
exceptions or error conditions, or make validation errors conditional.
* array = :meth:`.after_validate_dtype` (array) - hook for additional
validation or array modification mid-validation
* shape = :meth:`.get_shape` (array) - get the shape from the array,
override if eg. the shape is not contained in ``array.shape``
* valid = :meth:`.validate_shape` (shape) - check that the shape matches
the one in the NDArray specification. Override if special validation
logic is needed.
* :meth:`.raise_for_shape` (valid, shape) - after checking shape validity,
raise an exception if it was invalid. You know the deal bc it's the same
as raise for dtype.
* :meth:`.after_validation` - hook after validation for modifying the array
that is set as the model field value
passing the ``array`` argument and returning it from each.
Follow the method signatures and return types to override
Implementing an interface subclass largely consists of overriding these methods
as needed.
@ -58,8 +75,16 @@ class Interface(ABC, Generic[T]):
of :class:`.InterfaceError` )
"""
array = self.before_validation(array)
array = self.validate_dtype(array)
array = self.validate_shape(array)
dtype = self.get_dtype(array)
dtype_valid = self.validate_dtype(dtype)
self.raise_for_dtype(dtype_valid, dtype)
array = self.after_validate_dtype(array)
shape = self.get_shape(array)
shape_valid = self.validate_shape(shape)
self.raise_for_shape(shape_valid, shape)
array = self.after_validation(array)
return array
@ -72,40 +97,76 @@ class Interface(ABC, Generic[T]):
"""
return array
def validate_dtype(self, array: NDArrayType) -> NDArrayType:
def get_dtype(self, array: NDArrayType) -> DtypeType:
"""
Validate the dtype of the given array, returning it unmutated.
Get the dtype from the input array
"""
return array.dtype
def validate_dtype(self, dtype: DtypeType) -> bool:
"""
Validate the dtype of the given array, returning
``True`` if valid, ``False`` if not.
"""
if self.dtype is Any:
return True
if isinstance(self.dtype, tuple):
valid = dtype in self.dtype
elif self.dtype is np.str_:
valid = getattr(dtype, "type", None) is np.str_ or dtype is np.str_
else:
valid = dtype == self.dtype
return valid
def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None:
"""
After validating, raise an exception if invalid
Raises:
:class:`~numpydantic.exceptions.DtypeError`
"""
if self.dtype is Any:
return array
if isinstance(self.dtype, tuple):
valid = array.dtype in self.dtype
else:
valid = array.dtype == self.dtype
if not valid:
raise DtypeError(f"Invalid dtype! expected {self.dtype}, got {array.dtype}")
raise DtypeError(f"Invalid dtype! expected {self.dtype}, got {dtype}")
def after_validate_dtype(self, array: NDArrayType) -> NDArrayType:
"""
Hook to modify array after validating dtype.
Default is a no-op.
"""
return array
def validate_shape(self, array: NDArrayType) -> NDArrayType:
def get_shape(self, array: NDArrayType) -> Tuple[int, ...]:
"""
Validate the shape of the given array, returning it unmutated
Get the shape from the array as a tuple of integers
"""
return array.shape
def validate_shape(self, shape: Tuple[int, ...]) -> bool:
"""
Validate the shape of the given array against the shape
specifier, returning ``True`` if valid, ``False`` if not.
"""
if self.shape is Any:
return True
return check_shape(shape, self.shape)
def raise_for_shape(self, valid: bool, shape: Tuple[int, ...]) -> None:
"""
Raise a ShapeError if the shape is invalid.
Raises:
:class:`~numpydantic.exceptions.ShapeError`
"""
if self.shape is Any:
return array
if not check_shape(array.shape, self.shape):
if not valid:
raise ShapeError(
f"Invalid shape! expected shape {self.shape.prepared_args}, "
f"got shape {array.shape}"
f"got shape {shape}"
)
return array
def after_validation(self, array: NDArrayType) -> T:
"""

View file

@ -7,18 +7,22 @@ from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional, Sequence, Union
import numpy as np
from pydantic import SerializationInfo
from numpydantic.interface.interface import Interface
from numpydantic.types import DtypeType
try:
import zarr
from numcodecs import VLenUTF8
from zarr.core import Array as ZarrArray
from zarr.storage import StoreLike
except ImportError: # pragma: no cover
ZarrArray = None
StoreLike = None
storage = None
VLenUTF8 = None
@dataclass
@ -113,6 +117,19 @@ class ZarrInterface(Interface):
"""
return self._get_array(array)
def get_dtype(self, array: ZarrArray) -> DtypeType:
"""
Override base dtype getter to handle zarr's string-as-object encoding.
"""
if (
getattr(array.dtype, "type", None) is np.object_
and array.filters
and any([isinstance(f, VLenUTF8) for f in array.filters])
):
return np.str_
else:
return array.dtype
@classmethod
def to_json(
cls,

View file

@ -67,6 +67,7 @@ RGB_UNION: TypeAlias = Union[
NUMBER: TypeAlias = NDArray[Shape["*, *, *"], Number]
INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer]
FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float]
STRING: TypeAlias = NDArray[Shape["*, *, *"], str]
@pytest.fixture(
@ -121,10 +122,15 @@ def shape_cases(request) -> ValidationCase:
ValidationCase(annotation=INTEGER, dtype=np.uint8, passes=True),
ValidationCase(annotation=INTEGER, dtype=float, passes=False),
ValidationCase(annotation=INTEGER, dtype=np.float32, passes=False),
ValidationCase(annotation=INTEGER, dtype=str, passes=False),
ValidationCase(annotation=FLOAT, dtype=float, passes=True),
ValidationCase(annotation=FLOAT, dtype=np.float32, passes=True),
ValidationCase(annotation=FLOAT, dtype=int, passes=False),
ValidationCase(annotation=FLOAT, dtype=np.uint8, passes=False),
ValidationCase(annotation=FLOAT, dtype=str, passes=False),
ValidationCase(annotation=STRING, dtype=str, passes=True),
ValidationCase(annotation=STRING, dtype=int, passes=False),
ValidationCase(annotation=STRING, dtype=float, passes=False),
],
ids=[
"float",
@ -139,10 +145,15 @@ def shape_cases(request) -> ValidationCase:
"integer-uint8",
"integer-float",
"integer-float32",
"integer-str",
"float-float",
"float-float32",
"float-int",
"float-uint8",
"float-str",
"str-str",
"str-int",
"str-float",
],
)
def dtype_cases(request) -> ValidationCase: