From aac0c7abdd9ca077ce69ba60136a7371b75b8f5a Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Fri, 22 Sep 2023 02:48:40 -0700 Subject: [PATCH] need 2 stop for the night but its sort of happening --- nwb_linkml/src/nwb_linkml/io/hdf5.py | 39 +++++- nwb_linkml/src/nwb_linkml/io/hdf5_scratch.py | 134 +++++++++++++++++++ nwb_linkml/tests/test_io/test_io_hdf5.py | 3 +- 3 files changed, 169 insertions(+), 7 deletions(-) create mode 100644 nwb_linkml/src/nwb_linkml/io/hdf5_scratch.py diff --git a/nwb_linkml/src/nwb_linkml/io/hdf5.py b/nwb_linkml/src/nwb_linkml/io/hdf5.py index 96be760..180e273 100644 --- a/nwb_linkml/src/nwb_linkml/io/hdf5.py +++ b/nwb_linkml/src/nwb_linkml/io/hdf5.py @@ -54,9 +54,8 @@ class H5SourceItem(BaseModel): """What kind of hdf5 element this is""" depends: List[str] = Field(default_factory=list) """Paths of other source items that this item depends on before it can be instantiated. eg. from softlinks""" - - - + attrs: dict = Field(default_factory=dict) + """Any static attrs that can be had from the element""" model_config = ConfigDict(arbitrary_types_allowed=True) @property @@ -64,16 +63,23 @@ class H5SourceItem(BaseModel): """path split by /""" return self.path.split('/') +FlatH5 = Dict[str, H5SourceItem] + class ReadQueue(BaseModel): """Container model to store items as they are built """ + h5f: h5py.File = Field( + description="Open hdf5 file used when resolving the queue!" + ) queue: Dict[str,H5SourceItem] = Field( default_factory=dict, description="Items left to be instantiated, keyed by hdf5 path", ) - completed: Dict[str, BaseModel] = Field( + completed: Dict[str, Any] = Field( default_factory=dict, description="Items that have already been instantiated, keyed by hdf5 path" ) + model_config = ConfigDict(arbitrary_types_allowed=True) + class HDF5IO(): @@ -357,13 +363,36 @@ def flatten_hdf(h5f:h5py.File | h5py.Group, skip='specifications') -> Dict[str, h5_type = 'group' # dereference and get name of reference depends = list(set([h5f[i].name for i in refs])) + if not name.startswith('/'): + name = '/' + name items[name] = H5SourceItem.model_construct( path = name, leaf = leaf, - depends = depends + depends = depends, + h5_type=h5_type, + attrs = dict(obj.attrs.items()) ) h5f.visititems(_itemize) return items +def sort_flat_hdf(flat: Dict[str, H5SourceItem]) -> Dict[str, H5SourceItem]: + """ + Sort flat hdf5 file in a rough order of solvability + + * First process any leaf items + + * Put any items with dependencies at the end + + Args: + flat: + + Returns: + + """ + class Rank(NamedTuple): + has_depends: bool + not_leaf: bool + + diff --git a/nwb_linkml/src/nwb_linkml/io/hdf5_scratch.py b/nwb_linkml/src/nwb_linkml/io/hdf5_scratch.py new file mode 100644 index 0000000..ed6c2d0 --- /dev/null +++ b/nwb_linkml/src/nwb_linkml/io/hdf5_scratch.py @@ -0,0 +1,134 @@ +""" +Just saving a scratch file temporarily where i was trying a different strategy, +rather than doing one big recursive pass through, try and solve subsections +of the tree and then piece them together once you have the others done. + +sort of working. I think what i need to do is populate the 'depends' +field more so that at each pass i can work through the items whose dependencies +have been solved from the bottom up. +""" + +from nwb_linkml.io.hdf5 import HDF5IO, flatten_hdf +import h5py +from typing import NamedTuple, Tuple, Optional +from nwb_linkml.io.hdf5 import H5SourceItem, FlatH5, ReadQueue, HDF5IO +from nwb_linkml.providers.schema import SchemaProvider +from rich import print +from pydantic import BaseModel + + +class Rank(NamedTuple): + has_depends: bool + not_leaf: bool + not_dataset: bool + has_type: bool + +def sort_flat(item:Tuple[str, H5SourceItem]): + + return Rank( + has_depends=len(item[1].depends)>0, + not_leaf = ~item[1].leaf, + not_dataset = item[1].h5_type != 'dataset', + has_type = 'neurodata_type' in item[1].attrs + ) + +def prune_empty(flat: FlatH5) -> FlatH5: + """ + Groups without children or attrs can be removed + """ + deletes = [] + for k,v in flat.items(): + if v.leaf and v.h5_type == 'group' and len(v.attrs) == 0: + deletes.append(k) + + for k in deletes: + del flat[k] + + return flat + +def resolve_scalars(res: ReadQueue) -> ReadQueue: + for path, item in res.queue.copy().items(): + if item.h5_type == 'group': + continue + dset = res.h5f.get(path) + if dset.shape == (): + res.completed[path] = dset[()] + res.queue.pop(path) + return res + +def resolve_terminal_arrays(res:ReadQueue) -> ReadQueue: + """Terminal arrays can just get loaded as a dict""" + for path, item in res.queue.copy().items(): + if item.h5_type != 'dataset' or not item.leaf or len(item.depends) > 0: + continue + h5_object = res.h5f.get(path) + item_dict = { + 'name': path.split('/')[-1], + 'array': h5_object[:], + **h5_object.attrs, + } + res.completed[path] = item_dict + res.queue.pop(path) + return res + +def attempt_parentless(res:ReadQueue, provider:SchemaProvider) -> ReadQueue: + """Try the groups whose parents have no neurodata type (ie. acquisition)""" + for path, item in res.queue.copy().items(): + if item.h5_type == 'dataset': + continue + group = res.h5f.get(path) + if 'neurodata_type' in group.parent.attrs.keys() or 'neurodata_type' not in group.attrs.keys(): + continue + model = provider.get_class(group.attrs['namespace'], group.attrs['neurodata_type']) + res = naive_instantiation(group, model, res) + return res + + + +def naive_instantiation(element: h5py.Group|h5py.Dataset, model:BaseModel, res:ReadQueue) -> Optional[BaseModel]: + """ + Try to instantiate model with just the attrs and any resolved children + """ + print(element) + kwargs = {} + kwargs['name'] = element.name.split('/')[-1] + for k in element.attrs.keys(): + try: + kwargs[k] = element.attrs[k] + except Exception as e: + print(f'couldnt load attr: {e}') + for key, child in element.items(): + if child.name in res.completed: + kwargs[child.name] = res.completed[child.name] + + kwargs = {k:v for k,v in kwargs.items() if k in model.model_fields.keys()} + + try: + instance = model(**kwargs) + res.queue.pop(element.name) + res.completed[element.name] = instance + print('succeeded') + return res + except Exception as e: + print(f'failed: {e}') + return res + + +# -------------------------------------------------- +path = '/Users/jonny/Dropbox/lab/p2p_ld/data/nwb/sub-738651046_ses-760693773_probe-769322820_ecephys.nwb' + +h5io = HDF5IO(path) +provider = h5io.make_provider() + +h5f = h5py.File(path) +flat = flatten_hdf(h5f) + +flat = prune_empty(flat) +flat_sorted = dict(sorted(flat.items(), key=sort_flat)) + +res = ReadQueue(h5f=h5f, queue=flat_sorted.copy()) + +res = resolve_scalars(res) +res = resolve_terminal_arrays(res) +res = attempt_parentless(res, provider) + diff --git a/nwb_linkml/tests/test_io/test_io_hdf5.py b/nwb_linkml/tests/test_io/test_io_hdf5.py index 6f5cb8a..6cd54f9 100644 --- a/nwb_linkml/tests/test_io/test_io_hdf5.py +++ b/nwb_linkml/tests/test_io/test_io_hdf5.py @@ -76,7 +76,6 @@ def test_truncate_file(tmp_output_dir): assert target_h5f[target_h5f['link']['child'].attrs['reference_contig']].name == target_h5f['data']['dataset_contig'].name assert target_h5f[target_h5f['link']['child'].attrs['reference_chunked']].name == target_h5f['data']['dataset_chunked'].name assert target_h5f['data']['dataset_contig'].attrs['anattr'] == 1 - @pytest.mark.skip() def test_flatten_hdf(): from nwb_linkml.io.hdf5 import HDF5IO, flatten_hdf @@ -85,6 +84,6 @@ def test_flatten_hdf(): h5f = h5py.File(path) flat = flatten_hdf(h5f) assert not any(['specifications' in v.path for v in flat.values()]) - + pdb.set_trace() raise NotImplementedError('Just a stub for local testing for now, finish me!')