mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2025-01-09 13:44:26 +00:00
Merge pull request #24 from p2p-ld/dtype-union
Some checks are pending
Lint / Ruff Linting (push) Waiting to run
Lint / Black Formatting (push) Waiting to run
Tests / test (<2.0.0, macos-latest, 3.12) (push) Waiting to run
Tests / test (<2.0.0, macos-latest, 3.9) (push) Waiting to run
Tests / test (<2.0.0, ubuntu-latest, 3.12) (push) Waiting to run
Tests / test (<2.0.0, ubuntu-latest, 3.9) (push) Waiting to run
Tests / test (<2.0.0, windows-latest, 3.12) (push) Waiting to run
Tests / test (<2.0.0, windows-latest, 3.9) (push) Waiting to run
Tests / test (>=2.0.0, macos-latest, 3.12) (push) Waiting to run
Tests / test (>=2.0.0, macos-latest, 3.9) (push) Waiting to run
Tests / test (>=2.0.0, ubuntu-latest, 3.10) (push) Waiting to run
Tests / test (>=2.0.0, ubuntu-latest, 3.11) (push) Waiting to run
Tests / test (>=2.0.0, ubuntu-latest, 3.12) (push) Waiting to run
Tests / test (>=2.0.0, ubuntu-latest, 3.9) (push) Waiting to run
Tests / test (>=2.0.0, windows-latest, 3.12) (push) Waiting to run
Tests / test (>=2.0.0, windows-latest, 3.9) (push) Waiting to run
Tests / finish-coverage (push) Blocked by required conditions
Some checks are pending
Lint / Ruff Linting (push) Waiting to run
Lint / Black Formatting (push) Waiting to run
Tests / test (<2.0.0, macos-latest, 3.12) (push) Waiting to run
Tests / test (<2.0.0, macos-latest, 3.9) (push) Waiting to run
Tests / test (<2.0.0, ubuntu-latest, 3.12) (push) Waiting to run
Tests / test (<2.0.0, ubuntu-latest, 3.9) (push) Waiting to run
Tests / test (<2.0.0, windows-latest, 3.12) (push) Waiting to run
Tests / test (<2.0.0, windows-latest, 3.9) (push) Waiting to run
Tests / test (>=2.0.0, macos-latest, 3.12) (push) Waiting to run
Tests / test (>=2.0.0, macos-latest, 3.9) (push) Waiting to run
Tests / test (>=2.0.0, ubuntu-latest, 3.10) (push) Waiting to run
Tests / test (>=2.0.0, ubuntu-latest, 3.11) (push) Waiting to run
Tests / test (>=2.0.0, ubuntu-latest, 3.12) (push) Waiting to run
Tests / test (>=2.0.0, ubuntu-latest, 3.9) (push) Waiting to run
Tests / test (>=2.0.0, windows-latest, 3.12) (push) Waiting to run
Tests / test (>=2.0.0, windows-latest, 3.9) (push) Waiting to run
Tests / finish-coverage (push) Blocked by required conditions
[dtype] Support Unions
This commit is contained in:
commit
1cf69eb18c
23 changed files with 443 additions and 120 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -160,3 +160,5 @@ cython_debug/
|
|||
#.idea/
|
||||
.pdm-python
|
||||
ndarray.pyi
|
||||
|
||||
prof/
|
7
docs/api/validation/dtype.md
Normal file
7
docs/api/validation/dtype.md
Normal file
|
@ -0,0 +1,7 @@
|
|||
# dtype
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: numpydantic.validation.dtype
|
||||
:members:
|
||||
:undoc-members:
|
||||
```
|
6
docs/api/validation/index.md
Normal file
6
docs/api/validation/index.md
Normal file
|
@ -0,0 +1,6 @@
|
|||
# validation
|
||||
|
||||
```{toctree}
|
||||
dtype
|
||||
shape
|
||||
```
|
|
@ -1,7 +1,7 @@
|
|||
# shape
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: numpydantic.shape
|
||||
.. automodule:: numpydantic.validation.shape
|
||||
:members:
|
||||
:undoc-members:
|
||||
```
|
|
@ -4,6 +4,33 @@
|
|||
|
||||
### 1.6.*
|
||||
|
||||
#### 1.6.1 - 24-09-23 - Support Union Dtypes
|
||||
|
||||
It's now possible to do this, like it always should have been
|
||||
|
||||
```python
|
||||
class MyModel(BaseModel):
|
||||
array: NDArray[Any, int | float]
|
||||
```
|
||||
|
||||
**Features**
|
||||
- Support for Union Dtypes
|
||||
|
||||
**Structure**
|
||||
- New `validation` module containing `shape` and `dtype` convenience methods
|
||||
to declutter main namespace and make a grouping for related code
|
||||
- Rename all serialized arrays within a container dict to `value` to be able
|
||||
to identify them by convention and avoid long iteration - see perf below.
|
||||
|
||||
**Perf**
|
||||
- Avoid iterating over every item in an array trying to convert it to a path for
|
||||
a several order of magnitude perf improvement over `1.6.0` (oops)
|
||||
|
||||
**Docs**
|
||||
- Page for `dtypes`, mostly stubs at the moment, but more explicit documentation
|
||||
about what kind of dtypes we support.
|
||||
|
||||
|
||||
#### 1.6.0 - 24-09-23 - Roundtrip JSON Serialization
|
||||
|
||||
Roundtrip JSON serialization is here - with serialization to list of lists,
|
||||
|
|
98
docs/dtype.md
Normal file
98
docs/dtype.md
Normal file
|
@ -0,0 +1,98 @@
|
|||
# dtype
|
||||
|
||||
```{todo}
|
||||
This section is under construction as of 1.6.1
|
||||
|
||||
Much of the details of dtypes are covered in [syntax](./syntax.md)
|
||||
and in {mod}`numpydantic.dtype` , but this section will specifically
|
||||
address how dtypes are handled both generically and by interfaces
|
||||
as we expand custom dtype handling <3.
|
||||
|
||||
For details of support and implementation until the docs have time for some love,
|
||||
please see the tests, which are the source of truth for the functionality
|
||||
of the library for now and forever.
|
||||
```
|
||||
|
||||
Recall the general syntax:
|
||||
|
||||
```
|
||||
NDArray[Shape, dtype]
|
||||
```
|
||||
|
||||
These are the docs for what can do in `dtype`.
|
||||
|
||||
## Scalar Dtypes
|
||||
|
||||
Python builtin types and numpy types should be handled transparently,
|
||||
with some exception for complex numbers and objects (described below).
|
||||
|
||||
### Numbers
|
||||
|
||||
#### Complex numbers
|
||||
|
||||
```{todo}
|
||||
Document limitations for complex numbers and strategies for serialization/validation
|
||||
```
|
||||
|
||||
### Datetimes
|
||||
|
||||
```{todo}
|
||||
Datetimes are supported by every interface except :class:`.VideoInterface` ,
|
||||
with the caveat that HDF5 loses timezone information, and thus all timestamps should
|
||||
be re-encoded to UTC before saving/loading.
|
||||
|
||||
More generic datetime support is TODO.
|
||||
```
|
||||
|
||||
### Objects
|
||||
|
||||
```{todo}
|
||||
Generic objects are supported by all interfaces except
|
||||
:class:`.VideoInterface` , :class;`.HDF5Interface` , and :class:`.ZarrInterface` .
|
||||
|
||||
this might be expected, but there is also hope, TODO fill in serialization plans.
|
||||
```
|
||||
|
||||
### Strings
|
||||
|
||||
```{todo}
|
||||
Strings are supported by all interfaces except :class:`.VideoInterface` .
|
||||
|
||||
TODO is fill in the subtleties of how this works
|
||||
```
|
||||
|
||||
## Generic Dtypes
|
||||
|
||||
```{todo}
|
||||
For now these are handled as tuples of dtypes, see the source of
|
||||
{ref}`numpydantic.dtype.Float` . They should either be handled as Unions
|
||||
or as a more prescribed meta-type.
|
||||
|
||||
For now, use `int` and `float` to refer to the general concepts of
|
||||
"any int" or "any float" even if this is a bit mismatched from the numpy usage.
|
||||
```
|
||||
|
||||
## Extended Python Typing Universe
|
||||
|
||||
### Union Types
|
||||
|
||||
Union types can be used as expected.
|
||||
|
||||
Union types are tested recursively -- if any item within a ``Union`` matches
|
||||
the expected dtype at a given level of recursion, the dtype test passes.
|
||||
|
||||
```python
|
||||
class MyModel(BaseModel):
|
||||
array: NDArray[Any, int | float]
|
||||
```
|
||||
|
||||
## Compound Dtypes
|
||||
|
||||
```{todo}
|
||||
Compound dtypes are currently unsupported,
|
||||
though the HDF5 interface supports indexing into compound dtypes
|
||||
as separable dimensions/arrays using the third "field" parameter in
|
||||
{class}`.hdf5.H5ArrayPath` .
|
||||
```
|
||||
|
||||
|
|
@ -473,6 +473,7 @@ dumped = instance.model_dump_json(context={'zarr_dump_array': True})
|
|||
|
||||
design
|
||||
syntax
|
||||
dtype
|
||||
serialization
|
||||
interfaces
|
||||
```
|
||||
|
@ -484,13 +485,13 @@ interfaces
|
|||
|
||||
api/index
|
||||
api/interface/index
|
||||
api/validation/index
|
||||
api/dtype
|
||||
api/ndarray
|
||||
api/maps
|
||||
api/meta
|
||||
api/schema
|
||||
api/serialization
|
||||
api/shape
|
||||
api/types
|
||||
|
||||
```
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "numpydantic"
|
||||
version = "1.6.0"
|
||||
version = "1.6.1"
|
||||
description = "Type and shape validation and serialization for arbitrary array types in pydantic models"
|
||||
authors = [
|
||||
{name = "sneakers-the-rat", email = "sneakers-the-rat@protonmail.com"},
|
||||
|
@ -126,7 +126,7 @@ markers = [
|
|||
]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py311"
|
||||
target-version = "py39"
|
||||
include = ["src/numpydantic/**/*.py", "pyproject.toml"]
|
||||
exclude = ["tests"]
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
from numpydantic.ndarray import NDArray
|
||||
from numpydantic.meta import update_ndarray_stub
|
||||
from numpydantic.shape import Shape
|
||||
from numpydantic.validation.shape import Shape
|
||||
|
||||
update_ndarray_stub()
|
||||
|
||||
|
|
|
@ -12,6 +12,9 @@ module, rather than needing to import each individually.
|
|||
Some types like `Integer` are compound types - tuples of multiple dtypes.
|
||||
Check these using ``in`` rather than ``==``. This interface will develop in future
|
||||
versions to allow a single dtype check.
|
||||
|
||||
For internal helper functions for validating dtype,
|
||||
see :mod:`numpydantic.validation.dtype`
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
|
|
@ -33,11 +33,11 @@ class DaskJsonDict(JsonDict):
|
|||
name: str
|
||||
chunks: Iterable[tuple[int, ...]]
|
||||
dtype: str
|
||||
array: list
|
||||
value: list
|
||||
|
||||
def to_array_input(self) -> DaskArray:
|
||||
"""Construct a dask array"""
|
||||
np_array = np.array(self.array, dtype=self.dtype)
|
||||
np_array = np.array(self.value, dtype=self.dtype)
|
||||
array = from_array(
|
||||
np_array,
|
||||
name=self.name,
|
||||
|
@ -100,7 +100,7 @@ class DaskInterface(Interface):
|
|||
if info.round_trip:
|
||||
as_json = DaskJsonDict(
|
||||
type=cls.name,
|
||||
array=as_json,
|
||||
value=as_json,
|
||||
name=array.name,
|
||||
chunks=array.chunks,
|
||||
dtype=str(np_array.dtype),
|
||||
|
|
|
@ -20,8 +20,8 @@ from numpydantic.exceptions import (
|
|||
ShapeError,
|
||||
TooManyMatchesError,
|
||||
)
|
||||
from numpydantic.shape import check_shape
|
||||
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||
from numpydantic.validation import validate_dtype, validate_shape
|
||||
|
||||
T = TypeVar("T", bound=NDArrayType)
|
||||
U = TypeVar("U", bound="JsonDict")
|
||||
|
@ -76,6 +76,21 @@ class InterfaceMark(BaseModel):
|
|||
class JsonDict(BaseModel):
|
||||
"""
|
||||
Representation of array when dumped with round_trip == True.
|
||||
|
||||
.. admonition:: Developer's Note
|
||||
|
||||
Any JsonDict that contains an actual array should be named ``value``
|
||||
rather than array (or any other name), and nothing but the
|
||||
array data should be named ``value`` .
|
||||
|
||||
During JSON serialization, it becomes ambiguous what contains an array
|
||||
of data vs. an array of metadata. For the moment we would like to
|
||||
reserve the ability to have lists of metadata, so until we rule that out,
|
||||
we would like to be able to avoid iterating over every element of an array
|
||||
in any context parameter transformation like relativizing/absolutizing paths.
|
||||
To avoid that, it's good to agree on a single value name -- ``value`` --
|
||||
and avoid using it for anything else.
|
||||
|
||||
"""
|
||||
|
||||
type: str
|
||||
|
@ -274,25 +289,7 @@ class Interface(ABC, Generic[T]):
|
|||
Validate the dtype of the given array, returning
|
||||
``True`` if valid, ``False`` if not.
|
||||
"""
|
||||
if self.dtype is Any:
|
||||
return True
|
||||
|
||||
if isinstance(self.dtype, tuple):
|
||||
valid = dtype in self.dtype
|
||||
elif self.dtype is np.str_:
|
||||
valid = getattr(dtype, "type", None) in (np.str_, str) or dtype in (
|
||||
np.str_,
|
||||
str,
|
||||
)
|
||||
else:
|
||||
# try to match as any subclass, if self.dtype is a class
|
||||
try:
|
||||
valid = issubclass(dtype, self.dtype)
|
||||
except TypeError:
|
||||
# expected, if dtype or self.dtype is not a class
|
||||
valid = dtype == self.dtype
|
||||
|
||||
return valid
|
||||
return validate_dtype(dtype, self.dtype)
|
||||
|
||||
def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None:
|
||||
"""
|
||||
|
@ -326,7 +323,7 @@ class Interface(ABC, Generic[T]):
|
|||
if self.shape is Any:
|
||||
return True
|
||||
|
||||
return check_shape(shape, self.shape)
|
||||
return validate_shape(shape, self.shape)
|
||||
|
||||
def raise_for_shape(self, valid: bool, shape: Tuple[int, ...]) -> None:
|
||||
"""
|
||||
|
|
|
@ -27,13 +27,13 @@ class NumpyJsonDict(JsonDict):
|
|||
|
||||
type: Literal["numpy"]
|
||||
dtype: str
|
||||
array: list
|
||||
value: list
|
||||
|
||||
def to_array_input(self) -> ndarray:
|
||||
"""
|
||||
Construct a numpy array
|
||||
"""
|
||||
return np.array(self.array, dtype=self.dtype)
|
||||
return np.array(self.value, dtype=self.dtype)
|
||||
|
||||
|
||||
class NumpyInterface(Interface):
|
||||
|
@ -99,6 +99,6 @@ class NumpyInterface(Interface):
|
|||
|
||||
if info.round_trip:
|
||||
json_array = NumpyJsonDict(
|
||||
type=cls.name, dtype=str(array.dtype), array=json_array
|
||||
type=cls.name, dtype=str(array.dtype), value=json_array
|
||||
)
|
||||
return json_array
|
||||
|
|
|
@ -63,7 +63,7 @@ class ZarrJsonDict(JsonDict):
|
|||
type: Literal["zarr"]
|
||||
file: Optional[str] = None
|
||||
path: Optional[str] = None
|
||||
array: Optional[list] = None
|
||||
value: Optional[list] = None
|
||||
|
||||
def to_array_input(self) -> Union[ZarrArray, ZarrArrayPath]:
|
||||
"""
|
||||
|
@ -73,7 +73,7 @@ class ZarrJsonDict(JsonDict):
|
|||
if self.file:
|
||||
array = ZarrArrayPath(file=self.file, path=self.path)
|
||||
else:
|
||||
array = zarr.array(self.array)
|
||||
array = zarr.array(self.value)
|
||||
return array
|
||||
|
||||
|
||||
|
@ -202,7 +202,7 @@ class ZarrInterface(Interface):
|
|||
as_json["info"]["hexdigest"] = array.hexdigest()
|
||||
|
||||
if dump_array or not is_file:
|
||||
as_json["array"] = array[:].tolist()
|
||||
as_json["value"] = array[:].tolist()
|
||||
|
||||
as_json = ZarrJsonDict(**as_json)
|
||||
else:
|
||||
|
|
|
@ -13,7 +13,7 @@ Extension of nptyping NDArray for pydantic that allows for JSON-Schema serializa
|
|||
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Literal, Tuple, get_origin
|
||||
|
||||
import numpy as np
|
||||
from pydantic import GetJsonSchemaHandler
|
||||
|
@ -29,6 +29,7 @@ from numpydantic.schema import (
|
|||
)
|
||||
from numpydantic.serialization import jsonize_array
|
||||
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||
from numpydantic.validation.dtype import is_union
|
||||
from numpydantic.vendor.nptyping.error import InvalidArgumentsError
|
||||
from numpydantic.vendor.nptyping.ndarray import NDArrayMeta as _NDArrayMeta
|
||||
from numpydantic.vendor.nptyping.nptyping_type import NPTypingType
|
||||
|
@ -86,11 +87,18 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
|
|||
except InterfaceError:
|
||||
return False
|
||||
|
||||
def _is_literal_like(cls, item: Any) -> bool:
|
||||
"""
|
||||
Changes from nptyping:
|
||||
- doesn't just ducktype for literal but actually, yno, checks for being literal
|
||||
"""
|
||||
return get_origin(item) is Literal
|
||||
|
||||
def _get_shape(cls, dtype_candidate: Any) -> "Shape":
|
||||
"""
|
||||
Override of base method to use our local definition of shape
|
||||
"""
|
||||
from numpydantic.shape import Shape
|
||||
from numpydantic.validation.shape import Shape
|
||||
|
||||
if dtype_candidate is Any or dtype_candidate is Shape:
|
||||
shape = Any
|
||||
|
@ -120,7 +128,7 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
|
|||
|
||||
if dtype_candidate is Any:
|
||||
dtype = Any
|
||||
elif is_dtype:
|
||||
elif is_dtype or is_union(dtype_candidate):
|
||||
dtype = dtype_candidate
|
||||
elif issubclass(dtype_candidate, Structure): # pragma: no cover
|
||||
dtype = dtype_candidate
|
||||
|
|
|
@ -5,7 +5,7 @@ Helper functions for use with :class:`~numpydantic.NDArray` - see the note in
|
|||
|
||||
import hashlib
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, get_args
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel
|
||||
|
@ -16,6 +16,7 @@ from numpydantic import dtype as dt
|
|||
from numpydantic.interface import Interface
|
||||
from numpydantic.maps import np_to_python
|
||||
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||
from numpydantic.validation.dtype import is_union
|
||||
from numpydantic.vendor.nptyping.structure import StructureMeta
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
|
@ -51,7 +52,6 @@ def _lol_dtype(
|
|||
"""Get the innermost dtype schema to use in the generated pydantic schema"""
|
||||
if isinstance(dtype, StructureMeta): # pragma: no cover
|
||||
raise NotImplementedError("Structured dtypes are currently unsupported")
|
||||
|
||||
if isinstance(dtype, tuple):
|
||||
# if it's a meta-type that refers to a generic float/int, just make that
|
||||
if dtype in (dt.Float, dt.Number):
|
||||
|
@ -66,7 +66,10 @@ def _lol_dtype(
|
|||
array_type = core_schema.union_schema(
|
||||
[_lol_dtype(t, _handler) for t in types_]
|
||||
)
|
||||
|
||||
elif is_union(dtype):
|
||||
array_type = core_schema.union_schema(
|
||||
[_lol_dtype(t, _handler) for t in get_args(dtype)]
|
||||
)
|
||||
else:
|
||||
try:
|
||||
python_type = np_to_python[dtype]
|
||||
|
@ -110,7 +113,7 @@ def list_of_lists_schema(shape: "Shape", array_type: CoreSchema) -> ListSchema:
|
|||
array_type ( :class:`pydantic_core.CoreSchema` ): The pre-rendered pydantic
|
||||
core schema to use in the innermost list entry
|
||||
"""
|
||||
from numpydantic.shape import _is_range
|
||||
from numpydantic.validation.shape import _is_range
|
||||
|
||||
shape_parts = [part.strip() for part in shape.__args__[0].split(",")]
|
||||
# labels, if present
|
||||
|
|
|
@ -4,7 +4,7 @@ and :func:`pydantic.BaseModel.model_dump_json` .
|
|||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, TypeVar, Union
|
||||
from typing import Any, Callable, Iterable, TypeVar, Union
|
||||
|
||||
from pydantic_core.core_schema import SerializationInfo
|
||||
|
||||
|
@ -16,6 +16,9 @@ U = TypeVar("U")
|
|||
|
||||
def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
|
||||
"""Use an interface class to render an array as JSON"""
|
||||
# perf: keys to skip in generation - anything named "value" is array data.
|
||||
skip = ["value"]
|
||||
|
||||
interface_cls = Interface.match_output(value)
|
||||
array = interface_cls.to_json(value, info)
|
||||
if isinstance(array, JsonDict):
|
||||
|
@ -25,19 +28,37 @@ def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
|
|||
if info.context.get("mark_interface", False):
|
||||
array = interface_cls.mark_json(array)
|
||||
|
||||
if isinstance(array, list):
|
||||
return array
|
||||
|
||||
# ---- Perf Barrier ------------------------------------------------------
|
||||
# put context args intended to **wrap** the array above
|
||||
# put context args intended to **modify** the array below
|
||||
#
|
||||
# above, we assume that a list is **data** not to be modified.
|
||||
# below, we must mark whenever the data is in the line of fire
|
||||
# to avoid an expensive iteration.
|
||||
|
||||
if info.context.get("absolute_paths", False):
|
||||
array = _absolutize_paths(array)
|
||||
array = _absolutize_paths(array, skip)
|
||||
else:
|
||||
relative_to = info.context.get("relative_to", ".")
|
||||
array = _relativize_paths(array, relative_to)
|
||||
array = _relativize_paths(array, relative_to, skip)
|
||||
else:
|
||||
# relativize paths by default
|
||||
array = _relativize_paths(array, ".")
|
||||
if isinstance(array, list):
|
||||
return array
|
||||
|
||||
# ---- Perf Barrier ------------------------------------------------------
|
||||
# same as above, ensure any keys that contain array values are skipped right now
|
||||
|
||||
array = _relativize_paths(array, ".", skip)
|
||||
|
||||
return array
|
||||
|
||||
|
||||
def _relativize_paths(value: dict, relative_to: str = ".") -> dict:
|
||||
def _relativize_paths(
|
||||
value: dict, relative_to: str = ".", skip: Iterable = tuple()
|
||||
) -> dict:
|
||||
"""
|
||||
Make paths relative to either the current directory or the provided
|
||||
``relative_to`` directory, if provided in the context
|
||||
|
@ -46,6 +67,8 @@ def _relativize_paths(value: dict, relative_to: str = ".") -> dict:
|
|||
# pdb.set_trace()
|
||||
|
||||
def _r_path(v: Any) -> Any:
|
||||
if not isinstance(v, (str, Path)):
|
||||
return v
|
||||
try:
|
||||
path = Path(v)
|
||||
if not path.exists():
|
||||
|
@ -54,10 +77,10 @@ def _relativize_paths(value: dict, relative_to: str = ".") -> dict:
|
|||
except (TypeError, ValueError):
|
||||
return v
|
||||
|
||||
return _walk_and_apply(value, _r_path)
|
||||
return _walk_and_apply(value, _r_path, skip)
|
||||
|
||||
|
||||
def _absolutize_paths(value: dict) -> dict:
|
||||
def _absolutize_paths(value: dict, skip: Iterable = tuple()) -> dict:
|
||||
def _a_path(v: Any) -> Any:
|
||||
try:
|
||||
path = Path(v)
|
||||
|
@ -67,23 +90,25 @@ def _absolutize_paths(value: dict) -> dict:
|
|||
except (TypeError, ValueError):
|
||||
return v
|
||||
|
||||
return _walk_and_apply(value, _a_path)
|
||||
return _walk_and_apply(value, _a_path, skip)
|
||||
|
||||
|
||||
def _walk_and_apply(value: T, f: Callable[[U], U]) -> T:
|
||||
def _walk_and_apply(value: T, f: Callable[[U, bool], U], skip: Iterable = tuple()) -> T:
|
||||
"""
|
||||
Walk an object, applying a function
|
||||
"""
|
||||
if isinstance(value, dict):
|
||||
for k, v in value.items():
|
||||
if k in skip:
|
||||
continue
|
||||
if isinstance(v, dict):
|
||||
_walk_and_apply(v, f)
|
||||
_walk_and_apply(v, f, skip)
|
||||
elif isinstance(v, list):
|
||||
value[k] = [_walk_and_apply(sub_v, f) for sub_v in v]
|
||||
value[k] = [_walk_and_apply(sub_v, f, skip) for sub_v in v]
|
||||
else:
|
||||
value[k] = f(v)
|
||||
elif isinstance(value, list):
|
||||
value = [_walk_and_apply(v, f) for v in value]
|
||||
value = [_walk_and_apply(v, f, skip) for v in value]
|
||||
else:
|
||||
value = f(value)
|
||||
return value
|
||||
|
|
11
src/numpydantic/validation/__init__.py
Normal file
11
src/numpydantic/validation/__init__.py
Normal file
|
@ -0,0 +1,11 @@
|
|||
"""
|
||||
Helper functions for validation
|
||||
"""
|
||||
|
||||
from numpydantic.validation.dtype import validate_dtype
|
||||
from numpydantic.validation.shape import validate_shape
|
||||
|
||||
__all__ = [
|
||||
"validate_dtype",
|
||||
"validate_shape",
|
||||
]
|
63
src/numpydantic/validation/dtype.py
Normal file
63
src/numpydantic/validation/dtype.py
Normal file
|
@ -0,0 +1,63 @@
|
|||
"""
|
||||
Helper functions for validation of dtype.
|
||||
|
||||
For literal dtypes intended for use by end-users, see :mod:`numpydantic.dtype`
|
||||
"""
|
||||
|
||||
import sys
|
||||
from typing import Any, Union, get_args, get_origin
|
||||
|
||||
import numpy as np
|
||||
|
||||
from numpydantic.types import DtypeType
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from types import UnionType
|
||||
else:
|
||||
UnionType = None
|
||||
|
||||
|
||||
def validate_dtype(dtype: Any, target: DtypeType) -> bool:
|
||||
"""
|
||||
Validate a dtype against the target dtype
|
||||
|
||||
Args:
|
||||
dtype: The dtype to validate
|
||||
target (:class:`.DtypeType`): The target dtype
|
||||
|
||||
Returns:
|
||||
bool: ``True`` if valid, ``False`` otherwise
|
||||
"""
|
||||
if target is Any:
|
||||
return True
|
||||
|
||||
if isinstance(target, tuple):
|
||||
valid = dtype in target
|
||||
elif is_union(target):
|
||||
valid = any(
|
||||
[validate_dtype(dtype, target_dt) for target_dt in get_args(target)]
|
||||
)
|
||||
elif target is np.str_:
|
||||
valid = getattr(dtype, "type", None) in (np.str_, str) or dtype in (
|
||||
np.str_,
|
||||
str,
|
||||
)
|
||||
else:
|
||||
# try to match as any subclass, if target is a class
|
||||
try:
|
||||
valid = issubclass(dtype, target)
|
||||
except TypeError:
|
||||
# expected, if dtype or target is not a class
|
||||
valid = dtype == target
|
||||
|
||||
return valid
|
||||
|
||||
|
||||
def is_union(dtype: DtypeType) -> bool:
|
||||
"""
|
||||
Check if a dtype is a union
|
||||
"""
|
||||
if UnionType is None:
|
||||
return get_origin(dtype) is Union
|
||||
else:
|
||||
return get_origin(dtype) in (Union, UnionType)
|
|
@ -2,7 +2,7 @@
|
|||
Declaration and validation functions for array shapes.
|
||||
|
||||
Mostly a mildly modified version of nptyping's
|
||||
:func:`npytping.shape_expression.check_shape`
|
||||
:func:`npytping.shape_expression.validate_shape`
|
||||
and its internals to allow for extended syntax, including ranges of shapes.
|
||||
|
||||
Modifications from nptyping:
|
||||
|
@ -105,7 +105,7 @@ def validate_shape_expression(shape_expression: Union[ShapeExpression, Any]) ->
|
|||
|
||||
|
||||
@lru_cache
|
||||
def check_shape(shape: ShapeTuple, target: "Shape") -> bool:
|
||||
def validate_shape(shape: ShapeTuple, target: "Shape") -> bool:
|
||||
"""
|
||||
Check whether the given shape corresponds to the given shape_expression.
|
||||
:param shape: the shape in question.
|
|
@ -3,10 +3,6 @@ import sys
|
|||
import pytest
|
||||
from typing import Any, Tuple, Union, Type
|
||||
|
||||
if sys.version_info.minor >= 10:
|
||||
from typing import TypeAlias
|
||||
else:
|
||||
from typing_extensions import TypeAlias
|
||||
from pydantic import BaseModel, computed_field, ConfigDict
|
||||
from numpydantic import NDArray, Shape
|
||||
from numpydantic.ndarray import NDArrayMeta
|
||||
|
@ -15,6 +11,15 @@ import numpy as np
|
|||
|
||||
from tests.fixtures import *
|
||||
|
||||
if sys.version_info.minor >= 10:
|
||||
from typing import TypeAlias
|
||||
|
||||
YES_PIPE = True
|
||||
else:
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
YES_PIPE = False
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption(
|
||||
|
@ -80,6 +85,9 @@ INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer]
|
|||
FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float]
|
||||
STRING: TypeAlias = NDArray[Shape["*, *, *"], str]
|
||||
MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel]
|
||||
UNION_TYPE: TypeAlias = NDArray[Shape["*, *, *"], Union[np.uint32, np.float32]]
|
||||
if YES_PIPE:
|
||||
UNION_PIPE: TypeAlias = NDArray[Shape["*, *, *"], np.uint32 | np.float32]
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
|
@ -119,9 +127,7 @@ def shape_cases(request) -> ValidationCase:
|
|||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
scope="module",
|
||||
params=[
|
||||
DTYPE_CASES = [
|
||||
ValidationCase(dtype=float, passes=True),
|
||||
ValidationCase(dtype=int, passes=False),
|
||||
ValidationCase(dtype=np.uint8, passes=False),
|
||||
|
@ -147,8 +153,14 @@ def shape_cases(request) -> ValidationCase:
|
|||
ValidationCase(annotation=MODEL, dtype=BadModel, passes=False),
|
||||
ValidationCase(annotation=MODEL, dtype=int, passes=False),
|
||||
ValidationCase(annotation=MODEL, dtype=SubClass, passes=True),
|
||||
],
|
||||
ids=[
|
||||
ValidationCase(annotation=UNION_TYPE, dtype=np.uint32, passes=True),
|
||||
ValidationCase(annotation=UNION_TYPE, dtype=np.float32, passes=True),
|
||||
ValidationCase(annotation=UNION_TYPE, dtype=np.uint64, passes=False),
|
||||
ValidationCase(annotation=UNION_TYPE, dtype=np.float64, passes=False),
|
||||
ValidationCase(annotation=UNION_TYPE, dtype=str, passes=False),
|
||||
]
|
||||
|
||||
DTYPE_IDS = [
|
||||
"float",
|
||||
"int",
|
||||
"uint8",
|
||||
|
@ -174,7 +186,34 @@ def shape_cases(request) -> ValidationCase:
|
|||
"model-badmodel",
|
||||
"model-int",
|
||||
"model-subclass",
|
||||
],
|
||||
)
|
||||
"union-type-uint32",
|
||||
"union-type-float32",
|
||||
"union-type-uint64",
|
||||
"union-type-float64",
|
||||
"union-type-str",
|
||||
]
|
||||
|
||||
if YES_PIPE:
|
||||
DTYPE_CASES.extend(
|
||||
[
|
||||
ValidationCase(annotation=UNION_PIPE, dtype=np.uint32, passes=True),
|
||||
ValidationCase(annotation=UNION_PIPE, dtype=np.float32, passes=True),
|
||||
ValidationCase(annotation=UNION_PIPE, dtype=np.uint64, passes=False),
|
||||
ValidationCase(annotation=UNION_PIPE, dtype=np.float64, passes=False),
|
||||
ValidationCase(annotation=UNION_PIPE, dtype=str, passes=False),
|
||||
]
|
||||
)
|
||||
DTYPE_IDS.extend(
|
||||
[
|
||||
"union-pipe-uint32",
|
||||
"union-pipe-float32",
|
||||
"union-pipe-uint64",
|
||||
"union-pipe-float64",
|
||||
"union-pipe-str",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=DTYPE_CASES, ids=DTYPE_IDS)
|
||||
def dtype_cases(request) -> ValidationCase:
|
||||
return request.param
|
||||
|
|
|
@ -151,7 +151,7 @@ def test_zarr_to_json(store, model_blank, roundtrip, dump_array):
|
|||
|
||||
if roundtrip:
|
||||
if dump_array:
|
||||
assert as_json["array"] == lol_array
|
||||
assert as_json["value"] == lol_array
|
||||
else:
|
||||
if as_json.get("file", False):
|
||||
assert "array" not in as_json
|
||||
|
|
|
@ -10,6 +10,8 @@ from typing import Callable
|
|||
import numpy as np
|
||||
import json
|
||||
|
||||
from numpydantic.serialization import _walk_and_apply
|
||||
|
||||
pytestmark = pytest.mark.serialization
|
||||
|
||||
|
||||
|
@ -93,3 +95,34 @@ def test_relative_to_path(hdf5_at_path, tmp_output_dir, model_blank):
|
|||
|
||||
# shouldn't have absolutized subpath even if it's pathlike
|
||||
assert data["path"] == expected_dataset
|
||||
|
||||
|
||||
def test_walk_and_apply():
|
||||
"""
|
||||
Walk and apply should recursively apply a function to everything in a nesty structure
|
||||
"""
|
||||
test = {
|
||||
"a": 1,
|
||||
"b": 1,
|
||||
"c": [
|
||||
{"a": 1, "b": {"a": 1, "b": 1}, "c": [1, 1, 1]},
|
||||
{"a": 1, "b": [1, 1, 1]},
|
||||
],
|
||||
}
|
||||
|
||||
def _mult_2(v, skip: bool = False):
|
||||
return v * 2
|
||||
|
||||
def _assert_2(v, skip: bool = False):
|
||||
assert v == 2
|
||||
return v
|
||||
|
||||
walked = _walk_and_apply(test, _mult_2)
|
||||
_walk_and_apply(walked, _assert_2)
|
||||
|
||||
assert walked["a"] == 2
|
||||
assert walked["c"][0]["a"] == 2
|
||||
assert walked["c"][0]["b"]["a"] == 2
|
||||
assert all([w == 2 for w in walked["c"][0]["c"]])
|
||||
assert walked["c"][1]["a"] == 2
|
||||
assert all([w == 2 for w in walked["c"][1]["b"]])
|
||||
|
|
Loading…
Reference in a new issue