From 8405ea948ada53f90398b975abb1ed1eb94de59d Mon Sep 17 00:00:00 2001
From: sneakers-the-rat
Date: Thu, 5 Oct 2023 22:05:10 -0700
Subject: [PATCH] [tests] ndarray proxy and serialization

---
Review notes (free-form text between the three-dash separator and the first
file diff; not part of the commit message):

* generators/pydantic.py: the removed and added `compile_python` signature
  lines read identically in this copy of the patch. NOTE(review): presumably
  a whitespace-only change -- confirm against the tree.
* tests/fixtures.py: `tmp_output_dir_func` and `tmp_output_dir_mod` both used
  the `__tmp__` subdirectory of `tmp_output_dir`; they now use `__tmpfunc__`
  and `__tmpmod__` respectively, so each fixture removes and recreates its
  own subdirectory.
* tests/test_types/test_ndarray.py:
  - test_ndarray_type gains an optional `array_any: NDArray[Any, Any]` field
    and one extra instantiation that passes both fields.
  - New test_ndarray_coercion: a list of numbers is accepted for an NDArray
    field and compared with np.allclose; a list of strings must raise
    ValidationError.
  - New test_ndarray_serialize: model_dump_json() output is checked for the
    keys 'array', 'shape', 'dtype', 'unpack_fns' on the 1024x1024 array and
    for a plain list on the 3x3 array; model_dump() is checked to keep an
    np.ndarray.
  - test_ndarray_proxy is no longer skipped and no longer reads
    `data_dir / 'aibs_ecephys.nwb'`; it writes its own `test.h5` under
    `tmp_output_dir_func` with a '/data' and a '/data_bad' dataset.
* NOTE(review): leading indentation and blank context lines were restored
  from the hunk line counts and Python syntax; verify the context against
  the target tree before applying.

 .../src/nwb_linkml/generators/pydantic.py     |  2 +-
 nwb_linkml/tests/fixtures.py                  |  4 +-
 nwb_linkml/tests/test_types/test_ndarray.py   | 69 +++++++++++++++++--
 3 files changed, 67 insertions(+), 8 deletions(-)

diff --git a/nwb_linkml/src/nwb_linkml/generators/pydantic.py b/nwb_linkml/src/nwb_linkml/generators/pydantic.py
index 1ead9f7..96b5b85 100644
--- a/nwb_linkml/src/nwb_linkml/generators/pydantic.py
+++ b/nwb_linkml/src/nwb_linkml/generators/pydantic.py
@@ -780,7 +780,7 @@ class NWBPydanticGenerator(PydanticGenerator):
         except NameError as e:
             raise e
 
-def compile_python(text_or_fn: str, package_path: Path = None, module_name:str='test') -> ModuleType:
+def compile_python(text_or_fn: str, package_path: Path = None, module_name:str='test') -> ModuleType:
     """
     Compile the text or file and return the resulting module
     @param text_or_fn: Python text or file name that references python file
diff --git a/nwb_linkml/tests/fixtures.py b/nwb_linkml/tests/fixtures.py
index 68be6e2..bc1c81c 100644
--- a/nwb_linkml/tests/fixtures.py
+++ b/nwb_linkml/tests/fixtures.py
@@ -25,7 +25,7 @@ def tmp_output_dir_func(tmp_output_dir) -> Path:
     tmp output dir that gets cleared between every function
     cleans at the start rather than at cleanup in case the output is to be inspected
     """
-    subpath = tmp_output_dir / '__tmp__'
+    subpath = tmp_output_dir / '__tmpfunc__'
     if subpath.exists():
         shutil.rmtree(str(subpath))
     subpath.mkdir()
