[tests] ndarray proxy and serialization

sneakers-the-rat 2023-10-05 22:05:10 -07:00
parent 0a9ca82476
commit 8405ea948a
3 changed files with 67 additions and 8 deletions


@@ -25,7 +25,7 @@ def tmp_output_dir_func(tmp_output_dir) -> Path:
     tmp output dir that gets cleared between every function
     cleans at the start rather than at cleanup in case the output is to be inspected
     """
-    subpath = tmp_output_dir / '__tmp__'
+    subpath = tmp_output_dir / '__tmpfunc__'
     if subpath.exists():
         shutil.rmtree(str(subpath))
     subpath.mkdir()
@@ -37,7 +37,7 @@ def tmp_output_dir_mod(tmp_output_dir) -> Path:
     tmp output dir that gets cleared between every function
     cleans at the start rather than at cleanup in case the output is to be inspected
     """
-    subpath = tmp_output_dir / '__tmp__'
+    subpath = tmp_output_dir / '__tmpmod__'
    if subpath.exists():
         shutil.rmtree(str(subpath))
     subpath.mkdir()


@@ -1,19 +1,22 @@
 import pdb
 from typing import Union, Optional, Any
+import json
 import pytest
 import numpy as np
+import h5py
 from pydantic import BaseModel, ValidationError, Field
-from nwb_linkml.types.ndarray import NDArray
+from nwb_linkml.types.ndarray import NDArray, NDArrayProxy
 from nptyping import Shape, Number
-from ..fixtures import data_dir
+from ..fixtures import data_dir, tmp_output_dir, tmp_output_dir_func
 
 def test_ndarray_type():
     class Model(BaseModel):
         array: NDArray[Shape["2 x, * y"], Number]
+        array_any: Optional[NDArray[Any, Any]] = None
 
     schema = Model.model_json_schema()
     assert schema['properties']['array']['items'] == {'items': {'type': 'number'}, 'type': 'array'}
@@ -29,6 +32,8 @@ def test_ndarray_type():
     with pytest.raises(ValidationError):
         instance = Model(array=np.ones((2,3), dtype=bool))
+    instance = Model(array=np.zeros((2,3)), array_any = np.ones((3,4,5)))
 
 def test_ndarray_union():
     class Model(BaseModel):
@@ -55,6 +60,60 @@ def test_ndarray_union():
     with pytest.raises(ValidationError):
         instance = Model(array=np.random.random((5,10,4,6)))
 
-@pytest.mark.skip()
-def test_ndarray_proxy(data_dir):
-    h5f_source = data_dir / 'aibs_ecephys.nwb'
+def test_ndarray_coercion():
+    """
+    Coerce lists to arrays
+    """
+    class Model(BaseModel):
+        array: NDArray[Shape["* x"], Number]
+
+    amod = Model(array=[1,2,3,4.5])
+    assert np.allclose(amod.array, np.array([1,2,3,4.5]))
+    with pytest.raises(ValidationError):
+        amod = Model(array=['a', 'b', 'c'])
+
+def test_ndarray_serialize():
+    """
+    Large arrays should get compressed with blosc, otherwise just to list
+    """
+    class Model(BaseModel):
+        large_array: NDArray[Any, Number]
+        small_array: NDArray[Any, Number]
+
+    mod = Model(large_array=np.random.random((1024,1024)), small_array=np.random.random((3,3)))
+    mod_str = mod.model_dump_json()
+    mod_json = json.loads(mod_str)
+    for a in ('array', 'shape', 'dtype', 'unpack_fns'):
+        assert a in mod_json['large_array'].keys()
+    assert isinstance(mod_json['large_array']['array'], str)
+    assert isinstance(mod_json['small_array'], list)
+
+    # but when we just dump to a dict we don't compress
+    mod_dict = mod.model_dump()
+    assert isinstance(mod_dict['large_array'], np.ndarray)
+
+def test_ndarray_proxy(tmp_output_dir_func):
+    h5f_source = tmp_output_dir_func / 'test.h5'
+    with h5py.File(h5f_source, 'w') as h5f:
+        dset_good = h5f.create_dataset('/data', data=np.random.random((1024,1024,3)))
+        dset_bad = h5f.create_dataset('/data_bad', data=np.random.random((1024, 1024, 4)))
+
+    class Model(BaseModel):
+        array: NDArray[Shape["* x, * y, 3 z"], Number]
+
+    mod = Model(array=NDArrayProxy(h5f_file=h5f_source, path='/data'))
+    subarray = mod.array[0:5, 0:5, :]
+    assert isinstance(subarray, np.ndarray)
+    assert isinstance(subarray.sum(), float)
+    assert mod.array.name == '/data'
+
+    with pytest.raises(NotImplementedError):
+        mod.array[0] = 5
+
+    with pytest.raises(ValidationError):
+        mod = Model(array=NDArrayProxy(h5f_file=h5f_source, path='/data_bad'))