diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 63979cb..d08c914 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -38,7 +38,6 @@ jobs: uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - cache: 'pip' - name: Install dependencies run: pip install -e ".[tests]" diff --git a/docs/changelog.md b/docs/changelog.md index 2df9c76..8a0a566 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,17 @@ ## 1.* +### 1.3.0 - 24-08-05 - Better string dtype handling + +API Changes: +- Split apart the validation methods into smaller chunks to better support + overrides by interfaces. Customize getting and raising errors for dtype and shape, + as well as separation of concerns between getting, validating, and raising. + +Bugfix: +- [#4](https://github.com/p2p-ld/numpydantic/issues/4) - Support dtype checking + for strings in zarr and numpy arrays + ### 1.2.3 - 24-07-31 - Vendor `nptyping` `nptyping` vendored into `numpydantic.vendor.nptyping` - diff --git a/docs/interfaces.md b/docs/interfaces.md index 7ed4154..c4d873d 100644 --- a/docs/interfaces.md +++ b/docs/interfaces.md @@ -41,19 +41,42 @@ when cast to an `ndarray`, we only try as a last resort. ## Validation -Validation is a chain of lifecycle methods, with a single argument passed and returned -to and from each: +Validation is a chain of lifecycle methods, each of which can be overridden +for interfaces to implement custom behavior that matches the array format. -{meth}`.Interface.validate` calls in order: +{meth}`.Interface.validate` calls the following methods, in order: + +An initial hook for modifying the input data before validation, eg. +if it needs to be coerced or wrapped in some proxy class. This method +should accept all and only the types specified in that interface's +{attr}`~.Interface.input_types`. - {meth}`.Interface.before_validation` -- {meth}`.Interface.validate_dtype` -- {meth}`.Interface.validate_shape` -- {meth}`.Interface.after_validation` -The `before` and `after` methods provide hooks for coercion, loading, etc. such that -`validate` can accept one of the types in the interface's -{attr}`~.Interface.input_types` and return the {attr}`~.Interface.return_type` . +A cluster of methods for validating dtype. +Separating these methods allow for array formats that store dtype information +in a nonstandard attribute, require additional coercion, or for implementing +custom exception handlers or rescuers. +Check the method signatures and return types +when overriding and the docstrings for details. + +- {meth}`.Interface.get_dtype` +- {meth}`.Interface.validate_dtype` +- {meth}`.Interface.raise_for_dtype` + +A halftime hook for modifying the array or bailing early between validation phases. + +- {meth}`.Interface.after_validate_dtype` + +A cluster of methods for validating shape, similar to the dtype cluster. + +- {meth}`.Interface.get_shape` +- {meth}`.Interface.validate_shape` +- {meth}`.Interface.raise_for_shape` + +A final hook for modifying the array before passing it to be assigned to the field. +This method should return an object matching the interface's {attr}`~.Interface.return_type`. +- {meth}`.Interface.after_validation` ## Diagram diff --git a/pyproject.toml b/pyproject.toml index 3da4cd6..3a1ebf1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "numpydantic" -version = "1.2.3" +version = "1.3.0" description = "Type and shape validation and serialization for numpy arrays in pydantic models" authors = [ {name = "sneakers-the-rat", email = "sneakers-the-rat@protonmail.com"}, diff --git a/src/numpydantic/interface/interface.py b/src/numpydantic/interface/interface.py index 360b7f0..832fe83 100644 --- a/src/numpydantic/interface/interface.py +++ b/src/numpydantic/interface/interface.py @@ -40,12 +40,29 @@ class Interface(ABC, Generic[T]): Calls the methods, in order: - * :meth:`.before_validation` - * :meth:`.validate_dtype` - * :meth:`.validate_shape` - * :meth:`.after_validation` + * array = :meth:`.before_validation` (array) + * dtype = :meth:`.get_dtype` (array) - get the dtype from the array, + override if eg. the dtype is not contained in ``array.dtype`` + * valid = :meth:`.validate_dtype` (dtype) - check that the dtype matches + the one in the NDArray specification. Override if special + validation logic is needed for a given format + * :meth:`.raise_for_dtype` (valid, dtype) - after checking dtype validity, + raise an exception if it was invalid. Override to implement custom + exceptions or error conditions, or make validation errors conditional. + * array = :meth:`.after_validate_dtype` (array) - hook for additional + validation or array modification mid-validation + * shape = :meth:`.get_shape` (array) - get the shape from the array, + override if eg. the shape is not contained in ``array.shape`` + * valid = :meth:`.validate_shape` (shape) - check that the shape matches + the one in the NDArray specification. Override if special validation + logic is needed. + * :meth:`.raise_for_shape` (valid, shape) - after checking shape validity, + raise an exception if it was invalid. You know the deal bc it's the same + as raise for dtype. + * :meth:`.after_validation` - hook after validation for modifying the array + that is set as the model field value - passing the ``array`` argument and returning it from each. + Follow the method signatures and return types to override. Implementing an interface subclass largely consists of overriding these methods as needed. @@ -58,8 +75,16 @@ class Interface(ABC, Generic[T]): of :class:`.InterfaceError` ) """ array = self.before_validation(array) - array = self.validate_dtype(array) - array = self.validate_shape(array) + + dtype = self.get_dtype(array) + dtype_valid = self.validate_dtype(dtype) + self.raise_for_dtype(dtype_valid, dtype) + array = self.after_validate_dtype(array) + + shape = self.get_shape(array) + shape_valid = self.validate_shape(shape) + self.raise_for_shape(shape_valid, shape) + array = self.after_validation(array) return array @@ -72,40 +97,76 @@ class Interface(ABC, Generic[T]): """ return array - def validate_dtype(self, array: NDArrayType) -> NDArrayType: + def get_dtype(self, array: NDArrayType) -> DtypeType: """ - Validate the dtype of the given array, returning it unmutated. + Get the dtype from the input array + """ + return array.dtype + def validate_dtype(self, dtype: DtypeType) -> bool: + """ + Validate the dtype of the given array, returning + ``True`` if valid, ``False`` if not. + + + """ + if self.dtype is Any: + return True + + if isinstance(self.dtype, tuple): + valid = dtype in self.dtype + elif self.dtype is np.str_: + valid = getattr(dtype, "type", None) is np.str_ or dtype is np.str_ + else: + valid = dtype == self.dtype + return valid + + def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None: + """ + After validating, raise an exception if invalid Raises: :class:`~numpydantic.exceptions.DtypeError` """ - if self.dtype is Any: - return array - - if isinstance(self.dtype, tuple): - valid = array.dtype in self.dtype - else: - valid = array.dtype == self.dtype - if not valid: - raise DtypeError(f"Invalid dtype! expected {self.dtype}, got {array.dtype}") + raise DtypeError(f"Invalid dtype! expected {self.dtype}, got {dtype}") + + def after_validate_dtype(self, array: NDArrayType) -> NDArrayType: + """ + Hook to modify array after validating dtype. + Default is a no-op. + """ return array - def validate_shape(self, array: NDArrayType) -> NDArrayType: + def get_shape(self, array: NDArrayType) -> Tuple[int, ...]: """ - Validate the shape of the given array, returning it unmutated + Get the shape from the array as a tuple of integers + """ + return array.shape + + def validate_shape(self, shape: Tuple[int, ...]) -> bool: + """ + Validate the shape of the given array against the shape + specifier, returning ``True`` if valid, ``False`` if not. + + + """ + if self.shape is Any: + return True + + return check_shape(shape, self.shape) + + def raise_for_shape(self, valid: bool, shape: Tuple[int, ...]) -> None: + """ + Raise a ShapeError if the shape is invalid. Raises: :class:`~numpydantic.exceptions.ShapeError` """ - if self.shape is Any: - return array - if not check_shape(array.shape, self.shape): + if not valid: raise ShapeError( f"Invalid shape! expected shape {self.shape.prepared_args}, " - f"got shape {array.shape}" + f"got shape {shape}" ) - return array def after_validation(self, array: NDArrayType) -> T: """ diff --git a/src/numpydantic/interface/zarr.py b/src/numpydantic/interface/zarr.py index 5a300e6..87f538a 100644 --- a/src/numpydantic/interface/zarr.py +++ b/src/numpydantic/interface/zarr.py @@ -7,18 +7,22 @@ from dataclasses import dataclass from pathlib import Path from typing import Any, Optional, Sequence, Union +import numpy as np from pydantic import SerializationInfo from numpydantic.interface.interface import Interface +from numpydantic.types import DtypeType try: import zarr + from numcodecs import VLenUTF8 from zarr.core import Array as ZarrArray from zarr.storage import StoreLike except ImportError: # pragma: no cover ZarrArray = None StoreLike = None storage = None + VLenUTF8 = None @dataclass @@ -113,6 +117,19 @@ class ZarrInterface(Interface): """ return self._get_array(array) + def get_dtype(self, array: ZarrArray) -> DtypeType: + """ + Override base dtype getter to handle zarr's string-as-object encoding. + """ + if ( + getattr(array.dtype, "type", None) is np.object_ + and array.filters + and any([isinstance(f, VLenUTF8) for f in array.filters]) + ): + return np.str_ + else: + return array.dtype + @classmethod def to_json( cls, diff --git a/tests/conftest.py b/tests/conftest.py index 4687635..2292dd1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -67,6 +67,7 @@ RGB_UNION: TypeAlias = Union[ NUMBER: TypeAlias = NDArray[Shape["*, *, *"], Number] INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer] FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float] +STRING: TypeAlias = NDArray[Shape["*, *, *"], str] @pytest.fixture( @@ -121,10 +122,15 @@ def shape_cases(request) -> ValidationCase: ValidationCase(annotation=INTEGER, dtype=np.uint8, passes=True), ValidationCase(annotation=INTEGER, dtype=float, passes=False), ValidationCase(annotation=INTEGER, dtype=np.float32, passes=False), + ValidationCase(annotation=INTEGER, dtype=str, passes=False), ValidationCase(annotation=FLOAT, dtype=float, passes=True), ValidationCase(annotation=FLOAT, dtype=np.float32, passes=True), ValidationCase(annotation=FLOAT, dtype=int, passes=False), ValidationCase(annotation=FLOAT, dtype=np.uint8, passes=False), + ValidationCase(annotation=FLOAT, dtype=str, passes=False), + ValidationCase(annotation=STRING, dtype=str, passes=True), + ValidationCase(annotation=STRING, dtype=int, passes=False), + ValidationCase(annotation=STRING, dtype=float, passes=False), ], ids=[ "float", @@ -139,10 +145,15 @@ def shape_cases(request) -> ValidationCase: "integer-uint8", "integer-float", "integer-float32", + "integer-str", "float-float", "float-float32", "float-int", "float-uint8", + "float-str", + "str-str", + "str-int", + "str-float", ], ) def dtype_cases(request) -> ValidationCase: