mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2025-01-09 21:44:27 +00:00
dask and hdf5 array interfaces
This commit is contained in:
parent
a6391c08a3
commit
46060c1154
18 changed files with 330 additions and 37 deletions
|
@ -34,6 +34,8 @@ intersphinx_mapping = {
|
||||||
"linkml": ("https://linkml.io/linkml/", None),
|
"linkml": ("https://linkml.io/linkml/", None),
|
||||||
"linkml_runtime": ("https://linkml.io/linkml/", None),
|
"linkml_runtime": ("https://linkml.io/linkml/", None),
|
||||||
"linkml-runtime": ("https://linkml.io/linkml/", None),
|
"linkml-runtime": ("https://linkml.io/linkml/", None),
|
||||||
|
"dask": ("https://docs.dask.org/en/stable/", None),
|
||||||
|
"h5py": ("https://docs.h5py.org/en/stable/", None),
|
||||||
}
|
}
|
||||||
|
|
||||||
# -- Options for HTML output -------------------------------------------------
|
# -- Options for HTML output -------------------------------------------------
|
||||||
|
|
26
pdm.lock
26
pdm.lock
|
@ -5,7 +5,7 @@
|
||||||
groups = ["default", "arrays", "dask", "dev", "docs", "hdf5", "tests"]
|
groups = ["default", "arrays", "dask", "dev", "docs", "hdf5", "tests"]
|
||||||
strategy = ["cross_platform", "inherit_metadata"]
|
strategy = ["cross_platform", "inherit_metadata"]
|
||||||
lock_version = "4.4.1"
|
lock_version = "4.4.1"
|
||||||
content_hash = "sha256:761d4dccd4e594b9e441dddefdb5677d22a4a94c129183e0bb8c88d9acbea1b9"
|
content_hash = "sha256:37b2b742a3addd598fce4747623d941ce0b7b2f18b0c33e2a9a2015196239902"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "alabaster"
|
name = "alabaster"
|
||||||
|
@ -371,7 +371,7 @@ files = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "dask"
|
name = "dask"
|
||||||
version = "2024.4.0"
|
version = "2024.4.1"
|
||||||
requires_python = ">=3.9"
|
requires_python = ">=3.9"
|
||||||
summary = "Parallel PyData with Task Scheduling"
|
summary = "Parallel PyData with Task Scheduling"
|
||||||
groups = ["arrays", "dask", "dev", "tests"]
|
groups = ["arrays", "dask", "dev", "tests"]
|
||||||
|
@ -386,24 +386,8 @@ dependencies = [
|
||||||
"toolz>=0.10.0",
|
"toolz>=0.10.0",
|
||||||
]
|
]
|
||||||
files = [
|
files = [
|
||||||
{file = "dask-2024.4.0-py3-none-any.whl", hash = "sha256:f8332781ffde3d3e49df31fe4066e1eab571a87b94a11661a8ecf06e2892ee6d"},
|
{file = "dask-2024.4.1-py3-none-any.whl", hash = "sha256:cac5d28b9de7a7cfde46d6fbd8fa81f5654980d010b44d1dbe04dd13b5b63126"},
|
||||||
{file = "dask-2024.4.0.tar.gz", hash = "sha256:d5be22660b332865e7e868df2f1322a75f6cacaf8dd9ec08057e6fa8a96a19ac"},
|
{file = "dask-2024.4.1.tar.gz", hash = "sha256:6cd8eb03ddc8dc08d6ca5b167b8de559872bc51cc2b6587d0e9dc754ab19cdf0"},
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "dask"
|
|
||||||
version = "2024.4.0"
|
|
||||||
extras = ["array"]
|
|
||||||
requires_python = ">=3.9"
|
|
||||||
summary = "Parallel PyData with Task Scheduling"
|
|
||||||
groups = ["arrays", "dask", "dev", "tests"]
|
|
||||||
dependencies = [
|
|
||||||
"dask==2024.4.0",
|
|
||||||
"numpy>=1.21",
|
|
||||||
]
|
|
||||||
files = [
|
|
||||||
{file = "dask-2024.4.0-py3-none-any.whl", hash = "sha256:f8332781ffde3d3e49df31fe4066e1eab571a87b94a11661a8ecf06e2892ee6d"},
|
|
||||||
{file = "dask-2024.4.0.tar.gz", hash = "sha256:d5be22660b332865e7e868df2f1322a75f6cacaf8dd9ec08057e6fa8a96a19ac"},
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -811,7 +795,7 @@ name = "numpy"
|
||||||
version = "1.26.4"
|
version = "1.26.4"
|
||||||
requires_python = ">=3.9"
|
requires_python = ">=3.9"
|
||||||
summary = "Fundamental package for array computing in Python"
|
summary = "Fundamental package for array computing in Python"
|
||||||
groups = ["arrays", "dask", "default", "dev", "hdf5", "tests"]
|
groups = ["arrays", "default", "dev", "hdf5", "tests"]
|
||||||
files = [
|
files = [
|
||||||
{file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"},
|
{file = "numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0"},
|
||||||
{file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"},
|
{file = "numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a"},
|
||||||
|
|
|
@ -9,6 +9,7 @@ dependencies = [
|
||||||
"pydantic>=2.3.0",
|
"pydantic>=2.3.0",
|
||||||
"nptyping>=2.5.0",
|
"nptyping>=2.5.0",
|
||||||
"blosc2<3.0.0,>=2.5.1",
|
"blosc2<3.0.0,>=2.5.1",
|
||||||
|
"numpy>=1.24.0",
|
||||||
]
|
]
|
||||||
requires-python = "<4.0,>=3.9"
|
requires-python = "<4.0,>=3.9"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
@ -17,7 +18,7 @@ license = {text = "MIT"}
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
dask = [
|
dask = [
|
||||||
"dask[array]>=2024.1.1"
|
"dask>=2024.4.0",
|
||||||
]
|
]
|
||||||
hdf5 = [
|
hdf5 = [
|
||||||
"h5py>=3.10.0"
|
"h5py>=3.10.0"
|
||||||
|
@ -99,9 +100,11 @@ select = [
|
||||||
|
|
||||||
]
|
]
|
||||||
ignore = [
|
ignore = [
|
||||||
"ANN101", "ANN102", "ANN401",
|
"ANN101", "ANN102", "ANN401", "ANN204",
|
||||||
# builtin type annotations
|
# builtin type annotations
|
||||||
"UP006", "UP035",
|
"UP006", "UP035",
|
||||||
|
# | for Union types (only supported >=3.10
|
||||||
|
"UP007", "UP038",
|
||||||
# docstrings for __init__
|
# docstrings for __init__
|
||||||
"D107",
|
"D107",
|
||||||
]
|
]
|
||||||
|
|
|
@ -6,7 +6,10 @@ from numpydantic.monkeypatch import apply_patches
|
||||||
apply_patches()
|
apply_patches()
|
||||||
|
|
||||||
from numpydantic.ndarray import NDArray
|
from numpydantic.ndarray import NDArray
|
||||||
|
|
||||||
from numpydantic.meta import update_ndarray_stub
|
from numpydantic.meta import update_ndarray_stub
|
||||||
|
|
||||||
|
from nptyping import Shape
|
||||||
|
|
||||||
update_ndarray_stub()
|
update_ndarray_stub()
|
||||||
|
|
||||||
|
__all__ = ["NDArray", "Shape"]
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
|
from numpydantic.interface.dask import DaskInterface
|
||||||
|
from numpydantic.interface.hdf5 import H5Interface
|
||||||
from numpydantic.interface.interface import Interface
|
from numpydantic.interface.interface import Interface
|
||||||
from numpydantic.interface.numpy import NumpyInterface
|
from numpydantic.interface.numpy import NumpyInterface
|
||||||
|
|
||||||
__all__ = ["Interface", "NumpyInterface"]
|
__all__ = ["Interface", "DaskInterface", "H5Interface", "NumpyInterface"]
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
from typing import Any
|
||||||
|
from numpydantic.interface.interface import Interface
|
||||||
|
|
||||||
|
try:
|
||||||
|
from dask.array.core import Array as DaskArray
|
||||||
|
except ImportError:
|
||||||
|
DaskArray = None
|
||||||
|
|
||||||
|
|
||||||
|
class DaskInterface(Interface):
|
||||||
|
"""
|
||||||
|
Interface for Dask :class:`~dask.array.core.Array`
|
||||||
|
"""
|
||||||
|
|
||||||
|
input_types = (DaskArray,)
|
||||||
|
return_type = DaskArray
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def check(cls, array: Any) -> bool:
|
||||||
|
"""
|
||||||
|
check if array is a dask array
|
||||||
|
"""
|
||||||
|
if DaskArray is not None and isinstance(array, DaskArray):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def enabled(cls) -> bool:
|
||||||
|
"""check if we successfully imported dask"""
|
||||||
|
return DaskArray is not None
|
|
@ -0,0 +1,143 @@
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, NamedTuple, Tuple, Union, TypeAlias
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from numpydantic.interface.interface import Interface
|
||||||
|
from numpydantic.types import NDArrayType
|
||||||
|
|
||||||
|
try:
|
||||||
|
import h5py
|
||||||
|
except ImportError:
|
||||||
|
h5py = None
|
||||||
|
|
||||||
|
H5Arraylike: TypeAlias = Tuple[Union[Path, str], str]
|
||||||
|
|
||||||
|
|
||||||
|
class H5Array(NamedTuple):
|
||||||
|
"""Location specifier for arrays within an HDF5 file"""
|
||||||
|
|
||||||
|
file: Union[Path, str]
|
||||||
|
"""Location of HDF5 file"""
|
||||||
|
path: str
|
||||||
|
"""Path within the HDF5 file"""
|
||||||
|
|
||||||
|
|
||||||
|
class H5Proxy:
|
||||||
|
"""
|
||||||
|
Proxy class to mimic numpy-like array behavior with an HDF5 array
|
||||||
|
|
||||||
|
The attribute and item access methods only open the file for the duration of the method,
|
||||||
|
making it less perilous to share this object between threads and processes.
|
||||||
|
|
||||||
|
This class attempts to be a passthrough class to a :class:`h5py.Dataset` object,
|
||||||
|
including its attributes and item getters/setters.
|
||||||
|
|
||||||
|
When using read-only methods, no locking is attempted (beyond the HDF5 defaults),
|
||||||
|
but when using the write methods (setting an array value), try and use the ``locking``
|
||||||
|
methods of :class:`h5py.File` .
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file (pathlib.Path | str): Location of hdf5 file on filesystem
|
||||||
|
path (str): Path to array within hdf5 file
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, file: Union[Path, str], path: str):
|
||||||
|
self.file = Path(file)
|
||||||
|
self.path = path
|
||||||
|
|
||||||
|
def array_exists(self) -> bool:
|
||||||
|
"""Check that there is in fact an array at :attr:`.path` within :attr:`.file`"""
|
||||||
|
with h5py.File(self.file, "r") as h5f:
|
||||||
|
obj = h5f.get(self.path)
|
||||||
|
return obj is not None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_h5array(cls, h5array: H5Array) -> "H5Proxy":
|
||||||
|
"""Instantiate using :class:`.H5Array`"""
|
||||||
|
return H5Proxy(file=h5array.file, path=h5array.path)
|
||||||
|
|
||||||
|
def __getattr__(self, item: str):
|
||||||
|
with h5py.File(self.file, "r") as h5f:
|
||||||
|
obj = h5f.get(self.path)
|
||||||
|
return getattr(obj, item)
|
||||||
|
|
||||||
|
def __getitem__(self, item: Union[int, slice]) -> np.ndarray:
|
||||||
|
with h5py.File(self.file, "r") as h5f:
|
||||||
|
obj = h5f.get(self.path)
|
||||||
|
return obj[item]
|
||||||
|
|
||||||
|
def __setitem__(self, key: Union[int, slice], value: Union[int, float, np.ndarray]):
|
||||||
|
with h5py.File(self.file, "r+", locking=True) as h5f:
|
||||||
|
obj = h5f.get(self.path)
|
||||||
|
obj[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
class H5Interface(Interface):
|
||||||
|
"""
|
||||||
|
Interface for Arrays stored as datasets within an HDF5 file.
|
||||||
|
|
||||||
|
Takes a :class:`.H5Array` specifier to select a :class:`h5py.Dataset` from a
|
||||||
|
:class:`h5py.File` and returns a :class:`.H5Proxy` class that acts like a
|
||||||
|
passthrough numpy-like interface to the dataset.
|
||||||
|
"""
|
||||||
|
|
||||||
|
input_types = (
|
||||||
|
H5Array,
|
||||||
|
H5Arraylike,
|
||||||
|
)
|
||||||
|
return_type = H5Proxy
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def enabled(cls) -> bool:
|
||||||
|
"""Check whether h5py can be imported"""
|
||||||
|
return h5py is not None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def check(cls, array: Union[H5Array, Tuple[Union[Path, str], str]]) -> bool:
|
||||||
|
"""Check that the given array is a :class:`.H5Array` or something that resembles one."""
|
||||||
|
if isinstance(array, H5Array):
|
||||||
|
return True
|
||||||
|
|
||||||
|
if isinstance(array, (tuple, list)) and len(array) == 2:
|
||||||
|
# check that the first arg is an hdf5 file
|
||||||
|
try:
|
||||||
|
file = Path(array[0])
|
||||||
|
except TypeError:
|
||||||
|
# not a path, we don't apply.
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not file.exists():
|
||||||
|
return False
|
||||||
|
|
||||||
|
# hdf5 files are commonly given odd suffixes,
|
||||||
|
# so we just try and open it and see what happens
|
||||||
|
try:
|
||||||
|
with h5py.File(file, "r"):
|
||||||
|
# don't check that the array exists and raise here,
|
||||||
|
# this check is just for whether the validator applies or not.
|
||||||
|
pass
|
||||||
|
return True
|
||||||
|
except (FileNotFoundError, OSError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def before_validation(self, array: Any) -> NDArrayType:
|
||||||
|
"""Create an :class:`.H5Proxy` to use throughout validation"""
|
||||||
|
if isinstance(array, H5Array):
|
||||||
|
array = H5Proxy.from_h5array(h5array=array)
|
||||||
|
elif isinstance(array, (tuple, list)) and len(array) == 2:
|
||||||
|
array = H5Proxy(file=array[0], path=array[1])
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Need to specify a file and a path within an HDF5 file to use the HDF5 Interface"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not array.array_exists():
|
||||||
|
raise ValueError(
|
||||||
|
f"HDF5 file located at {array.file}, "
|
||||||
|
f"but no array found at {array.path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return array
|
|
@ -1,6 +1,6 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from operator import attrgetter
|
from operator import attrgetter
|
||||||
from typing import Any, Generic, List, Type, TypeVar, Tuple
|
from typing import Any, Generic, Tuple, Type, TypeVar
|
||||||
|
|
||||||
from nptyping.shape_expression import check_shape
|
from nptyping.shape_expression import check_shape
|
||||||
|
|
||||||
|
@ -15,6 +15,7 @@ class Interface(ABC, Generic[T]):
|
||||||
Abstract parent class for interfaces to different array formats
|
Abstract parent class for interfaces to different array formats
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
input_types: Tuple[Any, ...]
|
||||||
return_type: Type[T]
|
return_type: Type[T]
|
||||||
priority: int = 0
|
priority: int = 0
|
||||||
|
|
||||||
|
@ -109,6 +110,18 @@ class Interface(ABC, Generic[T]):
|
||||||
"""Return types for all enabled interfaces"""
|
"""Return types for all enabled interfaces"""
|
||||||
return tuple([i.return_type for i in cls.interfaces()])
|
return tuple([i.return_type for i in cls.interfaces()])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def input_types(cls) -> Tuple[Any, ...]:
|
||||||
|
"""Input types for all enabled interfaces"""
|
||||||
|
in_types = []
|
||||||
|
for iface in cls.interfaces():
|
||||||
|
if isinstance(iface.input_types, tuple | list):
|
||||||
|
in_types.extend(iface.input_types)
|
||||||
|
else:
|
||||||
|
in_types.append(iface.input_types)
|
||||||
|
|
||||||
|
return tuple(in_types)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def match(cls, array: Any) -> Type["Interface"]:
|
def match(cls, array: Any) -> Type["Interface"]:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -17,6 +17,7 @@ class NumpyInterface(Interface):
|
||||||
Numpy :class:`~numpy.ndarray` s!
|
Numpy :class:`~numpy.ndarray` s!
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
input_types = (ndarray, list)
|
||||||
return_type = ndarray
|
return_type = ndarray
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -3,9 +3,12 @@ Metaprogramming functions for numpydantic to modify itself :)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from warnings import warn
|
||||||
|
|
||||||
from numpydantic.interface import Interface
|
from numpydantic.interface import Interface
|
||||||
|
|
||||||
|
_BUILTIN_IMPORTS = ("import typing", "import pathlib")
|
||||||
|
|
||||||
|
|
||||||
def generate_ndarray_stub() -> str:
|
def generate_ndarray_stub() -> str:
|
||||||
"""
|
"""
|
||||||
|
@ -14,11 +17,16 @@ def generate_ndarray_stub() -> str:
|
||||||
|
|
||||||
import_strings = [
|
import_strings = [
|
||||||
f"from {arr.__module__} import {arr.__name__}"
|
f"from {arr.__module__} import {arr.__name__}"
|
||||||
for arr in Interface.array_types()
|
for arr in Interface.input_types()
|
||||||
|
if arr.__module__ != "builtins"
|
||||||
]
|
]
|
||||||
|
import_strings.extend(_BUILTIN_IMPORTS)
|
||||||
import_string = "\n".join(import_strings)
|
import_string = "\n".join(import_strings)
|
||||||
|
|
||||||
class_names = [arr.__name__ for arr in Interface.array_types()]
|
class_names = [
|
||||||
|
arr.__name__ if arr.__module__ != "typing" else str(arr)
|
||||||
|
for arr in Interface.input_types()
|
||||||
|
]
|
||||||
class_union = " | ".join(class_names)
|
class_union = " | ".join(class_names)
|
||||||
ndarray_type = "NDArray = " + class_union
|
ndarray_type = "NDArray = " + class_union
|
||||||
|
|
||||||
|
@ -32,8 +40,11 @@ def update_ndarray_stub() -> None:
|
||||||
"""
|
"""
|
||||||
from numpydantic import ndarray
|
from numpydantic import ndarray
|
||||||
|
|
||||||
stub_string = generate_ndarray_stub()
|
try:
|
||||||
|
stub_string = generate_ndarray_stub()
|
||||||
|
|
||||||
pyi_file = Path(ndarray.__file__).with_suffix(".pyi")
|
pyi_file = Path(ndarray.__file__).with_suffix(".pyi")
|
||||||
with open(pyi_file, "w") as pyi:
|
with open(pyi_file, "w") as pyi:
|
||||||
pyi.write(stub_string)
|
pyi.write(stub_string)
|
||||||
|
except Exception as e:
|
||||||
|
warn(f"ndarray.pyi stub file could not be generated: {e}", stacklevel=1)
|
||||||
|
|
|
@ -165,9 +165,6 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
|
||||||
- https://docs.pydantic.dev/latest/usage/types/custom/#handling-third-party-types
|
- https://docs.pydantic.dev/latest/usage/types/custom/#handling-third-party-types
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self: T):
|
|
||||||
pass
|
|
||||||
|
|
||||||
__args__: Tuple[ShapeType, DtypeType] = (Any, Any)
|
__args__: Tuple[ShapeType, DtypeType] = (Any, Any)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
@ -4,12 +4,11 @@ Types for numpydantic
|
||||||
Note that these are types as in python typing types, not classes.
|
Note that these are types as in python typing types, not classes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Protocol, Tuple, TypeVar, Union, runtime_checkable
|
from typing import Any, Protocol, Tuple, runtime_checkable
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from nptyping import DType
|
from nptyping import DType
|
||||||
|
|
||||||
|
|
||||||
ShapeType = Tuple[int, ...] | Any
|
ShapeType = Tuple[int, ...] | Any
|
||||||
DtypeType = np.dtype | str | type | Any | DType
|
DtypeType = np.dtype | str | type | Any | DType
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,40 @@
|
||||||
|
import pytest
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Union, Type
|
||||||
|
|
||||||
|
import h5py
|
||||||
|
import numpy as np
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from numpydantic.interface.hdf5 import H5Array
|
||||||
|
from numpydantic import NDArray, Shape
|
||||||
|
from nptyping import Number
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def model_rgb() -> Type[BaseModel]:
|
||||||
|
class RGB(BaseModel):
|
||||||
|
array: Optional[
|
||||||
|
Union[
|
||||||
|
NDArray[Shape["* x, * y"], Number],
|
||||||
|
NDArray[Shape["* x, * y, 3 r_g_b"], Number],
|
||||||
|
NDArray[Shape["* x, * y, 3 r_g_b, 4 r_g_b_a"], Number],
|
||||||
|
]
|
||||||
|
] = Field(None)
|
||||||
|
|
||||||
|
return RGB
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def h5file(tmp_path) -> h5py.File:
|
||||||
|
h5f = h5py.File(tmp_path / "file.h5", "w")
|
||||||
|
yield h5f
|
||||||
|
h5f.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def h5_array(h5file) -> H5Array:
|
||||||
|
"""trivial hdf5 array used for testing array existence"""
|
||||||
|
path = "/data"
|
||||||
|
h5file.create_dataset(path, data=np.zeros((3, 4)))
|
||||||
|
return H5Array(file=Path(h5file.filename), path=path)
|
0
tests/test_interface/__init__.py
Normal file
0
tests/test_interface/__init__.py
Normal file
21
tests/test_interface/conftest.py
Normal file
21
tests/test_interface/conftest.py
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import dask.array as da
|
||||||
|
|
||||||
|
from numpydantic import interface
|
||||||
|
from tests.conftest import h5_array, h5file
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(
|
||||||
|
scope="function",
|
||||||
|
params=[
|
||||||
|
([[1, 2], [3, 4]], interface.NumpyInterface),
|
||||||
|
(np.zeros((3, 4)), interface.NumpyInterface),
|
||||||
|
(h5_array, interface.H5Interface),
|
||||||
|
(da.random.random((10, 10)), interface.DaskInterface),
|
||||||
|
],
|
||||||
|
ids=["numpy_list", "numpy", "H5Array", "dask"],
|
||||||
|
)
|
||||||
|
def interface_type(request):
|
||||||
|
return request.param
|
44
tests/test_interface/test_dask.py
Normal file
44
tests/test_interface/test_dask.py
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import dask.array as da
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from numpydantic.interface import DaskInterface
|
||||||
|
|
||||||
|
|
||||||
|
def test_dask_enabled():
|
||||||
|
"""
|
||||||
|
We need dask to be available to run these tests :)
|
||||||
|
"""
|
||||||
|
assert DaskInterface.enabled()
|
||||||
|
|
||||||
|
|
||||||
|
def test_dask_check(interface_type):
|
||||||
|
if interface_type[1] is DaskInterface:
|
||||||
|
assert DaskInterface.check(interface_type[0])
|
||||||
|
else:
|
||||||
|
assert not DaskInterface.check(interface_type[0])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"array,passes",
|
||||||
|
[
|
||||||
|
(da.random.random((5, 10)), True),
|
||||||
|
(da.random.random((5, 10, 3)), True),
|
||||||
|
(da.random.random((5, 10, 3, 4)), True),
|
||||||
|
(da.random.random((5, 10, 4)), False),
|
||||||
|
(da.random.random((5, 10, 3, 6)), False),
|
||||||
|
(da.random.random((5, 10, 4, 6)), False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_dask_shape(model_rgb, array, passes):
|
||||||
|
if passes:
|
||||||
|
model_rgb(array=array)
|
||||||
|
else:
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
model_rgb(array=array)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip("TODO")
|
||||||
|
def test_dask_dtype():
|
||||||
|
pass
|
0
tests/test_interface/test_hdf5.py
Normal file
0
tests/test_interface/test_hdf5.py
Normal file
0
tests/test_interface/test_numpy.py
Normal file
0
tests/test_interface/test_numpy.py
Normal file
Loading…
Reference in a new issue