2024-02-03 06:45:50 +00:00
|
|
|
import pytest
|
|
|
|
|
|
|
|
from typing import Union, Optional, Any
|
|
|
|
import json
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
from pydantic import BaseModel, ValidationError, Field
|
|
|
|
from nptyping import Shape, Number
|
|
|
|
|
2024-04-03 23:34:03 +00:00
|
|
|
from numpydantic import NDArray
|
2024-04-04 03:52:33 +00:00
|
|
|
from numpydantic.exceptions import ShapeError, DtypeError
|
2024-02-03 06:45:50 +00:00
|
|
|
|
|
|
|
|
|
|
|
# from .fixtures import tmp_output_dir_func
|
|
|
|
|
|
|
|
|
|
|
|
def test_ndarray_type():
    """
    Check the JSON schema generated for a shaped NDArray field, then check
    that instantiation accepts a matching array and rejects mismatches
    """

    class Model(BaseModel):
        array: NDArray[Shape["2 x, * y"], Number]
        array_any: Optional[NDArray[Any, Any]] = None

    # pull out the sub-schema for the constrained field once
    array_schema = Model.model_json_schema()["properties"]["array"]
    expected_items = {"type": "array", "items": {"type": "number"}}
    assert array_schema["items"] == expected_items
    assert array_schema["maxItems"] == 2
    assert array_schema["minItems"] == 2

    # models should instantiate correctly!
    Model(array=np.zeros((2, 3)))

    # a shape that does not match "2 x, * y" is expected to fail validation
    wrong_shape = np.zeros((4, 6))
    with pytest.raises(ValidationError):
        Model(array=wrong_shape)

    # a boolean array with the right shape is expected to raise DtypeError
    wrong_dtype = np.ones((2, 3), dtype=bool)
    with pytest.raises(DtypeError):
        Model(array=wrong_dtype)

    # the Any/Any field is given an array of a different rank here
    Model(array=np.zeros((2, 3)), array_any=np.ones((3, 4, 5)))
|
|
|
|
|
|
|
|
|
|
|
|
def test_ndarray_union():
    """
    An optional field annotated with a union of three NDArray shapes should
    instantiate with arrays matching any member and fail validation otherwise
    """
    # the three alternatives, declared exactly as the field uses them
    shape_options = Union[
        NDArray[Shape["* x, * y"], Number],
        NDArray[Shape["* x, * y, 3 r_g_b"], Number],
        NDArray[Shape["* x, * y, 3 r_g_b, 4 r_g_b_a"], Number],
    ]

    class Model(BaseModel):
        array: Optional[shape_options] = Field(None)

    # the field defaults to None, so it can be omitted entirely
    Model()

    # one array per union member
    for good_shape in ((5, 10), (5, 10, 3), (5, 10, 3, 4)):
        Model(array=np.random.random(good_shape))

    # shapes that match none of the members
    for bad_shape in ((5,), (5, 10, 4), (5, 10, 3, 6), (5, 10, 4, 6)):
        with pytest.raises(ValidationError):
            Model(array=np.random.random(bad_shape))
|
|
|
|
|
|
|
|
|
|
|
|
def test_ndarray_coercion():
    """
    A plain list of numbers passed to an NDArray field should compare equal
    to the equivalent numpy array; a list of strings should raise DtypeError
    """

    class Model(BaseModel):
        array: NDArray[Shape["* x"], Number]

    numeric_values = [1, 2, 3, 4.5]
    coerced = Model(array=numeric_values)
    assert np.allclose(coerced.array, np.array(numeric_values))

    # strings are not Numbers, so this is expected to be rejected
    string_values = ["a", "b", "c"]
    with pytest.raises(DtypeError):
        Model(array=string_values)
|
|
|
|
|
|
|
|
|
|
|
|
def test_ndarray_serialize():
    """
    Dumping to JSON should give a list for the array field, while dumping
    to a dict should leave it as an ndarray
    """

    class Model(BaseModel):
        array: NDArray[Any, Number]

    model = Model(array=np.random.random((3, 3)))

    # JSON path: serialize, parse back, and inspect the parsed value
    parsed = json.loads(model.model_dump_json())
    assert isinstance(parsed["array"], list)

    # but when we just dump to a dict we don't coerce
    dumped = model.model_dump()
    assert isinstance(dumped["array"], np.ndarray)
|
2024-02-03 06:45:50 +00:00
|
|
|
|
|
|
|
|
|
|
|
# def test_ndarray_proxy(tmp_output_dir_func):
|
|
|
|
# h5f_source = tmp_output_dir_func / 'test.h5'
|
|
|
|
# with h5py.File(h5f_source, 'w') as h5f:
|
|
|
|
# dset_good = h5f.create_dataset('/data', data=np.random.random((1024,1024,3)))
|
|
|
|
# dset_bad = h5f.create_dataset('/data_bad', data=np.random.random((1024, 1024, 4)))
|
|
|
|
#
|
|
|
|
# class Model(BaseModel):
|
|
|
|
# array: NDArray[Shape["* x, * y, 3 z"], Number]
|
|
|
|
#
|
|
|
|
# mod = Model(array=NDArrayProxy(h5f_file=h5f_source, path='/data'))
|
|
|
|
# subarray = mod.array[0:5, 0:5, :]
|
|
|
|
# assert isinstance(subarray, np.ndarray)
|
|
|
|
# assert isinstance(subarray.sum(), float)
|
|
|
|
# assert mod.array.name == '/data'
|
|
|
|
#
|
|
|
|
# with pytest.raises(NotImplementedError):
|
|
|
|
# mod.array[0] = 5
|
|
|
|
#
|
|
|
|
# with pytest.raises(ValidationError):
|
|
|
|
# mod = Model(array=NDArrayProxy(h5f_file=h5f_source, path='/data_bad'))
|