@@ -37,7 +37,7 @@ def tmp_output_dir_mod(tmp_output_dir) -> Path:
     tmp output dir that gets cleared between every function
     cleans at the start rather than at cleanup in case the output is to be inspected
     """
-    subpath = tmp_output_dir / '__tmp__'
+    subpath = tmp_output_dir / '__tmpmod__'
     if subpath.exists():
         shutil.rmtree(str(subpath))
     subpath.mkdir()
diff --git a/nwb_linkml/tests/test_types/test_ndarray.py b/nwb_linkml/tests/test_types/test_ndarray.py
index d7d0506..9b81f17 100644
--- a/nwb_linkml/tests/test_types/test_ndarray.py
+++ b/nwb_linkml/tests/test_types/test_ndarray.py
@@ -1,19 +1,22 @@
 import pdb
 
 from typing import Union, Optional, Any
+import json
 import pytest
 
 import numpy as np
+import h5py
 from pydantic import BaseModel, ValidationError, Field
-from nwb_linkml.types.ndarray import NDArray
+from nwb_linkml.types.ndarray import NDArray, NDArrayProxy
 from nptyping import Shape, Number
 
-from ..fixtures import data_dir
+from ..fixtures import data_dir, tmp_output_dir, tmp_output_dir_func
 
 def test_ndarray_type():
 
     class Model(BaseModel):
         array: NDArray[Shape["2 x, * y"], Number]
+        array_any: Optional[NDArray[Any, Any]] = None
 
     schema = Model.model_json_schema()
     assert schema['properties']['array']['items'] == {'items': {'type': 'number'}, 'type': 'array'}
@@ -29,6 +32,8 @@ def test_ndarray_type():
     with pytest.raises(ValidationError):
         instance = Model(array=np.ones((2,3), dtype=bool))
 
+    instance = Model(array=np.zeros((2,3)), array_any = np.ones((3,4,5)))
+
 
 def test_ndarray_union():
     class Model(BaseModel):
@@ -55,6 +60,60 @@ def test_ndarray_union():
     with pytest.raises(ValidationError):
         instance = Model(array=np.random.random((5,10,4,6)))
 
-@pytest.mark.skip()
-def test_ndarray_proxy(data_dir):
-    h5f_source = data_dir / 'aibs_ecephys.nwb'
+def test_ndarray_coercion():
+    """
+    Coerce lists to arrays
+    """
+
+    class Model(BaseModel):
+        array: NDArray[Shape["* x"], Number]
+
+    amod = Model(array=[1,2,3,4.5])
+    assert np.allclose(amod.array, np.array([1,2,3,4.5]))
+    with pytest.raises(ValidationError):
+        amod = Model(array=['a', 'b', 'c'])
+
+def test_ndarray_serialize():
+    """
+    Large arrays should get compressed with blosc, otherwise just to list
+    """
+    class Model(BaseModel):
+        large_array: NDArray[Any, Number]
+        small_array: NDArray[Any, Number]
+
+    mod = Model(large_array=np.random.random((1024,1024)), small_array=np.random.random((3,3)))
+    mod_str = mod.model_dump_json()
+    mod_json = json.loads(mod_str)
+    for a in ('array', 'shape', 'dtype', 'unpack_fns'):
+        assert a in mod_json['large_array'].keys()
+    assert isinstance(mod_json['large_array']['array'], str)
+    assert isinstance(mod_json['small_array'], list)
+
+    # but when we just dump to a dict we don't compress
+    mod_dict = mod.model_dump()
+    assert isinstance(mod_dict['large_array'], np.ndarray)
+
+
+
+def test_ndarray_proxy(tmp_output_dir_func):
+    h5f_source = tmp_output_dir_func / 'test.h5'
+    with h5py.File(h5f_source, 'w') as h5f:
+        dset_good = h5f.create_dataset('/data', data=np.random.random((1024,1024,3)))
+        dset_bad = h5f.create_dataset('/data_bad', data=np.random.random((1024, 1024, 4)))
+
+    class Model(BaseModel):
+        array: NDArray[Shape["* x, * y, 3 z"], Number]
+
+    mod = Model(array=NDArrayProxy(h5f_file=h5f_source, path='/data'))
+    subarray = mod.array[0:5, 0:5, :]
+    assert isinstance(subarray, np.ndarray)
+    assert isinstance(subarray.sum(), float)
+    assert mod.array.name == '/data'
+
+    with pytest.raises(NotImplementedError):
+        mod.array[0] = 5
+
+    with pytest.raises(ValidationError):
+        mod = Model(array=NDArrayProxy(h5f_file=h5f_source, path='/data_bad'))
+
+