Merge pull request #5 from p2p-ld/dtype-str

Better string dtype handling
This commit is contained in:
Jonny Saunders 2024-08-05 19:55:20 -07:00 committed by GitHub
commit 32db88fc1b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 158 additions and 36 deletions

View file

@ -38,7 +38,6 @@ jobs:
uses: actions/setup-python@v4 uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
cache: 'pip'
- name: Install dependencies - name: Install dependencies
run: pip install -e ".[tests]" run: pip install -e ".[tests]"

View file

@ -2,6 +2,17 @@
## 1.* ## 1.*
### 1.3.0 - 24-08-05 - Better string dtype handling
API Changes:
- Split apart the validation methods into smaller chunks to better support
overrides by interfaces. Customize getting and raising errors for dtype and shape,
as well as separation of concerns between getting, validating, and raising.
Bugfix:
- [#4](https://github.com/p2p-ld/numpydantic/issues/4) - Support dtype checking
for strings in zarr and numpy arrays
### 1.2.3 - 24-07-31 - Vendor `nptyping` ### 1.2.3 - 24-07-31 - Vendor `nptyping`
`nptyping` vendored into `numpydantic.vendor.nptyping` - `nptyping` vendored into `numpydantic.vendor.nptyping` -

View file

@ -41,19 +41,42 @@ when cast to an `ndarray`, we only try as a last resort.
## Validation ## Validation
Validation is a chain of lifecycle methods, with a single argument passed and returned Validation is a chain of lifecycle methods, each of which can be overridden
to and from each: for interfaces to implement custom behavior that matches the array format.
{meth}`.Interface.validate` calls in order: {meth}`.Interface.validate` calls the following methods, in order:
An initial hook for modifying the input data before validation, eg.
if it needs to be coerced or wrapped in some proxy class. This method
should accept all and only the types specified in that interface's
{attr}`~.Interface.input_types`.
- {meth}`.Interface.before_validation` - {meth}`.Interface.before_validation`
- {meth}`.Interface.validate_dtype`
- {meth}`.Interface.validate_shape`
- {meth}`.Interface.after_validation`
The `before` and `after` methods provide hooks for coercion, loading, etc. such that A cluster of methods for validating dtype.
`validate` can accept one of the types in the interface's Separating these methods allow for array formats that store dtype information
{attr}`~.Interface.input_types` and return the {attr}`~.Interface.return_type` . in a nonstandard attribute, require additional coercion, or for implementing
custom exception handlers or rescuers.
Check the method signatures and return types
when overriding and the docstrings for details.
- {meth}`.Interface.get_dtype`
- {meth}`.Interface.validate_dtype`
- {meth}`.Interface.raise_for_dtype`
A halftime hook for modifying the array or bailing early between validation phases.
- {meth}`.Interface.after_validate_dtype`
A cluster of methods for validating shape, similar to the dtype cluster.
- {meth}`.Interface.get_shape`
- {meth}`.Interface.validate_shape`
- {meth}`.Interface.raise_for_shape`
A final hook for modifying the array before passing it to be assigned to the field.
This method should return an object matching the interface's {attr}`~.Interface.return_type`.
- {meth}`.Interface.after_validation`
## Diagram ## Diagram

View file

@ -1,6 +1,6 @@
[project] [project]
name = "numpydantic" name = "numpydantic"
version = "1.2.3" version = "1.3.0"
description = "Type and shape validation and serialization for numpy arrays in pydantic models" description = "Type and shape validation and serialization for numpy arrays in pydantic models"
authors = [ authors = [
{name = "sneakers-the-rat", email = "sneakers-the-rat@protonmail.com"}, {name = "sneakers-the-rat", email = "sneakers-the-rat@protonmail.com"},

View file

@ -40,12 +40,29 @@ class Interface(ABC, Generic[T]):
Calls the methods, in order: Calls the methods, in order:
* :meth:`.before_validation` * array = :meth:`.before_validation` (array)
* :meth:`.validate_dtype` * dtype = :meth:`.get_dtype` (array) - get the dtype from the array,
* :meth:`.validate_shape` override if eg. the dtype is not contained in ``array.dtype``
* :meth:`.after_validation` * valid = :meth:`.validate_dtype` (dtype) - check that the dtype matches
the one in the NDArray specification. Override if special
validation logic is needed for a given format
* :meth:`.raise_for_dtype` (valid, dtype) - after checking dtype validity,
raise an exception if it was invalid. Override to implement custom
exceptions or error conditions, or make validation errors conditional.
* array = :meth:`.after_validate_dtype` (array) - hook for additional
validation or array modification mid-validation
* shape = :meth:`.get_shape` (array) - get the shape from the array,
override if eg. the shape is not contained in ``array.shape``
* valid = :meth:`.validate_shape` (shape) - check that the shape matches
the one in the NDArray specification. Override if special validation
logic is needed.
* :meth:`.raise_for_shape` (valid, shape) - after checking shape validity,
raise an exception if it was invalid. You know the deal bc it's the same
as raise for dtype.
* :meth:`.after_validation` - hook after validation for modifying the array
that is set as the model field value
passing the ``array`` argument and returning it from each. Follow the method signatures and return types to override.
Implementing an interface subclass largely consists of overriding these methods Implementing an interface subclass largely consists of overriding these methods
as needed. as needed.
@ -58,8 +75,16 @@ class Interface(ABC, Generic[T]):
of :class:`.InterfaceError` ) of :class:`.InterfaceError` )
""" """
array = self.before_validation(array) array = self.before_validation(array)
array = self.validate_dtype(array)
array = self.validate_shape(array) dtype = self.get_dtype(array)
dtype_valid = self.validate_dtype(dtype)
self.raise_for_dtype(dtype_valid, dtype)
array = self.after_validate_dtype(array)
shape = self.get_shape(array)
shape_valid = self.validate_shape(shape)
self.raise_for_shape(shape_valid, shape)
array = self.after_validation(array) array = self.after_validation(array)
return array return array
@ -72,40 +97,76 @@ class Interface(ABC, Generic[T]):
""" """
return array return array
def validate_dtype(self, array: NDArrayType) -> NDArrayType: def get_dtype(self, array: NDArrayType) -> DtypeType:
""" """
Validate the dtype of the given array, returning it unmutated. Get the dtype from the input array
"""
return array.dtype
def validate_dtype(self, dtype: DtypeType) -> bool:
"""
Validate the dtype of the given array, returning
``True`` if valid, ``False`` if not.
"""
if self.dtype is Any:
return True
if isinstance(self.dtype, tuple):
valid = dtype in self.dtype
elif self.dtype is np.str_:
valid = getattr(dtype, "type", None) is np.str_ or dtype is np.str_
else:
valid = dtype == self.dtype
return valid
def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None:
"""
After validating, raise an exception if invalid
Raises: Raises:
:class:`~numpydantic.exceptions.DtypeError` :class:`~numpydantic.exceptions.DtypeError`
""" """
if self.dtype is Any:
return array
if isinstance(self.dtype, tuple):
valid = array.dtype in self.dtype
else:
valid = array.dtype == self.dtype
if not valid: if not valid:
raise DtypeError(f"Invalid dtype! expected {self.dtype}, got {array.dtype}") raise DtypeError(f"Invalid dtype! expected {self.dtype}, got {dtype}")
def after_validate_dtype(self, array: NDArrayType) -> NDArrayType:
"""
Hook to modify array after validating dtype.
Default is a no-op.
"""
return array return array
def validate_shape(self, array: NDArrayType) -> NDArrayType: def get_shape(self, array: NDArrayType) -> Tuple[int, ...]:
""" """
Validate the shape of the given array, returning it unmutated Get the shape from the array as a tuple of integers
"""
return array.shape
def validate_shape(self, shape: Tuple[int, ...]) -> bool:
"""
Validate the shape of the given array against the shape
specifier, returning ``True`` if valid, ``False`` if not.
"""
if self.shape is Any:
return True
return check_shape(shape, self.shape)
def raise_for_shape(self, valid: bool, shape: Tuple[int, ...]) -> None:
"""
Raise a ShapeError if the shape is invalid.
Raises: Raises:
:class:`~numpydantic.exceptions.ShapeError` :class:`~numpydantic.exceptions.ShapeError`
""" """
if self.shape is Any: if not valid:
return array
if not check_shape(array.shape, self.shape):
raise ShapeError( raise ShapeError(
f"Invalid shape! expected shape {self.shape.prepared_args}, " f"Invalid shape! expected shape {self.shape.prepared_args}, "
f"got shape {array.shape}" f"got shape {shape}"
) )
return array
def after_validation(self, array: NDArrayType) -> T: def after_validation(self, array: NDArrayType) -> T:
""" """

View file

@ -7,18 +7,22 @@ from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any, Optional, Sequence, Union from typing import Any, Optional, Sequence, Union
import numpy as np
from pydantic import SerializationInfo from pydantic import SerializationInfo
from numpydantic.interface.interface import Interface from numpydantic.interface.interface import Interface
from numpydantic.types import DtypeType
try: try:
import zarr import zarr
from numcodecs import VLenUTF8
from zarr.core import Array as ZarrArray from zarr.core import Array as ZarrArray
from zarr.storage import StoreLike from zarr.storage import StoreLike
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
ZarrArray = None ZarrArray = None
StoreLike = None StoreLike = None
storage = None storage = None
VLenUTF8 = None
@dataclass @dataclass
@ -113,6 +117,19 @@ class ZarrInterface(Interface):
""" """
return self._get_array(array) return self._get_array(array)
def get_dtype(self, array: ZarrArray) -> DtypeType:
"""
Override base dtype getter to handle zarr's string-as-object encoding.
"""
if (
getattr(array.dtype, "type", None) is np.object_
and array.filters
and any([isinstance(f, VLenUTF8) for f in array.filters])
):
return np.str_
else:
return array.dtype
@classmethod @classmethod
def to_json( def to_json(
cls, cls,

View file

@ -67,6 +67,7 @@ RGB_UNION: TypeAlias = Union[
NUMBER: TypeAlias = NDArray[Shape["*, *, *"], Number] NUMBER: TypeAlias = NDArray[Shape["*, *, *"], Number]
INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer] INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer]
FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float] FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float]
STRING: TypeAlias = NDArray[Shape["*, *, *"], str]
@pytest.fixture( @pytest.fixture(
@ -121,10 +122,15 @@ def shape_cases(request) -> ValidationCase:
ValidationCase(annotation=INTEGER, dtype=np.uint8, passes=True), ValidationCase(annotation=INTEGER, dtype=np.uint8, passes=True),
ValidationCase(annotation=INTEGER, dtype=float, passes=False), ValidationCase(annotation=INTEGER, dtype=float, passes=False),
ValidationCase(annotation=INTEGER, dtype=np.float32, passes=False), ValidationCase(annotation=INTEGER, dtype=np.float32, passes=False),
ValidationCase(annotation=INTEGER, dtype=str, passes=False),
ValidationCase(annotation=FLOAT, dtype=float, passes=True), ValidationCase(annotation=FLOAT, dtype=float, passes=True),
ValidationCase(annotation=FLOAT, dtype=np.float32, passes=True), ValidationCase(annotation=FLOAT, dtype=np.float32, passes=True),
ValidationCase(annotation=FLOAT, dtype=int, passes=False), ValidationCase(annotation=FLOAT, dtype=int, passes=False),
ValidationCase(annotation=FLOAT, dtype=np.uint8, passes=False), ValidationCase(annotation=FLOAT, dtype=np.uint8, passes=False),
ValidationCase(annotation=FLOAT, dtype=str, passes=False),
ValidationCase(annotation=STRING, dtype=str, passes=True),
ValidationCase(annotation=STRING, dtype=int, passes=False),
ValidationCase(annotation=STRING, dtype=float, passes=False),
], ],
ids=[ ids=[
"float", "float",
@ -139,10 +145,15 @@ def shape_cases(request) -> ValidationCase:
"integer-uint8", "integer-uint8",
"integer-float", "integer-float",
"integer-float32", "integer-float32",
"integer-str",
"float-float", "float-float",
"float-float32", "float-float32",
"float-int", "float-int",
"float-uint8", "float-uint8",
"float-str",
"str-str",
"str-int",
"str-float",
], ],
) )
def dtype_cases(request) -> ValidationCase: def dtype_cases(request) -> ValidationCase: