refactor validation into separate module, perf improvements from not iterating over every element of array, rename array to value universally in serialized json lol

This commit is contained in:
sneakers-the-rat 2024-09-23 23:25:20 -07:00
parent 7f2c79bbae
commit 85cef50603
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
19 changed files with 196 additions and 57 deletions

4
.gitignore vendored
View file

@ -159,4 +159,6 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder. # option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/ #.idea/
.pdm-python .pdm-python
ndarray.pyi ndarray.pyi
prof/

View file

@ -0,0 +1,7 @@
# dtype
```{eval-rst}
.. automodule:: numpydantic.validation.dtype
:members:
:undoc-members:
```

View file

@ -0,0 +1,6 @@
# validation
```{toctree}
dtype
shape
```

View file

@ -1,7 +1,7 @@
# shape # shape
```{eval-rst} ```{eval-rst}
.. automodule:: numpydantic.shape .. automodule:: numpydantic.validation.shape
:members: :members:
:undoc-members: :undoc-members:
``` ```

View file

@ -484,13 +484,13 @@ interfaces
api/index api/index
api/interface/index api/interface/index
api/validation/index
api/dtype api/dtype
api/ndarray api/ndarray
api/maps api/maps
api/meta api/meta
api/schema api/schema
api/serialization api/serialization
api/shape
api/types api/types
``` ```

View file

@ -4,7 +4,7 @@
from numpydantic.ndarray import NDArray from numpydantic.ndarray import NDArray
from numpydantic.meta import update_ndarray_stub from numpydantic.meta import update_ndarray_stub
from numpydantic.shape import Shape from numpydantic.validation.shape import Shape
update_ndarray_stub() update_ndarray_stub()

View file

@ -12,6 +12,9 @@ module, rather than needing to import each individually.
Some types like `Integer` are compound types - tuples of multiple dtypes. Some types like `Integer` are compound types - tuples of multiple dtypes.
Check these using ``in`` rather than ``==``. This interface will develop in future Check these using ``in`` rather than ``==``. This interface will develop in future
versions to allow a single dtype check. versions to allow a single dtype check.
For internal helper functions for validating dtype,
see :mod:`numpydantic.validation.dtype`
""" """
import sys import sys

View file

