Merge pull request #24 from p2p-ld/dtype-union
Some checks are pending
Lint / Ruff Linting (push) Waiting to run
Lint / Black Formatting (push) Waiting to run
Tests / test (<2.0.0, macos-latest, 3.12) (push) Waiting to run
Tests / test (<2.0.0, macos-latest, 3.9) (push) Waiting to run
Tests / test (<2.0.0, ubuntu-latest, 3.12) (push) Waiting to run
Tests / test (<2.0.0, ubuntu-latest, 3.9) (push) Waiting to run
Tests / test (<2.0.0, windows-latest, 3.12) (push) Waiting to run
Tests / test (<2.0.0, windows-latest, 3.9) (push) Waiting to run
Tests / test (>=2.0.0, macos-latest, 3.12) (push) Waiting to run
Tests / test (>=2.0.0, macos-latest, 3.9) (push) Waiting to run
Tests / test (>=2.0.0, ubuntu-latest, 3.10) (push) Waiting to run
Tests / test (>=2.0.0, ubuntu-latest, 3.11) (push) Waiting to run
Tests / test (>=2.0.0, ubuntu-latest, 3.12) (push) Waiting to run
Tests / test (>=2.0.0, ubuntu-latest, 3.9) (push) Waiting to run
Tests / test (>=2.0.0, windows-latest, 3.12) (push) Waiting to run
Tests / test (>=2.0.0, windows-latest, 3.9) (push) Waiting to run
Tests / finish-coverage (push) Blocked by required conditions

[dtype] Support Unions
This commit is contained in:
Jonny Saunders 2024-09-24 00:27:07 -07:00 committed by GitHub
commit 1cf69eb18c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 443 additions and 120 deletions

2
.gitignore vendored
View file

@ -160,3 +160,5 @@ cython_debug/
#.idea/ #.idea/
.pdm-python .pdm-python
ndarray.pyi ndarray.pyi
prof/

View file

@ -0,0 +1,7 @@
# dtype
```{eval-rst}
.. automodule:: numpydantic.validation.dtype
:members:
:undoc-members:
```

View file

@ -0,0 +1,6 @@
# validation
```{toctree}
dtype
shape
```

View file

@ -1,7 +1,7 @@
# shape # shape
```{eval-rst} ```{eval-rst}
.. automodule:: numpydantic.shape .. automodule:: numpydantic.validation.shape
:members: :members:
:undoc-members: :undoc-members:
``` ```

View file

@ -4,6 +4,33 @@
### 1.6.* ### 1.6.*
#### 1.6.1 - 24-09-23 - Support Union Dtypes
It's now possible to do this, like it always should have been
```python
class MyModel(BaseModel):
array: NDArray[Any, int | float]
```
**Features**
- Support for Union Dtypes
**Structure**
- New `validation` module containing `shape` and `dtype` convenience methods
to declutter main namespace and make a grouping for related code
- Rename all serialized arrays within a container dict to `value` to be able
to identify them by convention and avoid long iteration - see perf below.
**Perf**
- Avoid iterating over every item in an array trying to convert it to a path for
a several order of magnitude perf improvement over `1.6.0` (oops)
**Docs**
- Page for `dtypes`, mostly stubs at the moment, but more explicit documentation
about what kind of dtypes we support.
#### 1.6.0 - 24-09-23 - Roundtrip JSON Serialization #### 1.6.0 - 24-09-23 - Roundtrip JSON Serialization
Roundtrip JSON serialization is here - with serialization to list of lists, Roundtrip JSON serialization is here - with serialization to list of lists,

98
docs/dtype.md Normal file
View file

@ -0,0 +1,98 @@
# dtype
```{todo}
This section is under construction as of 1.6.1
Much of the details of dtypes are covered in [syntax](./syntax.md)
and in {mod}`numpydantic.dtype` , but this section will specifically
address how dtypes are handled both generically and by interfaces
as we expand custom dtype handling <3.
For details of support and implementation until the docs have time for some love,
please see the tests, which are the source of truth for the functionality
of the library for now and forever.
```
Recall the general syntax:
```
NDArray[Shape, dtype]
```
These are the docs for what can do in `dtype`.
## Scalar Dtypes
Python builtin types and numpy types should be handled transparently,
with some exception for complex numbers and objects (described below).
### Numbers
#### Complex numbers
```{todo}
Document limitations for complex numbers and strategies for serialization/validation
```
### Datetimes
```{todo}
Datetimes are supported by every interface except :class:`.VideoInterface` ,
with the caveat that HDF5 loses timezone information, and thus all timestamps should
be re-encoded to UTC before saving/loading.
More generic datetime support is TODO.
```
### Objects
```{todo}
Generic objects are supported by all interfaces except
:class:`.VideoInterface` , :class;`.HDF5Interface` , and :class:`.ZarrInterface` .
this might be expected, but there is also hope, TODO fill in serialization plans.
```
### Strings
```{todo}
Strings are supported by all interfaces except :class:`.VideoInterface` .
TODO is fill in the subtleties of how this works
```
## Generic Dtypes
```{todo}
For now these are handled as tuples of dtypes, see the source of
{ref}`numpydantic.dtype.Float` . They should either be handled as Unions
or as a more prescribed meta-type.
For now, use `int` and `float` to refer to the general concepts of
"any int" or "any float" even if this is a bit mismatched from the numpy usage.
```
## Extended Python Typing Universe
### Union Types
Union types can be used as expected.
Union types are tested recursively -- if any item within a ``Union`` matches
the expected dtype at a given level of recursion, the dtype test passes.
```python
class MyModel(BaseModel):
array: NDArray[Any, int | float]
```
## Compound Dtypes
```{todo}
Compound dtypes are currently unsupported,
though the HDF5 interface supports indexing into compound dtypes
as separable dimensions/arrays using the third "field" parameter in
{class}`.hdf5.H5ArrayPath` .
```

View file

@ -473,6 +473,7 @@ dumped = instance.model_dump_json(context={'zarr_dump_array': True})
design design
syntax syntax
dtype
serialization serialization
interfaces interfaces
``` ```
@ -484,13 +485,13 @@ interfaces
api/index api/index
api/interface/index api/interface/index
api/validation/index
api/dtype api/dtype
api/ndarray api/ndarray
api/maps api/maps
api/meta api/meta
api/schema api/schema
api/serialization api/serialization
api/shape
api/types api/types
``` ```

View file

@ -1,6 +1,6 @@
[project] [project]
name = "numpydantic" name = "numpydantic"
version = "1.6.0" version = "1.6.1"
description = "Type and shape validation and serialization for arbitrary array types in pydantic models" description = "Type and shape validation and serialization for arbitrary array types in pydantic models"
authors = [ authors = [
{name = "sneakers-the-rat", email = "sneakers-the-rat@protonmail.com"}, {name = "sneakers-the-rat", email = "sneakers-the-rat@protonmail.com"},
@ -126,7 +126,7 @@ markers = [
] ]
[tool.ruff] [tool.ruff]
target-version = "py311" target-version = "py39"
include = ["src/numpydantic/**/*.py", "pyproject.toml"] include = ["src/numpydantic/**/*.py", "pyproject.toml"]
exclude = ["tests"] exclude = ["tests"]

View file

@ -4,7 +4,7 @@
from numpydantic.ndarray import NDArray from numpydantic.ndarray import NDArray
from numpydantic.meta import update_ndarray_stub from numpydantic.meta import update_ndarray_stub
from numpydantic.shape import Shape from numpydantic.validation.shape import Shape
update_ndarray_stub() update_ndarray_stub()

View file

@ -12,6 +12,9 @@ module, rather than needing to import each individually.
Some types like `Integer` are compound types - tuples of multiple dtypes. Some types like `Integer` are compound types - tuples of multiple dtypes.
Check these using ``in`` rather than ``==``. This interface will develop in future Check these using ``in`` rather than ``==``. This interface will develop in future
versions to allow a single dtype check. versions to allow a single dtype check.
For internal helper functions for validating dtype,
see :mod:`numpydantic.validation.dtype`
""" """
import sys import sys

View file

@ -33,11 +33,11 @@ class DaskJsonDict(JsonDict):
name: str name: str
chunks: Iterable[tuple[int, ...]] chunks: Iterable[tuple[int, ...]]
dtype: str dtype: str
array: list value: list
def to_array_input(self) -> DaskArray: def to_array_input(self) -> DaskArray:
"""Construct a dask array""" """Construct a dask array"""
np_array = np.array(self.array, dtype=self.dtype) np_array = np.array(self.value, dtype=self.dtype)
array = from_array( array = from_array(
np_array, np_array,
name=self.name, name=self.name,
@ -100,7 +100,7 @@ class DaskInterface(Interface):
if info.round_trip: if info.round_trip:
as_json = DaskJsonDict( as_json = DaskJsonDict(
type=cls.name, type=cls.name,
array=as_json, value=as_json,
name=array.name, name=array.name,
chunks=array.chunks, chunks=array.chunks,
dtype=str(np_array.dtype), dtype=str(np_array.dtype),

View file

@ -20,8 +20,8 @@ from numpydantic.exceptions import (
ShapeError, ShapeError,
TooManyMatchesError, TooManyMatchesError,
) )
from numpydantic.shape import check_shape
from numpydantic.types import DtypeType, NDArrayType, ShapeType from numpydantic.types import DtypeType, NDArrayType, ShapeType
from numpydantic.validation import validate_dtype, validate_shape
T = TypeVar("T", bound=NDArrayType) T = TypeVar("T", bound=NDArrayType)
U = TypeVar("U", bound="JsonDict") U = TypeVar("U", bound="JsonDict")
@ -76,6 +76,21 @@ class InterfaceMark(BaseModel):
class JsonDict(BaseModel): class JsonDict(BaseModel):
""" """
Representation of array when dumped with round_trip == True. Representation of array when dumped with round_trip == True.
.. admonition:: Developer's Note
Any JsonDict that contains an actual array should be named ``value``
rather than array (or any other name), and nothing but the
array data should be named ``value`` .
During JSON serialization, it becomes ambiguous what contains an array
of data vs. an array of metadata. For the moment we would like to
reserve the ability to have lists of metadata, so until we rule that out,
we would like to be able to avoid iterating over every element of an array
in any context parameter transformation like relativizing/absolutizing paths.
To avoid that, it's good to agree on a single value name -- ``value`` --
and avoid using it for anything else.
""" """
type: str type: str
@ -274,25 +289,7 @@ class Interface(ABC, Generic[T]):
Validate the dtype of the given array, returning Validate the dtype of the given array, returning
``True`` if valid, ``False`` if not. ``True`` if valid, ``False`` if not.
""" """
if self.dtype is Any: return validate_dtype(dtype, self.dtype)
return True
if isinstance(self.dtype, tuple):
valid = dtype in self.dtype
elif self.dtype is np.str_:
valid = getattr(dtype, "type", None) in (np.str_, str) or dtype in (
np.str_,
str,
)
else:
# try to match as any subclass, if self.dtype is a class
try:
valid = issubclass(dtype, self.dtype)
except TypeError:
# expected, if dtype or self.dtype is not a class
valid = dtype == self.dtype
return valid
def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None: def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None:
""" """
@ -326,7 +323,7 @@ class Interface(ABC, Generic[T]):
if self.shape is Any: if self.shape is Any:
return True return True
return check_shape(shape, self.shape) return validate_shape(shape, self.shape)
def raise_for_shape(self, valid: bool, shape: Tuple[int, ...]) -> None: def raise_for_shape(self, valid: bool, shape: Tuple[int, ...]) -> None:
""" """

View file

@ -27,13 +27,13 @@ class NumpyJsonDict(JsonDict):
type: Literal["numpy"] type: Literal["numpy"]
dtype: str dtype: str
array: list value: list
def to_array_input(self) -> ndarray: def to_array_input(self) -> ndarray:
""" """
Construct a numpy array Construct a numpy array
""" """
return np.array(self.array, dtype=self.dtype) return np.array(self.value, dtype=self.dtype)
class NumpyInterface(Interface): class NumpyInterface(Interface):
@ -99,6 +99,6 @@ class NumpyInterface(Interface):
if info.round_trip: if info.round_trip:
json_array = NumpyJsonDict( json_array = NumpyJsonDict(
type=cls.name, dtype=str(array.dtype), array=json_array type=cls.name, dtype=str(array.dtype), value=json_array
) )
return json_array return json_array

View file

@ -63,7 +63,7 @@ class ZarrJsonDict(JsonDict):
type: Literal["zarr"] type: Literal["zarr"]
file: Optional[str] = None file: Optional[str] = None
path: Optional[str] = None path: Optional[str] = None
array: Optional[list] = None value: Optional[list] = None
def to_array_input(self) -> Union[ZarrArray, ZarrArrayPath]: def to_array_input(self) -> Union[ZarrArray, ZarrArrayPath]:
""" """
@ -73,7 +73,7 @@ class ZarrJsonDict(JsonDict):
if self.file: if self.file:
array = ZarrArrayPath(file=self.file, path=self.path) array = ZarrArrayPath(file=self.file, path=self.path)
else: else:
array = zarr.array(self.array) array = zarr.array(self.value)
return array return array
@ -202,7 +202,7 @@ class ZarrInterface(Interface):
as_json["info"]["hexdigest"] = array.hexdigest() as_json["info"]["hexdigest"] = array.hexdigest()
if dump_array or not is_file: if dump_array or not is_file:
as_json["array"] = array[:].tolist() as_json["value"] = array[:].tolist()
as_json = ZarrJsonDict(**as_json) as_json = ZarrJsonDict(**as_json)
else: else:

View file

@ -13,7 +13,7 @@ Extension of nptyping NDArray for pydantic that allows for JSON-Schema serializa
""" """
from typing import TYPE_CHECKING, Any, Tuple from typing import TYPE_CHECKING, Any, Literal, Tuple, get_origin
import numpy as np import numpy as np
from pydantic import GetJsonSchemaHandler from pydantic import GetJsonSchemaHandler
@ -29,6 +29,7 @@ from numpydantic.schema import (
) )
from numpydantic.serialization import jsonize_array from numpydantic.serialization import jsonize_array
from numpydantic.types import DtypeType, NDArrayType, ShapeType from numpydantic.types import DtypeType, NDArrayType, ShapeType
from numpydantic.validation.dtype import is_union
from numpydantic.vendor.nptyping.error import InvalidArgumentsError from numpydantic.vendor.nptyping.error import InvalidArgumentsError
from numpydantic.vendor.nptyping.ndarray import NDArrayMeta as _NDArrayMeta from numpydantic.vendor.nptyping.ndarray import NDArrayMeta as _NDArrayMeta
from numpydantic.vendor.nptyping.nptyping_type import NPTypingType from numpydantic.vendor.nptyping.nptyping_type import NPTypingType
@ -86,11 +87,18 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
except InterfaceError: except InterfaceError:
return False return False
def _is_literal_like(cls, item: Any) -> bool:
"""
Changes from nptyping:
- doesn't just ducktype for literal but actually, yno, checks for being literal
"""
return get_origin(item) is Literal
def _get_shape(cls, dtype_candidate: Any) -> "Shape": def _get_shape(cls, dtype_candidate: Any) -> "Shape":
""" """
Override of base method to use our local definition of shape Override of base method to use our local definition of shape
""" """
from numpydantic.shape import Shape from numpydantic.validation.shape import Shape
if dtype_candidate is Any or dtype_candidate is Shape: if dtype_candidate is Any or dtype_candidate is Shape:
shape = Any shape = Any
@ -120,7 +128,7 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
if dtype_candidate is Any: if dtype_candidate is Any:
dtype = Any dtype = Any
elif is_dtype: elif is_dtype or is_union(dtype_candidate):
dtype = dtype_candidate dtype = dtype_candidate
elif issubclass(dtype_candidate, Structure): # pragma: no cover elif issubclass(dtype_candidate, Structure): # pragma: no cover
dtype = dtype_candidate dtype = dtype_candidate

View file

@ -5,7 +5,7 @@ Helper functions for use with :class:`~numpydantic.NDArray` - see the note in
import hashlib import hashlib
import json import json
from typing import TYPE_CHECKING, Any, Callable, Optional from typing import TYPE_CHECKING, Any, Callable, Optional, get_args
import numpy as np import numpy as np
from pydantic import BaseModel from pydantic import BaseModel
@ -16,6 +16,7 @@ from numpydantic import dtype as dt
from numpydantic.interface import Interface from numpydantic.interface import Interface
from numpydantic.maps import np_to_python from numpydantic.maps import np_to_python
from numpydantic.types import DtypeType, NDArrayType, ShapeType from numpydantic.types import DtypeType, NDArrayType, ShapeType
from numpydantic.validation.dtype import is_union
from numpydantic.vendor.nptyping.structure import StructureMeta from numpydantic.vendor.nptyping.structure import StructureMeta
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
@ -51,7 +52,6 @@ def _lol_dtype(
"""Get the innermost dtype schema to use in the generated pydantic schema""" """Get the innermost dtype schema to use in the generated pydantic schema"""
if isinstance(dtype, StructureMeta): # pragma: no cover if isinstance(dtype, StructureMeta): # pragma: no cover
raise NotImplementedError("Structured dtypes are currently unsupported") raise NotImplementedError("Structured dtypes are currently unsupported")
if isinstance(dtype, tuple): if isinstance(dtype, tuple):
# if it's a meta-type that refers to a generic float/int, just make that # if it's a meta-type that refers to a generic float/int, just make that
if dtype in (dt.Float, dt.Number): if dtype in (dt.Float, dt.Number):
@ -66,7 +66,10 @@ def _lol_dtype(
array_type = core_schema.union_schema( array_type = core_schema.union_schema(
[_lol_dtype(t, _handler) for t in types_] [_lol_dtype(t, _handler) for t in types_]
) )
elif is_union(dtype):
array_type = core_schema.union_schema(
[_lol_dtype(t, _handler) for t in get_args(dtype)]
)
else: else:
try: try:
python_type = np_to_python[dtype] python_type = np_to_python[dtype]
@ -110,7 +113,7 @@ def list_of_lists_schema(shape: "Shape", array_type: CoreSchema) -> ListSchema:
array_type ( :class:`pydantic_core.CoreSchema` ): The pre-rendered pydantic array_type ( :class:`pydantic_core.CoreSchema` ): The pre-rendered pydantic
core schema to use in the innermost list entry core schema to use in the innermost list entry
""" """
from numpydantic.shape import _is_range from numpydantic.validation.shape import _is_range
shape_parts = [part.strip() for part in shape.__args__[0].split(",")] shape_parts = [part.strip() for part in shape.__args__[0].split(",")]
# labels, if present # labels, if present

View file

@ -4,7 +4,7 @@ and :func:`pydantic.BaseModel.model_dump_json` .
""" """
from pathlib import Path from pathlib import Path
from typing import Any, Callable, TypeVar, Union from typing import Any, Callable, Iterable, TypeVar, Union
from pydantic_core.core_schema import SerializationInfo from pydantic_core.core_schema import SerializationInfo
@ -16,6 +16,9 @@ U = TypeVar("U")
def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]: def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
"""Use an interface class to render an array as JSON""" """Use an interface class to render an array as JSON"""
# perf: keys to skip in generation - anything named "value" is array data.
skip = ["value"]
interface_cls = Interface.match_output(value) interface_cls = Interface.match_output(value)
array = interface_cls.to_json(value, info) array = interface_cls.to_json(value, info)
if isinstance(array, JsonDict): if isinstance(array, JsonDict):
@ -25,19 +28,37 @@ def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
if info.context.get("mark_interface", False): if info.context.get("mark_interface", False):
array = interface_cls.mark_json(array) array = interface_cls.mark_json(array)
if isinstance(array, list):
return array
# ---- Perf Barrier ------------------------------------------------------
# put context args intended to **wrap** the array above
# put context args intended to **modify** the array below
#
# above, we assume that a list is **data** not to be modified.
# below, we must mark whenever the data is in the line of fire
# to avoid an expensive iteration.
if info.context.get("absolute_paths", False): if info.context.get("absolute_paths", False):
array = _absolutize_paths(array) array = _absolutize_paths(array, skip)
else: else:
relative_to = info.context.get("relative_to", ".") relative_to = info.context.get("relative_to", ".")
array = _relativize_paths(array, relative_to) array = _relativize_paths(array, relative_to, skip)
else: else:
# relativize paths by default if isinstance(array, list):
array = _relativize_paths(array, ".") return array
# ---- Perf Barrier ------------------------------------------------------
# same as above, ensure any keys that contain array values are skipped right now
array = _relativize_paths(array, ".", skip)
return array return array
def _relativize_paths(value: dict, relative_to: str = ".") -> dict: def _relativize_paths(
value: dict, relative_to: str = ".", skip: Iterable = tuple()
) -> dict:
""" """
Make paths relative to either the current directory or the provided Make paths relative to either the current directory or the provided
``relative_to`` directory, if provided in the context ``relative_to`` directory, if provided in the context
@ -46,6 +67,8 @@ def _relativize_paths(value: dict, relative_to: str = ".") -> dict:
# pdb.set_trace() # pdb.set_trace()
def _r_path(v: Any) -> Any: def _r_path(v: Any) -> Any:
if not isinstance(v, (str, Path)):
return v
try: try:
path = Path(v) path = Path(v)
if not path.exists(): if not path.exists():
@ -54,10 +77,10 @@ def _relativize_paths(value: dict, relative_to: str = ".") -> dict:
except (TypeError, ValueError): except (TypeError, ValueError):
return v return v
return _walk_and_apply(value, _r_path) return _walk_and_apply(value, _r_path, skip)
def _absolutize_paths(value: dict) -> dict: def _absolutize_paths(value: dict, skip: Iterable = tuple()) -> dict:
def _a_path(v: Any) -> Any: def _a_path(v: Any) -> Any:
try: try:
path = Path(v) path = Path(v)
@ -67,23 +90,25 @@ def _absolutize_paths(value: dict) -> dict:
except (TypeError, ValueError): except (TypeError, ValueError):
return v return v
return _walk_and_apply(value, _a_path) return _walk_and_apply(value, _a_path, skip)
def _walk_and_apply(value: T, f: Callable[[U], U]) -> T: def _walk_and_apply(value: T, f: Callable[[U, bool], U], skip: Iterable = tuple()) -> T:
""" """
Walk an object, applying a function Walk an object, applying a function
""" """
if isinstance(value, dict): if isinstance(value, dict):
for k, v in value.items(): for k, v in value.items():
if k in skip:
continue
if isinstance(v, dict): if isinstance(v, dict):
_walk_and_apply(v, f) _walk_and_apply(v, f, skip)
elif isinstance(v, list): elif isinstance(v, list):
value[k] = [_walk_and_apply(sub_v, f) for sub_v in v] value[k] = [_walk_and_apply(sub_v, f, skip) for sub_v in v]
else: else:
value[k] = f(v) value[k] = f(v)
elif isinstance(value, list): elif isinstance(value, list):
value = [_walk_and_apply(v, f) for v in value] value = [_walk_and_apply(v, f, skip) for v in value]
else: else:
value = f(value) value = f(value)
return value return value

View file

@ -0,0 +1,11 @@
"""
Helper functions for validation
"""
from numpydantic.validation.dtype import validate_dtype
from numpydantic.validation.shape import validate_shape
__all__ = [
"validate_dtype",
"validate_shape",
]

View file

@ -0,0 +1,63 @@
"""
Helper functions for validation of dtype.
For literal dtypes intended for use by end-users, see :mod:`numpydantic.dtype`
"""
import sys
from typing import Any, Union, get_args, get_origin
import numpy as np
from numpydantic.types import DtypeType
if sys.version_info >= (3, 10):
from types import UnionType
else:
UnionType = None
def validate_dtype(dtype: Any, target: DtypeType) -> bool:
"""
Validate a dtype against the target dtype
Args:
dtype: The dtype to validate
target (:class:`.DtypeType`): The target dtype
Returns:
bool: ``True`` if valid, ``False`` otherwise
"""
if target is Any:
return True
if isinstance(target, tuple):
valid = dtype in target
elif is_union(target):
valid = any(
[validate_dtype(dtype, target_dt) for target_dt in get_args(target)]
)
elif target is np.str_:
valid = getattr(dtype, "type", None) in (np.str_, str) or dtype in (
np.str_,
str,
)
else:
# try to match as any subclass, if target is a class
try:
valid = issubclass(dtype, target)
except TypeError:
# expected, if dtype or target is not a class
valid = dtype == target
return valid
def is_union(dtype: DtypeType) -> bool:
"""
Check if a dtype is a union
"""
if UnionType is None:
return get_origin(dtype) is Union
else:
return get_origin(dtype) in (Union, UnionType)

View file

@ -2,7 +2,7 @@
Declaration and validation functions for array shapes. Declaration and validation functions for array shapes.
Mostly a mildly modified version of nptyping's Mostly a mildly modified version of nptyping's
:func:`npytping.shape_expression.check_shape` :func:`npytping.shape_expression.validate_shape`
and its internals to allow for extended syntax, including ranges of shapes. and its internals to allow for extended syntax, including ranges of shapes.
Modifications from nptyping: Modifications from nptyping:
@ -105,7 +105,7 @@ def validate_shape_expression(shape_expression: Union[ShapeExpression, Any]) ->
@lru_cache @lru_cache
def check_shape(shape: ShapeTuple, target: "Shape") -> bool: def validate_shape(shape: ShapeTuple, target: "Shape") -> bool:
""" """
Check whether the given shape corresponds to the given shape_expression. Check whether the given shape corresponds to the given shape_expression.
:param shape: the shape in question. :param shape: the shape in question.

View file

@ -3,10 +3,6 @@ import sys
import pytest import pytest
from typing import Any, Tuple, Union, Type from typing import Any, Tuple, Union, Type
if sys.version_info.minor >= 10:
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
from pydantic import BaseModel, computed_field, ConfigDict from pydantic import BaseModel, computed_field, ConfigDict
from numpydantic import NDArray, Shape from numpydantic import NDArray, Shape
from numpydantic.ndarray import NDArrayMeta from numpydantic.ndarray import NDArrayMeta
@ -15,6 +11,15 @@ import numpy as np
from tests.fixtures import * from tests.fixtures import *
if sys.version_info.minor >= 10:
from typing import TypeAlias
YES_PIPE = True
else:
from typing_extensions import TypeAlias
YES_PIPE = False
def pytest_addoption(parser): def pytest_addoption(parser):
parser.addoption( parser.addoption(
@ -80,6 +85,9 @@ INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer]
FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float] FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float]
STRING: TypeAlias = NDArray[Shape["*, *, *"], str] STRING: TypeAlias = NDArray[Shape["*, *, *"], str]
MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel] MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel]
UNION_TYPE: TypeAlias = NDArray[Shape["*, *, *"], Union[np.uint32, np.float32]]
if YES_PIPE:
UNION_PIPE: TypeAlias = NDArray[Shape["*, *, *"], np.uint32 | np.float32]
@pytest.fixture( @pytest.fixture(
@ -119,9 +127,7 @@ def shape_cases(request) -> ValidationCase:
return request.param return request.param
@pytest.fixture( DTYPE_CASES = [
scope="module",
params=[
ValidationCase(dtype=float, passes=True), ValidationCase(dtype=float, passes=True),
ValidationCase(dtype=int, passes=False), ValidationCase(dtype=int, passes=False),
ValidationCase(dtype=np.uint8, passes=False), ValidationCase(dtype=np.uint8, passes=False),
@ -147,8 +153,14 @@ def shape_cases(request) -> ValidationCase:
ValidationCase(annotation=MODEL, dtype=BadModel, passes=False), ValidationCase(annotation=MODEL, dtype=BadModel, passes=False),
ValidationCase(annotation=MODEL, dtype=int, passes=False), ValidationCase(annotation=MODEL, dtype=int, passes=False),
ValidationCase(annotation=MODEL, dtype=SubClass, passes=True), ValidationCase(annotation=MODEL, dtype=SubClass, passes=True),
], ValidationCase(annotation=UNION_TYPE, dtype=np.uint32, passes=True),
ids=[ ValidationCase(annotation=UNION_TYPE, dtype=np.float32, passes=True),
ValidationCase(annotation=UNION_TYPE, dtype=np.uint64, passes=False),
ValidationCase(annotation=UNION_TYPE, dtype=np.float64, passes=False),
ValidationCase(annotation=UNION_TYPE, dtype=str, passes=False),
]
DTYPE_IDS = [
"float", "float",
"int", "int",
"uint8", "uint8",
@ -174,7 +186,34 @@ def shape_cases(request) -> ValidationCase:
"model-badmodel", "model-badmodel",
"model-int", "model-int",
"model-subclass", "model-subclass",
], "union-type-uint32",
"union-type-float32",
"union-type-uint64",
"union-type-float64",
"union-type-str",
]
if YES_PIPE:
DTYPE_CASES.extend(
[
ValidationCase(annotation=UNION_PIPE, dtype=np.uint32, passes=True),
ValidationCase(annotation=UNION_PIPE, dtype=np.float32, passes=True),
ValidationCase(annotation=UNION_PIPE, dtype=np.uint64, passes=False),
ValidationCase(annotation=UNION_PIPE, dtype=np.float64, passes=False),
ValidationCase(annotation=UNION_PIPE, dtype=str, passes=False),
]
) )
DTYPE_IDS.extend(
[
"union-pipe-uint32",
"union-pipe-float32",
"union-pipe-uint64",
"union-pipe-float64",
"union-pipe-str",
]
)
@pytest.fixture(scope="module", params=DTYPE_CASES, ids=DTYPE_IDS)
def dtype_cases(request) -> ValidationCase: def dtype_cases(request) -> ValidationCase:
return request.param return request.param

View file

@ -151,7 +151,7 @@ def test_zarr_to_json(store, model_blank, roundtrip, dump_array):
if roundtrip: if roundtrip:
if dump_array: if dump_array:
assert as_json["array"] == lol_array assert as_json["value"] == lol_array
else: else:
if as_json.get("file", False): if as_json.get("file", False):
assert "array" not in as_json assert "array" not in as_json

View file

@ -10,6 +10,8 @@ from typing import Callable
import numpy as np import numpy as np
import json import json
from numpydantic.serialization import _walk_and_apply
pytestmark = pytest.mark.serialization pytestmark = pytest.mark.serialization
@ -93,3 +95,34 @@ def test_relative_to_path(hdf5_at_path, tmp_output_dir, model_blank):
# shouldn't have absolutized subpath even if it's pathlike # shouldn't have absolutized subpath even if it's pathlike
assert data["path"] == expected_dataset assert data["path"] == expected_dataset
def test_walk_and_apply():
"""
Walk and apply should recursively apply a function to everything in a nesty structure
"""
test = {
"a": 1,
"b": 1,
"c": [
{"a": 1, "b": {"a": 1, "b": 1}, "c": [1, 1, 1]},
{"a": 1, "b": [1, 1, 1]},
],
}
def _mult_2(v, skip: bool = False):
return v * 2
def _assert_2(v, skip: bool = False):
assert v == 2
return v
walked = _walk_and_apply(test, _mult_2)
_walk_and_apply(walked, _assert_2)
assert walked["a"] == 2
assert walked["c"][0]["a"] == 2
assert walked["c"][0]["b"]["a"] == 2
assert all([w == 2 for w in walked["c"][0]["c"]])
assert walked["c"][1]["a"] == 2
assert all([w == 2 for w in walked["c"][1]["b"]])