refactor validation into separate module, perf improvements from not iterating over every element of array, rename array to value universally in serialized json lol

sneakers-the-rat 2024-09-23 23:25:20 -07:00
parent 7f2c79bbae
commit 85cef50603
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
19 changed files with 196 additions and 57 deletions
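For orientation, the user-visible effect of the ``array`` -> ``value`` rename is in round-trip JSON dumps. A minimal sketch of the intended layout, assuming numpydantic's ``NDArray`` wrapping a plain numpy array (exact metadata keys vary by interface):

```python
import numpy as np
from pydantic import BaseModel
from numpydantic import NDArray

class Model(BaseModel):
    array: NDArray

model = Model(array=np.array([[1, 2], [3, 4]], dtype=np.uint8))
dumped = model.model_dump_json(round_trip=True)
# Before this commit the data lived under "array"; now it is under "value":
# {"array": {"type": "numpy", "dtype": "uint8", "value": [[1, 2], [3, 4]]}}
```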

.gitignore
View file

@@ -159,4 +159,6 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.pdm-python
ndarray.pyi
ndarray.pyi
prof/

View file

@@ -0,0 +1,7 @@
# dtype
```{eval-rst}
.. automodule:: numpydantic.validation.dtype
:members:
:undoc-members:
```

View file

@@ -0,0 +1,6 @@
# validation
```{toctree}
dtype
shape
```

View file

@@ -1,7 +1,7 @@
# shape
```{eval-rst}
.. automodule:: numpydantic.shape
.. automodule:: numpydantic.validation.shape
:members:
:undoc-members:
```

View file

@@ -484,13 +484,13 @@ interfaces
api/index
api/interface/index
api/validation/index
api/dtype
api/ndarray
api/maps
api/meta
api/schema
api/serialization
api/shape
api/types
```

View file

@@ -4,7 +4,7 @@
from numpydantic.ndarray import NDArray
from numpydantic.meta import update_ndarray_stub
from numpydantic.shape import Shape
from numpydantic.validation.shape import Shape
update_ndarray_stub()

View file

@@ -12,6 +12,9 @@ module, rather than needing to import each individually.
Some types like `Integer` are compound types - tuples of multiple dtypes.
Check these using ``in`` rather than ``==``. This interface will develop in future
versions to allow a single dtype check.
For the internal helper functions used to validate dtypes,
see :mod:`numpydantic.validation.dtype`
"""
import sys
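As a usage sketch of the note above -- compound types like ``Integer`` are tuples of dtypes, so membership is the supported check:

```python
import numpy as np
from numpydantic import dtype as dt

assert np.uint8 in dt.Integer       # compound types are checked with `in`
assert np.float32 not in dt.Integer
assert dt.Integer != np.uint8       # `==` compares against the whole tuple
```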

View file

@@ -33,11 +33,11 @@ class DaskJsonDict(JsonDict):
name: str
chunks: Iterable[tuple[int, ...]]
dtype: str
array: list
value: list
def to_array_input(self) -> DaskArray:
"""Construct a dask array"""
np_array = np.array(self.array, dtype=self.dtype)
np_array = np.array(self.value, dtype=self.dtype)
array = from_array(
np_array,
name=self.name,
@@ -100,7 +100,7 @@ class DaskInterface(Interface):
if info.round_trip:
as_json = DaskJsonDict(
type=cls.name,
array=as_json,
value=as_json,
name=array.name,
chunks=array.chunks,
dtype=str(np_array.dtype),

View file

@@ -20,8 +20,8 @@ from numpydantic.exceptions import (
ShapeError,
TooManyMatchesError,
)
from numpydantic.shape import check_shape
from numpydantic.types import DtypeType, NDArrayType, ShapeType
from numpydantic.validation import validate_dtype, validate_shape
T = TypeVar("T", bound=NDArrayType)
U = TypeVar("U", bound="JsonDict")
@@ -76,6 +76,21 @@ class InterfaceMark(BaseModel):
class JsonDict(BaseModel):
"""
Representation of array when dumped with round_trip == True.
.. admonition:: Developer's Note
In any JsonDict, the field that contains the actual array data should be
named ``value`` rather than ``array`` (or any other name), and nothing but
array data should ever be named ``value``.
During JSON serialization it is otherwise ambiguous whether a given list
contains array data or array metadata. For the moment we would like to
reserve the ability to have lists of metadata, so until we rule that out,
we need a way to avoid iterating over every element of an array in any
context-parameter transformation like relativizing/absolutizing paths.
Agreeing on a single name for array data -- ``value`` -- and never using
it for anything else makes those fields cheap to skip.
"""
type: str
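A sketch of the ambiguity the note describes, using a hypothetical payload (field values invented for illustration):

```python
# With a single agreed-on name, the serializer can skip array data by key alone.
dumped = {
    "type": "zarr",
    "path": "data/array.zarr",          # metadata: a path we may relativize
    "chunks": [[100, 100]],             # metadata that is itself a list
    "value": [[0.1, 0.2], [0.3, 0.4]],  # array data: potentially huge
}
for k, v in dumped.items():
    if k == "value":
        continue  # O(1) skip instead of walking every element of the array
    ...           # transform metadata, e.g. relativize paths
```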
@@ -274,25 +289,7 @@ class Interface(ABC, Generic[T]):
Validate the dtype of the given array, returning
``True`` if valid, ``False`` if not.
"""
if self.dtype is Any:
return True
if isinstance(self.dtype, tuple):
valid = dtype in self.dtype
elif self.dtype is np.str_:
valid = getattr(dtype, "type", None) in (np.str_, str) or dtype in (
np.str_,
str,
)
else:
# try to match as any subclass, if self.dtype is a class
try:
valid = issubclass(dtype, self.dtype)
except TypeError:
# expected, if dtype or self.dtype is not a class
valid = dtype == self.dtype
return valid
return validate_dtype(dtype, self.dtype)
def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None:
"""
@@ -326,7 +323,7 @@ class Interface(ABC, Generic[T]):
if self.shape is Any:
return True
return check_shape(shape, self.shape)
return validate_shape(shape, self.shape)
def raise_for_shape(self, valid: bool, shape: Tuple[int, ...]) -> None:
"""

View file

@@ -27,13 +27,13 @@ class NumpyJsonDict(JsonDict):
type: Literal["numpy"]
dtype: str
array: list
value: list
def to_array_input(self) -> ndarray:
"""
Construct a numpy array
"""
return np.array(self.array, dtype=self.dtype)
return np.array(self.value, dtype=self.dtype)
class NumpyInterface(Interface):
@@ -99,6 +99,6 @@ class NumpyInterface(Interface):
if info.round_trip:
json_array = NumpyJsonDict(
type=cls.name, dtype=str(array.dtype), array=json_array
type=cls.name, dtype=str(array.dtype), value=json_array
)
return json_array

View file

@@ -63,7 +63,7 @@ class ZarrJsonDict(JsonDict):
type: Literal["zarr"]
file: Optional[str] = None
path: Optional[str] = None
array: Optional[list] = None
value: Optional[list] = None
def to_array_input(self) -> Union[ZarrArray, ZarrArrayPath]:
"""
@@ -73,7 +73,7 @@ class ZarrJsonDict(JsonDict):
if self.file:
array = ZarrArrayPath(file=self.file, path=self.path)
else:
array = zarr.array(self.array)
array = zarr.array(self.value)
return array
@@ -202,7 +202,7 @@ class ZarrInterface(Interface):
as_json["info"]["hexdigest"] = array.hexdigest()
if dump_array or not is_file:
as_json["array"] = array[:].tolist()
as_json["value"] = array[:].tolist()
as_json = ZarrJsonDict(**as_json)
else:

View file

@@ -13,7 +13,7 @@ Extension of nptyping NDArray for pydantic that allows for JSON-Schema serialization
"""
from typing import TYPE_CHECKING, Any, Tuple
from typing import TYPE_CHECKING, Any, Literal, Tuple, get_origin
import numpy as np
from pydantic import GetJsonSchemaHandler
@@ -29,6 +29,7 @@ from numpydantic.schema import (
)
from numpydantic.serialization import jsonize_array
from numpydantic.types import DtypeType, NDArrayType, ShapeType
from numpydantic.validation.dtype import is_union
from numpydantic.vendor.nptyping.error import InvalidArgumentsError
from numpydantic.vendor.nptyping.ndarray import NDArrayMeta as _NDArrayMeta
from numpydantic.vendor.nptyping.nptyping_type import NPTypingType
@@ -86,11 +87,18 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
except InterfaceError:
return False
def _is_literal_like(cls, item: Any) -> bool:
"""
Changes from nptyping:
- doesn't just duck-type for Literal but actually, y'know, checks that the item is a Literal
"""
return get_origin(item) is Literal
def _get_shape(cls, dtype_candidate: Any) -> "Shape":
"""
Override of base method to use our local definition of shape
"""
from numpydantic.shape import Shape
from numpydantic.validation.shape import Shape
if dtype_candidate is Any or dtype_candidate is Shape:
shape = Any
@@ -120,7 +128,7 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
if dtype_candidate is Any:
dtype = Any
elif is_dtype:
elif is_dtype or is_union(dtype_candidate):
dtype = dtype_candidate
elif issubclass(dtype_candidate, Structure): # pragma: no cover
dtype = dtype_candidate
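The effect of the new ``is_union`` branch, sketched -- both annotation forms below are exercised in the test changes later in this commit:

```python
import numpy as np
from typing import Union
from numpydantic import NDArray, Shape

# Union dtypes are now accepted in the dtype slot instead of being
# rejected as non-dtype arguments.
UnionPipe = NDArray[Shape["*, *"], np.uint32 | np.float32]
UnionType = NDArray[Shape["*, *"], Union[np.uint32, np.float32]]
```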

View file

@@ -5,7 +5,7 @@ Helper functions for use with :class:`~numpydantic.NDArray` - see the note in
import hashlib
import json
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import TYPE_CHECKING, Any, Callable, Optional, get_args
import numpy as np
from pydantic import BaseModel
@@ -16,6 +16,7 @@ from numpydantic import dtype as dt
from numpydantic.interface import Interface
from numpydantic.maps import np_to_python
from numpydantic.types import DtypeType, NDArrayType, ShapeType
from numpydantic.validation.dtype import is_union
from numpydantic.vendor.nptyping.structure import StructureMeta
if TYPE_CHECKING: # pragma: no cover
@@ -51,7 +52,6 @@
"""Get the innermost dtype schema to use in the generated pydantic schema"""
if isinstance(dtype, StructureMeta): # pragma: no cover
raise NotImplementedError("Structured dtypes are currently unsupported")
if isinstance(dtype, tuple):
# if it's a meta-type that refers to a generic float/int, just make that
if dtype in (dt.Float, dt.Number):
@@ -66,7 +66,10 @@
array_type = core_schema.union_schema(
[_lol_dtype(t, _handler) for t in types_]
)
elif is_union(dtype):
array_type = core_schema.union_schema(
[_lol_dtype(t, _handler) for t in get_args(dtype)]
)
else:
try:
python_type = np_to_python[dtype]
@@ -110,7 +113,7 @@ def list_of_lists_schema(shape: "Shape", array_type: CoreSchema) -> ListSchema:
array_type ( :class:`pydantic_core.CoreSchema` ): The pre-rendered pydantic
core schema to use in the innermost list entry
"""
from numpydantic.shape import _is_range
from numpydantic.validation.shape import _is_range
shape_parts = [part.strip() for part in shape.__args__[0].split(",")]
# labels, if present
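A sketch of what the union branch should yield in a generated JSON schema; the exact nesting depends on the list-of-lists rendering, so the expected output here is an assumption:

```python
import numpy as np
from pydantic import TypeAdapter
from numpydantic import NDArray, Shape

schema = TypeAdapter(NDArray[Shape["*"], np.uint32 | np.float32]).json_schema()
# The innermost items should render as a union, roughly:
# {"anyOf": [{"type": "integer", ...}, {"type": "number", ...}]}
```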

View file

@@ -4,7 +4,7 @@ and :func:`pydantic.BaseModel.model_dump_json` .
"""
from pathlib import Path
from typing import Any, Callable, TypeVar, Union
from typing import Any, Callable, Iterable, TypeVar, Union
from pydantic_core.core_schema import SerializationInfo
@@ -16,6 +16,9 @@ U = TypeVar("U")
def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
"""Use an interface class to render an array as JSON"""
# perf: keys to skip in generation - anything named "value" is array data.
skip = ["value"]
interface_cls = Interface.match_output(value)
array = interface_cls.to_json(value, info)
if isinstance(array, JsonDict):
@@ -25,19 +28,37 @@ def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
if info.context.get("mark_interface", False):
array = interface_cls.mark_json(array)
if isinstance(array, list):
return array
# ---- Perf Barrier ------------------------------------------------------
# put context args intended to **wrap** the array above
# put context args intended to **modify** the array below
#
# above, we assume that a list is **data** not to be modified.
# below, we must mark whenever the data is in the line of fire
# to avoid an expensive iteration.
if info.context.get("absolute_paths", False):
array = _absolutize_paths(array)
array = _absolutize_paths(array, skip)
else:
relative_to = info.context.get("relative_to", ".")
array = _relativize_paths(array, relative_to)
array = _relativize_paths(array, relative_to, skip)
else:
# relativize paths by default
array = _relativize_paths(array, ".")
if isinstance(array, list):
return array
# ---- Perf Barrier ------------------------------------------------------
# same as above, ensure any keys that contain array values are skipped right now
array = _relativize_paths(array, ".", skip)
return array
def _relativize_paths(value: dict, relative_to: str = ".") -> dict:
def _relativize_paths(
value: dict, relative_to: str = ".", skip: Iterable = tuple()
) -> dict:
"""
Make paths relative to either the current directory or the provided
``relative_to`` directory, if provided in the context
@@ -46,6 +67,8 @@ def _relativize_paths(value: dict, relative_to: str = ".") -> dict:
# pdb.set_trace()
def _r_path(v: Any) -> Any:
if not isinstance(v, (str, Path)):
return v
try:
path = Path(v)
if not path.exists():
@@ -54,10 +77,10 @@ def _relativize_paths(value: dict, relative_to: str = ".") -> dict:
except (TypeError, ValueError):
return v
return _walk_and_apply(value, _r_path)
return _walk_and_apply(value, _r_path, skip)
def _absolutize_paths(value: dict) -> dict:
def _absolutize_paths(value: dict, skip: Iterable = tuple()) -> dict:
def _a_path(v: Any) -> Any:
try:
path = Path(v)
@@ -67,23 +90,25 @@ def _absolutize_paths(value: dict) -> dict:
except (TypeError, ValueError):
return v
return _walk_and_apply(value, _a_path)
return _walk_and_apply(value, _a_path, skip)
def _walk_and_apply(value: T, f: Callable[[U], U]) -> T:
def _walk_and_apply(value: T, f: Callable[[U], U], skip: Iterable = tuple()) -> T:
"""
Walk an object, applying a function
"""
if isinstance(value, dict):
for k, v in value.items():
if k in skip:
continue
if isinstance(v, dict):
_walk_and_apply(v, f)
_walk_and_apply(v, f, skip)
elif isinstance(v, list):
value[k] = [_walk_and_apply(sub_v, f) for sub_v in v]
value[k] = [_walk_and_apply(sub_v, f, skip) for sub_v in v]
else:
value[k] = f(v)
elif isinstance(value, list):
value = [_walk_and_apply(v, f) for v in value]
value = [_walk_and_apply(v, f, skip) for v in value]
else:
value = f(value)
return value
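Usage sketch for the context keys consumed above (``Model`` is a hypothetical example model; the context keys are the ones read in ``jsonize_array``):

```python
import numpy as np
from pydantic import BaseModel
from numpydantic import NDArray

class Model(BaseModel):
    array: NDArray

model = Model(array=np.zeros((2, 2)))

# Anything under a "value" key is skipped in both cases.
as_abs = model.model_dump_json(round_trip=True, context={"absolute_paths": True})
as_rel = model.model_dump_json(round_trip=True, context={"relative_to": "/data"})
```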

View file

@@ -0,0 +1,11 @@
"""
Helper functions for validation
"""
from numpydantic.validation.dtype import validate_dtype
from numpydantic.validation.shape import validate_shape
__all__ = [
"validate_dtype",
"validate_shape",
]

View file

@@ -0,0 +1,55 @@
"""
Helper functions for validation of dtype.
For literal dtypes intended for use by end-users, see :mod:`numpydantic.dtype`
"""
from types import UnionType
from typing import Any, Union, get_args, get_origin
import numpy as np
from numpydantic.types import DtypeType
def validate_dtype(dtype: Any, target: DtypeType) -> bool:
"""
Validate a dtype against the target dtype
Args:
dtype: The dtype to validate
target (:class:`.DtypeType`): The target dtype
Returns:
bool: ``True`` if valid, ``False`` otherwise
"""
if target is Any:
return True
if isinstance(target, tuple):
valid = dtype in target
elif is_union(target):
valid = any(
[validate_dtype(dtype, target_dt) for target_dt in get_args(target)]
)
elif target is np.str_:
valid = getattr(dtype, "type", None) in (np.str_, str) or dtype in (
np.str_,
str,
)
else:
# try to match as any subclass, if target is a class
try:
valid = issubclass(dtype, target)
except TypeError:
# expected, if dtype or target is not a class
valid = dtype == target
return valid
def is_union(dtype: DtypeType) -> bool:
"""
Check if a dtype is a union
"""
return get_origin(dtype) in (Union, UnionType)
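A few spot-checks, with expected results read directly off the code above:

```python
import numpy as np
from numpydantic.validation.dtype import is_union, validate_dtype

assert validate_dtype(np.uint32, np.uint32 | np.float32)      # union member
assert not validate_dtype(np.uint64, np.uint32 | np.float32)  # not a member
assert validate_dtype(np.uint8, (np.uint8, np.uint16))        # tuple membership
assert is_union(np.uint32 | np.float32)
assert not is_union(np.uint32)
```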

View file

@@ -2,7 +2,7 @@
Declaration and validation functions for array shapes.
Mostly a mildly modified version of nptyping's
:func:`nptyping.shape_expression.check_shape`
:func:`nptyping.shape_expression.validate_shape`
and its internals to allow for extended syntax, including ranges of shapes.
Modifications from nptyping:
@@ -105,7 +105,7 @@ def validate_shape_expression(shape_expression: Union[ShapeExpression, Any]) ->
@lru_cache
def check_shape(shape: ShapeTuple, target: "Shape") -> bool:
def validate_shape(shape: ShapeTuple, target: "Shape") -> bool:
"""
Check whether the given shape corresponds to the given shape_expression.
:param shape: the shape in question.
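Usage sketch of the renamed helper, assuming the wildcard semantics follow the vendored nptyping behavior:

```python
from numpydantic import Shape
from numpydantic.validation.shape import validate_shape

assert validate_shape((2, 3), Shape["2, 3"])
assert validate_shape((2, 3), Shape["*, 3"])      # `*` matches any size
assert not validate_shape((2, 3, 4), Shape["2, 3"])
```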

View file

@@ -80,6 +80,8 @@ INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer]
FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float]
STRING: TypeAlias = NDArray[Shape["*, *, *"], str]
MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel]
UNION_PIPE: TypeAlias = NDArray[Shape["*, *, *"], np.uint32 | np.float32]
UNION_TYPE: TypeAlias = NDArray[Shape["*, *, *"], Union[np.uint32, np.float32]]
@pytest.fixture(
@@ -147,6 +149,16 @@ def shape_cases(request) -> ValidationCase:
ValidationCase(annotation=MODEL, dtype=BadModel, passes=False),
ValidationCase(annotation=MODEL, dtype=int, passes=False),
ValidationCase(annotation=MODEL, dtype=SubClass, passes=True),
ValidationCase(annotation=UNION_PIPE, dtype=np.uint32, passes=True),
ValidationCase(annotation=UNION_PIPE, dtype=np.float32, passes=True),
ValidationCase(annotation=UNION_PIPE, dtype=np.uint64, passes=False),
ValidationCase(annotation=UNION_PIPE, dtype=np.float64, passes=False),
ValidationCase(annotation=UNION_PIPE, dtype=str, passes=False),
ValidationCase(annotation=UNION_TYPE, dtype=np.uint32, passes=True),
ValidationCase(annotation=UNION_TYPE, dtype=np.float32, passes=True),
ValidationCase(annotation=UNION_TYPE, dtype=np.uint64, passes=False),
ValidationCase(annotation=UNION_TYPE, dtype=np.float64, passes=False),
ValidationCase(annotation=UNION_TYPE, dtype=str, passes=False),
],
ids=[
"float",
@@ -174,6 +186,16 @@ def shape_cases(request) -> ValidationCase:
"model-badmodel",
"model-int",
"model-subclass",
"union-pipe-uint32",
"union-pipe-float32",
"union-pipe-uint64",
"union-pipe-float64",
"union-pipe-str",
"union-type-uint32",
"union-type-float32",
"union-type-uint64",
"union-type-float64",
"union-type-str",
],
)
def dtype_cases(request) -> ValidationCase:

View file

@@ -151,7 +151,7 @@ def test_zarr_to_json(store, model_blank, roundtrip, dump_array):
if roundtrip:
if dump_array:
assert as_json["array"] == lol_array
assert as_json["value"] == lol_array
else:
if as_json.get("file", False):
assert "array" not in as_json