@ -33,11 +33,11 @@ class DaskJsonDict(JsonDict):
name: str name: str
chunks: Iterable[tuple[int, ...]] chunks: Iterable[tuple[int, ...]]
dtype: str dtype: str
array: list value: list
def to_array_input(self) -> DaskArray: def to_array_input(self) -> DaskArray:
"""Construct a dask array""" """Construct a dask array"""
np_array = np.array(self.array, dtype=self.dtype) np_array = np.array(self.value, dtype=self.dtype)
array = from_array( array = from_array(
np_array, np_array,
name=self.name, name=self.name,
@ -100,7 +100,7 @@ class DaskInterface(Interface):
if info.round_trip: if info.round_trip:
as_json = DaskJsonDict( as_json = DaskJsonDict(
type=cls.name, type=cls.name,
array=as_json, value=as_json,
name=array.name, name=array.name,
chunks=array.chunks, chunks=array.chunks,
dtype=str(np_array.dtype), dtype=str(np_array.dtype),

View file

@ -20,8 +20,8 @@ from numpydantic.exceptions import (
ShapeError, ShapeError,
TooManyMatchesError, TooManyMatchesError,
) )
from numpydantic.shape import check_shape
from numpydantic.types import DtypeType, NDArrayType, ShapeType from numpydantic.types import DtypeType, NDArrayType, ShapeType
from numpydantic.validation import validate_dtype, validate_shape
T = TypeVar("T", bound=NDArrayType) T = TypeVar("T", bound=NDArrayType)
U = TypeVar("U", bound="JsonDict") U = TypeVar("U", bound="JsonDict")
@ -76,6 +76,21 @@ class InterfaceMark(BaseModel):
class JsonDict(BaseModel): class JsonDict(BaseModel):
""" """
Representation of array when dumped with round_trip == True. Representation of array when dumped with round_trip == True.
.. admonition:: Developer's Note
Any field of a JsonDict that contains actual array data should be named ``value``
rather than ``array`` (or any other name), and nothing but the
array data should be named ``value``.
During JSON serialization, it becomes ambiguous what contains an array
of data vs. an array of metadata. For the moment we would like to
reserve the ability to have lists of metadata, so until we rule that out,
we would like to be able to avoid iterating over every element of an array
in any context parameter transformation like relativizing/absolutizing paths.
To avoid that, it's good to agree on a single value name -- ``value`` --
and avoid using it for anything else.
""" """
type: str type: str
@ -274,25 +289,7 @@ class Interface(ABC, Generic[T]):
Validate the dtype of the given array, returning Validate the dtype of the given array, returning
``True`` if valid, ``False`` if not. ``True`` if valid, ``False`` if not.
""" """
if self.dtype is Any: return validate_dtype(dtype, self.dtype)
return True
if isinstance(self.dtype, tuple):
valid = dtype in self.dtype
elif self.dtype is np.str_:
valid = getattr(dtype, "type", None) in (np.str_, str) or dtype in (
np.str_,
str,
)
else:
# try to match as any subclass, if self.dtype is a class
try:
valid = issubclass(dtype, self.dtype)
except TypeError:
# expected, if dtype or self.dtype is not a class
valid = dtype == self.dtype
return valid
def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None: def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None:
""" """
@ -326,7 +323,7 @@ class Interface(ABC, Generic[T]):
if self.shape is Any: if self.shape is Any:
return True return True
return check_shape(shape, self.shape) return validate_shape(shape, self.shape)
def raise_for_shape(self, valid: bool, shape: Tuple[int, ...]) -> None: def raise_for_shape(self, valid: bool, shape: Tuple[int, ...]) -> None:
""" """

View file

@ -27,13 +27,13 @@ class NumpyJsonDict(JsonDict):
type: Literal["numpy"] type: Literal["numpy"]
dtype: str dtype: str
array: list value: list
def to_array_input(self) -> ndarray: def to_array_input(self) -> ndarray:
""" """
Construct a numpy array Construct a numpy array
""" """
return np.array(self.array, dtype=self.dtype) return np.array(self.value, dtype=self.dtype)
class NumpyInterface(Interface): class NumpyInterface(Interface):
@ -99,6 +99,6 @@ class NumpyInterface(Interface):
if info.round_trip: if info.round_trip:
json_array = NumpyJsonDict( json_array = NumpyJsonDict(
type=cls.name, dtype=str(array.dtype), array=json_array type=cls.name, dtype=str(array.dtype), value=json_array
) )
return json_array return json_array

View file

@ -63,7 +63,7 @@ class ZarrJsonDict(JsonDict):
type: Literal["zarr"] type: Literal["zarr"]
file: Optional[str] = None file: Optional[str] = None
path: Optional[str] = None path: Optional[str] = None
array: Optional[list] = None value: Optional[list] = None
def to_array_input(self) -> Union[ZarrArray, ZarrArrayPath]: def to_array_input(self) -> Union[ZarrArray, ZarrArrayPath]:
""" """
@ -73,7 +73,7 @@ class ZarrJsonDict(JsonDict):
if self.file: if self.file:
array = ZarrArrayPath(file=self.file, path=self.path) array = ZarrArrayPath(file=self.file, path=self.path)
else: else:
array = zarr.array(self.array) array = zarr.array(self.value)
return array return array
@ -202,7 +202,7 @@ class ZarrInterface(Interface):
as_json["info"]["hexdigest"] = array.hexdigest() as_json["info"]["hexdigest"] = array.hexdigest()
if dump_array or not is_file: if dump_array or not is_file:
as_json["array"] = array[:].tolist() as_json["value"] = array[:].tolist()
as_json = ZarrJsonDict(**as_json) as_json = ZarrJsonDict(**as_json)
else: else:

View file

@ -13,7 +13,7 @@ Extension of nptyping NDArray for pydantic that allows for JSON-Schema serializa
""" """
from typing import TYPE_CHECKING, Any, Tuple from typing import TYPE_CHECKING, Any, Literal, Tuple, get_origin
import numpy as np import numpy as np
from pydantic import GetJsonSchemaHandler from pydantic import GetJsonSchemaHandler
@ -29,6 +29,7 @@ from numpydantic.schema import (
) )
from numpydantic.serialization import jsonize_array from numpydantic.serialization import jsonize_array
from numpydantic.types import DtypeType, NDArrayType, ShapeType from numpydantic.types import DtypeType, NDArrayType, ShapeType
from numpydantic.validation.dtype import is_union
from numpydantic.vendor.nptyping.error import InvalidArgumentsError from numpydantic.vendor.nptyping.error import InvalidArgumentsError
from numpydantic.vendor.nptyping.ndarray import NDArrayMeta as _NDArrayMeta from numpydantic.vendor.nptyping.ndarray import NDArrayMeta as _NDArrayMeta
from numpydantic.vendor.nptyping.nptyping_type import NPTypingType from numpydantic.vendor.nptyping.nptyping_type import NPTypingType
@ -86,11 +87,18 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
except InterfaceError: except InterfaceError:
return False return False
def _is_literal_like(cls, item: Any) -> bool:
"""
Changes from nptyping:
- doesn't just ducktype for literal but actually, yno, checks for being literal
"""
return get_origin(item) is Literal
def _get_shape(cls, dtype_candidate: Any) -> "Shape": def _get_shape(cls, dtype_candidate: Any) -> "Shape":
""" """
Override of base method to use our local definition of shape Override of base method to use our local definition of shape
""" """
from numpydantic.shape import Shape from numpydantic.validation.shape import Shape
if dtype_candidate is Any or dtype_candidate is Shape: if dtype_candidate is Any or dtype_candidate is Shape:
shape = Any shape = Any
@ -120,7 +128,7 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
if dtype_candidate is Any: if dtype_candidate is Any:
dtype = Any dtype = Any
elif is_dtype: elif is_dtype or is_union(dtype_candidate):
dtype = dtype_candidate dtype = dtype_candidate
elif issubclass(dtype_candidate, Structure): # pragma: no cover elif issubclass(dtype_candidate, Structure): # pragma: no cover
dtype = dtype_candidate dtype = dtype_candidate

View file

@ -5,7 +5,7 @@ Helper functions for use with :class:`~numpydantic.NDArray` - see the note in
import hashlib import hashlib
import json import json
from typing import TYPE_CHECKING, Any, Callable, Optional from typing import TYPE_CHECKING, Any, Callable, Optional, get_args
import numpy as np import numpy as np
from pydantic import BaseModel from pydantic import BaseModel
@ -16,6 +16,7 @@ from numpydantic import dtype as dt
from numpydantic.interface import Interface from numpydantic.interface import Interface
from numpydantic.maps import np_to_python from numpydantic.maps import np_to_python
from numpydantic.types import DtypeType, NDArrayType, ShapeType from numpydantic.types import DtypeType, NDArrayType, ShapeType
from numpydantic.validation.dtype import is_union
from numpydantic.vendor.nptyping.structure import StructureMeta from numpydantic.vendor.nptyping.structure import StructureMeta
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
@ -51,7 +52,6 @@ def _lol_dtype(
"""Get the innermost dtype schema to use in the generated pydantic schema""" """Get the innermost dtype schema to use in the generated pydantic schema"""
if isinstance(dtype, StructureMeta): # pragma: no cover if isinstance(dtype, StructureMeta): # pragma: no cover
raise NotImplementedError("Structured dtypes are currently unsupported") raise NotImplementedError("Structured dtypes are currently unsupported")
if isinstance(dtype, tuple): if isinstance(dtype, tuple):
# if it's a meta-type that refers to a generic float/int, just make that # if it's a meta-type that refers to a generic float/int, just make that
if dtype in (dt.Float, dt.Number): if dtype in (dt.Float, dt.Number):
@ -66,7 +66,10 @@ def _lol_dtype(
array_type = core_schema.union_schema( array_type = core_schema.union_schema(
[_lol_dtype(t, _handler) for t in types_] [_lol_dtype(t, _handler) for t in types_]
) )
elif is_union(dtype):
array_type = core_schema.union_schema(
[_lol_dtype(t, _handler) for t in get_args(dtype)]
)
else: else:
try: try:
python_type = np_to_python[dtype] python_type = np_to_python[dtype]
@ -110,7 +113,7 @@ def list_of_lists_schema(shape: "Shape", array_type: CoreSchema) -> ListSchema:
array_type ( :class:`pydantic_core.CoreSchema` ): The pre-rendered pydantic array_type ( :class:`pydantic_core.CoreSchema` ): The pre-rendered pydantic
core schema to use in the innermost list entry core schema to use in the innermost list entry
""" """
from numpydantic.shape import _is_range from numpydantic.validation.shape import _is_range
shape_parts = [part.strip() for part in shape.__args__[0].split(",")] shape_parts = [part.strip() for part in shape.__args__[0].split(",")]
# labels, if present # labels, if present

View file

@ -4,7 +4,7 @@ and :func:`pydantic.BaseModel.model_dump_json` .
""" """
from pathlib import Path from pathlib import Path
from typing import Any, Callable, TypeVar, Union from typing import Any, Callable, Iterable, TypeVar, Union
from pydantic_core.core_schema import SerializationInfo from pydantic_core.core_schema import SerializationInfo
@ -16,6 +16,9 @@ U = TypeVar("U")
def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]: def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
"""Use an interface class to render an array as JSON""" """Use an interface class to render an array as JSON"""
# perf: keys to skip in generation - anything named "value" is array data.
skip = ["value"]
interface_cls = Interface.match_output(value) interface_cls = Interface.match_output(value)
array = interface_cls.to_json(value, info) array = interface_cls.to_json(value, info)
if isinstance(array, JsonDict): if isinstance(array, JsonDict):
@ -25,19 +28,37 @@ def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
if info.context.get("mark_interface", False): if info.context.get("mark_interface", False):
array = interface_cls.mark_json(array) array = interface_cls.mark_json(array)
if isinstance(array, list):
return array
# ---- Perf Barrier ------------------------------------------------------
# put context args intended to **wrap** the array above
# put context args intended to **modify** the array below
#
# above, we assume that a list is **data** not to be modified.
# below, we must mark whenever the data is in the line of fire
# to avoid an expensive iteration.
if info.context.get("absolute_paths", False): if info.context.get("absolute_paths", False):
array = _absolutize_paths(array) array = _absolutize_paths(array, skip)
else: else:
relative_to = info.context.get("relative_to", ".") relative_to = info.context.get("relative_to", ".")
array = _relativize_paths(array, relative_to) array = _relativize_paths(array, relative_to, skip)
else: else:
# relativize paths by default if isinstance(array, list):
array = _relativize_paths(array, ".") return array
# ---- Perf Barrier ------------------------------------------------------
# same as above, ensure any keys that contain array values are skipped right now
array = _relativize_paths(array, ".", skip)
return array return array
def _relativize_paths(value: dict, relative_to: str = ".") -> dict: def _relativize_paths(
value: dict, relative_to: str = ".", skip: Iterable = tuple()
) -> dict:
""" """
Make paths relative to either the current directory or the provided Make paths relative to either the current directory or the provided
``relative_to`` directory, if provided in the context ``relative_to`` directory, if provided in the context
@ -46,6 +67,8 @@ def _relativize_paths(value: dict, relative_to: str = ".") -> dict:
# pdb.set_trace() # pdb.set_trace()
def _r_path(v: Any) -> Any: def _r_path(v: Any) -> Any:
if not isinstance(v, (str, Path)):
return v
try: try:
path = Path(v) path = Path(v)
if not path.exists(): if not path.exists():
@ -54,10 +77,10 @@ def _relativize_paths(value: dict, relative_to: str = ".") -> dict:
except (TypeError, ValueError): except (TypeError, ValueError):
return v return v
return _walk_and_apply(value, _r_path) return _walk_and_apply(value, _r_path, skip)
def _absolutize_paths(value: dict) -> dict: def _absolutize_paths(value: dict, skip: Iterable = tuple()) -> dict:
def _a_path(v: Any) -> Any: def _a_path(v: Any) -> Any:
try: try:
path = Path(v) path = Path(v)
@ -67,23 +90,25 @@ def _absolutize_paths(value: dict) -> dict:
except (TypeError, ValueError): except (TypeError, ValueError):
return v return v
return _walk_and_apply(value, _a_path) return _walk_and_apply(value, _a_path, skip)
def _walk_and_apply(value: T, f: Callable[[U], U]) -> T: def _walk_and_apply(value: T, f: Callable[[U], U], skip: Iterable = tuple()) -> T:
""" """
Walk an object, applying a function Walk an object, applying a function
""" """
if isinstance(value, dict): if isinstance(value, dict):
for k, v in value.items(): for k, v in value.items():
if k in skip:
continue
if isinstance(v, dict): if isinstance(v, dict):
_walk_and_apply(v, f) _walk_and_apply(v, f, skip)
elif isinstance(v, list): elif isinstance(v, list):
value[k] = [_walk_and_apply(sub_v, f) for sub_v in v] value[k] = [_walk_and_apply(sub_v, f, skip) for sub_v in v]
else: else:
value[k] = f(v) value[k] = f(v)
elif isinstance(value, list): elif isinstance(value, list):
value = [_walk_and_apply(v, f) for v in value] value = [_walk_and_apply(v, f, skip) for v in value]
else: else:
value = f(value) value = f(value)
return value return value

View file

@ -0,0 +1,11 @@
"""
Helper functions for validation
"""
from numpydantic.validation.dtype import validate_dtype
from numpydantic.validation.shape import validate_shape
__all__ = [
"validate_dtype",
"validate_shape",
]

View file

@ -0,0 +1,55 @@
"""
Helper functions for validation of dtype.
For literal dtypes intended for use by end-users, see :mod:`numpydantic.dtype`
"""
from types import UnionType
from typing import Any, Union, get_args, get_origin
import numpy as np
from numpydantic.types import DtypeType
def validate_dtype(dtype: Any, target: DtypeType) -> bool:
"""
Validate a dtype against the target dtype
Args:
dtype: The dtype to validate
target (:class:`.DtypeType`): The target dtype
Returns:
bool: ``True`` if valid, ``False`` otherwise
"""
if target is Any:
return True
if isinstance(target, tuple):
valid = dtype in target
elif is_union(target):
valid = any(
[validate_dtype(dtype, target_dt) for target_dt in get_args(target)]
)
elif target is np.str_:
valid = getattr(dtype, "type", None) in (np.str_, str) or dtype in (
np.str_,
str,
)
else:
# try to match as any subclass, if target is a class
try:
valid = issubclass(dtype, target)
except TypeError:
# expected, if dtype or target is not a class
valid = dtype == target
return valid
def is_union(dtype: DtypeType) -> bool:
"""
Check if a dtype is a union
"""
return get_origin(dtype) in (Union, UnionType)

View file

@ -2,7 +2,7 @@
Declaration and validation functions for array shapes. Declaration and validation functions for array shapes.
Mostly a mildly modified version of nptyping's Mostly a mildly modified version of nptyping's
:func:`npytping.shape_expression.check_shape` :func:`nptyping.shape_expression.check_shape`
and its internals to allow for extended syntax, including ranges of shapes. and its internals to allow for extended syntax, including ranges of shapes.
Modifications from nptyping: Modifications from nptyping:
@ -105,7 +105,7 @@ def validate_shape_expression(shape_expression: Union[ShapeExpression, Any]) ->
@lru_cache @lru_cache
def check_shape(shape: ShapeTuple, target: "Shape") -> bool: def validate_shape(shape: ShapeTuple, target: "Shape") -> bool:
""" """
Check whether the given shape corresponds to the given shape_expression. Check whether the given shape corresponds to the given shape_expression.
:param shape: the shape in question. :param shape: the shape in question.

View file

@ -80,6 +80,8 @@ INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer]
FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float] FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float]
STRING: TypeAlias = NDArray[Shape["*, *, *"], str] STRING: TypeAlias = NDArray[Shape["*, *, *"], str]
MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel] MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel]
UNION_PIPE: TypeAlias = NDArray[Shape["*, *, *"], np.uint32 | np.float32]
UNION_TYPE: TypeAlias = NDArray[Shape["*, *, *"], Union[np.uint32, np.float32]]
@pytest.fixture( @pytest.fixture(
@ -147,6 +149,16 @@ def shape_cases(request) -> ValidationCase:
ValidationCase(annotation=MODEL, dtype=BadModel, passes=False), ValidationCase(annotation=MODEL, dtype=BadModel, passes=False),
ValidationCase(annotation=MODEL, dtype=int, passes=False), ValidationCase(annotation=MODEL, dtype=int, passes=False),
ValidationCase(annotation=MODEL, dtype=SubClass, passes=True), ValidationCase(annotation=MODEL, dtype=SubClass, passes=True),
ValidationCase(annotation=UNION_PIPE, dtype=np.uint32, passes=True),
ValidationCase(annotation=UNION_PIPE, dtype=np.float32, passes=True),
ValidationCase(annotation=UNION_PIPE, dtype=np.uint64, passes=False),
ValidationCase(annotation=UNION_PIPE, dtype=np.float64, passes=False),
ValidationCase(annotation=UNION_PIPE, dtype=str, passes=False),
ValidationCase(annotation=UNION_TYPE, dtype=np.uint32, passes=True),
ValidationCase(annotation=UNION_TYPE, dtype=np.float32, passes=True),
ValidationCase(annotation=UNION_TYPE, dtype=np.uint64, passes=False),
ValidationCase(annotation=UNION_TYPE, dtype=np.float64, passes=False),
ValidationCase(annotation=UNION_TYPE, dtype=str, passes=False),
], ],
ids=[ ids=[
"float", "float",
@ -174,6 +186,16 @@ def shape_cases(request) -> ValidationCase:
"model-badmodel", "model-badmodel",
"model-int", "model-int",
"model-subclass", "model-subclass",
"union-pipe-uint32",
"union-pipe-float32",
"union-pipe-uint64",
"union-pipe-float64",
"union-pipe-str",
"union-type-uint32",
"union-type-float32",
"union-type-uint64",
"union-type-float64",
"union-type-str",
], ],
) )
def dtype_cases(request) -> ValidationCase: def dtype_cases(request) -> ValidationCase:

View file

@ -151,7 +151,7 @@ def test_zarr_to_json(store, model_blank, roundtrip, dump_array):
if roundtrip: if roundtrip:
if dump_array: if dump_array:
assert as_json["array"] == lol_array assert as_json["value"] == lol_array
else: else:
if as_json.get("file", False): if as_json.get("file", False):
assert "array" not in as_json assert "array" not in as_json