diff --git a/nwb_linkml/src/nwb_linkml/generators/pydantic.py b/nwb_linkml/src/nwb_linkml/generators/pydantic.py index 6bb9ceb..b872ca0 100644 --- a/nwb_linkml/src/nwb_linkml/generators/pydantic.py +++ b/nwb_linkml/src/nwb_linkml/generators/pydantic.py @@ -247,7 +247,6 @@ class NWBPydanticGenerator(PydanticGenerator): # Don't get classes that are defined in this schema! if module_name == self.schema.name: continue - # pdb.set_trace() schema_name = module_name.split('.')[0] if self.versions and schema_name != self.schema.name.split('.')[0] and schema_name in self.versions: version = version_module_case(self.versions[schema_name]) @@ -255,6 +254,7 @@ class NWBPydanticGenerator(PydanticGenerator): local_mod_name = '...' + module_case(schema_name) + '.' + version + '.' + module_case(module_name) else: local_mod_name = '...' + module_case(schema_name) + '.' + version + '.' + 'namespace' + else: local_mod_name = '.' + module_case(module_name) @@ -709,6 +709,14 @@ class NWBPydanticGenerator(PydanticGenerator): self.generate_python_range(slot_range, s, class_def) for slot_range in slot_ranges ] + # -------------------------------------------------- + # Special Case - since we get abstract classes from + # potentially multiple versions (which are then different) + # model classes, we allow container classes to also + # be generic descendants of BaseModel + # -------------------------------------------------- + if 'DynamicTable' in pyranges: + pyranges.append('BaseModel') pyranges = list(set(pyranges)) # remove duplicates pyranges.sort() diff --git a/nwb_linkml/src/nwb_linkml/io/hdf5.py b/nwb_linkml/src/nwb_linkml/io/hdf5.py index 69ee62c..214709c 100644 --- a/nwb_linkml/src/nwb_linkml/io/hdf5.py +++ b/nwb_linkml/src/nwb_linkml/io/hdf5.py @@ -80,7 +80,6 @@ class HDF5IO(): provider=provider ) - #pdb.set_trace() # Apply initial planning phase of reading queue.apply_phase(ReadPhases.plan) print('phase - plan completed') @@ -89,23 +88,14 @@ class HDF5IO(): print('phase - read completed') + # pdb.set_trace() # if len(queue.queue)> 0: # warnings.warn('Did not complete all items during read phase!') - - queue.apply_phase(ReadPhases.construct) pdb.set_trace() - # -------------------------------------------------- - # FIXME: Hardcoding top-level file reading just for the win - # -------------------------------------------------- - root = finish_root_hackily(queue) - file = NWBFile(**root) - - - pdb.set_trace() diff --git a/nwb_linkml/src/nwb_linkml/maps/hdf5.py b/nwb_linkml/src/nwb_linkml/maps/hdf5.py index b23f217..8830ce9 100644 --- a/nwb_linkml/src/nwb_linkml/maps/hdf5.py +++ b/nwb_linkml/src/nwb_linkml/maps/hdf5.py @@ -4,9 +4,10 @@ Maps for reading and writing from HDF5 We have sort of diverged from the initial idea of a generalized map as in :class:`linkml.map.Map` , so we will make our own mapping class here and re-evaluate whether they should be unified later """ +import pdb from abc import ABC, abstractmethod from pathlib import Path -from typing import Literal, List, Dict, Optional, Type, Union +from typing import Literal, List, Dict, Optional, Type, Union, Tuple import h5py from enum import StrEnum @@ -135,16 +136,50 @@ class HDF5Map(ABC): # -------------------------------------------------- # Planning maps # -------------------------------------------------- + +def check_empty(obj: h5py.Group) -> bool: + """ + Check if a group has no attrs or children OR has no attrs and all its children also have no attrs and no children + + Returns: + bool + """ + if isinstance(obj, h5py.Dataset): + return False + + # check if we are empty + no_attrs = False + if len(obj.attrs) == 0: + no_attrs = True + + no_children = False + if len(obj.keys()) == 0: + no_children = True + + # check if immediate children are empty + # handles empty groups of empty groups + children_empty = False + if all([isinstance(item, h5py.Group) and \ + len(item.keys()) == 0 and \ + len(item.attrs) == 0 \ + for item in obj.values()]): + children_empty = True + + # if we have no attrs and we are a leaf OR our children are empty, remove us + if no_attrs and (no_children or children_empty): + return True + else: + return False + class PruneEmpty(HDF5Map): """Remove groups with no attrs """ phase = ReadPhases.plan @classmethod def check(cls, src: H5SourceItem, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> bool: - if src.leaf and src.h5_type == 'group': + if src.h5_type == 'group': with h5py.File(src.h5f_path, 'r') as h5f: obj = h5f.get(src.path) - if len(obj.attrs) == 0: - return True + return check_empty(obj) @classmethod def apply(cls, src: H5SourceItem, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> H5ReadResult: @@ -268,15 +303,18 @@ class ResolveModelGroup(HDF5Map): obj = h5f.get(src.path) for key, type in model.model_fields.items(): if key == 'children': - res[key] = {name: HDF5_Path(child.name) for name, child in obj.items()} - depends.extend([HDF5_Path(child.name) for child in obj.values()]) + res[key] = {name: resolve_hardlink(child) for name, child in obj.items()} + depends.extend([resolve_hardlink(child)for child in obj.values()]) elif key in obj.attrs: res[key] = obj.attrs[key] continue elif key in obj.keys(): + # make sure it's not empty + if check_empty(obj[key]): + continue # stash a reference to this, we'll compile it at the end - depends.append(HDF5_Path(obj[key].name)) - res[key] = HDF5_Path(obj[key].name) + depends.append(resolve_hardlink(obj[key])) + res[key] = resolve_hardlink(obj[key]) res['hdf5_path'] = src.path @@ -393,16 +431,18 @@ class ResolveContainerGroups(HDF5Map): children[k] = HDF5_Path(v.name) depends.append(HDF5_Path(v.name)) - res = { - 'name': src.parts[-1], - **children - } + # res = { + # 'name': src.parts[-1], + # 'hdf5_path': src.path, + # **children + # } return H5ReadResult( path=src.path, source=src, completed=True, - result=res, + result=children, + depends=depends, applied=['ResolveContainerGroups'] ) @@ -411,24 +451,64 @@ class ResolveContainerGroups(HDF5Map): # Completion Steps # -------------------------------------------------- -class CompleteDynamicTables(HDF5Map): - """Nothing to do! already done!""" +class CompletePassThrough(HDF5Map): + """ + Passthrough map for the construction phase for models that don't need any more work done + + - :class:`.ResolveDynamicTable` + - :class:`.ResolveDatasetAsDict` + - :class:`.ResolveScalars` + """ phase = ReadPhases.construct priority = 1 @classmethod def check(cls, src: H5ReadResult, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> bool: - if 'ResolveDynamicTable' in src.applied: + passthrough_ops = ('ResolveDynamicTable', 'ResolveDatasetAsDict', 'ResolveScalars') + + for op in passthrough_ops: + if hasattr(src, 'applied') and op in src.applied: + return True + return False + + @classmethod + def apply(cls, src: H5ReadResult, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> H5ReadResult: + return src + +class CompleteContainerGroups(HDF5Map): + """ + Complete container groups (usually top-level groups like /acquisition) + that do not have a ndueodata type of their own by resolving them as dictionaries + of values (that will then be given to their parent model) + + """ + phase = ReadPhases.construct + priority = 3 + + @classmethod + def check(cls, src: H5ReadResult, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> bool: + if src.model is None and \ + src.neurodata_type is None and \ + src.source.h5_type == 'group' and \ + all([depend in completed.keys() for depend in src.depends]): return True else: return False @classmethod def apply(cls, src: H5ReadResult, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> H5ReadResult: - return src + res, errors, completes = resolve_references(src.result, completed) + + return H5ReadResult( + result=res, + errors=errors, + completes=completes, + **src.model_dump(exclude={'result', 'errors', 'completes'}) + ) + class CompleteModelGroups(HDF5Map): phase = ReadPhases.construct - priority = 2 + priority = 4 @classmethod def check(cls, src: H5ReadResult, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> bool: @@ -442,31 +522,15 @@ class CompleteModelGroups(HDF5Map): @classmethod def apply(cls, src: H5ReadResult, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> H5ReadResult: # gather any results that were left for completion elsewhere + # first get all already-completed items res = {k:v for k,v in src.result.items() if not isinstance(v, HDF5_Path)} - errors = [] - completes = [] - for path, item in src.result.items(): - if isinstance(item, HDF5_Path): - other_item = completed.get(item, None) - if other_item is None: - errors.append(f'Couldnt find {item}') - continue - if isinstance(other_item.result, dict): - # resolve any other children that it might have... - # FIXME: refactor this lmao so bad - for k,v in other_item.result.items(): - if isinstance(v, HDF5_Path): - inner_result = completed.get(v, None) - if inner_result is None: - errors.append(f'Couldnt find inner item {v}') - continue - other_item.result[k] = inner_result.result - completes.append(v) - res[other_item.result['name']] = other_item.result - else: - res[path] = other_item.result + unpacked_results, errors, completes = resolve_references(src.result, completed) + res.update(unpacked_results) - completes.append(other_item.path) + # # final cleanups + # for key, val in res.items(): + # # if we're supposed to be a list, but instead we're an array, fix that! + # #try: instance = src.model(**res) @@ -523,11 +587,10 @@ class ReadQueue(BaseModel): default_factory=list, description="Phases that have already been completed") - def apply_phase(self, phase:ReadPhases): + def apply_phase(self, phase:ReadPhases, max_passes=5): phase_maps = [m for m in HDF5Map.__subclasses__() if m.phase == phase] phase_maps = sorted(phase_maps, key=lambda x: x.priority) - results = [] # TODO: Thread/multiprocess this @@ -585,6 +648,8 @@ class ReadQueue(BaseModel): # if we're not in the last phase, move our completed to our queue self.queue = self.completed self.completed = {} + elif max_passes>0: + self.apply_phase(phase, max_passes=max_passes-1) @@ -606,6 +671,7 @@ def flatten_hdf(h5f:h5py.File | h5py.Group, skip='specifications') -> Dict[str, def _itemize(name: str, obj: h5py.Dataset | h5py.Group): if skip in name: return + leaf = isinstance(obj, h5py.Dataset) or len(obj.keys()) == 0 if isinstance(obj, h5py.Dataset): @@ -635,8 +701,8 @@ def flatten_hdf(h5f:h5py.File | h5py.Group, skip='specifications') -> Dict[str, ) h5f.visititems(_itemize) - # # then add the root item - # _itemize(h5f.name, h5f) + # then add the root item + _itemize(h5f.name, h5f) return items @@ -681,3 +747,45 @@ def get_references(obj: h5py.Dataset | h5py.Group) -> List[str]: else: depends = list(set([obj.get(i).name for i in refs])) return depends + +def resolve_references(src: dict, completed: Dict[str, H5ReadResult]) -> Tuple[dict, List[str], List[HDF5_Path]]: + """ + Recursively replace references to other completed items with their results + + """ + completes = [] + errors = [] + res = {} + for path, item in src.items(): + if isinstance(item, HDF5_Path): + other_item = completed.get(item, None) + if other_item is None: + errors.append(f"Couldnt find: {item}") + res[path] = other_item.result + completes.append(item) + + elif isinstance(item, dict): + inner_res, inner_error, inner_completes = resolve_references(item, completed) + res[path] = inner_res + errors.extend(inner_error) + completes.extend(inner_completes) + else: + res[path] = item + return res, errors, completes + +def resolve_hardlink(obj: Union[h5py.Group, h5py.Dataset]) -> HDF5_Path: + """ + Unhelpfully, hardlinks are pretty challenging to detect with h5py, so we have + to do extra work to check if an item is "real" or a hardlink to another item. + + Particularly, an item will be excluded from the ``visititems`` method used by + :func:`.flatten_hdf` if it is a hardlink rather than an "original" dataset, + meaning that we don't even have them in our sources list when start reading. + + We basically dereference the object and return that path instead of the path + given by the object's ``name`` + """ + return HDF5_Path(obj.file[obj.ref].name) + + + diff --git a/nwb_linkml/src/nwb_linkml/maps/hdmf.py b/nwb_linkml/src/nwb_linkml/maps/hdmf.py index 63a7555..d92bcb8 100644 --- a/nwb_linkml/src/nwb_linkml/maps/hdmf.py +++ b/nwb_linkml/src/nwb_linkml/maps/hdmf.py @@ -106,11 +106,11 @@ def dynamictable_to_model( try: items[col] = da.from_array(group[col]) except NotImplementedError: - if str in get_inner_types(col_type.annotation): - # dask can't handle this, we just arrayproxy it - items[col] = NDArrayProxy(h5f_file=group.file.filename, path=group[col].name) - else: - warnings.warn(f"Dask cant handle object type arrays like {col} in {group.name}. Skipping") + # if str in get_inner_types(col_type.annotation): + # # dask can't handle this, we just arrayproxy it + items[col] = NDArrayProxy(h5f_file=group.file.filename, path=group[col].name) + #else: + # warnings.warn(f"Dask cant handle object type arrays like {col} in {group.name}. Skipping") # pdb.set_trace() # # can't auto-chunk with "object" type # items[col] = da.from_array(group[col], chunks=-1) diff --git a/nwb_linkml/src/nwb_linkml/providers/schema.py b/nwb_linkml/src/nwb_linkml/providers/schema.py index fa58b8d..b456cc0 100644 --- a/nwb_linkml/src/nwb_linkml/providers/schema.py +++ b/nwb_linkml/src/nwb_linkml/providers/schema.py @@ -468,11 +468,11 @@ class PydanticProvider(Provider): out_file: Optional[Path] = None, version: Optional[str] = None, versions: Optional[dict] = None, - split: bool = False, + split: bool = True, dump: bool = True, force: bool = False, **kwargs - ) -> str: + ) -> str | List[str]: """ Notes: @@ -528,12 +528,6 @@ class PydanticProvider(Provider): fn = module_case(fn) + '.py' out_file = self.path / name / version / fn - if out_file.exists() and not force: - with open(out_file, 'r') as ofile: - serialized = ofile.read() - return serialized - - default_kwargs = { 'split': split, 'emit_metadata': True, @@ -541,6 +535,17 @@ class PydanticProvider(Provider): 'pydantic_version': '2' } default_kwargs.update(kwargs) + if split: + return self._build_split(path, versions, default_kwargs, dump, out_file, force) + else: + return self._build_unsplit(path, versions, default_kwargs, dump, out_file, force) + + + def _build_unsplit(self, path, versions, default_kwargs, dump, out_file, force): + if out_file.exists() and not force: + with open(out_file, 'r') as ofile: + serialized = ofile.read() + return serialized generator = NWBPydanticGenerator( str(path), @@ -552,19 +557,28 @@ class PydanticProvider(Provider): out_file.parent.mkdir(parents=True,exist_ok=True) with open(out_file, 'w') as ofile: ofile.write(serialized) - with open(out_file.parent / '__init__.py', 'w') as initfile: - initfile.write(' ') - # make parent file, being a bit more careful because it could be for another module + + # make initfiles for this directory and parent, + initfile = out_file.parent / '__init__.py' parent_init = out_file.parent.parent / '__init__.py' - if not parent_init.exists(): - with open(parent_init, 'w') as initfile: - initfile.write(' ') + for ifile in (initfile, parent_init): + if not ifile.exists(): + with open(ifile, 'w') as ifile_open: + ifile_open.write(' ') return serialized + def _build_split(self, path:Path, versions, default_kwargs, dump, out_file, force) -> List[str]: + serialized = [] + for schema_file in path.parent.glob('*.yaml'): + this_out = out_file.parent / (module_case(schema_file.stem) + '.py') + serialized.append(self._build_unsplit(schema_file, versions, default_kwargs, dump, this_out, force)) + return serialized + + @classmethod def module_name(self, namespace:str, version: str) -> str: - name_pieces = ['nwb_linkml', 'models', 'pydantic', namespace, version_module_case(version), 'namespace'] + name_pieces = ['nwb_linkml', 'models', 'pydantic', module_case(namespace), version_module_case(version)] module_name = '.'.join(name_pieces) return module_name def import_module( @@ -594,6 +608,19 @@ class PydanticProvider(Provider): if not path.exists(): raise ImportError(f'Module has not been built yet {path}') module_name = self.module_name(namespace, version) + + # import module level first - when python does relative imports, + # it needs to have the parent modules imported separately + # this breaks split model creation when they are outside of the + # package repository (ie. loaded from an nwb file) because it tries + # to look for the containing namespace folder within the nwb_linkml package and fails + init_spec = importlib.util.spec_from_file_location(module_name, path.parent / '__init__.py') + init_module = importlib.util.module_from_spec(init_spec) + sys.modules[module_name] = init_module + init_spec.loader.exec_module(init_module) + + # then the namespace package + module_name = module_name + '.namespace' spec = importlib.util.spec_from_file_location(module_name, path) module = importlib.util.module_from_spec(spec) sys.modules[module_name] = module @@ -638,7 +665,7 @@ class PydanticProvider(Provider): if version is None: version = self.available_versions[namespace][-1] - module_name = self.module_name(namespace, version) + module_name = self.module_name(namespace, version) + '.namespace' if module_name in sys.modules: return sys.modules[module_name] @@ -745,10 +772,10 @@ class SchemaProvider(Provider): linkml_provider = LinkMLProvider(path=self.path, verbose=verbose) pydantic_provider = PydanticProvider(path=self.path, verbose=verbose) - linkml_res = linkml_provider.build(ns_adapter=ns_adapter, **linkml_kwargs) + linkml_res = linkml_provider.build(ns_adapter=ns_adapter, versions=self.versions, **linkml_kwargs) results = {} for ns, ns_result in linkml_res.items(): - results[ns] = pydantic_provider.build(ns_result['namespace'], **pydantic_kwargs) + results[ns] = pydantic_provider.build(ns_result['namespace'], versions=self.versions, **pydantic_kwargs) return results def get(self, namespace: str, version: Optional[str] = None) -> ModuleType: