mirror of https://github.com/p2p-ld/nwb-linkml.git
commit 8405ea948a (parent 0a9ca82476)
[tests] ndarray proxy and serialization

3 changed files with 67 additions and 8 deletions
@@ -780,7 +780,7 @@ class NWBPydanticGenerator(PydanticGenerator):
         except NameError as e:
             raise e

-def compile_python(text_or_fn: str, package_path: Path = None, module_name:str='test') -> ModuleType:
+def compile_python(text_or_fn: str, package_path: Path = None, module_name:str='test') -> ModuleType:
     """
     Compile the text or file and return the resulting module

     @param text_or_fn: Python text or file name that references python file
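The hunk above only shows compile_python's signature and docstring. As a reading aid, here is a minimal sketch of the usual ModuleType-plus-exec pattern such a helper follows; the path-detection heuristic is an illustrative assumption, not nwb-linkml's actual implementation:

from pathlib import Path
from types import ModuleType
from typing import Optional

def compile_python(text_or_fn: str, package_path: Optional[Path] = None, module_name: str = 'test') -> ModuleType:
    """Compile python source text (or a path to a .py file) and return the module."""
    # assumption: single-line inputs ending in .py are file paths, anything else is source
    if '\n' not in text_or_fn and text_or_fn.endswith('.py'):
        text = Path(text_or_fn).read_text()
    else:
        text = text_or_fn
    # package_path is accepted for signature parity; resolving relative imports is omitted here
    module = ModuleType(module_name)
    exec(compile(text, '<generated>', 'exec'), module.__dict__)
    return module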
@@ -25,7 +25,7 @@ def tmp_output_dir_func(tmp_output_dir) -> Path:
     tmp output dir that gets cleared between every function
     cleans at the start rather than at cleanup in case the output is to be inspected
     """
-    subpath = tmp_output_dir / '__tmp__'
+    subpath = tmp_output_dir / '__tmpfunc__'
     if subpath.exists():
         shutil.rmtree(str(subpath))
     subpath.mkdir()
@@ -37,7 +37,7 @@ def tmp_output_dir_mod(tmp_output_dir) -> Path:
     tmp output dir that gets cleared between every function
     cleans at the start rather than at cleanup in case the output is to be inspected
     """
-    subpath = tmp_output_dir / '__tmp__'
+    subpath = tmp_output_dir / '__tmpmod__'
     if subpath.exists():
         shutil.rmtree(str(subpath))
     subpath.mkdir()
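Context for the two renames above: before this commit both fixtures cleared the same __tmp__ subdirectory, so the function-scoped fixture could wipe output the module-scoped one was still using; __tmpfunc__ and __tmpmod__ give each scope its own directory. A sketch of how the two fixtures plausibly look in full (the decorators and return statements are assumed from the fixture names, not shown in the diff):

import shutil
from pathlib import Path
import pytest

@pytest.fixture(scope='function')
def tmp_output_dir_func(tmp_output_dir) -> Path:
    # cleared at the start of every test function so failed output stays inspectable
    subpath = tmp_output_dir / '__tmpfunc__'
    if subpath.exists():
        shutil.rmtree(str(subpath))
    subpath.mkdir()
    return subpath

@pytest.fixture(scope='module')
def tmp_output_dir_mod(tmp_output_dir) -> Path:
    # cleared once per test module
    subpath = tmp_output_dir / '__tmpmod__'
    if subpath.exists():
        shutil.rmtree(str(subpath))
    subpath.mkdir()
    return subpath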
@@ -1,19 +1,22 @@
 import pdb
 from typing import Union, Optional, Any
+import json

 import pytest

 import numpy as np
+import h5py

 from pydantic import BaseModel, ValidationError, Field
-from nwb_linkml.types.ndarray import NDArray
+from nwb_linkml.types.ndarray import NDArray, NDArrayProxy
 from nptyping import Shape, Number

-from ..fixtures import data_dir
+from ..fixtures import data_dir, tmp_output_dir, tmp_output_dir_func

 def test_ndarray_type():

     class Model(BaseModel):
         array: NDArray[Shape["2 x, * y"], Number]
+        array_any: Optional[NDArray[Any, Any]] = None

     schema = Model.model_json_schema()
     assert schema['properties']['array']['items'] == {'items': {'type': 'number'}, 'type': 'array'}
@@ -29,6 +32,8 @@ def test_ndarray_type():
     with pytest.raises(ValidationError):
         instance = Model(array=np.ones((2,3), dtype=bool))

+    instance = Model(array=np.zeros((2,3)), array_any = np.ones((3,4,5)))
+

 def test_ndarray_union():
     class Model(BaseModel):
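The shape strings in these tests come from nptyping's shape-expression grammar, which nwb-linkml's NDArray wraps for pydantic validation. A quick standalone demonstration against plain nptyping (independent of nwb-linkml):

import numpy as np
from nptyping import NDArray, Shape, Number

# "2 x, * y": first axis must be exactly 2, second axis may be any length
assert isinstance(np.zeros((2, 5)), NDArray[Shape["2 x, * y"], Number])
assert not isinstance(np.zeros((3, 5)), NDArray[Shape["2 x, * y"], Number])
# dtype is checked as well: bool arrays are not Number, matching the
# ValidationError the test above expects for np.ones((2,3), dtype=bool)
assert not isinstance(np.ones((2, 3), dtype=bool), NDArray[Shape["2 x, * y"], Number])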
@@ -55,6 +60,60 @@ def test_ndarray_union():
     with pytest.raises(ValidationError):
         instance = Model(array=np.random.random((5,10,4,6)))

-@pytest.mark.skip()
-def test_ndarray_proxy(data_dir):
-    h5f_source = data_dir / 'aibs_ecephys.nwb'
+def test_ndarray_coercion():
+    """
+    Coerce lists to arrays
+    """
+
+    class Model(BaseModel):
+        array: NDArray[Shape["* x"], Number]
+
+    amod = Model(array=[1,2,3,4.5])
+    assert np.allclose(amod.array, np.array([1,2,3,4.5]))
+    with pytest.raises(ValidationError):
+        amod = Model(array=['a', 'b', 'c'])
+
+def test_ndarray_serialize():
+    """
+    Large arrays should get compressed with blosc, otherwise just to list
+    """
+    class Model(BaseModel):
+        large_array: NDArray[Any, Number]
+        small_array: NDArray[Any, Number]
+
+    mod = Model(large_array=np.random.random((1024,1024)), small_array=np.random.random((3,3)))
+    mod_str = mod.model_dump_json()
+    mod_json = json.loads(mod_str)
+    for a in ('array', 'shape', 'dtype', 'unpack_fns'):
+        assert a in mod_json['large_array'].keys()
+    assert isinstance(mod_json['large_array']['array'], str)
+    assert isinstance(mod_json['small_array'], list)
+
+    # but when we just dump to a dict we don't compress
+    mod_dict = mod.model_dump()
+    assert isinstance(mod_dict['large_array'], np.ndarray)
+
+
+def test_ndarray_proxy(tmp_output_dir_func):
+    h5f_source = tmp_output_dir_func / 'test.h5'
+    with h5py.File(h5f_source, 'w') as h5f:
+        dset_good = h5f.create_dataset('/data', data=np.random.random((1024,1024,3)))
+        dset_bad = h5f.create_dataset('/data_bad', data=np.random.random((1024, 1024, 4)))
+
+    class Model(BaseModel):
+        array: NDArray[Shape["* x, * y, 3 z"], Number]
+
+    mod = Model(array=NDArrayProxy(h5f_file=h5f_source, path='/data'))
+    subarray = mod.array[0:5, 0:5, :]
+    assert isinstance(subarray, np.ndarray)
+    assert isinstance(subarray.sum(), float)
+    assert mod.array.name == '/data'
+
+    with pytest.raises(NotImplementedError):
+        mod.array[0] = 5
+
+    with pytest.raises(ValidationError):
+        mod = Model(array=NDArrayProxy(h5f_file=h5f_source, path='/data_bad'))
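test_ndarray_serialize pins down the serialization contract: past some size cutoff an array is dumped to JSON as a compressed string plus the metadata needed to rebuild it ('array', 'shape', 'dtype', 'unpack_fns'), while small arrays become plain lists and model_dump() leaves ndarrays untouched. A minimal sketch of that strategy, assuming the python-blosc package; the threshold and function names here are illustrative, not nwb-linkml's actual code:

import base64
import blosc
import numpy as np

# illustrative cutoff; the real threshold in nwb-linkml may differ
COMPRESSION_THRESHOLD = 16 * 1024  # bytes

def serialize_ndarray(array: np.ndarray):
    """Dump small arrays as plain lists; pack large ones as a blosc blob plus metadata."""
    if array.nbytes <= COMPRESSION_THRESHOLD:
        return array.tolist()
    packed = blosc.compress(array.tobytes(), typesize=array.itemsize)
    return {
        'array': base64.b64encode(packed).decode('ascii'),  # JSON-safe string
        'shape': list(array.shape),
        'dtype': str(array.dtype),
        'unpack_fns': ['base64.b64decode', 'blosc.decompress'],  # recipe to reverse the packing
    }

def deserialize_ndarray(value):
    """Reverse of serialize_ndarray."""
    if isinstance(value, list):
        return np.array(value)
    raw = blosc.decompress(base64.b64decode(value['array']))
    return np.frombuffer(raw, dtype=value['dtype']).reshape(value['shape'])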
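test_ndarray_proxy, in turn, fixes the proxy's contract: construction records a file and dataset path, slicing returns real ndarrays read through h5py, .name echoes the in-file path, assignment raises NotImplementedError, and shape validation still applies at model construction. A rough stand-in showing that read-only lazy behavior (not NDArrayProxy's actual source):

from pathlib import Path
import h5py
import numpy as np

class LazyH5Proxy:
    """Illustrative stand-in for NDArrayProxy: reads slices from HDF5 on demand."""
    def __init__(self, h5f_file: Path, path: str):
        self.h5f_file = h5f_file
        self.path = path

    @property
    def name(self) -> str:
        # mirrors h5py's Dataset.name: the path of the dataset inside the file
        return self.path

    def __getitem__(self, key) -> np.ndarray:
        # open per access so the proxy itself never holds a file handle
        with h5py.File(self.h5f_file, 'r') as h5f:
            return h5f[self.path][key]

    def __setitem__(self, key, value):
        raise NotImplementedError('proxied arrays are read-only; write to the source file instead')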