mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2024-11-14 10:44:28 +00:00
refactor validation into separate module, perf improvements from not iterating over every element of array, rename array
to value
universally in serialized json lol
This commit is contained in:
parent
7f2c79bbae
commit
85cef50603
19 changed files with 196 additions and 57 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -159,4 +159,6 @@ cython_debug/
|
||||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
#.idea/
|
#.idea/
|
||||||
.pdm-python
|
.pdm-python
|
||||||
ndarray.pyi
|
ndarray.pyi
|
||||||
|
|
||||||
|
prof/
|
7
docs/api/validation/dtype.md
Normal file
7
docs/api/validation/dtype.md
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
# dtype
|
||||||
|
|
||||||
|
```{eval-rst}
|
||||||
|
.. automodule:: numpydantic.validation.dtype
|
||||||
|
:members:
|
||||||
|
:undoc-members:
|
||||||
|
```
|
6
docs/api/validation/index.md
Normal file
6
docs/api/validation/index.md
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
# validation
|
||||||
|
|
||||||
|
```{toctree}
|
||||||
|
dtype
|
||||||
|
shape
|
||||||
|
```
|
|
@ -1,7 +1,7 @@
|
||||||
# shape
|
# shape
|
||||||
|
|
||||||
```{eval-rst}
|
```{eval-rst}
|
||||||
.. automodule:: numpydantic.shape
|
.. automodule:: numpydantic.validation.shape
|
||||||
:members:
|
:members:
|
||||||
:undoc-members:
|
:undoc-members:
|
||||||
```
|
```
|
|
@ -484,13 +484,13 @@ interfaces
|
||||||
|
|
||||||
api/index
|
api/index
|
||||||
api/interface/index
|
api/interface/index
|
||||||
|
api/validation/index
|
||||||
api/dtype
|
api/dtype
|
||||||
api/ndarray
|
api/ndarray
|
||||||
api/maps
|
api/maps
|
||||||
api/meta
|
api/meta
|
||||||
api/schema
|
api/schema
|
||||||
api/serialization
|
api/serialization
|
||||||
api/shape
|
|
||||||
api/types
|
api/types
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
|
|
||||||
from numpydantic.ndarray import NDArray
|
from numpydantic.ndarray import NDArray
|
||||||
from numpydantic.meta import update_ndarray_stub
|
from numpydantic.meta import update_ndarray_stub
|
||||||
from numpydantic.shape import Shape
|
from numpydantic.validation.shape import Shape
|
||||||
|
|
||||||
update_ndarray_stub()
|
update_ndarray_stub()
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,9 @@ module, rather than needing to import each individually.
|
||||||
Some types like `Integer` are compound types - tuples of multiple dtypes.
|
Some types like `Integer` are compound types - tuples of multiple dtypes.
|
||||||
Check these using ``in`` rather than ``==``. This interface will develop in future
|
Check these using ``in`` rather than ``==``. This interface will develop in future
|
||||||
versions to allow a single dtype check.
|
versions to allow a single dtype check.
|
||||||
|
|
||||||
|
For internal helper functions for validating dtype,
|
||||||
|
see :mod:`numpydantic.validation.dtype`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
|
@ -33,11 +33,11 @@ class DaskJsonDict(JsonDict):
|
||||||
name: str
|
name: str
|
||||||
chunks: Iterable[tuple[int, ...]]
|
chunks: Iterable[tuple[int, ...]]
|
||||||
dtype: str
|
dtype: str
|
||||||
array: list
|
value: list
|
||||||
|
|
||||||
def to_array_input(self) -> DaskArray:
|
def to_array_input(self) -> DaskArray:
|
||||||
"""Construct a dask array"""
|
"""Construct a dask array"""
|
||||||
np_array = np.array(self.array, dtype=self.dtype)
|
np_array = np.array(self.value, dtype=self.dtype)
|
||||||
array = from_array(
|
array = from_array(
|
||||||
np_array,
|
np_array,
|
||||||
name=self.name,
|
name=self.name,
|
||||||
|
@ -100,7 +100,7 @@ class DaskInterface(Interface):
|
||||||
if info.round_trip:
|
if info.round_trip:
|
||||||
as_json = DaskJsonDict(
|
as_json = DaskJsonDict(
|
||||||
type=cls.name,
|
type=cls.name,
|
||||||
array=as_json,
|
value=as_json,
|
||||||
name=array.name,
|
name=array.name,
|
||||||
chunks=array.chunks,
|
chunks=array.chunks,
|
||||||
dtype=str(np_array.dtype),
|
dtype=str(np_array.dtype),
|
||||||
|
|
|
@ -20,8 +20,8 @@ from numpydantic.exceptions import (
|
||||||
ShapeError,
|
ShapeError,
|
||||||
TooManyMatchesError,
|
TooManyMatchesError,
|
||||||
)
|
)
|
||||||
from numpydantic.shape import check_shape
|
|
||||||
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||||
|
from numpydantic.validation import validate_dtype, validate_shape
|
||||||
|
|
||||||
T = TypeVar("T", bound=NDArrayType)
|
T = TypeVar("T", bound=NDArrayType)
|
||||||
U = TypeVar("U", bound="JsonDict")
|
U = TypeVar("U", bound="JsonDict")
|
||||||
|
@ -76,6 +76,21 @@ class InterfaceMark(BaseModel):
|
||||||
class JsonDict(BaseModel):
|
class JsonDict(BaseModel):
|
||||||
"""
|
"""
|
||||||
Representation of array when dumped with round_trip == True.
|
Representation of array when dumped with round_trip == True.
|
||||||
|
|
||||||
|
.. admonition:: Developer's Note
|
||||||
|
|
||||||
|
Any JsonDict that contains an actual array should be named ``value``
|
||||||
|
rather than array (or any other name), and nothing but the
|
||||||
|
array data should be named ``value`` .
|
||||||
|
|
||||||
|
During JSON serialization, it becomes ambiguous what contains an array
|
||||||
|
of data vs. an array of metadata. For the moment we would like to
|
||||||
|
reserve the ability to have lists of metadata, so until we rule that out,
|
||||||
|
we would like to be able to avoid iterating over every element of an array
|
||||||
|
in any context parameter transformation like relativizing/absolutizing paths.
|
||||||
|
To avoid that, it's good to agree on a single value name -- ``value`` --
|
||||||
|
and avoid using it for anything else.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str
|
type: str
|
||||||
|
@ -274,25 +289,7 @@ class Interface(ABC, Generic[T]):
|
||||||
Validate the dtype of the given array, returning
|
Validate the dtype of the given array, returning
|
||||||
``True`` if valid, ``False`` if not.
|
``True`` if valid, ``False`` if not.
|
||||||
"""
|
"""
|
||||||
if self.dtype is Any:
|
return validate_dtype(dtype, self.dtype)
|
||||||
return True
|
|
||||||
|
|
||||||
if isinstance(self.dtype, tuple):
|
|
||||||
valid = dtype in self.dtype
|
|
||||||
elif self.dtype is np.str_:
|
|
||||||
valid = getattr(dtype, "type", None) in (np.str_, str) or dtype in (
|
|
||||||
np.str_,
|
|
||||||
str,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# try to match as any subclass, if self.dtype is a class
|
|
||||||
try:
|
|
||||||
valid = issubclass(dtype, self.dtype)
|
|
||||||
except TypeError:
|
|
||||||
# expected, if dtype or self.dtype is not a class
|
|
||||||
valid = dtype == self.dtype
|
|
||||||
|
|
||||||
return valid
|
|
||||||
|
|
||||||
def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None:
|
def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -326,7 +323,7 @@ class Interface(ABC, Generic[T]):
|
||||||
if self.shape is Any:
|
if self.shape is Any:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return check_shape(shape, self.shape)
|
return validate_shape(shape, self.shape)
|
||||||
|
|
||||||
def raise_for_shape(self, valid: bool, shape: Tuple[int, ...]) -> None:
|
def raise_for_shape(self, valid: bool, shape: Tuple[int, ...]) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -27,13 +27,13 @@ class NumpyJsonDict(JsonDict):
|
||||||
|
|
||||||
type: Literal["numpy"]
|
type: Literal["numpy"]
|
||||||
dtype: str
|
dtype: str
|
||||||
array: list
|
value: list
|
||||||
|
|
||||||
def to_array_input(self) -> ndarray:
|
def to_array_input(self) -> ndarray:
|
||||||
"""
|
"""
|
||||||
Construct a numpy array
|
Construct a numpy array
|
||||||
"""
|
"""
|
||||||
return np.array(self.array, dtype=self.dtype)
|
return np.array(self.value, dtype=self.dtype)
|
||||||
|
|
||||||
|
|
||||||
class NumpyInterface(Interface):
|
class NumpyInterface(Interface):
|
||||||
|
@ -99,6 +99,6 @@ class NumpyInterface(Interface):
|
||||||
|
|
||||||
if info.round_trip:
|
if info.round_trip:
|
||||||
json_array = NumpyJsonDict(
|
json_array = NumpyJsonDict(
|
||||||
type=cls.name, dtype=str(array.dtype), array=json_array
|
type=cls.name, dtype=str(array.dtype), value=json_array
|
||||||
)
|
)
|
||||||
return json_array
|
return json_array
|
||||||
|
|
|
@ -63,7 +63,7 @@ class ZarrJsonDict(JsonDict):
|
||||||
type: Literal["zarr"]
|
type: Literal["zarr"]
|
||||||
file: Optional[str] = None
|
file: Optional[str] = None
|
||||||
path: Optional[str] = None
|
path: Optional[str] = None
|
||||||
array: Optional[list] = None
|
value: Optional[list] = None
|
||||||
|
|
||||||
def to_array_input(self) -> Union[ZarrArray, ZarrArrayPath]:
|
def to_array_input(self) -> Union[ZarrArray, ZarrArrayPath]:
|
||||||
"""
|
"""
|
||||||
|
@ -73,7 +73,7 @@ class ZarrJsonDict(JsonDict):
|
||||||
if self.file:
|
if self.file:
|
||||||
array = ZarrArrayPath(file=self.file, path=self.path)
|
array = ZarrArrayPath(file=self.file, path=self.path)
|
||||||
else:
|
else:
|
||||||
array = zarr.array(self.array)
|
array = zarr.array(self.value)
|
||||||
return array
|
return array
|
||||||
|
|
||||||
|
|
||||||
|
@ -202,7 +202,7 @@ class ZarrInterface(Interface):
|
||||||
as_json["info"]["hexdigest"] = array.hexdigest()
|
as_json["info"]["hexdigest"] = array.hexdigest()
|
||||||
|
|
||||||
if dump_array or not is_file:
|
if dump_array or not is_file:
|
||||||
as_json["array"] = array[:].tolist()
|
as_json["value"] = array[:].tolist()
|
||||||
|
|
||||||
as_json = ZarrJsonDict(**as_json)
|
as_json = ZarrJsonDict(**as_json)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -13,7 +13,7 @@ Extension of nptyping NDArray for pydantic that allows for JSON-Schema serializa
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any, Tuple
|
from typing import TYPE_CHECKING, Any, Literal, Tuple, get_origin
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import GetJsonSchemaHandler
|
from pydantic import GetJsonSchemaHandler
|
||||||
|
@ -29,6 +29,7 @@ from numpydantic.schema import (
|
||||||
)
|
)
|
||||||
from numpydantic.serialization import jsonize_array
|
from numpydantic.serialization import jsonize_array
|
||||||
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||||
|
from numpydantic.validation.dtype import is_union
|
||||||
from numpydantic.vendor.nptyping.error import InvalidArgumentsError
|
from numpydantic.vendor.nptyping.error import InvalidArgumentsError
|
||||||
from numpydantic.vendor.nptyping.ndarray import NDArrayMeta as _NDArrayMeta
|
from numpydantic.vendor.nptyping.ndarray import NDArrayMeta as _NDArrayMeta
|
||||||
from numpydantic.vendor.nptyping.nptyping_type import NPTypingType
|
from numpydantic.vendor.nptyping.nptyping_type import NPTypingType
|
||||||
|
@ -86,11 +87,18 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
|
||||||
except InterfaceError:
|
except InterfaceError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _is_literal_like(cls, item: Any) -> bool:
|
||||||
|
"""
|
||||||
|
Changes from nptyping:
|
||||||
|
- doesn't just ducktype for literal but actually, yno, checks for being literal
|
||||||
|
"""
|
||||||
|
return get_origin(item) is Literal
|
||||||
|
|
||||||
def _get_shape(cls, dtype_candidate: Any) -> "Shape":
|
def _get_shape(cls, dtype_candidate: Any) -> "Shape":
|
||||||
"""
|
"""
|
||||||
Override of base method to use our local definition of shape
|
Override of base method to use our local definition of shape
|
||||||
"""
|
"""
|
||||||
from numpydantic.shape import Shape
|
from numpydantic.validation.shape import Shape
|
||||||
|
|
||||||
if dtype_candidate is Any or dtype_candidate is Shape:
|
if dtype_candidate is Any or dtype_candidate is Shape:
|
||||||
shape = Any
|
shape = Any
|
||||||
|
@ -120,7 +128,7 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
|
||||||
|
|
||||||
if dtype_candidate is Any:
|
if dtype_candidate is Any:
|
||||||
dtype = Any
|
dtype = Any
|
||||||
elif is_dtype:
|
elif is_dtype or is_union(dtype_candidate):
|
||||||
dtype = dtype_candidate
|
dtype = dtype_candidate
|
||||||
elif issubclass(dtype_candidate, Structure): # pragma: no cover
|
elif issubclass(dtype_candidate, Structure): # pragma: no cover
|
||||||
dtype = dtype_candidate
|
dtype = dtype_candidate
|
||||||
|
|
|
@ -5,7 +5,7 @@ Helper functions for use with :class:`~numpydantic.NDArray` - see the note in
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
from typing import TYPE_CHECKING, Any, Callable, Optional, get_args
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
@ -16,6 +16,7 @@ from numpydantic import dtype as dt
|
||||||
from numpydantic.interface import Interface
|
from numpydantic.interface import Interface
|
||||||
from numpydantic.maps import np_to_python
|
from numpydantic.maps import np_to_python
|
||||||
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||||
|
from numpydantic.validation.dtype import is_union
|
||||||
from numpydantic.vendor.nptyping.structure import StructureMeta
|
from numpydantic.vendor.nptyping.structure import StructureMeta
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma: no cover
|
if TYPE_CHECKING: # pragma: no cover
|
||||||
|
@ -51,7 +52,6 @@ def _lol_dtype(
|
||||||
"""Get the innermost dtype schema to use in the generated pydantic schema"""
|
"""Get the innermost dtype schema to use in the generated pydantic schema"""
|
||||||
if isinstance(dtype, StructureMeta): # pragma: no cover
|
if isinstance(dtype, StructureMeta): # pragma: no cover
|
||||||
raise NotImplementedError("Structured dtypes are currently unsupported")
|
raise NotImplementedError("Structured dtypes are currently unsupported")
|
||||||
|
|
||||||
if isinstance(dtype, tuple):
|
if isinstance(dtype, tuple):
|
||||||
# if it's a meta-type that refers to a generic float/int, just make that
|
# if it's a meta-type that refers to a generic float/int, just make that
|
||||||
if dtype in (dt.Float, dt.Number):
|
if dtype in (dt.Float, dt.Number):
|
||||||
|
@ -66,7 +66,10 @@ def _lol_dtype(
|
||||||
array_type = core_schema.union_schema(
|
array_type = core_schema.union_schema(
|
||||||
[_lol_dtype(t, _handler) for t in types_]
|
[_lol_dtype(t, _handler) for t in types_]
|
||||||
)
|
)
|
||||||
|
elif is_union(dtype):
|
||||||
|
array_type = core_schema.union_schema(
|
||||||
|
[_lol_dtype(t, _handler) for t in get_args(dtype)]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
python_type = np_to_python[dtype]
|
python_type = np_to_python[dtype]
|
||||||
|
@ -110,7 +113,7 @@ def list_of_lists_schema(shape: "Shape", array_type: CoreSchema) -> ListSchema:
|
||||||
array_type ( :class:`pydantic_core.CoreSchema` ): The pre-rendered pydantic
|
array_type ( :class:`pydantic_core.CoreSchema` ): The pre-rendered pydantic
|
||||||
core schema to use in the innermost list entry
|
core schema to use in the innermost list entry
|
||||||
"""
|
"""
|
||||||
from numpydantic.shape import _is_range
|
from numpydantic.validation.shape import _is_range
|
||||||
|
|
||||||
shape_parts = [part.strip() for part in shape.__args__[0].split(",")]
|
shape_parts = [part.strip() for part in shape.__args__[0].split(",")]
|
||||||
# labels, if present
|
# labels, if present
|
||||||
|
|
|
@ -4,7 +4,7 @@ and :func:`pydantic.BaseModel.model_dump_json` .
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, TypeVar, Union
|
from typing import Any, Callable, Iterable, TypeVar, Union
|
||||||
|
|
||||||
from pydantic_core.core_schema import SerializationInfo
|
from pydantic_core.core_schema import SerializationInfo
|
||||||
|
|
||||||
|
@ -16,6 +16,9 @@ U = TypeVar("U")
|
||||||
|
|
||||||
def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
|
def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
|
||||||
"""Use an interface class to render an array as JSON"""
|
"""Use an interface class to render an array as JSON"""
|
||||||
|
# perf: keys to skip in generation - anything named "value" is array data.
|
||||||
|
skip = ["value"]
|
||||||
|
|
||||||
interface_cls = Interface.match_output(value)
|
interface_cls = Interface.match_output(value)
|
||||||
array = interface_cls.to_json(value, info)
|
array = interface_cls.to_json(value, info)
|
||||||
if isinstance(array, JsonDict):
|
if isinstance(array, JsonDict):
|
||||||
|
@ -25,19 +28,37 @@ def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
|
||||||
if info.context.get("mark_interface", False):
|
if info.context.get("mark_interface", False):
|
||||||
array = interface_cls.mark_json(array)
|
array = interface_cls.mark_json(array)
|
||||||
|
|
||||||
|
if isinstance(array, list):
|
||||||
|
return array
|
||||||
|
|
||||||
|
# ---- Perf Barrier ------------------------------------------------------
|
||||||
|
# put context args intended to **wrap** the array above
|
||||||
|
# put context args intended to **modify** the array below
|
||||||
|
#
|
||||||
|
# above, we assume that a list is **data** not to be modified.
|
||||||
|
# below, we must mark whenever the data is in the line of fire
|
||||||
|
# to avoid an expensive iteration.
|
||||||
|
|
||||||
if info.context.get("absolute_paths", False):
|
if info.context.get("absolute_paths", False):
|
||||||
array = _absolutize_paths(array)
|
array = _absolutize_paths(array, skip)
|
||||||
else:
|
else:
|
||||||
relative_to = info.context.get("relative_to", ".")
|
relative_to = info.context.get("relative_to", ".")
|
||||||
array = _relativize_paths(array, relative_to)
|
array = _relativize_paths(array, relative_to, skip)
|
||||||
else:
|
else:
|
||||||
# relativize paths by default
|
if isinstance(array, list):
|
||||||
array = _relativize_paths(array, ".")
|
return array
|
||||||
|
|
||||||
|
# ---- Perf Barrier ------------------------------------------------------
|
||||||
|
# same as above, ensure any keys that contain array values are skipped right now
|
||||||
|
|
||||||
|
array = _relativize_paths(array, ".", skip)
|
||||||
|
|
||||||
return array
|
return array
|
||||||
|
|
||||||
|
|
||||||
def _relativize_paths(value: dict, relative_to: str = ".") -> dict:
|
def _relativize_paths(
|
||||||
|
value: dict, relative_to: str = ".", skip: Iterable = tuple()
|
||||||
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Make paths relative to either the current directory or the provided
|
Make paths relative to either the current directory or the provided
|
||||||
``relative_to`` directory, if provided in the context
|
``relative_to`` directory, if provided in the context
|
||||||
|
@ -46,6 +67,8 @@ def _relativize_paths(value: dict, relative_to: str = ".") -> dict:
|
||||||
# pdb.set_trace()
|
# pdb.set_trace()
|
||||||
|
|
||||||
def _r_path(v: Any) -> Any:
|
def _r_path(v: Any) -> Any:
|
||||||
|
if not isinstance(v, (str, Path)):
|
||||||
|
return v
|
||||||
try:
|
try:
|
||||||
path = Path(v)
|
path = Path(v)
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
|
@ -54,10 +77,10 @@ def _relativize_paths(value: dict, relative_to: str = ".") -> dict:
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
return v
|
return v
|
||||||
|
|
||||||
return _walk_and_apply(value, _r_path)
|
return _walk_and_apply(value, _r_path, skip)
|
||||||
|
|
||||||
|
|
||||||
def _absolutize_paths(value: dict) -> dict:
|
def _absolutize_paths(value: dict, skip: Iterable = tuple()) -> dict:
|
||||||
def _a_path(v: Any) -> Any:
|
def _a_path(v: Any) -> Any:
|
||||||
try:
|
try:
|
||||||
path = Path(v)
|
path = Path(v)
|
||||||
|
@ -67,23 +90,25 @@ def _absolutize_paths(value: dict) -> dict:
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
return v
|
return v
|
||||||
|
|
||||||
return _walk_and_apply(value, _a_path)
|
return _walk_and_apply(value, _a_path, skip)
|
||||||
|
|
||||||
|
|
||||||
def _walk_and_apply(value: T, f: Callable[[U], U]) -> T:
|
def _walk_and_apply(value: T, f: Callable[[U], U], skip: Iterable = tuple()) -> T:
|
||||||
"""
|
"""
|
||||||
Walk an object, applying a function
|
Walk an object, applying a function
|
||||||
"""
|
"""
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
for k, v in value.items():
|
for k, v in value.items():
|
||||||
|
if k in skip:
|
||||||
|
continue
|
||||||
if isinstance(v, dict):
|
if isinstance(v, dict):
|
||||||
_walk_and_apply(v, f)
|
_walk_and_apply(v, f, skip)
|
||||||
elif isinstance(v, list):
|
elif isinstance(v, list):
|
||||||
value[k] = [_walk_and_apply(sub_v, f) for sub_v in v]
|
value[k] = [_walk_and_apply(sub_v, f, skip) for sub_v in v]
|
||||||
else:
|
else:
|
||||||
value[k] = f(v)
|
value[k] = f(v)
|
||||||
elif isinstance(value, list):
|
elif isinstance(value, list):
|
||||||
value = [_walk_and_apply(v, f) for v in value]
|
value = [_walk_and_apply(v, f, skip) for v in value]
|
||||||
else:
|
else:
|
||||||
value = f(value)
|
value = f(value)
|
||||||
return value
|
return value
|
||||||
|
|
11
src/numpydantic/validation/__init__.py
Normal file
11
src/numpydantic/validation/__init__.py
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
"""
|
||||||
|
Helper functions for validation
|
||||||
|
"""
|
||||||
|
|
||||||
|
from numpydantic.validation.dtype import validate_dtype
|
||||||
|
from numpydantic.validation.shape import validate_shape
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"validate_dtype",
|
||||||
|
"validate_shape",
|
||||||
|
]
|
55
src/numpydantic/validation/dtype.py
Normal file
55
src/numpydantic/validation/dtype.py
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
"""
|
||||||
|
Helper functions for validation of dtype.
|
||||||
|
|
||||||
|
For literal dtypes intended for use by end-users, see :mod:`numpydantic.dtype`
|
||||||
|
"""
|
||||||
|
|
||||||
|
from types import UnionType
|
||||||
|
from typing import Any, Union, get_args, get_origin
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from numpydantic.types import DtypeType
|
||||||
|
|
||||||
|
|
||||||
|
def validate_dtype(dtype: Any, target: DtypeType) -> bool:
|
||||||
|
"""
|
||||||
|
Validate a dtype against the target dtype
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dtype: The dtype to validate
|
||||||
|
target (:class:`.DtypeType`): The target dtype
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: ``True`` if valid, ``False`` otherwise
|
||||||
|
"""
|
||||||
|
if target is Any:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if isinstance(target, tuple):
|
||||||
|
valid = dtype in target
|
||||||
|
elif is_union(target):
|
||||||
|
valid = any(
|
||||||
|
[validate_dtype(dtype, target_dt) for target_dt in get_args(target)]
|
||||||
|
)
|
||||||
|
elif target is np.str_:
|
||||||
|
valid = getattr(dtype, "type", None) in (np.str_, str) or dtype in (
|
||||||
|
np.str_,
|
||||||
|
str,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# try to match as any subclass, if target is a class
|
||||||
|
try:
|
||||||
|
valid = issubclass(dtype, target)
|
||||||
|
except TypeError:
|
||||||
|
# expected, if dtype or target is not a class
|
||||||
|
valid = dtype == target
|
||||||
|
|
||||||
|
return valid
|
||||||
|
|
||||||
|
|
||||||
|
def is_union(dtype: DtypeType) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a dtype is a union
|
||||||
|
"""
|
||||||
|
return get_origin(dtype) in (Union, UnionType)
|
|
@ -2,7 +2,7 @@
|
||||||
Declaration and validation functions for array shapes.
|
Declaration and validation functions for array shapes.
|
||||||
|
|
||||||
Mostly a mildly modified version of nptyping's
|
Mostly a mildly modified version of nptyping's
|
||||||
:func:`npytping.shape_expression.check_shape`
|
:func:`npytping.shape_expression.validate_shape`
|
||||||
and its internals to allow for extended syntax, including ranges of shapes.
|
and its internals to allow for extended syntax, including ranges of shapes.
|
||||||
|
|
||||||
Modifications from nptyping:
|
Modifications from nptyping:
|
||||||
|
@ -105,7 +105,7 @@ def validate_shape_expression(shape_expression: Union[ShapeExpression, Any]) ->
|
||||||
|
|
||||||
|
|
||||||
@lru_cache
|
@lru_cache
|
||||||
def check_shape(shape: ShapeTuple, target: "Shape") -> bool:
|
def validate_shape(shape: ShapeTuple, target: "Shape") -> bool:
|
||||||
"""
|
"""
|
||||||
Check whether the given shape corresponds to the given shape_expression.
|
Check whether the given shape corresponds to the given shape_expression.
|
||||||
:param shape: the shape in question.
|
:param shape: the shape in question.
|
|
@ -80,6 +80,8 @@ INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer]
|
||||||
FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float]
|
FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float]
|
||||||
STRING: TypeAlias = NDArray[Shape["*, *, *"], str]
|
STRING: TypeAlias = NDArray[Shape["*, *, *"], str]
|
||||||
MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel]
|
MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel]
|
||||||
|
UNION_PIPE: TypeAlias = NDArray[Shape["*, *, *"], np.uint32 | np.float32]
|
||||||
|
UNION_TYPE: TypeAlias = NDArray[Shape["*, *, *"], Union[np.uint32, np.float32]]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(
|
@pytest.fixture(
|
||||||
|
@ -147,6 +149,16 @@ def shape_cases(request) -> ValidationCase:
|
||||||
ValidationCase(annotation=MODEL, dtype=BadModel, passes=False),
|
ValidationCase(annotation=MODEL, dtype=BadModel, passes=False),
|
||||||
ValidationCase(annotation=MODEL, dtype=int, passes=False),
|
ValidationCase(annotation=MODEL, dtype=int, passes=False),
|
||||||
ValidationCase(annotation=MODEL, dtype=SubClass, passes=True),
|
ValidationCase(annotation=MODEL, dtype=SubClass, passes=True),
|
||||||
|
ValidationCase(annotation=UNION_PIPE, dtype=np.uint32, passes=True),
|
||||||
|
ValidationCase(annotation=UNION_PIPE, dtype=np.float32, passes=True),
|
||||||
|
ValidationCase(annotation=UNION_PIPE, dtype=np.uint64, passes=False),
|
||||||
|
ValidationCase(annotation=UNION_PIPE, dtype=np.float64, passes=False),
|
||||||
|
ValidationCase(annotation=UNION_PIPE, dtype=str, passes=False),
|
||||||
|
ValidationCase(annotation=UNION_TYPE, dtype=np.uint32, passes=True),
|
||||||
|
ValidationCase(annotation=UNION_TYPE, dtype=np.float32, passes=True),
|
||||||
|
ValidationCase(annotation=UNION_TYPE, dtype=np.uint64, passes=False),
|
||||||
|
ValidationCase(annotation=UNION_TYPE, dtype=np.float64, passes=False),
|
||||||
|
ValidationCase(annotation=UNION_TYPE, dtype=str, passes=False),
|
||||||
],
|
],
|
||||||
ids=[
|
ids=[
|
||||||
"float",
|
"float",
|
||||||
|
@ -174,6 +186,16 @@ def shape_cases(request) -> ValidationCase:
|
||||||
"model-badmodel",
|
"model-badmodel",
|
||||||
"model-int",
|
"model-int",
|
||||||
"model-subclass",
|
"model-subclass",
|
||||||
|
"union-pipe-uint32",
|
||||||
|
"union-pipe-float32",
|
||||||
|
"union-pipe-uint64",
|
||||||
|
"union-pipe-float64",
|
||||||
|
"union-pipe-str",
|
||||||
|
"union-type-uint32",
|
||||||
|
"union-type-float32",
|
||||||
|
"union-type-uint64",
|
||||||
|
"union-type-float64",
|
||||||
|
"union-type-str",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def dtype_cases(request) -> ValidationCase:
|
def dtype_cases(request) -> ValidationCase:
|
||||||
|
|
|
@ -151,7 +151,7 @@ def test_zarr_to_json(store, model_blank, roundtrip, dump_array):
|
||||||
|
|
||||||
if roundtrip:
|
if roundtrip:
|
||||||
if dump_array:
|
if dump_array:
|
||||||
assert as_json["array"] == lol_array
|
assert as_json["value"] == lol_array
|
||||||
else:
|
else:
|
||||||
if as_json.get("file", False):
|
if as_json.get("file", False):
|
||||||
assert "array" not in as_json
|
assert "array" not in as_json
|
||||||
|
|
Loading…
Reference in a new issue