Like 15 validation errors away from fully reading a file.

Just going to make a special completion phase for the top-level NWBFile because it's so special-cased that it's not worth trying to generalize.

- Resolving hardlinked datasets in read maps
- Resolving 2nd-order empty groups (see the sketch after this list)
- Completing container groups
- Move reference fixing into a recursive function
- Multiple passes for a given phase
- NDArrayProxies should resolve dask or h5py reading on their own
- Correctly build split schema from PydanticProvider instead of needing external logic
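
For the "Resolving 2nd-order empty groups" item, a minimal standalone sketch of the condition the new check_empty() below encodes (plain h5py against an in-memory file; it restates the check rather than importing it):

import h5py

# a group whose only children are themselves attr-less, child-less groups is
# prunable; a group that actually holds data is not
with h5py.File("demo.h5", "w", driver="core", backing_store=False) as h5f:
    h5f.create_group("/empty/inner")               # the 2nd-order empty case
    h5f.create_dataset("/full/data", data=[1, 2])

    grp = h5f["/empty"]
    prunable = len(grp.attrs) == 0 and all(
        isinstance(child, h5py.Group)
        and len(child.keys()) == 0
        and len(child.attrs) == 0
        for child in grp.values()
    )
    assert prunable                        # /empty (and /empty/inner) can go
    assert len(h5f["/full"].keys()) > 0    # /full still has real content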
sneakers-the-rat 2023-10-02 22:19:11 -07:00
parent eca7a5ec2e
commit 34f8969fa9
5 changed files with 213 additions and 80 deletions

View file

@ -247,7 +247,6 @@ class NWBPydanticGenerator(PydanticGenerator):
# Don't get classes that are defined in this schema!
if module_name == self.schema.name:
continue
# pdb.set_trace()
schema_name = module_name.split('.')[0]
if self.versions and schema_name != self.schema.name.split('.')[0] and schema_name in self.versions:
version = version_module_case(self.versions[schema_name])
@ -255,6 +254,7 @@ class NWBPydanticGenerator(PydanticGenerator):
local_mod_name = '...' + module_case(schema_name) + '.' + version + '.' + module_case(module_name)
else:
local_mod_name = '...' + module_case(schema_name) + '.' + version + '.' + 'namespace'
else:
local_mod_name = '.' + module_case(module_name)
@ -709,6 +709,14 @@ class NWBPydanticGenerator(PydanticGenerator):
self.generate_python_range(slot_range, s, class_def)
for slot_range in slot_ranges
]
# --------------------------------------------------
# Special Case - since we get abstract classes from
# potentially multiple versions (which are then different)
# model classes, we allow container classes to also
# be generic descendants of BaseModel
# --------------------------------------------------
if 'DynamicTable' in pyranges:
pyranges.append('BaseModel')
pyranges = list(set(pyranges)) # remove duplicates
pyranges.sort()
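
Roughly what the special case above buys (a hypothetical illustration, class names invented, not generated output): abstract classes like DynamicTable are generated separately per namespace version and are therefore distinct Python classes, so a container slot needs BaseModel in its union to accept a table from any of them.

from typing import Dict, Union
from pydantic import BaseModel

class DynamicTableV2_5(BaseModel):    # stand-in for a 2.5.x generated class
    pass

class DynamicTableV2_6(BaseModel):    # stand-in for a 2.6.x generated class
    pass

class Intervals(BaseModel):
    # without BaseModel in the union, a 2.6 table would fail validation here
    children: Dict[str, Union[DynamicTableV2_5, BaseModel]] = {}

Intervals(children={"trials": DynamicTableV2_6()})   # accepted via BaseModel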

View file

@ -80,7 +80,6 @@ class HDF5IO():
provider=provider
)
#pdb.set_trace()
# Apply initial planning phase of reading
queue.apply_phase(ReadPhases.plan)
print('phase - plan completed')
@ -89,23 +88,14 @@ class HDF5IO():
print('phase - read completed')
# pdb.set_trace()
# if len(queue.queue)> 0:
# warnings.warn('Did not complete all items during read phase!')
queue.apply_phase(ReadPhases.construct)
pdb.set_trace()
# --------------------------------------------------
# FIXME: Hardcoding top-level file reading just for the win
# --------------------------------------------------
root = finish_root_hackily(queue)
file = NWBFile(**root)
pdb.set_trace()

View file

@ -4,9 +4,10 @@ Maps for reading and writing from HDF5
We have sort of diverged from the initial idea of a generalized map as in :class:`linkml.map.Map` ,
so we will make our own mapping class here and re-evaluate whether they should be unified later
"""
import pdb
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Literal, List, Dict, Optional, Type, Union
from typing import Literal, List, Dict, Optional, Type, Union, Tuple
import h5py
from enum import StrEnum
@ -135,16 +136,50 @@ class HDF5Map(ABC):
# --------------------------------------------------
# Planning maps
# --------------------------------------------------
def check_empty(obj: h5py.Group) -> bool:
"""
Check if a group has no attrs and no children, OR has no attrs and all of its children also have no attrs and no children
Returns:
bool
"""
if isinstance(obj, h5py.Dataset):
return False
# check if we are empty
no_attrs = False
if len(obj.attrs) == 0:
no_attrs = True
no_children = False
if len(obj.keys()) == 0:
no_children = True
# check if immediate children are empty
# handles empty groups of empty groups
children_empty = False
if all([isinstance(item, h5py.Group) and \
len(item.keys()) == 0 and \
len(item.attrs) == 0 \
for item in obj.values()]):
children_empty = True
# if we have no attrs and we are a leaf OR our children are empty, remove us
if no_attrs and (no_children or children_empty):
return True
else:
return False
class PruneEmpty(HDF5Map):
"""Remove groups with no attrs """
phase = ReadPhases.plan
@classmethod
def check(cls, src: H5SourceItem, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> bool:
if src.leaf and src.h5_type == 'group':
if src.h5_type == 'group':
with h5py.File(src.h5f_path, 'r') as h5f:
obj = h5f.get(src.path)
if len(obj.attrs) == 0:
return True
return check_empty(obj)
@classmethod
def apply(cls, src: H5SourceItem, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> H5ReadResult:
@ -268,15 +303,18 @@ class ResolveModelGroup(HDF5Map):
obj = h5f.get(src.path)
for key, type in model.model_fields.items():
if key == 'children':
res[key] = {name: HDF5_Path(child.name) for name, child in obj.items()}
depends.extend([HDF5_Path(child.name) for child in obj.values()])
res[key] = {name: resolve_hardlink(child) for name, child in obj.items()}
depends.extend([resolve_hardlink(child) for child in obj.values()])
elif key in obj.attrs:
res[key] = obj.attrs[key]
continue
elif key in obj.keys():
# make sure it's not empty
if check_empty(obj[key]):
continue
# stash a reference to this, we'll compile it at the end
depends.append(HDF5_Path(obj[key].name))
res[key] = HDF5_Path(obj[key].name)
depends.append(resolve_hardlink(obj[key]))
res[key] = resolve_hardlink(obj[key])
res['hdf5_path'] = src.path
@ -393,16 +431,18 @@ class ResolveContainerGroups(HDF5Map):
children[k] = HDF5_Path(v.name)
depends.append(HDF5_Path(v.name))
res = {
'name': src.parts[-1],
**children
}
# res = {
# 'name': src.parts[-1],
# 'hdf5_path': src.path,
# **children
# }
return H5ReadResult(
path=src.path,
source=src,
completed=True,
result=res,
result=children,
depends=depends,
applied=['ResolveContainerGroups']
)
@ -411,24 +451,64 @@ class ResolveContainerGroups(HDF5Map):
# Completion Steps
# --------------------------------------------------
class CompleteDynamicTables(HDF5Map):
"""Nothing to do! already done!"""
class CompletePassThrough(HDF5Map):
"""
Passthrough map for the construction phase for models that don't need any more work done
- :class:`.ResolveDynamicTable`
- :class:`.ResolveDatasetAsDict`
- :class:`.ResolveScalars`
"""
phase = ReadPhases.construct
priority = 1
@classmethod
def check(cls, src: H5ReadResult, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> bool:
if 'ResolveDynamicTable' in src.applied:
passthrough_ops = ('ResolveDynamicTable', 'ResolveDatasetAsDict', 'ResolveScalars')
for op in passthrough_ops:
if hasattr(src, 'applied') and op in src.applied:
return True
return False
@classmethod
def apply(cls, src: H5ReadResult, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> H5ReadResult:
return src
class CompleteContainerGroups(HDF5Map):
"""
Complete container groups (usually top-level groups like /acquisition)
that do not have a neurodata type of their own by resolving them as dictionaries
of values (that will then be given to their parent model)
"""
phase = ReadPhases.construct
priority = 3
@classmethod
def check(cls, src: H5ReadResult, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> bool:
if src.model is None and \
src.neurodata_type is None and \
src.source.h5_type == 'group' and \
all([depend in completed.keys() for depend in src.depends]):
return True
else:
return False
@classmethod
def apply(cls, src: H5ReadResult, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> H5ReadResult:
return src
res, errors, completes = resolve_references(src.result, completed)
return H5ReadResult(
result=res,
errors=errors,
completes=completes,
**src.model_dump(exclude={'result', 'errors', 'completes'})
)
class CompleteModelGroups(HDF5Map):
phase = ReadPhases.construct
priority = 2
priority = 4
@classmethod
def check(cls, src: H5ReadResult, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> bool:
@ -442,31 +522,15 @@ class CompleteModelGroups(HDF5Map):
@classmethod
def apply(cls, src: H5ReadResult, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> H5ReadResult:
# gather any results that were left for completion elsewhere
# first get all already-completed items
res = {k:v for k,v in src.result.items() if not isinstance(v, HDF5_Path)}
errors = []
completes = []
for path, item in src.result.items():
if isinstance(item, HDF5_Path):
other_item = completed.get(item, None)
if other_item is None:
errors.append(f'Couldnt find {item}')
continue
if isinstance(other_item.result, dict):
# resolve any other children that it might have...
# FIXME: refactor this lmao so bad
for k,v in other_item.result.items():
if isinstance(v, HDF5_Path):
inner_result = completed.get(v, None)
if inner_result is None:
errors.append(f'Couldnt find inner item {v}')
continue
other_item.result[k] = inner_result.result
completes.append(v)
res[other_item.result['name']] = other_item.result
else:
res[path] = other_item.result
unpacked_results, errors, completes = resolve_references(src.result, completed)
res.update(unpacked_results)
completes.append(other_item.path)
# # final cleanups
# for key, val in res.items():
# # if we're supposed to be a list, but instead we're an array, fix that!
#
#try:
instance = src.model(**res)
@ -523,11 +587,10 @@ class ReadQueue(BaseModel):
default_factory=list,
description="Phases that have already been completed")
def apply_phase(self, phase:ReadPhases):
def apply_phase(self, phase:ReadPhases, max_passes=5):
phase_maps = [m for m in HDF5Map.__subclasses__() if m.phase == phase]
phase_maps = sorted(phase_maps, key=lambda x: x.priority)
results = []
# TODO: Thread/multiprocess this
@ -585,6 +648,8 @@ class ReadQueue(BaseModel):
# if we're not in the last phase, move our completed to our queue
self.queue = self.completed
self.completed = {}
elif max_passes>0:
self.apply_phase(phase, max_passes=max_passes-1)
@ -606,6 +671,7 @@ def flatten_hdf(h5f:h5py.File | h5py.Group, skip='specifications') -> Dict[str,
def _itemize(name: str, obj: h5py.Dataset | h5py.Group):
if skip in name:
return
leaf = isinstance(obj, h5py.Dataset) or len(obj.keys()) == 0
if isinstance(obj, h5py.Dataset):
@ -635,8 +701,8 @@ def flatten_hdf(h5f:h5py.File | h5py.Group, skip='specifications') -> Dict[str,
)
h5f.visititems(_itemize)
# # then add the root item
# _itemize(h5f.name, h5f)
# then add the root item
_itemize(h5f.name, h5f)
return items
@ -681,3 +747,45 @@ def get_references(obj: h5py.Dataset | h5py.Group) -> List[str]:
else:
depends = list(set([obj.get(i).name for i in refs]))
return depends
def resolve_references(src: dict, completed: Dict[str, H5ReadResult]) -> Tuple[dict, List[str], List[HDF5_Path]]:
"""
Recursively replace references to other completed items with their results
"""
completes = []
errors = []
res = {}
for path, item in src.items():
if isinstance(item, HDF5_Path):
other_item = completed.get(item, None)
if other_item is None:
errors.append(f"Couldnt find: {item}")
res[path] = other_item.result
completes.append(item)
elif isinstance(item, dict):
inner_res, inner_error, inner_completes = resolve_references(item, completed)
res[path] = inner_res
errors.extend(inner_error)
completes.extend(inner_completes)
else:
res[path] = item
return res, errors, completes
def resolve_hardlink(obj: Union[h5py.Group, h5py.Dataset]) -> HDF5_Path:
"""
Unhelpfully, hardlinks are pretty challenging to detect with h5py, so we have
to do extra work to check if an item is "real" or a hardlink to another item.
Particularly, an item will be excluded from the ``visititems`` method used by
:func:`.flatten_hdf` if it is a hardlink rather than an "original" dataset,
meaning that we don't even have them in our sources list when we start reading.
We basically dereference the object and return that path instead of the path
given by the object's ``name``
"""
return HDF5_Path(obj.file[obj.ref].name)
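
A standalone demonstration of the behavior the docstring describes (plain h5py, not part of the module): ``visititems`` reports a multiply-linked object under only one of its names, and dereferencing through the object reference yields the same canonical path from either handle.

import h5py

with h5py.File("links.h5", "w", driver="core", backing_store=False) as h5f:
    h5f.create_dataset("/original", data=[1, 2, 3])
    h5f["/alias"] = h5f["/original"]        # hard link to the same object

    seen = []
    h5f.visititems(lambda name, obj: seen.append(name))
    print(seen)                             # only one of the two names appears

    # both handles dereference to the same object and the same canonical path
    assert h5f[h5f["/alias"].ref].name == h5f[h5f["/original"].ref].name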

View file

@ -106,11 +106,11 @@ def dynamictable_to_model(
try:
items[col] = da.from_array(group[col])
except NotImplementedError:
if str in get_inner_types(col_type.annotation):
# dask can't handle this, we just arrayproxy it
items[col] = NDArrayProxy(h5f_file=group.file.filename, path=group[col].name)
else:
warnings.warn(f"Dask cant handle object type arrays like {col} in {group.name}. Skipping")
# if str in get_inner_types(col_type.annotation):
# # dask can't handle this, we just arrayproxy it
items[col] = NDArrayProxy(h5f_file=group.file.filename, path=group[col].name)
#else:
# warnings.warn(f"Dask cant handle object type arrays like {col} in {group.name}. Skipping")
# pdb.set_trace()
# # can't auto-chunk with "object" type
# items[col] = da.from_array(group[col], chunks=-1)
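
On the "NDArrayProxies should resolve dask or h5py reading on their own" point, the shape of the idea is a proxy that stores only the file path and dataset path and opens the file when indexed, deciding between a plain h5py read and a dask array at that point. A rough sketch of that shape, not the project's actual NDArrayProxy:

from dataclasses import dataclass

import h5py

@dataclass
class LazyH5Proxy:
    """Illustrative stand-in: open the file only when the data is accessed."""
    h5f_file: str
    path: str

    def __getitem__(self, item):
        with h5py.File(self.h5f_file, "r") as h5f:
            return h5f[self.path][item]

    @property
    def shape(self):
        with h5py.File(self.h5f_file, "r") as h5f:
            return h5f[self.path].shape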

View file

@ -468,11 +468,11 @@ class PydanticProvider(Provider):
out_file: Optional[Path] = None,
version: Optional[str] = None,
versions: Optional[dict] = None,
split: bool = False,
split: bool = True,
dump: bool = True,
force: bool = False,
**kwargs
) -> str:
) -> str | List[str]:
"""
Notes:
@ -528,12 +528,6 @@ class PydanticProvider(Provider):
fn = module_case(fn) + '.py'
out_file = self.path / name / version / fn
if out_file.exists() and not force:
with open(out_file, 'r') as ofile:
serialized = ofile.read()
return serialized
default_kwargs = {
'split': split,
'emit_metadata': True,
@ -541,6 +535,17 @@ class PydanticProvider(Provider):
'pydantic_version': '2'
}
default_kwargs.update(kwargs)
if split:
return self._build_split(path, versions, default_kwargs, dump, out_file, force)
else:
return self._build_unsplit(path, versions, default_kwargs, dump, out_file, force)
def _build_unsplit(self, path, versions, default_kwargs, dump, out_file, force):
if out_file.exists() and not force:
with open(out_file, 'r') as ofile:
serialized = ofile.read()
return serialized
generator = NWBPydanticGenerator(
str(path),
@ -552,19 +557,28 @@ class PydanticProvider(Provider):
out_file.parent.mkdir(parents=True,exist_ok=True)
with open(out_file, 'w') as ofile:
ofile.write(serialized)
with open(out_file.parent / '__init__.py', 'w') as initfile:
initfile.write(' ')
# make parent file, being a bit more careful because it could be for another module
# make initfiles for this directory and its parent
initfile = out_file.parent / '__init__.py'
parent_init = out_file.parent.parent / '__init__.py'
if not parent_init.exists():
with open(parent_init, 'w') as initfile:
initfile.write(' ')
for ifile in (initfile, parent_init):
if not ifile.exists():
with open(ifile, 'w') as ifile_open:
ifile_open.write(' ')
return serialized
def _build_split(self, path:Path, versions, default_kwargs, dump, out_file, force) -> List[str]:
serialized = []
for schema_file in path.parent.glob('*.yaml'):
this_out = out_file.parent / (module_case(schema_file.stem) + '.py')
serialized.append(self._build_unsplit(schema_file, versions, default_kwargs, dump, this_out, force))
return serialized
@classmethod
def module_name(self, namespace:str, version: str) -> str:
name_pieces = ['nwb_linkml', 'models', 'pydantic', namespace, version_module_case(version), 'namespace']
name_pieces = ['nwb_linkml', 'models', 'pydantic', module_case(namespace), version_module_case(version)]
module_name = '.'.join(name_pieces)
return module_name
def import_module(
@ -594,6 +608,19 @@ class PydanticProvider(Provider):
if not path.exists():
raise ImportError(f'Module has not been built yet {path}')
module_name = self.module_name(namespace, version)
# import the package-level module first - when python does relative imports,
# it needs to have the parent modules imported separately.
# otherwise this breaks split model creation when the models are outside of the
# package repository (i.e. loaded from an nwb file), because python tries
# to look for the containing namespace folder within the nwb_linkml package and fails
init_spec = importlib.util.spec_from_file_location(module_name, path.parent / '__init__.py')
init_module = importlib.util.module_from_spec(init_spec)
sys.modules[module_name] = init_module
init_spec.loader.exec_module(init_module)
# then the namespace package
module_name = module_name + '.namespace'
spec = importlib.util.spec_from_file_location(module_name, path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
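
The same import pattern in isolation, for loading any generated package from an arbitrary directory so the relative imports inside it resolve (a sketch with illustrative names, not the provider's exact code):

import importlib.util
import sys
from pathlib import Path

def import_package_module(pkg_dir: Path, pkg_name: str, submodule: str = "namespace"):
    # register the package itself first so relative imports inside the
    # submodule can find their parent package in sys.modules
    init_spec = importlib.util.spec_from_file_location(pkg_name, pkg_dir / "__init__.py")
    package = importlib.util.module_from_spec(init_spec)
    sys.modules[pkg_name] = package
    init_spec.loader.exec_module(package)

    # then load the submodule under its dotted name
    sub_name = f"{pkg_name}.{submodule}"
    sub_spec = importlib.util.spec_from_file_location(sub_name, pkg_dir / f"{submodule}.py")
    module = importlib.util.module_from_spec(sub_spec)
    sys.modules[sub_name] = module
    sub_spec.loader.exec_module(module)
    return module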
@ -638,7 +665,7 @@ class PydanticProvider(Provider):
if version is None:
version = self.available_versions[namespace][-1]
module_name = self.module_name(namespace, version)
module_name = self.module_name(namespace, version) + '.namespace'
if module_name in sys.modules:
return sys.modules[module_name]
@ -745,10 +772,10 @@ class SchemaProvider(Provider):
linkml_provider = LinkMLProvider(path=self.path, verbose=verbose)
pydantic_provider = PydanticProvider(path=self.path, verbose=verbose)
linkml_res = linkml_provider.build(ns_adapter=ns_adapter, **linkml_kwargs)
linkml_res = linkml_provider.build(ns_adapter=ns_adapter, versions=self.versions, **linkml_kwargs)
results = {}
for ns, ns_result in linkml_res.items():
results[ns] = pydantic_provider.build(ns_result['namespace'], **pydantic_kwargs)
results[ns] = pydantic_provider.build(ns_result['namespace'], versions=self.versions, **pydantic_kwargs)
return results
def get(self, namespace: str, version: Optional[str] = None) -> ModuleType: