diff --git a/.gitignore b/.gitignore index 3b0fc39..1470b8c 100644 --- a/.gitignore +++ b/.gitignore @@ -159,4 +159,6 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ .pdm-python -ndarray.pyi \ No newline at end of file +ndarray.pyi + +prof/ \ No newline at end of file diff --git a/docs/api/validation/dtype.md b/docs/api/validation/dtype.md new file mode 100644 index 0000000..fd599ca --- /dev/null +++ b/docs/api/validation/dtype.md @@ -0,0 +1,7 @@ +# dtype + +```{eval-rst} +.. automodule:: numpydantic.validation.dtype + :members: + :undoc-members: +``` \ No newline at end of file diff --git a/docs/api/validation/index.md b/docs/api/validation/index.md new file mode 100644 index 0000000..d0f5a11 --- /dev/null +++ b/docs/api/validation/index.md @@ -0,0 +1,6 @@ +# validation + +```{toctree} +dtype +shape +``` \ No newline at end of file diff --git a/docs/api/shape.md b/docs/api/validation/shape.md similarity index 54% rename from docs/api/shape.md rename to docs/api/validation/shape.md index 4c8638c..60a7a19 100644 --- a/docs/api/shape.md +++ b/docs/api/validation/shape.md @@ -1,7 +1,7 @@ # shape ```{eval-rst} -.. automodule:: numpydantic.shape +.. automodule:: numpydantic.validation.shape :members: :undoc-members: ``` \ No newline at end of file diff --git a/docs/changelog.md b/docs/changelog.md index af6375e..5874507 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -4,6 +4,33 @@ ### 1.6.* +#### 1.6.1 - 24-09-23 - Support Union Dtypes + +It's now possible to do this, like it always should have been + +```python +class MyModel(BaseModel): + array: NDArray[Any, int | float] +``` + +**Features** +- Support for Union Dtypes + +**Structure** +- New `validation` module containing `shape` and `dtype` convenience methods + to declutter main namespace and make a grouping for related code +- Rename all serialized arrays within a container dict to `value` to be able + to identify them by convention and avoid long iteration - see perf below. + +**Perf** +- Avoid iterating over every item in an array trying to convert it to a path for + a several order of magnitude perf improvement over `1.6.0` (oops) + +**Docs** +- Page for `dtypes`, mostly stubs at the moment, but more explicit documentation + about what kind of dtypes we support. + + #### 1.6.0 - 24-09-23 - Roundtrip JSON Serialization Roundtrip JSON serialization is here - with serialization to list of lists, diff --git a/docs/dtype.md b/docs/dtype.md new file mode 100644 index 0000000..f7be5ca --- /dev/null +++ b/docs/dtype.md @@ -0,0 +1,98 @@ +# dtype + +```{todo} +This section is under construction as of 1.6.1 + +Much of the details of dtypes are covered in [syntax](./syntax.md) +and in {mod}`numpydantic.dtype` , but this section will specifically +address how dtypes are handled both generically and by interfaces +as we expand custom dtype handling <3. + +For details of support and implementation until the docs have time for some love, +please see the tests, which are the source of truth for the functionality +of the library for now and forever. +``` + +Recall the general syntax: + +``` +NDArray[Shape, dtype] +``` + +These are the docs for what can do in `dtype`. + +## Scalar Dtypes + +Python builtin types and numpy types should be handled transparently, +with some exception for complex numbers and objects (described below). + +### Numbers + +#### Complex numbers + +```{todo} +Document limitations for complex numbers and strategies for serialization/validation +``` + +### Datetimes + +```{todo} +Datetimes are supported by every interface except :class:`.VideoInterface` , +with the caveat that HDF5 loses timezone information, and thus all timestamps should +be re-encoded to UTC before saving/loading. + +More generic datetime support is TODO. +``` + +### Objects + +```{todo} +Generic objects are supported by all interfaces except +:class:`.VideoInterface` , :class;`.HDF5Interface` , and :class:`.ZarrInterface` . + +this might be expected, but there is also hope, TODO fill in serialization plans. +``` + +### Strings + +```{todo} +Strings are supported by all interfaces except :class:`.VideoInterface` . + +TODO is fill in the subtleties of how this works +``` + +## Generic Dtypes + +```{todo} +For now these are handled as tuples of dtypes, see the source of +{ref}`numpydantic.dtype.Float` . They should either be handled as Unions +or as a more prescribed meta-type. + +For now, use `int` and `float` to refer to the general concepts of +"any int" or "any float" even if this is a bit mismatched from the numpy usage. +``` + +## Extended Python Typing Universe + +### Union Types + +Union types can be used as expected. + +Union types are tested recursively -- if any item within a ``Union`` matches +the expected dtype at a given level of recursion, the dtype test passes. + +```python +class MyModel(BaseModel): + array: NDArray[Any, int | float] +``` + +## Compound Dtypes + +```{todo} +Compound dtypes are currently unsupported, +though the HDF5 interface supports indexing into compound dtypes +as separable dimensions/arrays using the third "field" parameter in +{class}`.hdf5.H5ArrayPath` . +``` + + diff --git a/docs/index.md b/docs/index.md index 0880a59..bced657 100644 --- a/docs/index.md +++ b/docs/index.md @@ -473,6 +473,7 @@ dumped = instance.model_dump_json(context={'zarr_dump_array': True}) design syntax +dtype serialization interfaces ``` @@ -484,13 +485,13 @@ interfaces api/index api/interface/index +api/validation/index api/dtype api/ndarray api/maps api/meta api/schema api/serialization -api/shape api/types ``` diff --git a/pyproject.toml b/pyproject.toml index 0e6926b..af66c4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "numpydantic" -version = "1.6.0" +version = "1.6.1" description = "Type and shape validation and serialization for arbitrary array types in pydantic models" authors = [ {name = "sneakers-the-rat", email = "sneakers-the-rat@protonmail.com"}, @@ -126,7 +126,7 @@ markers = [ ] [tool.ruff] -target-version = "py311" +target-version = "py39" include = ["src/numpydantic/**/*.py", "pyproject.toml"] exclude = ["tests"] diff --git a/src/numpydantic/__init__.py b/src/numpydantic/__init__.py index 803d8d2..98b8bd1 100644 --- a/src/numpydantic/__init__.py +++ b/src/numpydantic/__init__.py @@ -4,7 +4,7 @@ from numpydantic.ndarray import NDArray from numpydantic.meta import update_ndarray_stub -from numpydantic.shape import Shape +from numpydantic.validation.shape import Shape update_ndarray_stub() diff --git a/src/numpydantic/dtype.py b/src/numpydantic/dtype.py index 12d766a..84b28cc 100644 --- a/src/numpydantic/dtype.py +++ b/src/numpydantic/dtype.py @@ -12,6 +12,9 @@ module, rather than needing to import each individually. Some types like `Integer` are compound types - tuples of multiple dtypes. Check these using ``in`` rather than ``==``. This interface will develop in future versions to allow a single dtype check. + +For internal helper functions for validating dtype, +see :mod:`numpydantic.validation.dtype` """ import sys diff --git a/src/numpydantic/interface/dask.py b/src/numpydantic/interface/dask.py index cd36a65..95d0619 100644 --- a/src/numpydantic/interface/dask.py +++ b/src/numpydantic/interface/dask.py @@ -33,11 +33,11 @@ class DaskJsonDict(JsonDict): name: str chunks: Iterable[tuple[int, ...]] dtype: str - array: list + value: list def to_array_input(self) -> DaskArray: """Construct a dask array""" - np_array = np.array(self.array, dtype=self.dtype) + np_array = np.array(self.value, dtype=self.dtype) array = from_array( np_array, name=self.name, @@ -100,7 +100,7 @@ class DaskInterface(Interface): if info.round_trip: as_json = DaskJsonDict( type=cls.name, - array=as_json, + value=as_json, name=array.name, chunks=array.chunks, dtype=str(np_array.dtype), diff --git a/src/numpydantic/interface/interface.py b/src/numpydantic/interface/interface.py index bee85b6..42bb891 100644 --- a/src/numpydantic/interface/interface.py +++ b/src/numpydantic/interface/interface.py @@ -20,8 +20,8 @@ from numpydantic.exceptions import ( ShapeError, TooManyMatchesError, ) -from numpydantic.shape import check_shape from numpydantic.types import DtypeType, NDArrayType, ShapeType +from numpydantic.validation import validate_dtype, validate_shape T = TypeVar("T", bound=NDArrayType) U = TypeVar("U", bound="JsonDict") @@ -76,6 +76,21 @@ class InterfaceMark(BaseModel): class JsonDict(BaseModel): """ Representation of array when dumped with round_trip == True. + + .. admonition:: Developer's Note + + Any JsonDict that contains an actual array should be named ``value`` + rather than array (or any other name), and nothing but the + array data should be named ``value`` . + + During JSON serialization, it becomes ambiguous what contains an array + of data vs. an array of metadata. For the moment we would like to + reserve the ability to have lists of metadata, so until we rule that out, + we would like to be able to avoid iterating over every element of an array + in any context parameter transformation like relativizing/absolutizing paths. + To avoid that, it's good to agree on a single value name -- ``value`` -- + and avoid using it for anything else. + """ type: str @@ -274,25 +289,7 @@ class Interface(ABC, Generic[T]): Validate the dtype of the given array, returning ``True`` if valid, ``False`` if not. """ - if self.dtype is Any: - return True - - if isinstance(self.dtype, tuple): - valid = dtype in self.dtype - elif self.dtype is np.str_: - valid = getattr(dtype, "type", None) in (np.str_, str) or dtype in ( - np.str_, - str, - ) - else: - # try to match as any subclass, if self.dtype is a class - try: - valid = issubclass(dtype, self.dtype) - except TypeError: - # expected, if dtype or self.dtype is not a class - valid = dtype == self.dtype - - return valid + return validate_dtype(dtype, self.dtype) def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None: """ @@ -326,7 +323,7 @@ class Interface(ABC, Generic[T]): if self.shape is Any: return True - return check_shape(shape, self.shape) + return validate_shape(shape, self.shape) def raise_for_shape(self, valid: bool, shape: Tuple[int, ...]) -> None: """ diff --git a/src/numpydantic/interface/numpy.py b/src/numpydantic/interface/numpy.py index ad97474..6c84232 100644 --- a/src/numpydantic/interface/numpy.py +++ b/src/numpydantic/interface/numpy.py @@ -27,13 +27,13 @@ class NumpyJsonDict(JsonDict): type: Literal["numpy"] dtype: str - array: list + value: list def to_array_input(self) -> ndarray: """ Construct a numpy array """ - return np.array(self.array, dtype=self.dtype) + return np.array(self.value, dtype=self.dtype) class NumpyInterface(Interface): @@ -99,6 +99,6 @@ class NumpyInterface(Interface): if info.round_trip: json_array = NumpyJsonDict( - type=cls.name, dtype=str(array.dtype), array=json_array + type=cls.name, dtype=str(array.dtype), value=json_array ) return json_array diff --git a/src/numpydantic/interface/zarr.py b/src/numpydantic/interface/zarr.py index 41cad03..5dc647e 100644 --- a/src/numpydantic/interface/zarr.py +++ b/src/numpydantic/interface/zarr.py @@ -63,7 +63,7 @@ class ZarrJsonDict(JsonDict): type: Literal["zarr"] file: Optional[str] = None path: Optional[str] = None - array: Optional[list] = None + value: Optional[list] = None def to_array_input(self) -> Union[ZarrArray, ZarrArrayPath]: """ @@ -73,7 +73,7 @@ class ZarrJsonDict(JsonDict): if self.file: array = ZarrArrayPath(file=self.file, path=self.path) else: - array = zarr.array(self.array) + array = zarr.array(self.value) return array @@ -202,7 +202,7 @@ class ZarrInterface(Interface): as_json["info"]["hexdigest"] = array.hexdigest() if dump_array or not is_file: - as_json["array"] = array[:].tolist() + as_json["value"] = array[:].tolist() as_json = ZarrJsonDict(**as_json) else: diff --git a/src/numpydantic/ndarray.py b/src/numpydantic/ndarray.py index fb81f69..6969d44 100644 --- a/src/numpydantic/ndarray.py +++ b/src/numpydantic/ndarray.py @@ -13,7 +13,7 @@ Extension of nptyping NDArray for pydantic that allows for JSON-Schema serializa """ -from typing import TYPE_CHECKING, Any, Tuple +from typing import TYPE_CHECKING, Any, Literal, Tuple, get_origin import numpy as np from pydantic import GetJsonSchemaHandler @@ -29,6 +29,7 @@ from numpydantic.schema import ( ) from numpydantic.serialization import jsonize_array from numpydantic.types import DtypeType, NDArrayType, ShapeType +from numpydantic.validation.dtype import is_union from numpydantic.vendor.nptyping.error import InvalidArgumentsError from numpydantic.vendor.nptyping.ndarray import NDArrayMeta as _NDArrayMeta from numpydantic.vendor.nptyping.nptyping_type import NPTypingType @@ -86,11 +87,18 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"): except InterfaceError: return False + def _is_literal_like(cls, item: Any) -> bool: + """ + Changes from nptyping: + - doesn't just ducktype for literal but actually, yno, checks for being literal + """ + return get_origin(item) is Literal + def _get_shape(cls, dtype_candidate: Any) -> "Shape": """ Override of base method to use our local definition of shape """ - from numpydantic.shape import Shape + from numpydantic.validation.shape import Shape if dtype_candidate is Any or dtype_candidate is Shape: shape = Any @@ -120,7 +128,7 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"): if dtype_candidate is Any: dtype = Any - elif is_dtype: + elif is_dtype or is_union(dtype_candidate): dtype = dtype_candidate elif issubclass(dtype_candidate, Structure): # pragma: no cover dtype = dtype_candidate diff --git a/src/numpydantic/schema.py b/src/numpydantic/schema.py index cafa1f4..bfea3aa 100644 --- a/src/numpydantic/schema.py +++ b/src/numpydantic/schema.py @@ -5,7 +5,7 @@ Helper functions for use with :class:`~numpydantic.NDArray` - see the note in import hashlib import json -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional, get_args import numpy as np from pydantic import BaseModel @@ -16,6 +16,7 @@ from numpydantic import dtype as dt from numpydantic.interface import Interface from numpydantic.maps import np_to_python from numpydantic.types import DtypeType, NDArrayType, ShapeType +from numpydantic.validation.dtype import is_union from numpydantic.vendor.nptyping.structure import StructureMeta if TYPE_CHECKING: # pragma: no cover @@ -51,7 +52,6 @@ def _lol_dtype( """Get the innermost dtype schema to use in the generated pydantic schema""" if isinstance(dtype, StructureMeta): # pragma: no cover raise NotImplementedError("Structured dtypes are currently unsupported") - if isinstance(dtype, tuple): # if it's a meta-type that refers to a generic float/int, just make that if dtype in (dt.Float, dt.Number): @@ -66,7 +66,10 @@ def _lol_dtype( array_type = core_schema.union_schema( [_lol_dtype(t, _handler) for t in types_] ) - + elif is_union(dtype): + array_type = core_schema.union_schema( + [_lol_dtype(t, _handler) for t in get_args(dtype)] + ) else: try: python_type = np_to_python[dtype] @@ -110,7 +113,7 @@ def list_of_lists_schema(shape: "Shape", array_type: CoreSchema) -> ListSchema: array_type ( :class:`pydantic_core.CoreSchema` ): The pre-rendered pydantic core schema to use in the innermost list entry """ - from numpydantic.shape import _is_range + from numpydantic.validation.shape import _is_range shape_parts = [part.strip() for part in shape.__args__[0].split(",")] # labels, if present diff --git a/src/numpydantic/serialization.py b/src/numpydantic/serialization.py index 1f1edd0..f901994 100644 --- a/src/numpydantic/serialization.py +++ b/src/numpydantic/serialization.py @@ -4,7 +4,7 @@ and :func:`pydantic.BaseModel.model_dump_json` . """ from pathlib import Path -from typing import Any, Callable, TypeVar, Union +from typing import Any, Callable, Iterable, TypeVar, Union from pydantic_core.core_schema import SerializationInfo @@ -16,6 +16,9 @@ U = TypeVar("U") def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]: """Use an interface class to render an array as JSON""" + # perf: keys to skip in generation - anything named "value" is array data. + skip = ["value"] + interface_cls = Interface.match_output(value) array = interface_cls.to_json(value, info) if isinstance(array, JsonDict): @@ -25,19 +28,37 @@ def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]: if info.context.get("mark_interface", False): array = interface_cls.mark_json(array) + if isinstance(array, list): + return array + + # ---- Perf Barrier ------------------------------------------------------ + # put context args intended to **wrap** the array above + # put context args intended to **modify** the array below + # + # above, we assume that a list is **data** not to be modified. + # below, we must mark whenever the data is in the line of fire + # to avoid an expensive iteration. + if info.context.get("absolute_paths", False): - array = _absolutize_paths(array) + array = _absolutize_paths(array, skip) else: relative_to = info.context.get("relative_to", ".") - array = _relativize_paths(array, relative_to) + array = _relativize_paths(array, relative_to, skip) else: - # relativize paths by default - array = _relativize_paths(array, ".") + if isinstance(array, list): + return array + + # ---- Perf Barrier ------------------------------------------------------ + # same as above, ensure any keys that contain array values are skipped right now + + array = _relativize_paths(array, ".", skip) return array -def _relativize_paths(value: dict, relative_to: str = ".") -> dict: +def _relativize_paths( + value: dict, relative_to: str = ".", skip: Iterable = tuple() +) -> dict: """ Make paths relative to either the current directory or the provided ``relative_to`` directory, if provided in the context @@ -46,6 +67,8 @@ def _relativize_paths(value: dict, relative_to: str = ".") -> dict: # pdb.set_trace() def _r_path(v: Any) -> Any: + if not isinstance(v, (str, Path)): + return v try: path = Path(v) if not path.exists(): @@ -54,10 +77,10 @@ def _relativize_paths(value: dict, relative_to: str = ".") -> dict: except (TypeError, ValueError): return v - return _walk_and_apply(value, _r_path) + return _walk_and_apply(value, _r_path, skip) -def _absolutize_paths(value: dict) -> dict: +def _absolutize_paths(value: dict, skip: Iterable = tuple()) -> dict: def _a_path(v: Any) -> Any: try: path = Path(v) @@ -67,23 +90,25 @@ def _absolutize_paths(value: dict) -> dict: except (TypeError, ValueError): return v - return _walk_and_apply(value, _a_path) + return _walk_and_apply(value, _a_path, skip) -def _walk_and_apply(value: T, f: Callable[[U], U]) -> T: +def _walk_and_apply(value: T, f: Callable[[U, bool], U], skip: Iterable = tuple()) -> T: """ Walk an object, applying a function """ if isinstance(value, dict): for k, v in value.items(): + if k in skip: + continue if isinstance(v, dict): - _walk_and_apply(v, f) + _walk_and_apply(v, f, skip) elif isinstance(v, list): - value[k] = [_walk_and_apply(sub_v, f) for sub_v in v] + value[k] = [_walk_and_apply(sub_v, f, skip) for sub_v in v] else: value[k] = f(v) elif isinstance(value, list): - value = [_walk_and_apply(v, f) for v in value] + value = [_walk_and_apply(v, f, skip) for v in value] else: value = f(value) return value diff --git a/src/numpydantic/validation/__init__.py b/src/numpydantic/validation/__init__.py new file mode 100644 index 0000000..73d7374 --- /dev/null +++ b/src/numpydantic/validation/__init__.py @@ -0,0 +1,11 @@ +""" +Helper functions for validation +""" + +from numpydantic.validation.dtype import validate_dtype +from numpydantic.validation.shape import validate_shape + +__all__ = [ + "validate_dtype", + "validate_shape", +] diff --git a/src/numpydantic/validation/dtype.py b/src/numpydantic/validation/dtype.py new file mode 100644 index 0000000..5eeb124 --- /dev/null +++ b/src/numpydantic/validation/dtype.py @@ -0,0 +1,63 @@ +""" +Helper functions for validation of dtype. + +For literal dtypes intended for use by end-users, see :mod:`numpydantic.dtype` +""" + +import sys +from typing import Any, Union, get_args, get_origin + +import numpy as np + +from numpydantic.types import DtypeType + +if sys.version_info >= (3, 10): + from types import UnionType +else: + UnionType = None + + +def validate_dtype(dtype: Any, target: DtypeType) -> bool: + """ + Validate a dtype against the target dtype + + Args: + dtype: The dtype to validate + target (:class:`.DtypeType`): The target dtype + + Returns: + bool: ``True`` if valid, ``False`` otherwise + """ + if target is Any: + return True + + if isinstance(target, tuple): + valid = dtype in target + elif is_union(target): + valid = any( + [validate_dtype(dtype, target_dt) for target_dt in get_args(target)] + ) + elif target is np.str_: + valid = getattr(dtype, "type", None) in (np.str_, str) or dtype in ( + np.str_, + str, + ) + else: + # try to match as any subclass, if target is a class + try: + valid = issubclass(dtype, target) + except TypeError: + # expected, if dtype or target is not a class + valid = dtype == target + + return valid + + +def is_union(dtype: DtypeType) -> bool: + """ + Check if a dtype is a union + """ + if UnionType is None: + return get_origin(dtype) is Union + else: + return get_origin(dtype) in (Union, UnionType) diff --git a/src/numpydantic/shape.py b/src/numpydantic/validation/shape.py similarity index 98% rename from src/numpydantic/shape.py rename to src/numpydantic/validation/shape.py index 62a567f..e899ecd 100644 --- a/src/numpydantic/shape.py +++ b/src/numpydantic/validation/shape.py @@ -2,7 +2,7 @@ Declaration and validation functions for array shapes. Mostly a mildly modified version of nptyping's -:func:`npytping.shape_expression.check_shape` +:func:`npytping.shape_expression.validate_shape` and its internals to allow for extended syntax, including ranges of shapes. Modifications from nptyping: @@ -105,7 +105,7 @@ def validate_shape_expression(shape_expression: Union[ShapeExpression, Any]) -> @lru_cache -def check_shape(shape: ShapeTuple, target: "Shape") -> bool: +def validate_shape(shape: ShapeTuple, target: "Shape") -> bool: """ Check whether the given shape corresponds to the given shape_expression. :param shape: the shape in question. diff --git a/tests/conftest.py b/tests/conftest.py index c9035f4..96f8a7e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,10 +3,6 @@ import sys import pytest from typing import Any, Tuple, Union, Type -if sys.version_info.minor >= 10: - from typing import TypeAlias -else: - from typing_extensions import TypeAlias from pydantic import BaseModel, computed_field, ConfigDict from numpydantic import NDArray, Shape from numpydantic.ndarray import NDArrayMeta @@ -15,6 +11,15 @@ import numpy as np from tests.fixtures import * +if sys.version_info.minor >= 10: + from typing import TypeAlias + + YES_PIPE = True +else: + from typing_extensions import TypeAlias + + YES_PIPE = False + def pytest_addoption(parser): parser.addoption( @@ -80,6 +85,9 @@ INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer] FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float] STRING: TypeAlias = NDArray[Shape["*, *, *"], str] MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel] +UNION_TYPE: TypeAlias = NDArray[Shape["*, *, *"], Union[np.uint32, np.float32]] +if YES_PIPE: + UNION_PIPE: TypeAlias = NDArray[Shape["*, *, *"], np.uint32 | np.float32] @pytest.fixture( @@ -119,62 +127,93 @@ def shape_cases(request) -> ValidationCase: return request.param -@pytest.fixture( - scope="module", - params=[ - ValidationCase(dtype=float, passes=True), - ValidationCase(dtype=int, passes=False), - ValidationCase(dtype=np.uint8, passes=False), - ValidationCase(annotation=NUMBER, dtype=int, passes=True), - ValidationCase(annotation=NUMBER, dtype=float, passes=True), - ValidationCase(annotation=NUMBER, dtype=np.uint8, passes=True), - ValidationCase(annotation=NUMBER, dtype=np.float16, passes=True), - ValidationCase(annotation=NUMBER, dtype=str, passes=False), - ValidationCase(annotation=INTEGER, dtype=int, passes=True), - ValidationCase(annotation=INTEGER, dtype=np.uint8, passes=True), - ValidationCase(annotation=INTEGER, dtype=float, passes=False), - ValidationCase(annotation=INTEGER, dtype=np.float32, passes=False), - ValidationCase(annotation=INTEGER, dtype=str, passes=False), - ValidationCase(annotation=FLOAT, dtype=float, passes=True), - ValidationCase(annotation=FLOAT, dtype=np.float32, passes=True), - ValidationCase(annotation=FLOAT, dtype=int, passes=False), - ValidationCase(annotation=FLOAT, dtype=np.uint8, passes=False), - ValidationCase(annotation=FLOAT, dtype=str, passes=False), - ValidationCase(annotation=STRING, dtype=str, passes=True), - ValidationCase(annotation=STRING, dtype=int, passes=False), - ValidationCase(annotation=STRING, dtype=float, passes=False), - ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True), - ValidationCase(annotation=MODEL, dtype=BadModel, passes=False), - ValidationCase(annotation=MODEL, dtype=int, passes=False), - ValidationCase(annotation=MODEL, dtype=SubClass, passes=True), - ], - ids=[ - "float", - "int", - "uint8", - "number-int", - "number-float", - "number-uint8", - "number-float16", - "number-str", - "integer-int", - "integer-uint8", - "integer-float", - "integer-float32", - "integer-str", - "float-float", - "float-float32", - "float-int", - "float-uint8", - "float-str", - "str-str", - "str-int", - "str-float", - "model-model", - "model-badmodel", - "model-int", - "model-subclass", - ], -) +DTYPE_CASES = [ + ValidationCase(dtype=float, passes=True), + ValidationCase(dtype=int, passes=False), + ValidationCase(dtype=np.uint8, passes=False), + ValidationCase(annotation=NUMBER, dtype=int, passes=True), + ValidationCase(annotation=NUMBER, dtype=float, passes=True), + ValidationCase(annotation=NUMBER, dtype=np.uint8, passes=True), + ValidationCase(annotation=NUMBER, dtype=np.float16, passes=True), + ValidationCase(annotation=NUMBER, dtype=str, passes=False), + ValidationCase(annotation=INTEGER, dtype=int, passes=True), + ValidationCase(annotation=INTEGER, dtype=np.uint8, passes=True), + ValidationCase(annotation=INTEGER, dtype=float, passes=False), + ValidationCase(annotation=INTEGER, dtype=np.float32, passes=False), + ValidationCase(annotation=INTEGER, dtype=str, passes=False), + ValidationCase(annotation=FLOAT, dtype=float, passes=True), + ValidationCase(annotation=FLOAT, dtype=np.float32, passes=True), + ValidationCase(annotation=FLOAT, dtype=int, passes=False), + ValidationCase(annotation=FLOAT, dtype=np.uint8, passes=False), + ValidationCase(annotation=FLOAT, dtype=str, passes=False), + ValidationCase(annotation=STRING, dtype=str, passes=True), + ValidationCase(annotation=STRING, dtype=int, passes=False), + ValidationCase(annotation=STRING, dtype=float, passes=False), + ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True), + ValidationCase(annotation=MODEL, dtype=BadModel, passes=False), + ValidationCase(annotation=MODEL, dtype=int, passes=False), + ValidationCase(annotation=MODEL, dtype=SubClass, passes=True), + ValidationCase(annotation=UNION_TYPE, dtype=np.uint32, passes=True), + ValidationCase(annotation=UNION_TYPE, dtype=np.float32, passes=True), + ValidationCase(annotation=UNION_TYPE, dtype=np.uint64, passes=False), + ValidationCase(annotation=UNION_TYPE, dtype=np.float64, passes=False), + ValidationCase(annotation=UNION_TYPE, dtype=str, passes=False), +] + +DTYPE_IDS = [ + "float", + "int", + "uint8", + "number-int", + "number-float", + "number-uint8", + "number-float16", + "number-str", + "integer-int", + "integer-uint8", + "integer-float", + "integer-float32", + "integer-str", + "float-float", + "float-float32", + "float-int", + "float-uint8", + "float-str", + "str-str", + "str-int", + "str-float", + "model-model", + "model-badmodel", + "model-int", + "model-subclass", + "union-type-uint32", + "union-type-float32", + "union-type-uint64", + "union-type-float64", + "union-type-str", +] + +if YES_PIPE: + DTYPE_CASES.extend( + [ + ValidationCase(annotation=UNION_PIPE, dtype=np.uint32, passes=True), + ValidationCase(annotation=UNION_PIPE, dtype=np.float32, passes=True), + ValidationCase(annotation=UNION_PIPE, dtype=np.uint64, passes=False), + ValidationCase(annotation=UNION_PIPE, dtype=np.float64, passes=False), + ValidationCase(annotation=UNION_PIPE, dtype=str, passes=False), + ] + ) + DTYPE_IDS.extend( + [ + "union-pipe-uint32", + "union-pipe-float32", + "union-pipe-uint64", + "union-pipe-float64", + "union-pipe-str", + ] + ) + + +@pytest.fixture(scope="module", params=DTYPE_CASES, ids=DTYPE_IDS) def dtype_cases(request) -> ValidationCase: return request.param diff --git a/tests/test_interface/test_zarr.py b/tests/test_interface/test_zarr.py index ed5c252..6b21b20 100644 --- a/tests/test_interface/test_zarr.py +++ b/tests/test_interface/test_zarr.py @@ -151,7 +151,7 @@ def test_zarr_to_json(store, model_blank, roundtrip, dump_array): if roundtrip: if dump_array: - assert as_json["array"] == lol_array + assert as_json["value"] == lol_array else: if as_json.get("file", False): assert "array" not in as_json diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 702dc1a..5d0b2d8 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -10,6 +10,8 @@ from typing import Callable import numpy as np import json +from numpydantic.serialization import _walk_and_apply + pytestmark = pytest.mark.serialization @@ -93,3 +95,34 @@ def test_relative_to_path(hdf5_at_path, tmp_output_dir, model_blank): # shouldn't have absolutized subpath even if it's pathlike assert data["path"] == expected_dataset + + +def test_walk_and_apply(): + """ + Walk and apply should recursively apply a function to everything in a nesty structure + """ + test = { + "a": 1, + "b": 1, + "c": [ + {"a": 1, "b": {"a": 1, "b": 1}, "c": [1, 1, 1]}, + {"a": 1, "b": [1, 1, 1]}, + ], + } + + def _mult_2(v, skip: bool = False): + return v * 2 + + def _assert_2(v, skip: bool = False): + assert v == 2 + return v + + walked = _walk_and_apply(test, _mult_2) + _walk_and_apply(walked, _assert_2) + + assert walked["a"] == 2 + assert walked["c"][0]["a"] == 2 + assert walked["c"][0]["b"]["a"] == 2 + assert all([w == 2 for w in walked["c"][0]["c"]]) + assert walked["c"][1]["a"] == 2 + assert all([w == 2 for w in walked["c"][1]["b"]])