diff --git a/.gitignore b/.gitignore
index 3b0fc39..1470b8c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -159,4 +159,6 @@ cython_debug/
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
 .pdm-python
-ndarray.pyi
\ No newline at end of file
+ndarray.pyi
+
+prof/
\ No newline at end of file
diff --git a/docs/api/validation/dtype.md b/docs/api/validation/dtype.md
new file mode 100644
index 0000000..fd599ca
--- /dev/null
+++ b/docs/api/validation/dtype.md
@@ -0,0 +1,7 @@
+# dtype
+
+```{eval-rst}
+.. automodule:: numpydantic.validation.dtype
+    :members:
+    :undoc-members:
+```
\ No newline at end of file
diff --git a/docs/api/validation/index.md b/docs/api/validation/index.md
new file mode 100644
index 0000000..d0f5a11
--- /dev/null
+++ b/docs/api/validation/index.md
@@ -0,0 +1,6 @@
+# validation
+
+```{toctree}
+dtype
+shape
+```
\ No newline at end of file
diff --git a/docs/api/shape.md b/docs/api/validation/shape.md
similarity index 54%
rename from docs/api/shape.md
rename to docs/api/validation/shape.md
index 4c8638c..60a7a19 100644
--- a/docs/api/shape.md
+++ b/docs/api/validation/shape.md
@@ -1,7 +1,7 @@
 # shape
 
 ```{eval-rst}
-.. automodule:: numpydantic.shape
+.. automodule:: numpydantic.validation.shape
     :members:
     :undoc-members:
 ```
\ No newline at end of file
diff --git a/docs/index.md b/docs/index.md
index 0880a59..9caaaf7 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -484,13 +484,13 @@ interfaces
 api/index
 api/interface/index
+api/validation/index
 api/dtype
 api/ndarray
 api/maps
 api/meta
 api/schema
 api/serialization
-api/shape
 api/types
 ```
diff --git a/src/numpydantic/__init__.py b/src/numpydantic/__init__.py
index 803d8d2..98b8bd1 100644
--- a/src/numpydantic/__init__.py
+++ b/src/numpydantic/__init__.py
@@ -4,7 +4,7 @@
 from numpydantic.ndarray import NDArray
 from numpydantic.meta import update_ndarray_stub
-from numpydantic.shape import Shape
+from numpydantic.validation.shape import Shape
 
 update_ndarray_stub()
diff --git a/src/numpydantic/dtype.py b/src/numpydantic/dtype.py
index 12d766a..84b28cc 100644
--- a/src/numpydantic/dtype.py
+++ b/src/numpydantic/dtype.py
@@ -12,6 +12,9 @@ module, rather than needing to import each individually.
 Some types like `Integer` are compound types - tuples of multiple dtypes.
 Check these using ``in`` rather than ``==``.
 This interface will develop in future versions to allow a single dtype check.
+
+For internal helper functions for validating dtype,
+see :mod:`numpydantic.validation.dtype`
 """
 
 import sys
diff --git a/src/numpydantic/interface/dask.py b/src/numpydantic/interface/dask.py
index cd36a65..95d0619 100644
--- a/src/numpydantic/interface/dask.py
+++ b/src/numpydantic/interface/dask.py
@@ -33,11 +33,11 @@ class DaskJsonDict(JsonDict):
     name: str
     chunks: Iterable[tuple[int, ...]]
     dtype: str
-    array: list
+    value: list
 
     def to_array_input(self) -> DaskArray:
         """Construct a dask array"""
-        np_array = np.array(self.array, dtype=self.dtype)
+        np_array = np.array(self.value, dtype=self.dtype)
         array = from_array(
             np_array,
             name=self.name,
@@ -100,7 +100,7 @@ class DaskInterface(Interface):
         if info.round_trip:
             as_json = DaskJsonDict(
                 type=cls.name,
-                array=as_json,
+                value=as_json,
                 name=array.name,
                 chunks=array.chunks,
                 dtype=str(np_array.dtype),
diff --git a/src/numpydantic/interface/interface.py b/src/numpydantic/interface/interface.py
index bee85b6..42bb891 100644
--- a/src/numpydantic/interface/interface.py
+++ b/src/numpydantic/interface/interface.py
@@ -20,8 +20,8 @@ from numpydantic.exceptions import (
     ShapeError,
     TooManyMatchesError,
 )
-from numpydantic.shape import check_shape
 from numpydantic.types import DtypeType, NDArrayType, ShapeType
+from numpydantic.validation import validate_dtype, validate_shape
 
 T = TypeVar("T", bound=NDArrayType)
 U = TypeVar("U", bound="JsonDict")
@@ -76,6 +76,21 @@ class InterfaceMark(BaseModel):
 class JsonDict(BaseModel):
     """
     Representation of array when dumped with round_trip == True.
+
+    .. admonition:: Developer's Note
+
+        Any JsonDict field that contains actual array data should be named
+        ``value`` rather than ``array`` (or any other name), and nothing but
+        array data should be named ``value``.
+
+        During JSON serialization it becomes ambiguous what contains an array
+        of data vs. an array of metadata. For the moment we would like to
+        reserve the ability to have lists of metadata, so until we rule that
+        out, we want to avoid iterating over every element of an array in any
+        context-parameter transformation like relativizing/absolutizing paths.
+        Agreeing on a single name for array data -- ``value`` -- and using it
+        for nothing else lets those transformations skip array data cheaply.
+
     """
 
     type: str
@@ -274,25 +289,7 @@ class Interface(ABC, Generic[T]):
         Validate the dtype of the given array, returning
         ``True`` if valid, ``False`` if not.
""" - if self.dtype is Any: - return True - - if isinstance(self.dtype, tuple): - valid = dtype in self.dtype - elif self.dtype is np.str_: - valid = getattr(dtype, "type", None) in (np.str_, str) or dtype in ( - np.str_, - str, - ) - else: - # try to match as any subclass, if self.dtype is a class - try: - valid = issubclass(dtype, self.dtype) - except TypeError: - # expected, if dtype or self.dtype is not a class - valid = dtype == self.dtype - - return valid + return validate_dtype(dtype, self.dtype) def raise_for_dtype(self, valid: bool, dtype: DtypeType) -> None: """ @@ -326,7 +323,7 @@ class Interface(ABC, Generic[T]): if self.shape is Any: return True - return check_shape(shape, self.shape) + return validate_shape(shape, self.shape) def raise_for_shape(self, valid: bool, shape: Tuple[int, ...]) -> None: """ diff --git a/src/numpydantic/interface/numpy.py b/src/numpydantic/interface/numpy.py index ad97474..6c84232 100644 --- a/src/numpydantic/interface/numpy.py +++ b/src/numpydantic/interface/numpy.py @@ -27,13 +27,13 @@ class NumpyJsonDict(JsonDict): type: Literal["numpy"] dtype: str - array: list + value: list def to_array_input(self) -> ndarray: """ Construct a numpy array """ - return np.array(self.array, dtype=self.dtype) + return np.array(self.value, dtype=self.dtype) class NumpyInterface(Interface): @@ -99,6 +99,6 @@ class NumpyInterface(Interface): if info.round_trip: json_array = NumpyJsonDict( - type=cls.name, dtype=str(array.dtype), array=json_array + type=cls.name, dtype=str(array.dtype), value=json_array ) return json_array diff --git a/src/numpydantic/interface/zarr.py b/src/numpydantic/interface/zarr.py index 41cad03..5dc647e 100644 --- a/src/numpydantic/interface/zarr.py +++ b/src/numpydantic/interface/zarr.py @@ -63,7 +63,7 @@ class ZarrJsonDict(JsonDict): type: Literal["zarr"] file: Optional[str] = None path: Optional[str] = None - array: Optional[list] = None + value: Optional[list] = None def to_array_input(self) -> Union[ZarrArray, ZarrArrayPath]: """ @@ -73,7 +73,7 @@ class ZarrJsonDict(JsonDict): if self.file: array = ZarrArrayPath(file=self.file, path=self.path) else: - array = zarr.array(self.array) + array = zarr.array(self.value) return array @@ -202,7 +202,7 @@ class ZarrInterface(Interface): as_json["info"]["hexdigest"] = array.hexdigest() if dump_array or not is_file: - as_json["array"] = array[:].tolist() + as_json["value"] = array[:].tolist() as_json = ZarrJsonDict(**as_json) else: diff --git a/src/numpydantic/ndarray.py b/src/numpydantic/ndarray.py index fb81f69..6969d44 100644 --- a/src/numpydantic/ndarray.py +++ b/src/numpydantic/ndarray.py @@ -13,7 +13,7 @@ Extension of nptyping NDArray for pydantic that allows for JSON-Schema serializa """ -from typing import TYPE_CHECKING, Any, Tuple +from typing import TYPE_CHECKING, Any, Literal, Tuple, get_origin import numpy as np from pydantic import GetJsonSchemaHandler @@ -29,6 +29,7 @@ from numpydantic.schema import ( ) from numpydantic.serialization import jsonize_array from numpydantic.types import DtypeType, NDArrayType, ShapeType +from numpydantic.validation.dtype import is_union from numpydantic.vendor.nptyping.error import InvalidArgumentsError from numpydantic.vendor.nptyping.ndarray import NDArrayMeta as _NDArrayMeta from numpydantic.vendor.nptyping.nptyping_type import NPTypingType @@ -86,11 +87,18 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"): except InterfaceError: return False + def _is_literal_like(cls, item: Any) -> bool: + """ + Changes from nptyping: + - doesn't just 
ducktype for literal but actually, yno, checks for being literal + """ + return get_origin(item) is Literal + def _get_shape(cls, dtype_candidate: Any) -> "Shape": """ Override of base method to use our local definition of shape """ - from numpydantic.shape import Shape + from numpydantic.validation.shape import Shape if dtype_candidate is Any or dtype_candidate is Shape: shape = Any @@ -120,7 +128,7 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"): if dtype_candidate is Any: dtype = Any - elif is_dtype: + elif is_dtype or is_union(dtype_candidate): dtype = dtype_candidate elif issubclass(dtype_candidate, Structure): # pragma: no cover dtype = dtype_candidate diff --git a/src/numpydantic/schema.py b/src/numpydantic/schema.py index cafa1f4..bfea3aa 100644 --- a/src/numpydantic/schema.py +++ b/src/numpydantic/schema.py @@ -5,7 +5,7 @@ Helper functions for use with :class:`~numpydantic.NDArray` - see the note in import hashlib import json -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional, get_args import numpy as np from pydantic import BaseModel @@ -16,6 +16,7 @@ from numpydantic import dtype as dt from numpydantic.interface import Interface from numpydantic.maps import np_to_python from numpydantic.types import DtypeType, NDArrayType, ShapeType +from numpydantic.validation.dtype import is_union from numpydantic.vendor.nptyping.structure import StructureMeta if TYPE_CHECKING: # pragma: no cover @@ -51,7 +52,6 @@ def _lol_dtype( """Get the innermost dtype schema to use in the generated pydantic schema""" if isinstance(dtype, StructureMeta): # pragma: no cover raise NotImplementedError("Structured dtypes are currently unsupported") - if isinstance(dtype, tuple): # if it's a meta-type that refers to a generic float/int, just make that if dtype in (dt.Float, dt.Number): @@ -66,7 +66,10 @@ def _lol_dtype( array_type = core_schema.union_schema( [_lol_dtype(t, _handler) for t in types_] ) - + elif is_union(dtype): + array_type = core_schema.union_schema( + [_lol_dtype(t, _handler) for t in get_args(dtype)] + ) else: try: python_type = np_to_python[dtype] @@ -110,7 +113,7 @@ def list_of_lists_schema(shape: "Shape", array_type: CoreSchema) -> ListSchema: array_type ( :class:`pydantic_core.CoreSchema` ): The pre-rendered pydantic core schema to use in the innermost list entry """ - from numpydantic.shape import _is_range + from numpydantic.validation.shape import _is_range shape_parts = [part.strip() for part in shape.__args__[0].split(",")] # labels, if present diff --git a/src/numpydantic/serialization.py b/src/numpydantic/serialization.py index 1f1edd0..07924eb 100644 --- a/src/numpydantic/serialization.py +++ b/src/numpydantic/serialization.py @@ -4,7 +4,7 @@ and :func:`pydantic.BaseModel.model_dump_json` . """ from pathlib import Path -from typing import Any, Callable, TypeVar, Union +from typing import Any, Callable, Iterable, TypeVar, Union from pydantic_core.core_schema import SerializationInfo @@ -16,6 +16,9 @@ U = TypeVar("U") def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]: """Use an interface class to render an array as JSON""" + # perf: keys to skip in generation - anything named "value" is array data. 
+ skip = ["value"] + interface_cls = Interface.match_output(value) array = interface_cls.to_json(value, info) if isinstance(array, JsonDict): @@ -25,19 +28,37 @@ def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]: if info.context.get("mark_interface", False): array = interface_cls.mark_json(array) + if isinstance(array, list): + return array + + # ---- Perf Barrier ------------------------------------------------------ + # put context args intended to **wrap** the array above + # put context args intended to **modify** the array below + # + # above, we assume that a list is **data** not to be modified. + # below, we must mark whenever the data is in the line of fire + # to avoid an expensive iteration. + if info.context.get("absolute_paths", False): - array = _absolutize_paths(array) + array = _absolutize_paths(array, skip) else: relative_to = info.context.get("relative_to", ".") - array = _relativize_paths(array, relative_to) + array = _relativize_paths(array, relative_to, skip) else: - # relativize paths by default - array = _relativize_paths(array, ".") + if isinstance(array, list): + return array + + # ---- Perf Barrier ------------------------------------------------------ + # same as above, ensure any keys that contain array values are skipped right now + + array = _relativize_paths(array, ".", skip) return array -def _relativize_paths(value: dict, relative_to: str = ".") -> dict: +def _relativize_paths( + value: dict, relative_to: str = ".", skip: Iterable = tuple() +) -> dict: """ Make paths relative to either the current directory or the provided ``relative_to`` directory, if provided in the context @@ -46,6 +67,8 @@ def _relativize_paths(value: dict, relative_to: str = ".") -> dict: # pdb.set_trace() def _r_path(v: Any) -> Any: + if not isinstance(v, (str, Path)): + return v try: path = Path(v) if not path.exists(): @@ -54,10 +77,10 @@ def _relativize_paths(value: dict, relative_to: str = ".") -> dict: except (TypeError, ValueError): return v - return _walk_and_apply(value, _r_path) + return _walk_and_apply(value, _r_path, skip) -def _absolutize_paths(value: dict) -> dict: +def _absolutize_paths(value: dict, skip: Iterable = tuple()) -> dict: def _a_path(v: Any) -> Any: try: path = Path(v) @@ -67,23 +90,25 @@ def _absolutize_paths(value: dict) -> dict: except (TypeError, ValueError): return v - return _walk_and_apply(value, _a_path) + return _walk_and_apply(value, _a_path, skip) -def _walk_and_apply(value: T, f: Callable[[U], U]) -> T: +def _walk_and_apply(value: T, f: Callable[[U], U], skip: Iterable = tuple()) -> T: """ Walk an object, applying a function """ if isinstance(value, dict): for k, v in value.items(): + if k in skip: + continue if isinstance(v, dict): - _walk_and_apply(v, f) + _walk_and_apply(v, f, skip) elif isinstance(v, list): - value[k] = [_walk_and_apply(sub_v, f) for sub_v in v] + value[k] = [_walk_and_apply(sub_v, f, skip) for sub_v in v] else: value[k] = f(v) elif isinstance(value, list): - value = [_walk_and_apply(v, f) for v in value] + value = [_walk_and_apply(v, f, skip) for v in value] else: value = f(value) return value diff --git a/src/numpydantic/validation/__init__.py b/src/numpydantic/validation/__init__.py new file mode 100644 index 0000000..73d7374 --- /dev/null +++ b/src/numpydantic/validation/__init__.py @@ -0,0 +1,11 @@ +""" +Helper functions for validation +""" + +from numpydantic.validation.dtype import validate_dtype +from numpydantic.validation.shape import validate_shape + +__all__ = [ + "validate_dtype", + 
"validate_shape", +] diff --git a/src/numpydantic/validation/dtype.py b/src/numpydantic/validation/dtype.py new file mode 100644 index 0000000..bc7723c --- /dev/null +++ b/src/numpydantic/validation/dtype.py @@ -0,0 +1,55 @@ +""" +Helper functions for validation of dtype. + +For literal dtypes intended for use by end-users, see :mod:`numpydantic.dtype` +""" + +from types import UnionType +from typing import Any, Union, get_args, get_origin + +import numpy as np + +from numpydantic.types import DtypeType + + +def validate_dtype(dtype: Any, target: DtypeType) -> bool: + """ + Validate a dtype against the target dtype + + Args: + dtype: The dtype to validate + target (:class:`.DtypeType`): The target dtype + + Returns: + bool: ``True`` if valid, ``False`` otherwise + """ + if target is Any: + return True + + if isinstance(target, tuple): + valid = dtype in target + elif is_union(target): + valid = any( + [validate_dtype(dtype, target_dt) for target_dt in get_args(target)] + ) + elif target is np.str_: + valid = getattr(dtype, "type", None) in (np.str_, str) or dtype in ( + np.str_, + str, + ) + else: + # try to match as any subclass, if target is a class + try: + valid = issubclass(dtype, target) + except TypeError: + # expected, if dtype or target is not a class + valid = dtype == target + + return valid + + +def is_union(dtype: DtypeType) -> bool: + """ + Check if a dtype is a union + """ + return get_origin(dtype) in (Union, UnionType) diff --git a/src/numpydantic/shape.py b/src/numpydantic/validation/shape.py similarity index 98% rename from src/numpydantic/shape.py rename to src/numpydantic/validation/shape.py index 62a567f..e899ecd 100644 --- a/src/numpydantic/shape.py +++ b/src/numpydantic/validation/shape.py @@ -2,7 +2,7 @@ Declaration and validation functions for array shapes. Mostly a mildly modified version of nptyping's -:func:`npytping.shape_expression.check_shape` +:func:`npytping.shape_expression.validate_shape` and its internals to allow for extended syntax, including ranges of shapes. Modifications from nptyping: @@ -105,7 +105,7 @@ def validate_shape_expression(shape_expression: Union[ShapeExpression, Any]) -> @lru_cache -def check_shape(shape: ShapeTuple, target: "Shape") -> bool: +def validate_shape(shape: ShapeTuple, target: "Shape") -> bool: """ Check whether the given shape corresponds to the given shape_expression. :param shape: the shape in question. 
diff --git a/tests/conftest.py b/tests/conftest.py
index c9035f4..076e95d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -80,6 +80,8 @@ INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer]
 FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float]
 STRING: TypeAlias = NDArray[Shape["*, *, *"], str]
 MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel]
+UNION_PIPE: TypeAlias = NDArray[Shape["*, *, *"], np.uint32 | np.float32]
+UNION_TYPE: TypeAlias = NDArray[Shape["*, *, *"], Union[np.uint32, np.float32]]
 
 
 @pytest.fixture(
@@ -147,6 +149,16 @@ def shape_cases(request) -> ValidationCase:
         ValidationCase(annotation=MODEL, dtype=BadModel, passes=False),
         ValidationCase(annotation=MODEL, dtype=int, passes=False),
         ValidationCase(annotation=MODEL, dtype=SubClass, passes=True),
+        ValidationCase(annotation=UNION_PIPE, dtype=np.uint32, passes=True),
+        ValidationCase(annotation=UNION_PIPE, dtype=np.float32, passes=True),
+        ValidationCase(annotation=UNION_PIPE, dtype=np.uint64, passes=False),
+        ValidationCase(annotation=UNION_PIPE, dtype=np.float64, passes=False),
+        ValidationCase(annotation=UNION_PIPE, dtype=str, passes=False),
+        ValidationCase(annotation=UNION_TYPE, dtype=np.uint32, passes=True),
+        ValidationCase(annotation=UNION_TYPE, dtype=np.float32, passes=True),
+        ValidationCase(annotation=UNION_TYPE, dtype=np.uint64, passes=False),
+        ValidationCase(annotation=UNION_TYPE, dtype=np.float64, passes=False),
+        ValidationCase(annotation=UNION_TYPE, dtype=str, passes=False),
     ],
     ids=[
         "float",
@@ -174,6 +186,16 @@ def shape_cases(request) -> ValidationCase:
         "model-badmodel",
         "model-int",
        "model-subclass",
+        "union-pipe-uint32",
+        "union-pipe-float32",
+        "union-pipe-uint64",
+        "union-pipe-float64",
+        "union-pipe-str",
+        "union-type-uint32",
+        "union-type-float32",
+        "union-type-uint64",
+        "union-type-float64",
+        "union-type-str",
     ],
 )
 def dtype_cases(request) -> ValidationCase:
diff --git a/tests/test_interface/test_zarr.py b/tests/test_interface/test_zarr.py
index ed5c252..6b21b20 100644
--- a/tests/test_interface/test_zarr.py
+++ b/tests/test_interface/test_zarr.py
@@ -151,7 +151,7 @@ def test_zarr_to_json(store, model_blank, roundtrip, dump_array):
 
     if roundtrip:
         if dump_array:
-            assert as_json["array"] == lol_array
+            assert as_json["value"] == lol_array
         else:
             if as_json.get("file", False):
                 assert "array" not in as_json
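Illustrative usage of the changes above (a minimal sketch, not part of the patch; the Image model and frame field are hypothetical names). Union dtypes may be written either with the pipe operator or with typing.Union, mirroring the UNION_PIPE / UNION_TYPE cases added to tests/conftest.py, and the relocated helper can be called directly from numpydantic.validation:

from typing import Union

import numpy as np
from pydantic import BaseModel

from numpydantic import NDArray, Shape
from numpydantic.validation import validate_dtype


class Image(BaseModel):
    # either member of the union is accepted as the array's dtype
    frame: NDArray[Shape["*, *"], np.uint32 | np.float32]


# a float32 array matches one member of the union and validates
Image(frame=np.zeros((4, 4), dtype=np.float32))

# the standalone helper performs the same check outside a model
assert validate_dtype(np.float32, Union[np.uint32, np.float32])
assert not validate_dtype(np.float64, Union[np.uint32, np.float32])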