Like 15 validation errors away from fully reading a file.

Just going to make a special completion phase for the top-level NWBFile because it's so special-cased that it's not worth trying to generalize.

- Resolving hardlinked datasets in read maps
- Resolving 2nd-order empty groups
- Completing container groups
- Move fixing references to recursive function
- Multiple passes for a given phase (see the sketch below)
- NDArrayProxies should resolve dask or h5py reading on their own
- Correctly build split schema from PydanticProvider instead of needing external logic
sneakers-the-rat 2023-10-02 22:19:11 -07:00
parent eca7a5ec2e
commit 34f8969fa9
5 changed files with 213 additions and 80 deletions
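
The "multiple passes for a given phase" note above boils down to: re-run a phase's maps over whatever is still queued, because some items only become completable after their dependencies finished in an earlier pass (see the `max_passes` recursion added to `apply_phase` in the diff). A minimal, self-contained sketch of that control flow — the names here (`apply_phase`, `MapFn`) are illustrative, not the actual `ReadQueue` API:

```python
from typing import Callable, Dict, List, Optional

# Illustrative only: each "map" inspects an item and, if it can handle it
# given what has already completed, returns a finished result.
MapFn = Callable[[str, Dict[str, str]], Optional[str]]

def apply_phase(queue: Dict[str, str], maps: List[MapFn], max_passes: int = 5) -> Dict[str, str]:
    """Run every map over every queued item; items whose dependencies only
    completed during this pass get another chance on the next pass."""
    completed: Dict[str, str] = {}
    for path, item in list(queue.items()):
        for apply_map in maps:
            result = apply_map(item, completed)
            if result is not None:
                completed[path] = result
                del queue[path]
                break
    if queue and max_passes > 0:
        # something is still unfinished - recurse with one fewer pass allowed
        completed.update(apply_phase(queue, maps, max_passes - 1))
    return completed
```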

View file

@@ -247,7 +247,6 @@ class NWBPydanticGenerator(PydanticGenerator):
             # Don't get classes that are defined in this schema!
             if module_name == self.schema.name:
                 continue
-            # pdb.set_trace()
             schema_name = module_name.split('.')[0]
             if self.versions and schema_name != self.schema.name.split('.')[0] and schema_name in self.versions:
                 version = version_module_case(self.versions[schema_name])
@@ -255,6 +254,7 @@ class NWBPydanticGenerator(PydanticGenerator):
                     local_mod_name = '...' + module_case(schema_name) + '.' + version + '.' + module_case(module_name)
                 else:
                     local_mod_name = '...' + module_case(schema_name) + '.' + version + '.' + 'namespace'
             else:
                 local_mod_name = '.' + module_case(module_name)
@@ -709,6 +709,14 @@ class NWBPydanticGenerator(PydanticGenerator):
             self.generate_python_range(slot_range, s, class_def)
             for slot_range in slot_ranges
         ]
+        # --------------------------------------------------
+        # Special Case - since we get abstract classes from
+        # potentially multiple versions (which are then different
+        # model classes), we allow container classes to also
+        # be generic descendants of BaseModel
+        # --------------------------------------------------
+        if 'DynamicTable' in pyranges:
+            pyranges.append('BaseModel')
         pyranges = list(set(pyranges)) # remove duplicates
         pyranges.sort()
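
For context on that special case: the generated pydantic models put a slot's possible classes into a Union, so appending `BaseModel` effectively means "any model is acceptable here," which is what lets a table class generated from a different schema version (a distinct class object) still validate. A hypothetical generated field might end up looking roughly like this — illustrative only, not literal generator output:

```python
from typing import Dict, Optional, Union
from pydantic import BaseModel

class DynamicTable(BaseModel):
    name: str

class ProcessingModule(BaseModel):
    # DynamicTable *or* any BaseModel, so tables built from another
    # schema version's classes still pass validation
    children: Optional[Dict[str, Union[DynamicTable, BaseModel]]] = None
```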

View file

@@ -80,7 +80,6 @@ class HDF5IO():
             provider=provider
         )
-        #pdb.set_trace()
         # Apply initial planning phase of reading
         queue.apply_phase(ReadPhases.plan)
         print('phase - plan completed')
@@ -89,23 +88,14 @@
         print('phase - read completed')
-        # pdb.set_trace()
         # if len(queue.queue)> 0:
         #     warnings.warn('Did not complete all items during read phase!')
         queue.apply_phase(ReadPhases.construct)
         pdb.set_trace()
-        # --------------------------------------------------
-        # FIXME: Hardcoding top-level file reading just for the win
-        # --------------------------------------------------
-        root = finish_root_hackily(queue)
-        file = NWBFile(**root)
-        pdb.set_trace()

View file

@@ -4,9 +4,10 @@ Maps for reading and writing from HDF5
 We have sort of diverged from the initial idea of a generalized map as in :class:`linkml.map.Map` ,
 so we will make our own mapping class here and re-evaluate whether they should be unified later
 """
+import pdb
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Literal, List, Dict, Optional, Type, Union
+from typing import Literal, List, Dict, Optional, Type, Union, Tuple
 import h5py
 from enum import StrEnum
@@ -135,16 +136,50 @@ class HDF5Map(ABC):
 # --------------------------------------------------
 # Planning maps
 # --------------------------------------------------
+def check_empty(obj: h5py.Group) -> bool:
+    """
+    Check if a group has no attrs or children, OR has no attrs and all of its
+    children also have no attrs and no children.
+
+    Returns:
+        bool
+    """
+    if isinstance(obj, h5py.Dataset):
+        return False
+
+    # check if we are empty
+    no_attrs = False
+    if len(obj.attrs) == 0:
+        no_attrs = True
+
+    no_children = False
+    if len(obj.keys()) == 0:
+        no_children = True
+
+    # check if immediate children are empty
+    # (handles empty groups of empty groups)
+    children_empty = False
+    if all([isinstance(item, h5py.Group) and \
+            len(item.keys()) == 0 and \
+            len(item.attrs) == 0 \
+            for item in obj.values()]):
+        children_empty = True
+
+    # if we have no attrs and we are a leaf OR our children are empty, remove us
+    if no_attrs and (no_children or children_empty):
+        return True
+    else:
+        return False
+
+
 class PruneEmpty(HDF5Map):
     """Remove groups with no attrs """
     phase = ReadPhases.plan
     @classmethod
     def check(cls, src: H5SourceItem, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> bool:
-        if src.leaf and src.h5_type == 'group':
+        if src.h5_type == 'group':
             with h5py.File(src.h5f_path, 'r') as h5f:
                 obj = h5f.get(src.path)
-                if len(obj.attrs) == 0:
-                    return True
+                return check_empty(obj)

     @classmethod
     def apply(cls, src: H5SourceItem, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> H5ReadResult:
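
To make the "2nd-order empty groups" idea concrete, here is a small, hypothetical HDF5 layout and what `check_empty` (from the hunk above, assumed in scope) would say about it — the file name and paths are made up for illustration:

```python
import h5py

with h5py.File("empty_demo.h5", "w") as f:
    f.create_group("/empty")                    # no attrs, no children
    f.create_group("/second_order/also_empty")  # parent whose only child is an empty group
    grp = f.create_group("/not_empty")
    grp.attrs["neurodata_type"] = "ProcessingModule"

with h5py.File("empty_demo.h5", "r") as f:
    print(check_empty(f["/empty"]))         # True - leaf group with no attrs
    print(check_empty(f["/second_order"]))  # True - only contains attr-less, child-less groups
    print(check_empty(f["/not_empty"]))     # False - it has attrs
```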
@@ -268,15 +303,18 @@ class ResolveModelGroup(HDF5Map):
             obj = h5f.get(src.path)
             for key, type in model.model_fields.items():
                 if key == 'children':
-                    res[key] = {name: HDF5_Path(child.name) for name, child in obj.items()}
-                    depends.extend([HDF5_Path(child.name) for child in obj.values()])
+                    res[key] = {name: resolve_hardlink(child) for name, child in obj.items()}
+                    depends.extend([resolve_hardlink(child) for child in obj.values()])
                 elif key in obj.attrs:
                     res[key] = obj.attrs[key]
                     continue
                 elif key in obj.keys():
+                    # make sure it's not empty
+                    if check_empty(obj[key]):
+                        continue
                     # stash a reference to this, we'll compile it at the end
-                    depends.append(HDF5_Path(obj[key].name))
-                    res[key] = HDF5_Path(obj[key].name)
+                    depends.append(resolve_hardlink(obj[key]))
+                    res[key] = resolve_hardlink(obj[key])

         res['hdf5_path'] = src.path
@@ -393,16 +431,18 @@ class ResolveContainerGroups(HDF5Map):
                 children[k] = HDF5_Path(v.name)
                 depends.append(HDF5_Path(v.name))

-        res = {
-            'name': src.parts[-1],
-            **children
-        }
+        # res = {
+        #     'name': src.parts[-1],
+        #     'hdf5_path': src.path,
+        #     **children
+        # }

         return H5ReadResult(
             path=src.path,
             source=src,
             completed=True,
-            result=res,
+            result=children,
+            depends=depends,
             applied=['ResolveContainerGroups']
         )
@@ -411,24 +451,64 @@ class ResolveContainerGroups(HDF5Map):
 # --------------------------------------------------
 # Completion Steps
 # --------------------------------------------------
-class CompleteDynamicTables(HDF5Map):
-    """Nothing to do! already done!"""
+class CompletePassThrough(HDF5Map):
+    """
+    Passthrough map for the construction phase for models that don't need any more work done
+
+    - :class:`.ResolveDynamicTable`
+    - :class:`.ResolveDatasetAsDict`
+    - :class:`.ResolveScalars`
+    """
     phase = ReadPhases.construct
     priority = 1

     @classmethod
     def check(cls, src: H5ReadResult, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> bool:
-        if 'ResolveDynamicTable' in src.applied:
-            return True
-        else:
-            return False
+        passthrough_ops = ('ResolveDynamicTable', 'ResolveDatasetAsDict', 'ResolveScalars')
+        for op in passthrough_ops:
+            if hasattr(src, 'applied') and op in src.applied:
+                return True
+        return False

     @classmethod
     def apply(cls, src: H5ReadResult, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> H5ReadResult:
         return src


+class CompleteContainerGroups(HDF5Map):
+    """
+    Complete container groups (usually top-level groups like /acquisition)
+    that do not have a neurodata type of their own by resolving them as dictionaries
+    of values (that will then be given to their parent model)
+    """
+    phase = ReadPhases.construct
+    priority = 3
+
+    @classmethod
+    def check(cls, src: H5ReadResult, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> bool:
+        if src.model is None and \
+                src.neurodata_type is None and \
+                src.source.h5_type == 'group' and \
+                all([depend in completed.keys() for depend in src.depends]):
+            return True
+        else:
+            return False
+
+    @classmethod
+    def apply(cls, src: H5ReadResult, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> H5ReadResult:
+        res, errors, completes = resolve_references(src.result, completed)
+
+        return H5ReadResult(
+            result=res,
+            errors=errors,
+            completes=completes,
+            **src.model_dump(exclude={'result', 'errors', 'completes'})
+        )
+
+
 class CompleteModelGroups(HDF5Map):
     phase = ReadPhases.construct
-    priority = 2
+    priority = 4

     @classmethod
     def check(cls, src: H5ReadResult, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> bool:
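
A note on how these maps get picked up (visible in the ReadQueue hunk further down): every subclass of `HDF5Map` is collected via `__subclasses__()`, filtered by phase, and applied in ascending `priority` order — so `CompletePassThrough` (1) runs before `CompleteContainerGroups` (3), which runs before `CompleteModelGroups` (4). A stripped-down sketch of that dispatch with placeholder classes, not the real map classes:

```python
class Map:
    phase: str
    priority: int

class PassThrough(Map):
    phase, priority = "construct", 1

class ContainerGroups(Map):
    phase, priority = "construct", 3

class ModelGroups(Map):
    phase, priority = "construct", 4

def maps_for(phase: str) -> list[type]:
    # subclass registration "for free": defining the class is enough to enroll it
    maps = [m for m in Map.__subclasses__() if m.phase == phase]
    return sorted(maps, key=lambda m: m.priority)

print([m.__name__ for m in maps_for("construct")])
# ['PassThrough', 'ContainerGroups', 'ModelGroups']
```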
@@ -442,31 +522,15 @@ class CompleteModelGroups(HDF5Map):
     @classmethod
     def apply(cls, src: H5ReadResult, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> H5ReadResult:
         # gather any results that were left for completion elsewhere
+        # first get all already-completed items
         res = {k:v for k,v in src.result.items() if not isinstance(v, HDF5_Path)}
-        errors = []
-        completes = []
-        for path, item in src.result.items():
-            if isinstance(item, HDF5_Path):
-                other_item = completed.get(item, None)
-                if other_item is None:
-                    errors.append(f'Couldnt find {item}')
-                    continue
-                if isinstance(other_item.result, dict):
-                    # resolve any other children that it might have...
-                    # FIXME: refactor this lmao so bad
-                    for k,v in other_item.result.items():
-                        if isinstance(v, HDF5_Path):
-                            inner_result = completed.get(v, None)
-                            if inner_result is None:
-                                errors.append(f'Couldnt find inner item {v}')
-                                continue
-                            other_item.result[k] = inner_result.result
-                            completes.append(v)
-                    res[other_item.result['name']] = other_item.result
-                else:
-                    res[path] = other_item.result
-                completes.append(other_item.path)
+        unpacked_results, errors, completes = resolve_references(src.result, completed)
+        res.update(unpacked_results)
+
+        # # final cleanups
+        # for key, val in res.items():
+        #     # if we're supposed to be a list, but instead we're an array, fix that!
+        #

         #try:
         instance = src.model(**res)
@@ -523,11 +587,10 @@ class ReadQueue(BaseModel):
         default_factory=list,
         description="Phases that have already been completed")

-    def apply_phase(self, phase:ReadPhases):
+    def apply_phase(self, phase:ReadPhases, max_passes=5):
         phase_maps = [m for m in HDF5Map.__subclasses__() if m.phase == phase]
         phase_maps = sorted(phase_maps, key=lambda x: x.priority)

         results = []
         # TODO: Thread/multiprocess this
@@ -585,6 +648,8 @@ class ReadQueue(BaseModel):
             # if we're not in the last phase, move our completed to our queue
             self.queue = self.completed
             self.completed = {}
+        elif max_passes>0:
+            self.apply_phase(phase, max_passes=max_passes-1)
@@ -606,6 +671,7 @@ def flatten_hdf(h5f:h5py.File | h5py.Group, skip='specifications') -> Dict[str,
     def _itemize(name: str, obj: h5py.Dataset | h5py.Group):
         if skip in name:
             return
         leaf = isinstance(obj, h5py.Dataset) or len(obj.keys()) == 0
         if isinstance(obj, h5py.Dataset):
@@ -635,8 +701,8 @@ def flatten_hdf(h5f:h5py.File | h5py.Group, skip='specifications') -> Dict[str,
         )
     h5f.visititems(_itemize)

-    # # then add the root item
-    # _itemize(h5f.name, h5f)
+    # then add the root item
+    _itemize(h5f.name, h5f)

     return items
@@ -681,3 +747,45 @@ def get_references(obj: h5py.Dataset | h5py.Group) -> List[str]:
     else:
         depends = list(set([obj.get(i).name for i in refs]))
     return depends
+
+
+def resolve_references(src: dict, completed: Dict[str, H5ReadResult]) -> Tuple[dict, List[str], List[HDF5_Path]]:
+    """
+    Recursively replace references to other completed items with their results
+    """
+    completes = []
+    errors = []
+    res = {}
+    for path, item in src.items():
+        if isinstance(item, HDF5_Path):
+            other_item = completed.get(item, None)
+            if other_item is None:
+                errors.append(f"Couldnt find: {item}")
+            res[path] = other_item.result
+            completes.append(item)
+        elif isinstance(item, dict):
+            inner_res, inner_error, inner_completes = resolve_references(item, completed)
+            res[path] = inner_res
+            errors.extend(inner_error)
+            completes.extend(inner_completes)
+        else:
+            res[path] = item
+
+    return res, errors, completes
+
+
+def resolve_hardlink(obj: Union[h5py.Group, h5py.Dataset]) -> HDF5_Path:
+    """
+    Unhelpfully, hardlinks are pretty challenging to detect with h5py, so we have
+    to do extra work to check if an item is "real" or a hardlink to another item.
+
+    Particularly, an item will be excluded from the ``visititems`` method used by
+    :func:`.flatten_hdf` if it is a hardlink rather than an "original" dataset,
+    meaning that we don't even have them in our sources list when we start reading.
+
+    We basically dereference the object and return that path instead of the path
+    given by the object's ``name``
+    """
+    return HDF5_Path(obj.file[obj.ref].name)
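
To see why `resolve_hardlink` is needed at all, here is a small standalone demonstration of the h5py behavior the docstring describes: visiting is per-object, so a hard-linked alias never shows up as a separate item, and dereferencing through `obj.ref` recovers a resolvable path for the shared object. The file name and paths below are made up for the example:

```python
import h5py
import numpy as np

# build a file where /linked/alias is a hard link to /original/data
with h5py.File("hardlink_demo.h5", "w") as f:
    f.create_dataset("/original/data", data=np.arange(3))
    f.require_group("/linked")
    f["/linked/alias"] = f["/original/data"]   # hard link, not a copy

with h5py.File("hardlink_demo.h5", "r") as f:
    visited = []
    f.visititems(lambda name, obj: visited.append(name))
    # each object is visited once, so the shared dataset appears under
    # exactly one of its two paths
    print(visited)

    alias = f["/linked/alias"]
    # dereferencing through the object reference yields a path that reaches
    # the underlying object, which is what resolve_hardlink() stores
    print(f[alias.ref].name)
```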

View file

@@ -106,11 +106,11 @@ def dynamictable_to_model(
             try:
                 items[col] = da.from_array(group[col])
             except NotImplementedError:
-                if str in get_inner_types(col_type.annotation):
-                    # dask can't handle this, we just arrayproxy it
+                # if str in get_inner_types(col_type.annotation):
+                #     # dask can't handle this, we just arrayproxy it
                     items[col] = NDArrayProxy(h5f_file=group.file.filename, path=group[col].name)
-                else:
-                    warnings.warn(f"Dask cant handle object type arrays like {col} in {group.name}. Skipping")
+                #else:
+                #    warnings.warn(f"Dask cant handle object type arrays like {col} in {group.name}. Skipping")
                 # pdb.set_trace()
                 # # can't auto-chunk with "object" type
                 # items[col] = da.from_array(group[col], chunks=-1)
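
`NDArrayProxy` here is nwb_linkml's own class; the commit note about proxies "resolving dask or h5py reading on their own" amounts to deferred reading. A minimal sketch of that pattern (not the actual `NDArrayProxy` implementation, class name is hypothetical):

```python
from pathlib import Path
import h5py

class LazyH5Proxy:
    """Stand-in sketch of a proxy that defers reading until indexed."""
    def __init__(self, h5f_file: str | Path, path: str):
        self.h5f_file = Path(h5f_file)
        self.path = path

    def __getitem__(self, item):
        # open the file only when data is actually requested
        with h5py.File(self.h5f_file, "r") as h5f:
            return h5f[self.path][item]

    @property
    def shape(self) -> tuple:
        with h5py.File(self.h5f_file, "r") as h5f:
            return h5f[self.path].shape
```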

View file

@@ -468,11 +468,11 @@ class PydanticProvider(Provider):
         out_file: Optional[Path] = None,
         version: Optional[str] = None,
         versions: Optional[dict] = None,
-        split: bool = False,
+        split: bool = True,
         dump: bool = True,
         force: bool = False,
         **kwargs
-    ) -> str:
+    ) -> str | List[str]:
         """
         Notes:
@@ -528,12 +528,6 @@ class PydanticProvider(Provider):
             fn = module_case(fn) + '.py'
             out_file = self.path / name / version / fn

-        if out_file.exists() and not force:
-            with open(out_file, 'r') as ofile:
-                serialized = ofile.read()
-            return serialized
-
         default_kwargs = {
             'split': split,
             'emit_metadata': True,
@@ -541,6 +535,17 @@ class PydanticProvider(Provider):
             'pydantic_version': '2'
         }
         default_kwargs.update(kwargs)
+
+        if split:
+            return self._build_split(path, versions, default_kwargs, dump, out_file, force)
+        else:
+            return self._build_unsplit(path, versions, default_kwargs, dump, out_file, force)
+
+    def _build_unsplit(self, path, versions, default_kwargs, dump, out_file, force):
+        if out_file.exists() and not force:
+            with open(out_file, 'r') as ofile:
+                serialized = ofile.read()
+            return serialized

         generator = NWBPydanticGenerator(
             str(path),
@@ -552,19 +557,28 @@ class PydanticProvider(Provider):
             out_file.parent.mkdir(parents=True,exist_ok=True)
             with open(out_file, 'w') as ofile:
                 ofile.write(serialized)
-            with open(out_file.parent / '__init__.py', 'w') as initfile:
-                initfile.write(' ')
-            # make parent file, being a bit more careful because it could be for another module
+
+            # make initfiles for this directory and parent,
+            initfile = out_file.parent / '__init__.py'
             parent_init = out_file.parent.parent / '__init__.py'
-            if not parent_init.exists():
-                with open(parent_init, 'w') as initfile:
-                    initfile.write(' ')
+            for ifile in (initfile, parent_init):
+                if not ifile.exists():
+                    with open(ifile, 'w') as ifile_open:
+                        ifile_open.write(' ')

         return serialized

+    def _build_split(self, path:Path, versions, default_kwargs, dump, out_file, force) -> List[str]:
+        serialized = []
+        for schema_file in path.parent.glob('*.yaml'):
+            this_out = out_file.parent / (module_case(schema_file.stem) + '.py')
+            serialized.append(self._build_unsplit(schema_file, versions, default_kwargs, dump, this_out, force))
+
+        return serialized
+
     @classmethod
     def module_name(self, namespace:str, version: str) -> str:
-        name_pieces = ['nwb_linkml', 'models', 'pydantic', namespace, version_module_case(version), 'namespace']
+        name_pieces = ['nwb_linkml', 'models', 'pydantic', module_case(namespace), version_module_case(version)]
         module_name = '.'.join(name_pieces)
         return module_name

     def import_module(
@@ -594,6 +608,19 @@ class PydanticProvider(Provider):
         if not path.exists():
             raise ImportError(f'Module has not been built yet {path}')
         module_name = self.module_name(namespace, version)
+
+        # import module level first - when python does relative imports,
+        # it needs to have the parent modules imported separately
+        # this breaks split model creation when they are outside of the
+        # package repository (ie. loaded from an nwb file) because it tries
+        # to look for the containing namespace folder within the nwb_linkml package and fails
+        init_spec = importlib.util.spec_from_file_location(module_name, path.parent / '__init__.py')
+        init_module = importlib.util.module_from_spec(init_spec)
+        sys.modules[module_name] = init_module
+        init_spec.loader.exec_module(init_module)
+
+        # then the namespace package
+        module_name = module_name + '.namespace'
         spec = importlib.util.spec_from_file_location(module_name, path)
         module = importlib.util.module_from_spec(spec)
         sys.modules[module_name] = module
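
The import dance above — execute the package's `__init__.py` under the package name first, then load `namespace.py` as a submodule of it — is what lets the generated models' relative imports resolve even when they live outside the installed nwb_linkml tree. A simplified, self-contained version of the same idea; the helper name, directory layout, and `namespace.py` filename are assumptions for the example:

```python
import importlib.util
import sys
from pathlib import Path

def import_generated_package(pkg_name: str, pkg_dir: Path):
    """Load pkg_dir/__init__.py as ``pkg_name`` first so Python has a parent
    package registered, then load pkg_dir/namespace.py as ``pkg_name.namespace``
    so relative imports inside it can find their sibling modules."""
    init_spec = importlib.util.spec_from_file_location(pkg_name, pkg_dir / "__init__.py")
    package = importlib.util.module_from_spec(init_spec)
    sys.modules[pkg_name] = package
    init_spec.loader.exec_module(package)

    mod_name = pkg_name + ".namespace"
    spec = importlib.util.spec_from_file_location(mod_name, pkg_dir / "namespace.py")
    module = importlib.util.module_from_spec(spec)
    sys.modules[mod_name] = module
    spec.loader.exec_module(module)
    return module

# hypothetical usage, pointing at a directory of generated model files:
# core = import_generated_package("core", Path("/some/cache/dir/core/v2_6_0"))
```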
@@ -638,7 +665,7 @@ class PydanticProvider(Provider):
         if version is None:
             version = self.available_versions[namespace][-1]

-        module_name = self.module_name(namespace, version)
+        module_name = self.module_name(namespace, version) + '.namespace'
         if module_name in sys.modules:
             return sys.modules[module_name]
@@ -745,10 +772,10 @@ class SchemaProvider(Provider):
         linkml_provider = LinkMLProvider(path=self.path, verbose=verbose)
         pydantic_provider = PydanticProvider(path=self.path, verbose=verbose)

-        linkml_res = linkml_provider.build(ns_adapter=ns_adapter, **linkml_kwargs)
+        linkml_res = linkml_provider.build(ns_adapter=ns_adapter, versions=self.versions, **linkml_kwargs)
         results = {}
         for ns, ns_result in linkml_res.items():
-            results[ns] = pydantic_provider.build(ns_result['namespace'], **pydantic_kwargs)
+            results[ns] = pydantic_provider.build(ns_result['namespace'], versions=self.versions, **pydantic_kwargs)
         return results

     def get(self, namespace: str, version: Optional[str] = None) -> ModuleType: