mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2025-01-09 13:44:26 +00:00
Merge pull request #24 from p2p-ld/dtype-union
Some checks are pending
Lint / Ruff Linting (push) Waiting to run
Lint / Black Formatting (push) Waiting to run
Tests / test (<2.0.0, macos-latest, 3.12) (push) Waiting to run
Tests / test (<2.0.0, macos-latest, 3.9) (push) Waiting to run
Tests / test (<2.0.0, ubuntu-latest, 3.12) (push) Waiting to run
Tests / test (<2.0.0, ubuntu-latest, 3.9) (push) Waiting to run
Tests / test (<2.0.0, windows-latest, 3.12) (push) Waiting to run
Tests / test (<2.0.0, windows-latest, 3.9) (push) Waiting to run
Tests / test (>=2.0.0, macos-latest, 3.12) (push) Waiting to run
Tests / test (>=2.0.0, macos-latest, 3.9) (push) Waiting to run
Tests / test (>=2.0.0, ubuntu-latest, 3.10) (push) Waiting to run
Tests / test (>=2.0.0, ubuntu-latest, 3.11) (push) Waiting to run
Tests / test (>=2.0.0, ubuntu-latest, 3.12) (push) Waiting to run
Tests / test (>=2.0.0, ubuntu-latest, 3.9) (push) Waiting to run
Tests / test (>=2.0.0, windows-latest, 3.12) (push) Waiting to run
Tests / test (>=2.0.0, windows-latest, 3.9) (push) Waiting to run
Tests / finish-coverage (push) Blocked by required conditions
Some checks are pending
Lint / Ruff Linting (push) Waiting to run
Lint / Black Formatting (push) Waiting to run
Tests / test (<2.0.0, macos-latest, 3.12) (push) Waiting to run
Tests / test (<2.0.0, macos-latest, 3.9) (push) Waiting to run
Tests / test (<2.0.0, ubuntu-latest, 3.12) (push) Waiting to run
Tests / test (<2.0.0, ubuntu-latest, 3.9) (push) Waiting to run
Tests / test (<2.0.0, windows-latest, 3.12) (push) Waiting to run
Tests / test (<2.0.0, windows-latest, 3.9) (push) Waiting to run
Tests / test (>=2.0.0, macos-latest, 3.12) (push) Waiting to run
Tests / test (>=2.0.0, macos-latest, 3.9) (push) Waiting to run
Tests / test (>=2.0.0, ubuntu-latest, 3.10) (push) Waiting to run
Tests / test (>=2.0.0, ubuntu-latest, 3.11) (push) Waiting to run
Tests / test (>=2.0.0, ubuntu-latest, 3.12) (push) Waiting to run
Tests / test (>=2.0.0, ubuntu-latest, 3.9) (push) Waiting to run
Tests / test (>=2.0.0, windows-latest, 3.12) (push) Waiting to run
Tests / test (>=2.0.0, windows-latest, 3.9) (push) Waiting to run
Tests / finish-coverage (push) Blocked by required conditions
[dtype] Support Unions
This commit is contained in:
commit
1cf69eb18c
23 changed files with 443 additions and 120 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -159,4 +159,6 @@ cython_debug/
|
||||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
#.idea/
|
#.idea/
|
||||||
.pdm-python
|
.pdm-python
|
||||||
ndarray.pyi
|
ndarray.pyi
|
||||||
|
|
||||||
|
prof/
|
7
docs/api/validation/dtype.md
Normal file
7
docs/api/validation/dtype.md
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
# dtype
|
||||||
|
|
||||||
|
```{eval-rst}
|
||||||
|
.. automodule:: numpydantic.validation.dtype
|
||||||
|
:members:
|
||||||
|
:undoc-members:
|
||||||
|
```
|
6
docs/api/validation/index.md
Normal file
6
docs/api/validation/index.md
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
# validation
|
||||||
|
|
||||||
|
```{toctree}
|
||||||
|
dtype
|
||||||
|
shape
|
||||||
|
```
|
|
@ -1,7 +1,7 @@
|
||||||
# shape
|
# shape
|
||||||
|
|
||||||
```{eval-rst}
|
```{eval-rst}
|
||||||
.. automodule:: numpydantic.shape
|
.. automodule:: numpydantic.validation.shape
|
||||||
:members:
|
:members:
|
||||||
:undoc-members:
|
:undoc-members:
|
||||||
```
|
```
|
|
@ -4,6 +4,33 @@
|
||||||
|
|
||||||
### 1.6.*
|
### 1.6.*
|
||||||
|
|
||||||
|
#### 1.6.1 - 24-09-23 - Support Union Dtypes
|
||||||
|
|
||||||
|
It's now possible to do this, like it always should have been
|
||||||
|
|
||||||
|
```python
|
||||||
|
class MyModel(BaseModel):
|
||||||
|
array: NDArray[Any, int | float]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Features**
|
||||||
|
- Support for Union Dtypes
|
||||||
|
|
||||||
|
**Structure**
|
||||||
|
- New `validation` module containing `shape` and `dtype` convenience methods
|
||||||
|
to declutter main namespace and make a grouping for related code
|
||||||
|
- Rename all serialized arrays within a container dict to `value` to be able
|
||||||
|
to identify them by convention and avoid long iteration - see perf below.
|
||||||
|
|
||||||
|
**Perf**
|
||||||
|
- Avoid iterating over every item in an array trying to convert it to a path for
|
||||||
|
a several order of magnitude perf improvement over `1.6.0` (oops)
|
||||||
|
|
||||||
|
**Docs**
|
||||||
|
- Page for `dtypes`, mostly stubs at the moment, but more explicit documentation
|
||||||
|
about what kind of dtypes we support.
|
||||||
|
|
||||||
|
|
||||||
#### 1.6.0 - 24-09-23 - Roundtrip JSON Serialization
|
#### 1.6.0 - 24-09-23 - Roundtrip JSON Serialization
|
||||||
|
|
||||||
Roundtrip JSON serialization is here - with serialization to list of lists,
|
Roundtrip JSON serialization is here - with serialization to list of lists,
|
||||||
|
|
98
docs/dtype.md
Normal file
98
docs/dtype.md
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
# dtype
|
||||||
|
|
||||||
|
```{todo}
|
||||||
|
This section is under construction as of 1.6.1
|
||||||
|
|
||||||
|
Much of the details of dtypes are covered in [syntax](./syntax.md)
|
||||||
|
and in {mod}`numpydantic.dtype` , but this section will specifically
|
||||||
|
address how dtypes are handled both generically and by interfaces
|
||||||
|
as we expand custom dtype handling <3.
|
||||||
|
|
||||||
|
For details of support and implementation until the docs have time for some love,
|
||||||
|
please see the tests, which are the source of truth for the functionality
|
||||||
|
of the library for now and forever.
|
||||||
|
```
|
||||||
|
|
||||||
|
Recall the general syntax:
|
||||||
|
|
||||||
|
```
|
||||||
|
NDArray[Shape, dtype]
|
||||||
|
```
|
||||||
|
|
||||||
|
These are the docs for what can do in `dtype`.
|
||||||
|
|
||||||
|
## Scalar Dtypes
|
||||||
|
|
||||||
|
Python builtin types and numpy types should be handled transparently,
|
||||||
|
with some exception for complex numbers and objects (described below).
|
||||||
|
|
||||||
|
### Numbers
|
||||||
|
|
||||||
|
#### Complex numbers
|
||||||
|
|
||||||
|
```{todo}
|
||||||
|
Document limitations for complex numbers and strategies for serialization/validation
|
||||||
|
```
|
||||||
|
|
||||||
|
### Datetimes
|
||||||
|
|
||||||
|
```{todo}
|
||||||
|
Datetimes are supported by every interface except :class:`.VideoInterface` ,
|
||||||
|
with the caveat that HDF5 loses timezone information, and thus all timestamps should
|
||||||
|
be re-encoded to UTC before saving/loading.
|
||||||
|
|
||||||
|
More generic datetime support is TODO.
|
||||||
|
```
|
||||||
|
|
||||||
|
### Objects
|
||||||
|
|
||||||
|
```{todo}
|
||||||
|
Generic objects are supported by all interfaces except
|
||||||
|
:class:`.VideoInterface` , :class;`.HDF5Interface` , and :class:`.ZarrInterface` .
|
||||||
|
|
||||||
|
this might be expected, but there is also hope, TODO fill in serialization plans.
|
||||||
|
```
|
||||||
|
|
||||||
|
### Strings
|
||||||
|
|
||||||
|
```{todo}
|
||||||
|
Strings are supported by all interfaces except :class:`.VideoInterface` .
|
||||||
|
|
||||||
|
TODO is fill in the subtleties of how this works
|
||||||
|
```
|
||||||
|
|
||||||
|
## Generic Dtypes
|
||||||
|
|
||||||
|
```{todo}
|
||||||
|
For now these are handled as tuples of dtypes, see the source of
|
||||||
|
{ref}`numpydantic.dtype.Float` . They should either be handled as Unions
|
||||||
|
or as a more prescribed meta-type.
|
||||||
|
|
||||||
|
For now, use `int` and `float` to refer to the general concepts of
|
||||||
|
"any int" or "any float" even if this is a bit mismatched from the numpy usage.
|
||||||
|
```
|
||||||
|
|
||||||
|
## Extended Python Typing Universe
|
||||||
|
|
||||||
|
### Union Types
|
||||||
|
|
||||||
|
Union types can be used as expected.
|
||||||
|
|
||||||
|
Union types are tested recursively -- if any item within a ``Union`` matches
|
||||||
|
the expected dtype at a given level of recursion, the dtype test passes.
|
||||||
|
|
||||||
|
```python
|
||||||
|
class MyModel(BaseModel):
|
||||||
|
array: NDArray[Any, int | float]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Compound Dtypes
|
||||||
|
|
||||||
|
```{todo}
|
||||||
|
Compound dtypes are currently unsupported,
|
||||||
|
though the HDF5 interface supports indexing into compound dtypes
|
||||||
|
as separable dimensions/arrays using the third "field" parameter in
|
||||||
|
{class}`.hdf5.H5ArrayPath` .
|
||||||
|
```
|
||||||
|
|
||||||
|
|
|
@ -473,6 +473,7 @@ dumped = instance.model_dump_json(context={'zarr_dump_array': True})
|
||||||
|
|
||||||
design
|
design
|
||||||
syntax
|
syntax
|
||||||
|
dtype
|
||||||
serialization
|
serialization
|
||||||
interfaces
|
interfaces
|
||||||
```
|
```
|
||||||
|
@ -484,13 +485,13 @@ interfaces
|
||||||
|
|
||||||
api/index
|
api/index
|
||||||
api/interface/index
|
api/interface/index
|
||||||
|
api/validation/index
|
||||||
api/dtype
|
api/dtype
|
||||||
api/ndarray
|
api/ndarray
|
||||||
api/maps
|
api/maps
|
||||||
api/meta
|
api/meta
|
||||||
api/schema
|
api/schema
|
||||||
api/serialization
|
api/serialization
|
||||||
api/shape
|
|
||||||
api/types
|
api/types
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[project]
|
[project]
|
||||||
name = "numpydantic"
|
name = "numpydantic"
|
||||||
version = "1.6.0"
|
version = "1.6.1"
|
||||||
description = "Type and shape validation and serialization for arbitrary array types in pydantic models"
|
description = "Type and shape validation and serialization for arbitrary array types in pydantic models"
|
||||||
authors = [
|
authors = [
|
||||||
{name = "sneakers-the-rat", email = "sneakers-the-rat@protonmail.com"},
|
{name = "sneakers-the-rat", email = "sneakers-the-rat@protonmail.com"},
|
||||||
|
@ -126,7 +126,7 @@ markers = [
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
target-version = "py311"
|
target-version = "py39"
|
||||||
include = ["src/numpydantic/**/*.py", "pyproject.toml"]
|
include = ["src/numpydantic/**/*.py", "pyproject.toml"]
|
||||||
exclude = ["tests"]
|
exclude = ["tests"]
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
|
|
||||||
from numpydantic.ndarray import NDArray
|
from numpydantic.ndarray import NDArray
|
||||||
from numpydantic.meta import update_ndarray_stub
|
from numpydantic.meta import update_ndarray_stub
|
||||||
from numpydantic.shape import Shape
|
from numpydantic.validation.shape import Shape
|
||||||
|
|
||||||
update_ndarray_stub()
|
update_ndarray_stub()
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,9 @@ module, rather than needing to import each individually.
|
||||||
Some types like `Integer` are compound types - tuples of multiple dtypes.
|
Some types like `Integer` are compound types - tuples of multiple dtypes.
|
||||||
Check these using ``in`` rather than ``==``. This interface will develop in future
|
Check these using ``in`` rather than ``==``. This interface will develop in future
|
||||||
versions to allow a single dtype check.
|
versions to allow a single dtype check.
|
||||||
|
|
||||||
|
For internal helper functions for validating dtype,
|
||||||
|
see :mod:`numpydantic.validation.dtype`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
|
@ -33,11 +33,11 @@ class DaskJsonDict(JsonDict):
|
||||||
name: str
|
name: str
|
||||||
chunks: Iterable[tuple[int, ...]]
|
chunks: Iterable[tuple[int, ...]]
|
||||||
dtype: str
|
dtype: str
|
||||||
array: list
|
value: list
|
||||||
|
|
||||||
def to_array_input(self) -> DaskArray:
|
def to_array_input(self) -> DaskArray:
|
||||||
"""Construct a dask array"""
|
"""Construct a dask array"""
|
||||||
np_array = np.array(self.array, dtype=self.dtype)
|
np_array = np.array(self.value, dtype=self.dtype)
|
||||||
array = from_array(
|
array = from_array(
|
||||||
np_array,
|
np_array,
|
||||||
name=self.name,
|
name=self.name,
|
||||||
|
@ -100,7 +100,7 @@ class DaskInterface(Interface):
|
||||||
if info.round_trip:
|
if info.round_trip:
|
||||||
as_json = DaskJsonDict(
|
as_json = DaskJsonDict(
|
||||||
type=cls.name,
|
type=cls.name,
|
||||||
array=as_json,
|
value=as_json,
|
||||||
name=array.name,
|
name=array.name,
|
||||||
chunks=array.chunks,
|
chunks=array.chunks,
|
||||||
dtype=str(np_array.dtype),
|
dtype=str(np_array.dtype),
|
||||||
|
|
|
@ -20,8 +20,8 @@ from numpydantic.exceptions import (
|
||||||
ShapeError,
|
ShapeError,
|
||||||
TooManyMatchesError,
|
TooManyMatchesError,
|
||||||
)
|
)
|
||||||
from numpydantic.shape import check_shape
|
|
||||||
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||||
|
from numpydantic.validation import validate_dtype, validate_shape
|
||||||
|
|
||||||
T = TypeVar("T", bound=NDArrayType)
|
T = TypeVar("T", bound=NDArrayType)
|
||||||
U = TypeVar("U", bound="JsonDict")
|
U = TypeVar("U", bound="JsonDict")
|
||||||
|
@ -76,6 +76,21 @@ class InterfaceMark(BaseModel):
|
||||||
class JsonDict(BaseModel):
|
class JsonDict(BaseModel):
|
||||||
"""
|
"""
|
||||||
Representation of array when dumped with round_trip == True.
|
Representation of array when dumped with round_trip == True.
|
||||||
|
|
||||||
|
.. admonition:: Developer's Note
|
||||||
|
|
||||||
|
Any JsonDict that contains an actual array should be named ``value``
|
||||||
|
rather than array (or any other name), and nothing but the
|
||||||
|
array data should be named ``value`` .
|
||||||
|
|
||||||
|
During JSON serialization, it becomes ambiguous what contains an array
|
||||||
|
of data vs. an array of metadata. For the moment we would like to
|
||||||
|
reserve the ability to have lists of metadata, so until we rule that out,
|
||||||
|
we would like to be able to avoid iterating over every element of an array
|
||||||
|
in any context parameter transformation like relativizing/absolutizing paths.
|
||||||
|
To avoid that, it's good to agree on a single value name -- ``value`` --
|
||||||
|
and avoid using it for anything else.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str
|
type: str
|
||||||
|
@ -274,25 +289,7 @@ class Interface(ABC, Generic[T]):
|
||||||
Validate the dtype of the given array, returning
|
Validate the dtype of the given array, returning
|
||||||
``True`` if valid, ``False`` if not.
|
``True`` if valid, ``False`` if not.
|
||||||
"""
|
"""
|
||||||
if self.dtype is Any:
|
return validate_dtype(dtype, self.dtype)
|
||||||
return True
|
|
||||||
|
|
||||||
if isinstance(self.dtype, tuple):
|
|
||||||
valid = dtype in self.dtype
|
|
||||||
elif self.dtype is np.str_:
|
|
||||||
valid = getattr(dtype, "type", None) in (np.str_, str) or dtype in (
|
|
||||||
np.str_,
|
|
||||||
str,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# try to match as any subclass, if self.dtype is a class
|
|
||||||
try:
|
|
||||||
valid = issubclass(dtype, self.dtype)
|
|
||||||
except TypeError:
|
|
||||||
# expected, if dtype or self.dtype is not a class
|
|
||||||
valid = dtype == self.dtype
|
|
||||||
|
|
||||||
return valid
|
|
||||||
|
|
||||||
def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None:
|
def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -326,7 +323,7 @@ class Interface(ABC, Generic[T]):
|
||||||
if self.shape is Any:
|
if self.shape is Any:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return check_shape(shape, self.shape)
|
return validate_shape(shape, self.shape)
|
||||||
|
|
||||||
def raise_for_shape(self, valid: bool, shape: Tuple[int, ...]) -> None:
|
def raise_for_shape(self, valid: bool, shape: Tuple[int, ...]) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -27,13 +27,13 @@ class NumpyJsonDict(JsonDict):
|
||||||
|
|
||||||
type: Literal["numpy"]
|
type: Literal["numpy"]
|
||||||
dtype: str
|
dtype: str
|
||||||
array: list
|
value: list
|
||||||
|
|
||||||
def to_array_input(self) -> ndarray:
|
def to_array_input(self) -> ndarray:
|
||||||
"""
|
"""
|
||||||
Construct a numpy array
|
Construct a numpy array
|
||||||
"""
|
"""
|
||||||
return np.array(self.array, dtype=self.dtype)
|
return np.array(self.value, dtype=self.dtype)
|
||||||
|
|
||||||
|
|
||||||
class NumpyInterface(Interface):
|
class NumpyInterface(Interface):
|
||||||
|
@ -99,6 +99,6 @@ class NumpyInterface(Interface):
|
||||||
|
|
||||||
if info.round_trip:
|
if info.round_trip:
|
||||||
json_array = NumpyJsonDict(
|
json_array = NumpyJsonDict(
|
||||||
type=cls.name, dtype=str(array.dtype), array=json_array
|
type=cls.name, dtype=str(array.dtype), value=json_array
|
||||||
)
|
)
|
||||||
return json_array
|
return json_array
|
||||||
|
|
|
@ -63,7 +63,7 @@ class ZarrJsonDict(JsonDict):
|
||||||
type: Literal["zarr"]
|
type: Literal["zarr"]
|
||||||
file: Optional[str] = None
|
file: Optional[str] = None
|
||||||
path: Optional[str] = None
|
path: Optional[str] = None
|
||||||
array: Optional[list] = None
|
value: Optional[list] = None
|
||||||
|
|
||||||
def to_array_input(self) -> Union[ZarrArray, ZarrArrayPath]:
|
def to_array_input(self) -> Union[ZarrArray, ZarrArrayPath]:
|
||||||
"""
|
"""
|
||||||
|
@ -73,7 +73,7 @@ class ZarrJsonDict(JsonDict):
|
||||||
if self.file:
|
if self.file:
|
||||||
array = ZarrArrayPath(file=self.file, path=self.path)
|
array = ZarrArrayPath(file=self.file, path=self.path)
|
||||||
else:
|
else:
|
||||||
array = zarr.array(self.array)
|
array = zarr.array(self.value)
|
||||||
return array
|
return array
|
||||||
|
|
||||||
|
|
||||||
|
@ -202,7 +202,7 @@ class ZarrInterface(Interface):
|
||||||
as_json["info"]["hexdigest"] = array.hexdigest()
|
as_json["info"]["hexdigest"] = array.hexdigest()
|
||||||
|
|
||||||
if dump_array or not is_file:
|
if dump_array or not is_file:
|
||||||
as_json["array"] = array[:].tolist()
|
as_json["value"] = array[:].tolist()
|
||||||
|
|
||||||
as_json = ZarrJsonDict(**as_json)
|
as_json = ZarrJsonDict(**as_json)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -13,7 +13,7 @@ Extension of nptyping NDArray for pydantic that allows for JSON-Schema serializa
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any, Tuple
|
from typing import TYPE_CHECKING, Any, Literal, Tuple, get_origin
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import GetJsonSchemaHandler
|
from pydantic import GetJsonSchemaHandler
|
||||||
|
@ -29,6 +29,7 @@ from numpydantic.schema import (
|
||||||
)
|
)
|
||||||
from numpydantic.serialization import jsonize_array
|
from numpydantic.serialization import jsonize_array
|
||||||
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||||
|
from numpydantic.validation.dtype import is_union
|
||||||
from numpydantic.vendor.nptyping.error import InvalidArgumentsError
|
from numpydantic.vendor.nptyping.error import InvalidArgumentsError
|
||||||
from numpydantic.vendor.nptyping.ndarray import NDArrayMeta as _NDArrayMeta
|
from numpydantic.vendor.nptyping.ndarray import NDArrayMeta as _NDArrayMeta
|
||||||
from numpydantic.vendor.nptyping.nptyping_type import NPTypingType
|
from numpydantic.vendor.nptyping.nptyping_type import NPTypingType
|
||||||
|
@ -86,11 +87,18 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
|
||||||
except InterfaceError:
|
except InterfaceError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _is_literal_like(cls, item: Any) -> bool:
|
||||||
|
"""
|
||||||
|
Changes from nptyping:
|
||||||
|
- doesn't just ducktype for literal but actually, yno, checks for being literal
|
||||||
|
"""
|
||||||
|
return get_origin(item) is Literal
|
||||||
|
|
||||||
def _get_shape(cls, dtype_candidate: Any) -> "Shape":
|
def _get_shape(cls, dtype_candidate: Any) -> "Shape":
|
||||||
"""
|
"""
|
||||||
Override of base method to use our local definition of shape
|
Override of base method to use our local definition of shape
|
||||||
"""
|
"""
|
||||||
from numpydantic.shape import Shape
|
from numpydantic.validation.shape import Shape
|
||||||
|
|
||||||
if dtype_candidate is Any or dtype_candidate is Shape:
|
if dtype_candidate is Any or dtype_candidate is Shape:
|
||||||
shape = Any
|
shape = Any
|
||||||
|
@ -120,7 +128,7 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
|
||||||
|
|
||||||
if dtype_candidate is Any:
|
if dtype_candidate is Any:
|
||||||
dtype = Any
|
dtype = Any
|
||||||
elif is_dtype:
|
elif is_dtype or is_union(dtype_candidate):
|
||||||
dtype = dtype_candidate
|
dtype = dtype_candidate
|
||||||
elif issubclass(dtype_candidate, Structure): # pragma: no cover
|
elif issubclass(dtype_candidate, Structure): # pragma: no cover
|
||||||
dtype = dtype_candidate
|
dtype = dtype_candidate
|
||||||
|
|
|
@ -5,7 +5,7 @@ Helper functions for use with :class:`~numpydantic.NDArray` - see the note in
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
from typing import TYPE_CHECKING, Any, Callable, Optional, get_args
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
@ -16,6 +16,7 @@ from numpydantic import dtype as dt
|
||||||
from numpydantic.interface import Interface
|
from numpydantic.interface import Interface
|
||||||
from numpydantic.maps import np_to_python
|
from numpydantic.maps import np_to_python
|
||||||
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
from numpydantic.types import DtypeType, NDArrayType, ShapeType
|
||||||
|
from numpydantic.validation.dtype import is_union
|
||||||
from numpydantic.vendor.nptyping.structure import StructureMeta
|
from numpydantic.vendor.nptyping.structure import StructureMeta
|
||||||
|
|
||||||
if TYPE_CHECKING: # pragma: no cover
|
if TYPE_CHECKING: # pragma: no cover
|
||||||
|
@ -51,7 +52,6 @@ def _lol_dtype(
|
||||||
"""Get the innermost dtype schema to use in the generated pydantic schema"""
|
"""Get the innermost dtype schema to use in the generated pydantic schema"""
|
||||||
if isinstance(dtype, StructureMeta): # pragma: no cover
|
if isinstance(dtype, StructureMeta): # pragma: no cover
|
||||||
raise NotImplementedError("Structured dtypes are currently unsupported")
|
raise NotImplementedError("Structured dtypes are currently unsupported")
|
||||||
|
|
||||||
if isinstance(dtype, tuple):
|
if isinstance(dtype, tuple):
|
||||||
# if it's a meta-type that refers to a generic float/int, just make that
|
# if it's a meta-type that refers to a generic float/int, just make that
|
||||||
if dtype in (dt.Float, dt.Number):
|
if dtype in (dt.Float, dt.Number):
|
||||||
|
@ -66,7 +66,10 @@ def _lol_dtype(
|
||||||
array_type = core_schema.union_schema(
|
array_type = core_schema.union_schema(
|
||||||
[_lol_dtype(t, _handler) for t in types_]
|
[_lol_dtype(t, _handler) for t in types_]
|
||||||
)
|
)
|
||||||
|
elif is_union(dtype):
|
||||||
|
array_type = core_schema.union_schema(
|
||||||
|
[_lol_dtype(t, _handler) for t in get_args(dtype)]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
python_type = np_to_python[dtype]
|
python_type = np_to_python[dtype]
|
||||||
|
@ -110,7 +113,7 @@ def list_of_lists_schema(shape: "Shape", array_type: CoreSchema) -> ListSchema:
|
||||||
array_type ( :class:`pydantic_core.CoreSchema` ): The pre-rendered pydantic
|
array_type ( :class:`pydantic_core.CoreSchema` ): The pre-rendered pydantic
|
||||||
core schema to use in the innermost list entry
|
core schema to use in the innermost list entry
|
||||||
"""
|
"""
|
||||||
from numpydantic.shape import _is_range
|
from numpydantic.validation.shape import _is_range
|
||||||
|
|
||||||
shape_parts = [part.strip() for part in shape.__args__[0].split(",")]
|
shape_parts = [part.strip() for part in shape.__args__[0].split(",")]
|
||||||
# labels, if present
|
# labels, if present
|
||||||
|
|
|
@ -4,7 +4,7 @@ and :func:`pydantic.BaseModel.model_dump_json` .
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, TypeVar, Union
|
from typing import Any, Callable, Iterable, TypeVar, Union
|
||||||
|
|
||||||
from pydantic_core.core_schema import SerializationInfo
|
from pydantic_core.core_schema import SerializationInfo
|
||||||
|
|
||||||
|
@ -16,6 +16,9 @@ U = TypeVar("U")
|
||||||
|
|
||||||
def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
|
def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
|
||||||
"""Use an interface class to render an array as JSON"""
|
"""Use an interface class to render an array as JSON"""
|
||||||
|
# perf: keys to skip in generation - anything named "value" is array data.
|
||||||
|
skip = ["value"]
|
||||||
|
|
||||||
interface_cls = Interface.match_output(value)
|
interface_cls = Interface.match_output(value)
|
||||||
array = interface_cls.to_json(value, info)
|
array = interface_cls.to_json(value, info)
|
||||||
if isinstance(array, JsonDict):
|
if isinstance(array, JsonDict):
|
||||||
|
@ -25,19 +28,37 @@ def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
|
||||||
if info.context.get("mark_interface", False):
|
if info.context.get("mark_interface", False):
|
||||||
array = interface_cls.mark_json(array)
|
array = interface_cls.mark_json(array)
|
||||||
|
|
||||||
|
if isinstance(array, list):
|
||||||
|
return array
|
||||||
|
|
||||||
|
# ---- Perf Barrier ------------------------------------------------------
|
||||||
|
# put context args intended to **wrap** the array above
|
||||||
|
# put context args intended to **modify** the array below
|
||||||
|
#
|
||||||
|
# above, we assume that a list is **data** not to be modified.
|
||||||
|
# below, we must mark whenever the data is in the line of fire
|
||||||
|
# to avoid an expensive iteration.
|
||||||
|
|
||||||
if info.context.get("absolute_paths", False):
|
if info.context.get("absolute_paths", False):
|
||||||
array = _absolutize_paths(array)
|
array = _absolutize_paths(array, skip)
|
||||||
else:
|
else:
|
||||||
relative_to = info.context.get("relative_to", ".")
|
relative_to = info.context.get("relative_to", ".")
|
||||||
array = _relativize_paths(array, relative_to)
|
array = _relativize_paths(array, relative_to, skip)
|
||||||
else:
|
else:
|
||||||
# relativize paths by default
|
if isinstance(array, list):
|
||||||
array = _relativize_paths(array, ".")
|
return array
|
||||||
|
|
||||||
|
# ---- Perf Barrier ------------------------------------------------------
|
||||||
|
# same as above, ensure any keys that contain array values are skipped right now
|
||||||
|
|
||||||
|
array = _relativize_paths(array, ".", skip)
|
||||||
|
|
||||||
return array
|
return array
|
||||||
|
|
||||||
|
|
||||||
def _relativize_paths(value: dict, relative_to: str = ".") -> dict:
|
def _relativize_paths(
|
||||||
|
value: dict, relative_to: str = ".", skip: Iterable = tuple()
|
||||||
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Make paths relative to either the current directory or the provided
|
Make paths relative to either the current directory or the provided
|
||||||
``relative_to`` directory, if provided in the context
|
``relative_to`` directory, if provided in the context
|
||||||
|
@ -46,6 +67,8 @@ def _relativize_paths(value: dict, relative_to: str = ".") -> dict:
|
||||||
# pdb.set_trace()
|
# pdb.set_trace()
|
||||||
|
|
||||||
def _r_path(v: Any) -> Any:
|
def _r_path(v: Any) -> Any:
|
||||||
|
if not isinstance(v, (str, Path)):
|
||||||
|
return v
|
||||||
try:
|
try:
|
||||||
path = Path(v)
|
path = Path(v)
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
|
@ -54,10 +77,10 @@ def _relativize_paths(value: dict, relative_to: str = ".") -> dict:
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
return v
|
return v
|
||||||
|
|
||||||
return _walk_and_apply(value, _r_path)
|
return _walk_and_apply(value, _r_path, skip)
|
||||||
|
|
||||||
|
|
||||||
def _absolutize_paths(value: dict) -> dict:
|
def _absolutize_paths(value: dict, skip: Iterable = tuple()) -> dict:
|
||||||
def _a_path(v: Any) -> Any:
|
def _a_path(v: Any) -> Any:
|
||||||
try:
|
try:
|
||||||
path = Path(v)
|
path = Path(v)
|
||||||
|
@ -67,23 +90,25 @@ def _absolutize_paths(value: dict) -> dict:
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
return v
|
return v
|
||||||
|
|
||||||
return _walk_and_apply(value, _a_path)
|
return _walk_and_apply(value, _a_path, skip)
|
||||||
|
|
||||||
|
|
||||||
def _walk_and_apply(value: T, f: Callable[[U], U]) -> T:
|
def _walk_and_apply(value: T, f: Callable[[U, bool], U], skip: Iterable = tuple()) -> T:
|
||||||
"""
|
"""
|
||||||
Walk an object, applying a function
|
Walk an object, applying a function
|
||||||
"""
|
"""
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
for k, v in value.items():
|
for k, v in value.items():
|
||||||
|
if k in skip:
|
||||||
|
continue
|
||||||
if isinstance(v, dict):
|
if isinstance(v, dict):
|
||||||
_walk_and_apply(v, f)
|
_walk_and_apply(v, f, skip)
|
||||||
elif isinstance(v, list):
|
elif isinstance(v, list):
|
||||||
value[k] = [_walk_and_apply(sub_v, f) for sub_v in v]
|
value[k] = [_walk_and_apply(sub_v, f, skip) for sub_v in v]
|
||||||
else:
|
else:
|
||||||
value[k] = f(v)
|
value[k] = f(v)
|
||||||
elif isinstance(value, list):
|
elif isinstance(value, list):
|
||||||
value = [_walk_and_apply(v, f) for v in value]
|
value = [_walk_and_apply(v, f, skip) for v in value]
|
||||||
else:
|
else:
|
||||||
value = f(value)
|
value = f(value)
|
||||||
return value
|
return value
|
||||||
|
|
11
src/numpydantic/validation/__init__.py
Normal file
11
src/numpydantic/validation/__init__.py
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
"""
|
||||||
|
Helper functions for validation
|
||||||
|
"""
|
||||||
|
|
||||||
|
from numpydantic.validation.dtype import validate_dtype
|
||||||
|
from numpydantic.validation.shape import validate_shape
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"validate_dtype",
|
||||||
|
"validate_shape",
|
||||||
|
]
|
63
src/numpydantic/validation/dtype.py
Normal file
63
src/numpydantic/validation/dtype.py
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
"""
|
||||||
|
Helper functions for validation of dtype.
|
||||||
|
|
||||||
|
For literal dtypes intended for use by end-users, see :mod:`numpydantic.dtype`
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from typing import Any, Union, get_args, get_origin
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from numpydantic.types import DtypeType
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 10):
|
||||||
|
from types import UnionType
|
||||||
|
else:
|
||||||
|
UnionType = None
|
||||||
|
|
||||||
|
|
||||||
|
def validate_dtype(dtype: Any, target: DtypeType) -> bool:
|
||||||
|
"""
|
||||||
|
Validate a dtype against the target dtype
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dtype: The dtype to validate
|
||||||
|
target (:class:`.DtypeType`): The target dtype
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: ``True`` if valid, ``False`` otherwise
|
||||||
|
"""
|
||||||
|
if target is Any:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if isinstance(target, tuple):
|
||||||
|
valid = dtype in target
|
||||||
|
elif is_union(target):
|
||||||
|
valid = any(
|
||||||
|
[validate_dtype(dtype, target_dt) for target_dt in get_args(target)]
|
||||||
|
)
|
||||||
|
elif target is np.str_:
|
||||||
|
valid = getattr(dtype, "type", None) in (np.str_, str) or dtype in (
|
||||||
|
np.str_,
|
||||||
|
str,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# try to match as any subclass, if target is a class
|
||||||
|
try:
|
||||||
|
valid = issubclass(dtype, target)
|
||||||
|
except TypeError:
|
||||||
|
# expected, if dtype or target is not a class
|
||||||
|
valid = dtype == target
|
||||||
|
|
||||||
|
return valid
|
||||||
|
|
||||||
|
|
||||||
|
def is_union(dtype: DtypeType) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a dtype is a union
|
||||||
|
"""
|
||||||
|
if UnionType is None:
|
||||||
|
return get_origin(dtype) is Union
|
||||||
|
else:
|
||||||
|
return get_origin(dtype) in (Union, UnionType)
|
|
@ -2,7 +2,7 @@
|
||||||
Declaration and validation functions for array shapes.
|
Declaration and validation functions for array shapes.
|
||||||
|
|
||||||
Mostly a mildly modified version of nptyping's
|
Mostly a mildly modified version of nptyping's
|
||||||
:func:`npytping.shape_expression.check_shape`
|
:func:`npytping.shape_expression.validate_shape`
|
||||||
and its internals to allow for extended syntax, including ranges of shapes.
|
and its internals to allow for extended syntax, including ranges of shapes.
|
||||||
|
|
||||||
Modifications from nptyping:
|
Modifications from nptyping:
|
||||||
|
@ -105,7 +105,7 @@ def validate_shape_expression(shape_expression: Union[ShapeExpression, Any]) ->
|
||||||
|
|
||||||
|
|
||||||
@lru_cache
|
@lru_cache
|
||||||
def check_shape(shape: ShapeTuple, target: "Shape") -> bool:
|
def validate_shape(shape: ShapeTuple, target: "Shape") -> bool:
|
||||||
"""
|
"""
|
||||||
Check whether the given shape corresponds to the given shape_expression.
|
Check whether the given shape corresponds to the given shape_expression.
|
||||||
:param shape: the shape in question.
|
:param shape: the shape in question.
|
|
@ -3,10 +3,6 @@ import sys
|
||||||
import pytest
|
import pytest
|
||||||
from typing import Any, Tuple, Union, Type
|
from typing import Any, Tuple, Union, Type
|
||||||
|
|
||||||
if sys.version_info.minor >= 10:
|
|
||||||
from typing import TypeAlias
|
|
||||||
else:
|
|
||||||
from typing_extensions import TypeAlias
|
|
||||||
from pydantic import BaseModel, computed_field, ConfigDict
|
from pydantic import BaseModel, computed_field, ConfigDict
|
||||||
from numpydantic import NDArray, Shape
|
from numpydantic import NDArray, Shape
|
||||||
from numpydantic.ndarray import NDArrayMeta
|
from numpydantic.ndarray import NDArrayMeta
|
||||||
|
@ -15,6 +11,15 @@ import numpy as np
|
||||||
|
|
||||||
from tests.fixtures import *
|
from tests.fixtures import *
|
||||||
|
|
||||||
|
if sys.version_info.minor >= 10:
|
||||||
|
from typing import TypeAlias
|
||||||
|
|
||||||
|
YES_PIPE = True
|
||||||
|
else:
|
||||||
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
|
YES_PIPE = False
|
||||||
|
|
||||||
|
|
||||||
def pytest_addoption(parser):
|
def pytest_addoption(parser):
|
||||||
parser.addoption(
|
parser.addoption(
|
||||||
|
@ -80,6 +85,9 @@ INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer]
|
||||||
FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float]
|
FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float]
|
||||||
STRING: TypeAlias = NDArray[Shape["*, *, *"], str]
|
STRING: TypeAlias = NDArray[Shape["*, *, *"], str]
|
||||||
MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel]
|
MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel]
|
||||||
|
UNION_TYPE: TypeAlias = NDArray[Shape["*, *, *"], Union[np.uint32, np.float32]]
|
||||||
|
if YES_PIPE:
|
||||||
|
UNION_PIPE: TypeAlias = NDArray[Shape["*, *, *"], np.uint32 | np.float32]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(
|
@pytest.fixture(
|
||||||
|
@ -119,62 +127,93 @@ def shape_cases(request) -> ValidationCase:
|
||||||
return request.param
|
return request.param
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(
|
DTYPE_CASES = [
|
||||||
scope="module",
|
ValidationCase(dtype=float, passes=True),
|
||||||
params=[
|
ValidationCase(dtype=int, passes=False),
|
||||||
ValidationCase(dtype=float, passes=True),
|
ValidationCase(dtype=np.uint8, passes=False),
|
||||||
ValidationCase(dtype=int, passes=False),
|
ValidationCase(annotation=NUMBER, dtype=int, passes=True),
|
||||||
ValidationCase(dtype=np.uint8, passes=False),
|
ValidationCase(annotation=NUMBER, dtype=float, passes=True),
|
||||||
ValidationCase(annotation=NUMBER, dtype=int, passes=True),
|
ValidationCase(annotation=NUMBER, dtype=np.uint8, passes=True),
|
||||||
ValidationCase(annotation=NUMBER, dtype=float, passes=True),
|
ValidationCase(annotation=NUMBER, dtype=np.float16, passes=True),
|
||||||
ValidationCase(annotation=NUMBER, dtype=np.uint8, passes=True),
|
ValidationCase(annotation=NUMBER, dtype=str, passes=False),
|
||||||
ValidationCase(annotation=NUMBER, dtype=np.float16, passes=True),
|
ValidationCase(annotation=INTEGER, dtype=int, passes=True),
|
||||||
ValidationCase(annotation=NUMBER, dtype=str, passes=False),
|
ValidationCase(annotation=INTEGER, dtype=np.uint8, passes=True),
|
||||||
ValidationCase(annotation=INTEGER, dtype=int, passes=True),
|
ValidationCase(annotation=INTEGER, dtype=float, passes=False),
|
||||||
ValidationCase(annotation=INTEGER, dtype=np.uint8, passes=True),
|
ValidationCase(annotation=INTEGER, dtype=np.float32, passes=False),
|
||||||
ValidationCase(annotation=INTEGER, dtype=float, passes=False),
|
ValidationCase(annotation=INTEGER, dtype=str, passes=False),
|
||||||
ValidationCase(annotation=INTEGER, dtype=np.float32, passes=False),
|
ValidationCase(annotation=FLOAT, dtype=float, passes=True),
|
||||||
ValidationCase(annotation=INTEGER, dtype=str, passes=False),
|
ValidationCase(annotation=FLOAT, dtype=np.float32, passes=True),
|
||||||
ValidationCase(annotation=FLOAT, dtype=float, passes=True),
|
ValidationCase(annotation=FLOAT, dtype=int, passes=False),
|
||||||
ValidationCase(annotation=FLOAT, dtype=np.float32, passes=True),
|
ValidationCase(annotation=FLOAT, dtype=np.uint8, passes=False),
|
||||||
ValidationCase(annotation=FLOAT, dtype=int, passes=False),
|
ValidationCase(annotation=FLOAT, dtype=str, passes=False),
|
||||||
ValidationCase(annotation=FLOAT, dtype=np.uint8, passes=False),
|
ValidationCase(annotation=STRING, dtype=str, passes=True),
|
||||||
ValidationCase(annotation=FLOAT, dtype=str, passes=False),
|
ValidationCase(annotation=STRING, dtype=int, passes=False),
|
||||||
ValidationCase(annotation=STRING, dtype=str, passes=True),
|
ValidationCase(annotation=STRING, dtype=float, passes=False),
|
||||||
ValidationCase(annotation=STRING, dtype=int, passes=False),
|
ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True),
|
||||||
ValidationCase(annotation=STRING, dtype=float, passes=False),
|
ValidationCase(annotation=MODEL, dtype=BadModel, passes=False),
|
||||||
ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True),
|
ValidationCase(annotation=MODEL, dtype=int, passes=False),
|
||||||
ValidationCase(annotation=MODEL, dtype=BadModel, passes=False),
|
ValidationCase(annotation=MODEL, dtype=SubClass, passes=True),
|
||||||
ValidationCase(annotation=MODEL, dtype=int, passes=False),
|
ValidationCase(annotation=UNION_TYPE, dtype=np.uint32, passes=True),
|
||||||
ValidationCase(annotation=MODEL, dtype=SubClass, passes=True),
|
ValidationCase(annotation=UNION_TYPE, dtype=np.float32, passes=True),
|
||||||
],
|
ValidationCase(annotation=UNION_TYPE, dtype=np.uint64, passes=False),
|
||||||
ids=[
|
ValidationCase(annotation=UNION_TYPE, dtype=np.float64, passes=False),
|
||||||
"float",
|
ValidationCase(annotation=UNION_TYPE, dtype=str, passes=False),
|
||||||
"int",
|
]
|
||||||
"uint8",
|
|
||||||
"number-int",
|
DTYPE_IDS = [
|
||||||
"number-float",
|
"float",
|
||||||
"number-uint8",
|
"int",
|
||||||
"number-float16",
|
"uint8",
|
||||||
"number-str",
|
"number-int",
|
||||||
"integer-int",
|
"number-float",
|
||||||
"integer-uint8",
|
"number-uint8",
|
||||||
"integer-float",
|
"number-float16",
|
||||||
"integer-float32",
|
"number-str",
|
||||||
"integer-str",
|
"integer-int",
|
||||||
"float-float",
|
"integer-uint8",
|
||||||
"float-float32",
|
"integer-float",
|
||||||
"float-int",
|
"integer-float32",
|
||||||
"float-uint8",
|
"integer-str",
|
||||||
"float-str",
|
"float-float",
|
||||||
"str-str",
|
"float-float32",
|
||||||
"str-int",
|
"float-int",
|
||||||
"str-float",
|
"float-uint8",
|
||||||
"model-model",
|
"float-str",
|
||||||
"model-badmodel",
|
"str-str",
|
||||||
"model-int",
|
"str-int",
|
||||||
"model-subclass",
|
"str-float",
|
||||||
],
|
"model-model",
|
||||||
)
|
"model-badmodel",
|
||||||
|
"model-int",
|
||||||
|
"model-subclass",
|
||||||
|
"union-type-uint32",
|
||||||
|
"union-type-float32",
|
||||||
|
"union-type-uint64",
|
||||||
|
"union-type-float64",
|
||||||
|
"union-type-str",
|
||||||
|
]
|
||||||
|
|
||||||
|
if YES_PIPE:
|
||||||
|
DTYPE_CASES.extend(
|
||||||
|
[
|
||||||
|
ValidationCase(annotation=UNION_PIPE, dtype=np.uint32, passes=True),
|
||||||
|
ValidationCase(annotation=UNION_PIPE, dtype=np.float32, passes=True),
|
||||||
|
ValidationCase(annotation=UNION_PIPE, dtype=np.uint64, passes=False),
|
||||||
|
ValidationCase(annotation=UNION_PIPE, dtype=np.float64, passes=False),
|
||||||
|
ValidationCase(annotation=UNION_PIPE, dtype=str, passes=False),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
DTYPE_IDS.extend(
|
||||||
|
[
|
||||||
|
"union-pipe-uint32",
|
||||||
|
"union-pipe-float32",
|
||||||
|
"union-pipe-uint64",
|
||||||
|
"union-pipe-float64",
|
||||||
|
"union-pipe-str",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", params=DTYPE_CASES, ids=DTYPE_IDS)
|
||||||
def dtype_cases(request) -> ValidationCase:
|
def dtype_cases(request) -> ValidationCase:
|
||||||
return request.param
|
return request.param
|
||||||
|
|
|
@ -151,7 +151,7 @@ def test_zarr_to_json(store, model_blank, roundtrip, dump_array):
|
||||||
|
|
||||||
if roundtrip:
|
if roundtrip:
|
||||||
if dump_array:
|
if dump_array:
|
||||||
assert as_json["array"] == lol_array
|
assert as_json["value"] == lol_array
|
||||||
else:
|
else:
|
||||||
if as_json.get("file", False):
|
if as_json.get("file", False):
|
||||||
assert "array" not in as_json
|
assert "array" not in as_json
|
||||||
|
|
|
@ -10,6 +10,8 @@ from typing import Callable
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
from numpydantic.serialization import _walk_and_apply
|
||||||
|
|
||||||
pytestmark = pytest.mark.serialization
|
pytestmark = pytest.mark.serialization
|
||||||
|
|
||||||
|
|
||||||
|
@ -93,3 +95,34 @@ def test_relative_to_path(hdf5_at_path, tmp_output_dir, model_blank):
|
||||||
|
|
||||||
# shouldn't have absolutized subpath even if it's pathlike
|
# shouldn't have absolutized subpath even if it's pathlike
|
||||||
assert data["path"] == expected_dataset
|
assert data["path"] == expected_dataset
|
||||||
|
|
||||||
|
|
||||||
|
def test_walk_and_apply():
|
||||||
|
"""
|
||||||
|
Walk and apply should recursively apply a function to everything in a nesty structure
|
||||||
|
"""
|
||||||
|
test = {
|
||||||
|
"a": 1,
|
||||||
|
"b": 1,
|
||||||
|
"c": [
|
||||||
|
{"a": 1, "b": {"a": 1, "b": 1}, "c": [1, 1, 1]},
|
||||||
|
{"a": 1, "b": [1, 1, 1]},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
def _mult_2(v, skip: bool = False):
|
||||||
|
return v * 2
|
||||||
|
|
||||||
|
def _assert_2(v, skip: bool = False):
|
||||||
|
assert v == 2
|
||||||
|
return v
|
||||||
|
|
||||||
|
walked = _walk_and_apply(test, _mult_2)
|
||||||
|
_walk_and_apply(walked, _assert_2)
|
||||||
|
|
||||||
|
assert walked["a"] == 2
|
||||||
|
assert walked["c"][0]["a"] == 2
|
||||||
|
assert walked["c"][0]["b"]["a"] == 2
|
||||||
|
assert all([w == 2 for w in walked["c"][0]["c"]])
|
||||||
|
assert walked["c"][1]["a"] == 2
|
||||||
|
assert all([w == 2 for w in walked["c"][1]["b"]])
|
||||||
|
|
Loading…
Reference in a new issue