mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2024-11-10 00:34:29 +00:00
more incremental progress towards a v0.1.0, importing tests that will surely fail.
This commit is contained in:
parent
657f981514
commit
690f9cd53a
8 changed files with 376 additions and 54 deletions
|
@ -4,7 +4,9 @@ Type and shape validation and serialization for numpy arrays in pydantic models
|
||||||
|
|
||||||
This package was picked out of [nwb-linkml](https://github.com/p2p-ld/nwb-linkml/), a
|
This package was picked out of [nwb-linkml](https://github.com/p2p-ld/nwb-linkml/), a
|
||||||
translation of the [NWB](https://www.nwb.org/) schema language and data format to
|
translation of the [NWB](https://www.nwb.org/) schema language and data format to
|
||||||
linkML and pydantic models.
|
linkML and pydantic models. It's in a hurried and limited form to make it
|
||||||
|
available for a LinkML hackathon, but will be matured as part of `nwb-linkml` development
|
||||||
|
as the primary place this logic exists.
|
||||||
|
|
||||||
It does two primary things:
|
It does two primary things:
|
||||||
- **Provide types** - Annotations (based on [npytyping](https://github.com/ramonhagenaars/nptyping))
|
- **Provide types** - Annotations (based on [npytyping](https://github.com/ramonhagenaars/nptyping))
|
||||||
|
@ -12,8 +14,30 @@ It does two primary things:
|
||||||
- **Generate models from LinkML** - extend the LinkML pydantic generator to create models that
|
- **Generate models from LinkML** - extend the LinkML pydantic generator to create models that
|
||||||
that use the [linkml-arrays](https://github.com/linkml/linkml-arrays) syntax
|
that use the [linkml-arrays](https://github.com/linkml/linkml-arrays) syntax
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The Python type annotation system is weird and not like the rest of Python!
|
||||||
|
(at least until [PEP 0649](https://peps.python.org/pep-0649/) gets mainlined).
|
||||||
|
Similarly, Pydantic 2's core_schema system is wonderful but still relatively poorly
|
||||||
|
documented for custom types! This package does the work of plugging them in
|
||||||
|
together to make some kind of type validation frankenstein.
|
||||||
|
|
||||||
|
The first problem is that type annotations are evaluated statically by python, mypy,
|
||||||
|
etc. This means you can't use typical python syntax for declaring types - it has to
|
||||||
|
be present at the time `__new__` is called, rather than `__init__`.
|
||||||
|
|
||||||
|
- pydantic schema
|
||||||
|
- validation
|
||||||
|
- serialization
|
||||||
|
- lazy loading
|
||||||
|
- compression
|
||||||
|
|
||||||
|
|
||||||
```{toctree}
|
```{toctree}
|
||||||
:maxdepth: 2
|
:maxdepth: 2
|
||||||
:caption: Contents:
|
:caption: Contents:
|
||||||
|
:hidden:
|
||||||
|
|
||||||
|
hooks
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -7,17 +7,14 @@ import base64
|
||||||
import sys
|
import sys
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from pathlib import Path
|
from typing import TYPE_CHECKING, Any, Union
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import blosc2
|
import blosc2
|
||||||
import h5py
|
|
||||||
import nptyping.structure
|
import nptyping.structure
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
# TODO: conditional import
|
# TODO: conditional import of dask, remove from required dependencies
|
||||||
from dask.array.core import Array as DaskArray
|
from dask.array.core import Array as DaskArray
|
||||||
from nptyping import NDArray as _NDArray
|
|
||||||
from nptyping import Shape
|
from nptyping import Shape
|
||||||
from nptyping.ndarray import NDArrayMeta as _NDArrayMeta
|
from nptyping.ndarray import NDArrayMeta as _NDArrayMeta
|
||||||
from nptyping.nptyping_type import NPTypingType
|
from nptyping.nptyping_type import NPTypingType
|
||||||
|
@ -27,14 +24,19 @@ from pydantic_core.core_schema import ListSchema
|
||||||
|
|
||||||
from numpydantic.maps import np_to_python
|
from numpydantic.maps import np_to_python
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from numpydantic.proxy import NDArrayProxy
|
||||||
|
|
||||||
COMPRESSION_THRESHOLD = 16 * 1024
|
COMPRESSION_THRESHOLD = 16 * 1024
|
||||||
"""
|
"""
|
||||||
Arrays larger than this size (in bytes) will be compressed and b64 encoded when
|
Arrays larger than this size (in bytes) will be compressed and b64 encoded when
|
||||||
serializing to JSON.
|
serializing to JSON.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
ARRAY_TYPES = Union[np.ndarray, DaskArray, "NDArrayProxy"]
|
||||||
|
|
||||||
def list_of_lists_schema(shape: Shape, array_type_handler) -> ListSchema:
|
|
||||||
|
def list_of_lists_schema(shape: Shape, array_type_handler: dict) -> ListSchema:
|
||||||
"""Make a pydantic JSON schema for an array as a list of lists."""
|
"""Make a pydantic JSON schema for an array as a list of lists."""
|
||||||
shape_parts = shape.__args__[0].split(",")
|
shape_parts = shape.__args__[0].split(",")
|
||||||
split_parts = [
|
split_parts = [
|
||||||
|
@ -66,7 +68,7 @@ def list_of_lists_schema(shape: Shape, array_type_handler) -> ListSchema:
|
||||||
return list_schema
|
return list_schema
|
||||||
|
|
||||||
|
|
||||||
def jsonize_array(array: np.ndarray | DaskArray) -> list | dict:
|
def jsonize_array(array: ARRAY_TYPES) -> list | dict:
|
||||||
"""
|
"""
|
||||||
Render an array to base python types that can be serialized to JSON
|
Render an array to base python types that can be serialized to JSON
|
||||||
|
|
||||||
|
@ -166,7 +168,7 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
|
||||||
|
|
||||||
# get pydantic core schema for the given specified type
|
# get pydantic core schema for the given specified type
|
||||||
if isinstance(dtype, nptyping.structure.StructureMeta):
|
if isinstance(dtype, nptyping.structure.StructureMeta):
|
||||||
raise NotImplementedError("Jonny finish this")
|
raise NotImplementedError("Finish handling structured dtypes!")
|
||||||
# functools.reduce(operator.or_, [int, float, str])
|
# functools.reduce(operator.or_, [int, float, str])
|
||||||
else:
|
else:
|
||||||
array_type_handler = _handler.generate_schema(np_to_python[dtype])
|
array_type_handler = _handler.generate_schema(np_to_python[dtype])
|
||||||
|
@ -201,48 +203,3 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
|
||||||
jsonize_array, when_used="json"
|
jsonize_array, when_used="json"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class NDArrayProxy:
|
|
||||||
"""
|
|
||||||
Thin proxy to numpy arrays stored within hdf5 files,
|
|
||||||
only read into memory when accessed, but otherwise
|
|
||||||
passthrough all attempts to access attributes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, h5f_file: Path | str, path: str):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
h5f_file (:class:`pathlib.Path`): Path to source HDF5 file
|
|
||||||
path (str): Location within HDF5 file where this array is located
|
|
||||||
"""
|
|
||||||
self.h5f_file = Path(h5f_file)
|
|
||||||
self.path = path
|
|
||||||
|
|
||||||
def __getattr__(self, item) -> Any:
|
|
||||||
with h5py.File(self.h5f_file, "r") as h5f:
|
|
||||||
obj = h5f.get(self.path)
|
|
||||||
return getattr(obj, item)
|
|
||||||
|
|
||||||
def __getitem__(self, slice: slice) -> np.ndarray:
|
|
||||||
with h5py.File(self.h5f_file, "r") as h5f:
|
|
||||||
obj = h5f.get(self.path)
|
|
||||||
return obj[slice]
|
|
||||||
|
|
||||||
def __setitem__(self, slice, value) -> None:
|
|
||||||
raise NotImplementedError("Cant write into an arrayproxy yet!")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def __get_pydantic_core_schema__(
|
|
||||||
cls,
|
|
||||||
_source_type: _NDArray,
|
|
||||||
_handler: Callable[[Any], core_schema.CoreSchema],
|
|
||||||
) -> core_schema.CoreSchema:
|
|
||||||
# return core_schema.no_info_after_validator_function(
|
|
||||||
# serialization=core_schema.plain_serializer_function_ser_schema(
|
|
||||||
# lambda array: array.tolist(),
|
|
||||||
# when_used='json'
|
|
||||||
# )
|
|
||||||
# )
|
|
||||||
|
|
||||||
return NDArray_.__get_pydantic_core_schema__(cls, _source_type, _handler)
|
|
||||||
|
|
53
numpydantic/proxy.py
Normal file
53
numpydantic/proxy.py
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
from collections.abc import Callable
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import h5py
|
||||||
|
import numpy as np
|
||||||
|
from nptyping import NDArray as _NDArray
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
|
||||||
|
class NDArrayProxy:
|
||||||
|
"""
|
||||||
|
Thin proxy to numpy arrays stored within hdf5 files,
|
||||||
|
only read into memory when accessed, but otherwise
|
||||||
|
passthrough all attempts to access attributes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, h5f_file: Path | str, path: str):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
h5f_file (:class:`pathlib.Path`): Path to source HDF5 file
|
||||||
|
path (str): Location within HDF5 file where this array is located
|
||||||
|
"""
|
||||||
|
self.h5f_file = Path(h5f_file)
|
||||||
|
self.path = path
|
||||||
|
|
||||||
|
def __getattr__(self, item) -> Any:
|
||||||
|
with h5py.File(self.h5f_file, "r") as h5f:
|
||||||
|
obj = h5f.get(self.path)
|
||||||
|
return getattr(obj, item)
|
||||||
|
|
||||||
|
def __getitem__(self, slice: slice) -> np.ndarray:
|
||||||
|
with h5py.File(self.h5f_file, "r") as h5f:
|
||||||
|
obj = h5f.get(self.path)
|
||||||
|
return obj[slice]
|
||||||
|
|
||||||
|
def __setitem__(self, slice, value) -> None:
|
||||||
|
raise NotImplementedError("Cant write into an arrayproxy yet!")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(
|
||||||
|
cls,
|
||||||
|
_source_type: _NDArray,
|
||||||
|
_handler: Callable[[Any], core_schema.CoreSchema],
|
||||||
|
) -> core_schema.CoreSchema:
|
||||||
|
# return core_schema.no_info_after_validator_function(
|
||||||
|
# serialization=core_schema.plain_serializer_function_ser_schema(
|
||||||
|
# lambda array: array.tolist(),
|
||||||
|
# when_used='json'
|
||||||
|
# )
|
||||||
|
# )
|
||||||
|
|
||||||
|
return NDArray_.__get_pydantic_core_schema__(cls, _source_type, _handler)
|
|
@ -70,6 +70,7 @@ testpaths = [
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
target-version = "py311"
|
target-version = "py311"
|
||||||
include = ["numpydantic/**/*.py", "pyproject.toml"]
|
include = ["numpydantic/**/*.py", "pyproject.toml"]
|
||||||
|
exclude = ["tests"]
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = [
|
select = [
|
||||||
|
|
0
tests/conftest.py
Normal file
0
tests/conftest.py
Normal file
40
tests/fixtures.py
Normal file
40
tests/fixtures.py
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def tmp_output_dir() -> Path:
|
||||||
|
path = Path(__file__).parent.resolve() / "__tmp__"
|
||||||
|
if path.exists():
|
||||||
|
shutil.rmtree(str(path))
|
||||||
|
path.mkdir()
|
||||||
|
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def tmp_output_dir_func(tmp_output_dir) -> Path:
|
||||||
|
"""
|
||||||
|
tmp output dir that gets cleared between every function
|
||||||
|
cleans at the start rather than at cleanup in case the output is to be inspected
|
||||||
|
"""
|
||||||
|
subpath = tmp_output_dir / "__tmpfunc__"
|
||||||
|
if subpath.exists():
|
||||||
|
shutil.rmtree(str(subpath))
|
||||||
|
subpath.mkdir()
|
||||||
|
return subpath
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def tmp_output_dir_mod(tmp_output_dir) -> Path:
|
||||||
|
"""
|
||||||
|
tmp output dir that gets cleared between every function
|
||||||
|
cleans at the start rather than at cleanup in case the output is to be inspected
|
||||||
|
"""
|
||||||
|
subpath = tmp_output_dir / "__tmpmod__"
|
||||||
|
if subpath.exists():
|
||||||
|
shutil.rmtree(str(subpath))
|
||||||
|
subpath.mkdir()
|
||||||
|
return subpath
|
120
tests/test_linkml.py
Normal file
120
tests/test_linkml.py
Normal file
|
@ -0,0 +1,120 @@
|
||||||
|
"""
|
||||||
|
Test custom features of the pydantic generator
|
||||||
|
|
||||||
|
Note that since this is largely a subclass, we don't test all of the functionality of the generator
|
||||||
|
because it's tested in the base linkml package.
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
def test_arraylike(imported_schema):
|
||||||
|
"""
|
||||||
|
Arraylike classes are converted to slots that specify nptyping arrays
|
||||||
|
|
||||||
|
array: Optional[Union[
|
||||||
|
NDArray[Shape["* x, * y"], Number],
|
||||||
|
NDArray[Shape["* x, * y, 3 z"], Number],
|
||||||
|
NDArray[Shape["* x, * y, 3 z, 4 a"], Number]
|
||||||
|
]] = Field(None)
|
||||||
|
"""
|
||||||
|
# check that we have gotten an NDArray annotation and its shape is correct
|
||||||
|
array = imported_schema["core"].MainTopLevel.model_fields["array"].annotation
|
||||||
|
args = typing.get_args(array)
|
||||||
|
for i, shape in enumerate(("* x, * y", "* x, * y, 3 z", "* x, * y, 3 z, 4 a")):
|
||||||
|
assert isinstance(args[i], NDArrayMeta)
|
||||||
|
assert args[i].__args__[0].__args__
|
||||||
|
assert args[i].__args__[1] == np.number
|
||||||
|
|
||||||
|
# we shouldn't have an actual class for the array
|
||||||
|
assert not hasattr(imported_schema["core"], "MainTopLevel__Array")
|
||||||
|
assert not hasattr(imported_schema["core"], "MainTopLevelArray")
|
||||||
|
|
||||||
|
|
||||||
|
def test_inject_fields(imported_schema):
|
||||||
|
"""
|
||||||
|
Our root model should have the special fields we injected
|
||||||
|
"""
|
||||||
|
base = imported_schema["core"].ConfiguredBaseModel
|
||||||
|
assert "hdf5_path" in base.model_fields
|
||||||
|
assert "object_id" in base.model_fields
|
||||||
|
|
||||||
|
|
||||||
|
def test_linkml_meta(imported_schema):
|
||||||
|
"""
|
||||||
|
We should be able to store some linkml metadata with our classes
|
||||||
|
"""
|
||||||
|
meta = imported_schema["core"].LinkML_Meta
|
||||||
|
assert "tree_root" in meta.model_fields
|
||||||
|
assert imported_schema["core"].MainTopLevel.linkml_meta.default.tree_root == True
|
||||||
|
assert imported_schema["core"].OtherClass.linkml_meta.default.tree_root == False
|
||||||
|
|
||||||
|
|
||||||
|
def test_skip(linkml_schema):
|
||||||
|
"""
|
||||||
|
We can skip slots and classes
|
||||||
|
"""
|
||||||
|
modules = generate_and_import(
|
||||||
|
linkml_schema,
|
||||||
|
split=False,
|
||||||
|
generator_kwargs={
|
||||||
|
"SKIP_SLOTS": ("SkippableSlot",),
|
||||||
|
"SKIP_CLASSES": ("Skippable", "skippable"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert not hasattr(modules["core"], "Skippable")
|
||||||
|
assert "SkippableSlot" not in modules["core"].MainTopLevel.model_fields
|
||||||
|
|
||||||
|
|
||||||
|
def test_inline_with_identifier(imported_schema):
|
||||||
|
"""
|
||||||
|
By default, if a class has an identifier attribute, it is inlined
|
||||||
|
as a string rather than its class. We overrode that to be able to make dictionaries of collections
|
||||||
|
"""
|
||||||
|
main = imported_schema["core"].MainTopLevel
|
||||||
|
inline = main.model_fields["inline_dict"].annotation
|
||||||
|
assert typing.get_origin(typing.get_args(inline)[0]) == dict
|
||||||
|
# god i hate pythons typing interface
|
||||||
|
otherclass, stillanother = typing.get_args(
|
||||||
|
typing.get_args(typing.get_args(inline)[0])[1]
|
||||||
|
)
|
||||||
|
assert otherclass is imported_schema["core"].OtherClass
|
||||||
|
assert stillanother is imported_schema["core"].StillAnotherClass
|
||||||
|
|
||||||
|
|
||||||
|
def test_namespace(imported_schema):
|
||||||
|
"""
|
||||||
|
Namespace schema import all classes from the other schema
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
ns = imported_schema["namespace"]
|
||||||
|
|
||||||
|
for classname, modname in (
|
||||||
|
("MainThing", "test_schema.imported"),
|
||||||
|
("Arraylike", "test_schema.imported"),
|
||||||
|
("MainTopLevel", "test_schema.core"),
|
||||||
|
("Skippable", "test_schema.core"),
|
||||||
|
("OtherClass", "test_schema.core"),
|
||||||
|
("StillAnotherClass", "test_schema.core"),
|
||||||
|
):
|
||||||
|
assert hasattr(ns, classname)
|
||||||
|
if imported_schema["split"]:
|
||||||
|
assert getattr(ns, classname).__module__ == modname
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_set_item(imported_schema):
|
||||||
|
"""We can get and set without explicitly addressing array"""
|
||||||
|
cls = imported_schema["core"].MainTopLevel(array=np.array([[1, 2, 3], [4, 5, 6]]))
|
||||||
|
cls[0] = 50
|
||||||
|
assert (cls[0] == 50).all()
|
||||||
|
assert (cls.array[0] == 50).all()
|
||||||
|
|
||||||
|
cls[1, 1] = 100
|
||||||
|
assert cls[1, 1] == 100
|
||||||
|
assert cls.array[1, 1] == 100
|
127
tests/test_ndarray.py
Normal file
127
tests/test_ndarray.py
Normal file
|
@ -0,0 +1,127 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from typing import Union, Optional, Any
|
||||||
|
import json
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from pydantic import BaseModel, ValidationError, Field
|
||||||
|
from nptyping import Shape, Number
|
||||||
|
|
||||||
|
from numpydantic.ndarray import NDArray
|
||||||
|
from numpydantic.proxy import NDArrayProxy
|
||||||
|
|
||||||
|
|
||||||
|
# from .fixtures import tmp_output_dir_func
|
||||||
|
|
||||||
|
|
||||||
|
def test_ndarray_type():
|
||||||
|
class Model(BaseModel):
|
||||||
|
array: NDArray[Shape["2 x, * y"], Number]
|
||||||
|
array_any: Optional[NDArray[Any, Any]] = None
|
||||||
|
|
||||||
|
schema = Model.model_json_schema()
|
||||||
|
assert schema["properties"]["array"]["items"] == {
|
||||||
|
"items": {"type": "number"},
|
||||||
|
"type": "array",
|
||||||
|
}
|
||||||
|
assert schema["properties"]["array"]["maxItems"] == 2
|
||||||
|
assert schema["properties"]["array"]["minItems"] == 2
|
||||||
|
|
||||||
|
# models should instantiate correctly!
|
||||||
|
instance = Model(array=np.zeros((2, 3)))
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
instance = Model(array=np.zeros((4, 6)))
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
instance = Model(array=np.ones((2, 3), dtype=bool))
|
||||||
|
|
||||||
|
instance = Model(array=np.zeros((2, 3)), array_any=np.ones((3, 4, 5)))
|
||||||
|
|
||||||
|
|
||||||
|
def test_ndarray_union():
|
||||||
|
class Model(BaseModel):
|
||||||
|
array: Optional[
|
||||||
|
Union[
|
||||||
|
NDArray[Shape["* x, * y"], Number],
|
||||||
|
NDArray[Shape["* x, * y, 3 r_g_b"], Number],
|
||||||
|
NDArray[Shape["* x, * y, 3 r_g_b, 4 r_g_b_a"], Number],
|
||||||
|
]
|
||||||
|
] = Field(None)
|
||||||
|
|
||||||
|
instance = Model()
|
||||||
|
instance = Model(array=np.random.random((5, 10)))
|
||||||
|
instance = Model(array=np.random.random((5, 10, 3)))
|
||||||
|
instance = Model(array=np.random.random((5, 10, 3, 4)))
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
instance = Model(array=np.random.random((5,)))
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
instance = Model(array=np.random.random((5, 10, 4)))
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
instance = Model(array=np.random.random((5, 10, 3, 6)))
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
instance = Model(array=np.random.random((5, 10, 4, 6)))
|
||||||
|
|
||||||
|
|
||||||
|
def test_ndarray_coercion():
|
||||||
|
"""
|
||||||
|
Coerce lists to arrays
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Model(BaseModel):
|
||||||
|
array: NDArray[Shape["* x"], Number]
|
||||||
|
|
||||||
|
amod = Model(array=[1, 2, 3, 4.5])
|
||||||
|
assert np.allclose(amod.array, np.array([1, 2, 3, 4.5]))
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
amod = Model(array=["a", "b", "c"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_ndarray_serialize():
|
||||||
|
"""
|
||||||
|
Large arrays should get compressed with blosc, otherwise just to list
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Model(BaseModel):
|
||||||
|
large_array: NDArray[Any, Number]
|
||||||
|
small_array: NDArray[Any, Number]
|
||||||
|
|
||||||
|
mod = Model(
|
||||||
|
large_array=np.random.random((1024, 1024)), small_array=np.random.random((3, 3))
|
||||||
|
)
|
||||||
|
mod_str = mod.model_dump_json()
|
||||||
|
mod_json = json.loads(mod_str)
|
||||||
|
for a in ("array", "shape", "dtype", "unpack_fns"):
|
||||||
|
assert a in mod_json["large_array"].keys()
|
||||||
|
assert isinstance(mod_json["large_array"]["array"], str)
|
||||||
|
assert isinstance(mod_json["small_array"], list)
|
||||||
|
|
||||||
|
# but when we just dump to a dict we don't compress
|
||||||
|
mod_dict = mod.model_dump()
|
||||||
|
assert isinstance(mod_dict["large_array"], np.ndarray)
|
||||||
|
|
||||||
|
|
||||||
|
# def test_ndarray_proxy(tmp_output_dir_func):
|
||||||
|
# h5f_source = tmp_output_dir_func / 'test.h5'
|
||||||
|
# with h5py.File(h5f_source, 'w') as h5f:
|
||||||
|
# dset_good = h5f.create_dataset('/data', data=np.random.random((1024,1024,3)))
|
||||||
|
# dset_bad = h5f.create_dataset('/data_bad', data=np.random.random((1024, 1024, 4)))
|
||||||
|
#
|
||||||
|
# class Model(BaseModel):
|
||||||
|
# array: NDArray[Shape["* x, * y, 3 z"], Number]
|
||||||
|
#
|
||||||
|
# mod = Model(array=NDArrayProxy(h5f_file=h5f_source, path='/data'))
|
||||||
|
# subarray = mod.array[0:5, 0:5, :]
|
||||||
|
# assert isinstance(subarray, np.ndarray)
|
||||||
|
# assert isinstance(subarray.sum(), float)
|
||||||
|
# assert mod.array.name == '/data'
|
||||||
|
#
|
||||||
|
# with pytest.raises(NotImplementedError):
|
||||||
|
# mod.array[0] = 5
|
||||||
|
#
|
||||||
|
# with pytest.raises(ValidationError):
|
||||||
|
# mod = Model(array=NDArrayProxy(h5f_file=h5f_source, path='/data_bad'))
|
Loading…
Reference in a new issue