[tests] ndarray proxy and serialization

This commit is contained in:
sneakers-the-rat 2023-10-05 22:05:10 -07:00
parent 0a9ca82476
commit 8405ea948a
3 changed files with 67 additions and 8 deletions

View file

@ -780,7 +780,7 @@ class NWBPydanticGenerator(PydanticGenerator):
except NameError as e:
raise e
def compile_python(text_or_fn: str, package_path: Path = None, module_name:str='test') -> ModuleType:
def compile_python(text_or_fn: str, package_path: Path = None, module_name:str='test') -> ModuleType:
"""
Compile the text or file and return the resulting module
@param text_or_fn: Python text or file name that references python file

View file

@ -25,7 +25,7 @@ def tmp_output_dir_func(tmp_output_dir) -> Path:
tmp output dir that gets cleared between every function
cleans at the start rather than at cleanup in case the output is to be inspected
"""
subpath = tmp_output_dir / '__tmp__'
subpath = tmp_output_dir / '__tmpfunc__'
if subpath.exists():
shutil.rmtree(str(subpath))
subpath.mkdir()
@ -37,7 +37,7 @@ def tmp_output_dir_mod(tmp_output_dir) -> Path:
tmp output dir that gets cleared between every function
cleans at the start rather than at cleanup in case the output is to be inspected
"""
subpath = tmp_output_dir / '__tmp__'
subpath = tmp_output_dir / '__tmpmod__'
if subpath.exists():
shutil.rmtree(str(subpath))
subpath.mkdir()

View file

@ -1,19 +1,22 @@
import pdb
from typing import Union, Optional, Any
import json
import pytest
import numpy as np
import h5py
from pydantic import BaseModel, ValidationError, Field
from nwb_linkml.types.ndarray import NDArray
from nwb_linkml.types.ndarray import NDArray, NDArrayProxy
from nptyping import Shape, Number
from ..fixtures import data_dir
from ..fixtures import data_dir, tmp_output_dir, tmp_output_dir_func
def test_ndarray_type():
class Model(BaseModel):
array: NDArray[Shape["2 x, * y"], Number]
array_any: Optional[NDArray[Any, Any]] = None
schema = Model.model_json_schema()
assert schema['properties']['array']['items'] == {'items': {'type': 'number'}, 'type': 'array'}
@ -29,6 +32,8 @@ def test_ndarray_type():
with pytest.raises(ValidationError):
instance = Model(array=np.ones((2,3), dtype=bool))
instance = Model(array=np.zeros((2,3)), array_any = np.ones((3,4,5)))
def test_ndarray_union():
class Model(BaseModel):
@ -55,6 +60,60 @@ def test_ndarray_union():
with pytest.raises(ValidationError):
instance = Model(array=np.random.random((5,10,4,6)))
@pytest.mark.skip()
def test_ndarray_proxy(data_dir):
h5f_source = data_dir / 'aibs_ecephys.nwb'
def test_ndarray_coercion():
"""
Coerce lists to arrays
"""
class Model(BaseModel):
array: NDArray[Shape["* x"], Number]
amod = Model(array=[1,2,3,4.5])
assert np.allclose(amod.array, np.array([1,2,3,4.5]))
with pytest.raises(ValidationError):
amod = Model(array=['a', 'b', 'c'])
def test_ndarray_serialize():
"""
Large arrays should get compressed with blosc, otherwise just to list
"""
class Model(BaseModel):
large_array: NDArray[Any, Number]
small_array: NDArray[Any, Number]
mod = Model(large_array=np.random.random((1024,1024)), small_array=np.random.random((3,3)))
mod_str = mod.model_dump_json()
mod_json = json.loads(mod_str)
for a in ('array', 'shape', 'dtype', 'unpack_fns'):
assert a in mod_json['large_array'].keys()
assert isinstance(mod_json['large_array']['array'], str)
assert isinstance(mod_json['small_array'], list)
# but when we just dump to a dict we don't compress
mod_dict = mod.model_dump()
assert isinstance(mod_dict['large_array'], np.ndarray)
def test_ndarray_proxy(tmp_output_dir_func):
h5f_source = tmp_output_dir_func / 'test.h5'
with h5py.File(h5f_source, 'w') as h5f:
dset_good = h5f.create_dataset('/data', data=np.random.random((1024,1024,3)))
dset_bad = h5f.create_dataset('/data_bad', data=np.random.random((1024, 1024, 4)))
class Model(BaseModel):
array: NDArray[Shape["* x, * y, 3 z"], Number]
mod = Model(array=NDArrayProxy(h5f_file=h5f_source, path='/data'))
subarray = mod.array[0:5, 0:5, :]
assert isinstance(subarray, np.ndarray)
assert isinstance(subarray.sum(), float)
assert mod.array.name == '/data'
with pytest.raises(NotImplementedError):
mod.array[0] = 5
with pytest.raises(ValidationError):
mod = Model(array=NDArrayProxy(h5f_file=h5f_source, path='/data_bad'))