clean up old hdf5 reader methods, fix truncate_hdf5 method, make proper'd test data files with working references

sneakers-the-rat 2024-09-11 19:02:15 -07:00
parent 1f1325e4aa
commit d31ac29294
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
6 changed files with 173 additions and 853 deletions


@@ -22,7 +22,6 @@ Other TODO:
import json
import os
-import pdb
import re
import shutil
import subprocess
@@ -39,7 +38,12 @@ from numpydantic.interface.hdf5 import H5ArrayPath
from pydantic import BaseModel
from tqdm import tqdm

-from nwb_linkml.maps.hdf5 import get_references, resolve_hardlink
+from nwb_linkml.maps.hdf5 import (
+    get_attr_references,
+    get_dataset_references,
+    get_references,
+    resolve_hardlink,
+)

if TYPE_CHECKING:
    from nwb_linkml.providers.schema import SchemaProvider
@@ -342,8 +346,6 @@ class HDF5IO:
            path = "/"
        return context[path]

-        pdb.set_trace()
-
    def write(self, path: Path) -> Never:
        """
        Write to NWB file
@@ -396,7 +398,7 @@ class HDF5IO:
        provider = SchemaProvider(versions=versions)

        # build schema so we have them cached
-        # provider.build_from_dicts(schema)
+        provider.build_from_dicts(schema)

        h5f.close()
        return provider
@@ -484,7 +486,7 @@ def find_references(h5f: h5py.File, path: str) -> List[str]:
    return references


-def truncate_file(source: Path, target: Optional[Path] = None, n: int = 10) -> Path:
+def truncate_file(source: Path, target: Optional[Path] = None, n: int = 10) -> Path | None:
    """
    Create a truncated HDF5 file where only the first few samples are kept.
@@ -500,6 +502,14 @@ def truncate_file(source: Path, target: Optional[Path] = None, n: int = 10) -> P
    Returns:
        :class:`pathlib.Path` path of the truncated file
    """
+    if shutil.which("h5repack") is None:
+        warnings.warn(
+            "Truncation requires h5repack to be available, "
+            "or else the truncated files will be no smaller than the originals",
+            stacklevel=2,
+        )
+        return
+
    target = source.parent / (source.stem + "_truncated.hdf5") if target is None else Path(target)
    source = Path(source)
@@ -515,17 +525,34 @@ def truncate_file(source: Path, target: Optional[Path] = None, n: int = 10) -> P
    os.chmod(target, 0o774)

    to_resize = []
+    attr_refs = {}
+    dataset_refs = {}

    def _need_resizing(name: str, obj: h5py.Dataset | h5py.Group) -> None:
        if isinstance(obj, h5py.Dataset) and obj.size > n:
            to_resize.append(name)

-    print("Resizing datasets...")
+    def _find_attr_refs(name: str, obj: h5py.Dataset | h5py.Group) -> None:
+        """Find all references in object attrs"""
+        refs = get_attr_references(obj)
+        if refs:
+            attr_refs[name] = refs
+
+    def _find_dataset_refs(name: str, obj: h5py.Dataset | h5py.Group) -> None:
+        """Find all references in datasets themselves"""
+        refs = get_dataset_references(obj)
+        if refs:
+            dataset_refs[name] = refs
+
    # first we get the items that need to be resized and then resize them below
    # problems with writing to the file from within the visititems call
+    print("Planning resize...")
    h5f_target = h5py.File(str(target), "r+")
    h5f_target.visititems(_need_resizing)
+    h5f_target.visititems(_find_attr_refs)
+    h5f_target.visititems(_find_dataset_refs)

+    print("Resizing datasets...")
    for resize in to_resize:
        obj = h5f_target.get(resize)
        try:
@@ -535,10 +562,14 @@ def truncate_file(source: Path, target: Optional[Path] = None, n: int = 10) -> P
            # so we have to copy and create a new dataset
            tmp_name = obj.name + "__tmp"
            original_name = obj.name
            obj.parent.move(obj.name, tmp_name)
            old_obj = obj.parent.get(tmp_name)
-            new_obj = obj.parent.create_dataset(original_name, data=old_obj[0:n])
+            new_obj = obj.parent.create_dataset(
+                original_name, data=old_obj[0:n], dtype=old_obj.dtype
+            )
            for k, v in old_obj.attrs.items():
                new_obj.attrs[k] = v
            del new_obj.parent[tmp_name]
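The dtype=old_obj.dtype argument added above is a key part of the truncation fix: when the shortened copy is created from a slice, the element type is pinned to the source dataset's dtype instead of whatever numpy infers. A minimal sketch of the pattern (file and dataset names here are made up, not from this commit):

import h5py
import numpy as np

with h5py.File("example.h5", "w") as h5f:
    full = h5f.create_dataset("full", data=np.arange(100, dtype=np.int16))

    # copy only the first 10 rows, pinning the copy's dtype to the source's;
    # this mirrors what the resize loop above does for its shortened datasets
    truncated = h5f.create_dataset("truncated", data=full[0:10], dtype=full.dtype)
    assert truncated.dtype == full.dtype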
@@ -546,16 +577,18 @@ def truncate_file(source: Path, target: Optional[Path] = None, n: int = 10) -> P
    h5f_target.close()

    # use h5repack to actually remove the items from the dataset
-    if shutil.which("h5repack") is None:
-        warnings.warn(
-            "Truncated file made, but since h5repack not found in path, file won't be any smaller",
-            stacklevel=2,
-        )
-        return target
-
    print("Repacking hdf5...")
    res = subprocess.run(
-        ["h5repack", "-f", "GZIP=9", str(target), str(target_tmp)], capture_output=True
+        [
+            "h5repack",
+            "--verbose=2",
+            "--enable-error-stack",
+            "-f",
+            "GZIP=9",
+            str(target),
+            str(target_tmp),
+        ],
+        capture_output=True,
    )
    if res.returncode != 0:
        warnings.warn(f"h5repack did not return 0: {res.stderr} {res.stdout}", stacklevel=2)
@@ -563,6 +596,36 @@ def truncate_file(source: Path, target: Optional[Path] = None, n: int = 10) -> P
        target_tmp.unlink()
        return target

+    h5f_target = h5py.File(str(target_tmp), "r+")
+
+    # recreate references after repacking, because repacking ruins them if they
+    # are in a compound dtype
+    for obj_name, obj_refs in attr_refs.items():
+        obj = h5f_target.get(obj_name)
+        for attr_name, ref_target in obj_refs.items():
+            ref_target = h5f_target.get(ref_target)
+            obj.attrs[attr_name] = ref_target.ref
+
+    for obj_name, obj_refs in dataset_refs.items():
+        obj = h5f_target.get(obj_name)
+        if isinstance(obj_refs, list):
+            if len(obj_refs) == 1:
+                ref_target = h5f_target.get(obj_refs[0])
+                obj[()] = ref_target.ref
+            else:
+                targets = [h5f_target.get(ref).ref for ref in obj_refs[:n]]
+                obj[:] = targets
+        else:
+            # dict for a compound dataset
+            for col_name, column_refs in obj_refs.items():
+                targets = [h5f_target.get(ref).ref for ref in column_refs[:n]]
+                data = obj[:]
+                data[col_name] = targets
+                obj[:] = data
+
+    h5f_target.flush()
+    h5f_target.close()
+
    target.unlink()
    target_tmp.rename(target)
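Taken together, the revised truncate_file warns and returns None up front when h5repack is missing, records attribute and dataset references before repacking, and rewrites them afterwards. A hedged usage sketch (the input path is only an example; the signature, default target, and return behaviour come from the diff above):

from pathlib import Path

from nwb_linkml.io.hdf5 import truncate_file

source = Path("sub-738651046_ses-760693773.nwb")  # example input file
truncated = truncate_file(source, n=10)

if truncated is None:
    # h5repack was not found on PATH, so no truncated file was produced
    print("h5repack is required to generate truncated test data")
else:
    # default target is <stem>_truncated.hdf5 next to the source
    print(f"wrote {truncated}")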


@@ -5,772 +5,47 @@ We have sort of diverged from the initial idea of a generalized map as in :class
so we will make our own mapping class here and re-evaluate whether they should be unified later
"""

+# FIXME: return and document whatever is left of this godforsaken module after refactoring
# ruff: noqa: D102
# ruff: noqa: D101

import contextlib
+from typing import List, Union
import datetime
import inspect
import sys
from abc import abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Type, Union
import h5py
from numpydantic.interface.hdf5 import H5ArrayPath
from pydantic import BaseModel, ConfigDict, Field
from nwb_linkml.annotations import unwrap_optional
from nwb_linkml.maps import Map
from nwb_linkml.types.hdf5 import HDF5_Path
if sys.version_info.minor >= 11:
from enum import StrEnum
else:
from enum import Enum
class StrEnum(str, Enum):
"""StrEnum-ish class for python 3.10"""
if TYPE_CHECKING:
    from nwb_linkml.providers.schema import SchemaProvider

+def get_attr_references(obj: h5py.Dataset | h5py.Group) -> dict[str, str]:
+    """
+    Get any references in object attributes
+    """
+    refs = {
+        k: obj.file.get(ref).name
+        for k, ref in obj.attrs.items()
+        if isinstance(ref, h5py.h5r.Reference)
+    }
+    return refs
+
+
+def get_dataset_references(obj: h5py.Dataset | h5py.Group) -> list[str] | dict[str, str]:
+    """
+    Get references in datasets
+    """
+    refs = []
+    # For datasets, apply checks depending on shape of data.
+    if isinstance(obj, h5py.Dataset):
+        if obj.shape == ():
+            # scalar
+            if isinstance(obj[()], h5py.h5r.Reference):
+                refs = [obj.file.get(obj[()]).name]
+        elif len(obj) > 0 and isinstance(obj[0], h5py.h5r.Reference):
+            # single-column
+            refs = [obj.file.get(ref).name for ref in obj[:]]
+        elif len(obj.dtype) > 1:
+            # "compound" datasets
+            refs = {}
+            for name in obj.dtype.names:
+                if isinstance(obj[name][0], h5py.h5r.Reference):
+                    refs[name] = [obj.file.get(ref).name for ref in obj[name]]
+    return refs

class ReadPhases(StrEnum):
plan = "plan"
"""Before reading starts, building an index of objects to read"""
read = "read"
"""Main reading operation"""
construct = "construct"
"""After reading, casting the results of the read into their models"""
class H5SourceItem(BaseModel):
""" """
Descriptor of items for each element when :func:`.flatten_hdf` flattens an hdf5 file. Get references in datasets
Consumed by :class:`.HDF5Map` classes, orchestrated by :class:`.ReadQueue`
"""
path: str
"""Absolute hdf5 path of element"""
h5f_path: str
"""Path to the source hdf5 file"""
leaf: bool
"""
If ``True``, this item has no children
(and thus we should start instantiating it before ascending to parent classes)
"""
h5_type: Literal["group", "dataset"]
"""What kind of hdf5 element this is"""
depends: List[str] = Field(default_factory=list)
"""
Paths of other source items that this item depends on before it can be instantiated.
eg. from softlinks
"""
attrs: dict = Field(default_factory=dict)
"""Any static attrs that can be had from the element"""
namespace: Optional[str] = None
"""Optional: The namespace that the neurodata type belongs to"""
neurodata_type: Optional[str] = None
"""Optional: the neurodata type for this dataset or group"""
model_config = ConfigDict(arbitrary_types_allowed=True)
@property
def parts(self) -> List[str]:
"""path split by /"""
return self.path.split("/")
class H5ReadResult(BaseModel):
"""
Result returned by each of our mapping operations.
Also used as the source for operations in the ``construct`` :class:`.ReadPhases`
"""
path: str
"""absolute hdf5 path of element"""
source: Union[H5SourceItem, "H5ReadResult"]
"""
Source that this result is based on.
The map can modify this item, so the container should update the source
queue on each pass
"""
completed: bool = False
"""
Was this item completed by this map step? False for cases where eg.
we still have dependencies that need to be completed before this one
"""
result: Optional[dict | str | int | float | BaseModel] = None
"""
If completed, built result. A dict that can be instantiated into the model.
If completed is True and result is None, then remove this object
"""
model: Optional[Type[BaseModel]] = None
"""
The model that this item should be cast into
"""
completes: List[HDF5_Path] = Field(default_factory=list)
"""
If this result completes any other fields, we remove them from the build queue.
"""
namespace: Optional[str] = None
"""
Optional: the namespace of the neurodata type for this object
"""
neurodata_type: Optional[str] = None
"""
Optional: The neurodata type to use for this object
"""
applied: List[str] = Field(default_factory=list)
"""
Which map operations were applied to this item
"""
errors: List[str] = Field(default_factory=list)
"""
Problems that occurred during resolution
"""
depends: List[HDF5_Path] = Field(default_factory=list)
"""
Other items that the final resolution of this item depends on
"""
FlatH5 = Dict[str, H5SourceItem]
class HDF5Map(Map):
phase: ReadPhases
priority: int = 0
"""
Within a phase, sort mapping operations from low to high priority
(maybe this should be renamed because highest priority last doesn't make a lot of sense)
"""
@classmethod
@abstractmethod
def check(
cls,
src: H5SourceItem | H5ReadResult,
provider: "SchemaProvider",
completed: Dict[str, H5ReadResult],
) -> bool:
"""Check if this map applies to the given item to read"""
@classmethod
@abstractmethod
def apply(
cls,
src: H5SourceItem | H5ReadResult,
provider: "SchemaProvider",
completed: Dict[str, H5ReadResult],
) -> H5ReadResult:
"""Actually apply the map!"""
# --------------------------------------------------
# Planning maps
# --------------------------------------------------
def check_empty(obj: h5py.Group) -> bool:
"""
Check if a group has no attrs or children OR has no attrs and all its children
also have no attrs and no children
Returns:
bool
""" """
refs = []
# For datasets, apply checks depending on shape of data.
if isinstance(obj, h5py.Dataset): if isinstance(obj, h5py.Dataset):
return False if obj.shape == ():
# scalar
# check if we are empty if isinstance(obj[()], h5py.h5r.Reference):
no_attrs = False refs = [obj.file.get(obj[()]).name]
if len(obj.attrs) == 0: elif len(obj) > 0 and isinstance(obj[0], h5py.h5r.Reference):
no_attrs = True # single-column
refs = [obj.file.get(ref).name for ref in obj[:]]
no_children = False elif len(obj.dtype) > 1:
if len(obj.keys()) == 0: # "compound" datasets
no_children = True refs = {}
for name in obj.dtype.names:
# check if immediate children are empty if isinstance(obj[name][0], h5py.h5r.Reference):
# handles empty groups of empty groups refs[name] = [obj.file.get(ref).name for ref in obj[name]]
children_empty = False return refs
if all(
[
isinstance(item, h5py.Group) and len(item.keys()) == 0 and len(item.attrs) == 0
for item in obj.values()
]
):
children_empty = True
# if we have no attrs and we are a leaf OR our children are empty, remove us
return bool(no_attrs and (no_children or children_empty))
class PruneEmpty(HDF5Map):
"""Remove groups with no attrs"""
phase = ReadPhases.plan
@classmethod
def check(
cls, src: H5SourceItem, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
) -> bool:
if src.h5_type == "group":
with h5py.File(src.h5f_path, "r") as h5f:
obj = h5f.get(src.path)
return check_empty(obj)
@classmethod
def apply(
cls, src: H5SourceItem, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
) -> H5ReadResult:
return H5ReadResult.model_construct(path=src.path, source=src, completed=True)
class ResolveModelGroup(HDF5Map):
"""
HDF5 Groups that have a model, as indicated by ``neurodata_type`` in their attrs.
We use the model to determine what fields we should get, and then stash references
to the children to process later as :class:`.HDF5_Path`
**Special Case:** Some groups like ``ProcessingGroup`` and others that have an arbitrary
number of named children have a special ``children`` field that is a dictionary mapping
names to the objects themselves.
So for example, this:
/processing/
eye_tracking/
cr_ellipse_fits/
center_x
center_y
...
eye_ellipse_fits/
...
pupil_ellipse_fits/
...
eye_tracking_rig_metadata/
...
would pack the ``eye_tracking`` group (a ``ProcessingModule`` ) as:
{
"name": "eye_tracking",
"children": {
"cr_ellipse_fits": HDF5_Path('/processing/eye_tracking/cr_ellipse_fits'),
"eye_ellipse_fits" : HDF5_Path('/processing/eye_tracking/eye_ellipse_fits'),
...
}
}
We will do some nice things in the model metaclass to make it possible to access the children
like ``nwbfile.processing.cr_ellipse_fits.center_x``
rather than having to switch between indexing and attribute access :)
"""
phase = ReadPhases.read
priority = 10 # do this generally last
@classmethod
def check(
cls, src: H5SourceItem, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
) -> bool:
return bool("neurodata_type" in src.attrs and src.h5_type == "group")
@classmethod
def apply(
cls, src: H5SourceItem, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
) -> H5ReadResult:
model = provider.get_class(src.namespace, src.neurodata_type)
res = {}
depends = []
with h5py.File(src.h5f_path, "r") as h5f:
obj = h5f.get(src.path)
for key in model.model_fields:
if key == "children":
res[key] = {name: resolve_hardlink(child) for name, child in obj.items()}
depends.extend([resolve_hardlink(child) for child in obj.values()])
elif key in obj.attrs:
res[key] = obj.attrs[key]
continue
elif key in obj:
# make sure it's not empty
if check_empty(obj[key]):
continue
# stash a reference to this, we'll compile it at the end
depends.append(resolve_hardlink(obj[key]))
res[key] = resolve_hardlink(obj[key])
res["hdf5_path"] = src.path
res["name"] = src.parts[-1]
return H5ReadResult(
path=src.path,
source=src,
completed=True,
result=res,
model=model,
namespace=src.namespace,
neurodata_type=src.neurodata_type,
applied=["ResolveModelGroup"],
depends=depends,
)
class ResolveDatasetAsDict(HDF5Map):
"""
Resolve datasets that do not have a ``neurodata_type`` of their own as a dictionary
that will be packaged into a model in the next step. Grabs the array in an
:class:`~numpydantic.interface.hdf5.H5ArrayPath`
under an ``array`` key, and then grabs any additional ``attrs`` as well.
Mutually exclusive with :class:`.ResolveScalars` - this only applies to datasets that are larger
than a single entry.
"""
phase = ReadPhases.read
priority = 11
@classmethod
def check(
cls, src: H5SourceItem, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
) -> bool:
if src.h5_type == "dataset" and "neurodata_type" not in src.attrs:
with h5py.File(src.h5f_path, "r") as h5f:
obj = h5f.get(src.path)
return obj.shape != ()
else:
return False
@classmethod
def apply(
cls, src: H5SourceItem, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
) -> H5ReadResult:
res = {
"array": H5ArrayPath(file=src.h5f_path, path=src.path),
"hdf5_path": src.path,
"name": src.parts[-1],
**src.attrs,
}
return H5ReadResult(
path=src.path, source=src, completed=True, result=res, applied=["ResolveDatasetAsDict"]
)
class ResolveScalars(HDF5Map):
phase = ReadPhases.read
priority = 11 # catchall
@classmethod
def check(
cls, src: H5SourceItem, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
) -> bool:
if src.h5_type == "dataset" and "neurodata_type" not in src.attrs:
with h5py.File(src.h5f_path, "r") as h5f:
obj = h5f.get(src.path)
return obj.shape == ()
else:
return False
@classmethod
def apply(
cls, src: H5SourceItem, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
) -> H5ReadResult:
with h5py.File(src.h5f_path, "r") as h5f:
obj = h5f.get(src.path)
res = obj[()]
return H5ReadResult(
path=src.path, source=src, completed=True, result=res, applied=["ResolveScalars"]
)
class ResolveContainerGroups(HDF5Map):
"""
Groups like ``/acquisition``` and others that have no ``neurodata_type``
(and thus no model) are returned as a dictionary with :class:`.HDF5_Path` references to
the children they contain
"""
phase = ReadPhases.read
priority = 9
@classmethod
def check(
cls, src: H5SourceItem, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
) -> bool:
if src.h5_type == "group" and "neurodata_type" not in src.attrs and len(src.attrs) == 0:
with h5py.File(src.h5f_path, "r") as h5f:
obj = h5f.get(src.path)
return len(obj.keys()) > 0
else:
return False
@classmethod
def apply(
cls, src: H5SourceItem, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
) -> H5ReadResult:
"""Simple, just return a dict with references to its children"""
depends = []
with h5py.File(src.h5f_path, "r") as h5f:
obj = h5f.get(src.path)
children = {}
for k, v in obj.items():
children[k] = HDF5_Path(v.name)
depends.append(HDF5_Path(v.name))
# res = {
# 'name': src.parts[-1],
# 'hdf5_path': src.path,
# **children
# }
return H5ReadResult(
path=src.path,
source=src,
completed=True,
result=children,
depends=depends,
applied=["ResolveContainerGroups"],
)
# --------------------------------------------------
# Completion Steps
# --------------------------------------------------
class CompletePassThrough(HDF5Map):
"""
Passthrough map for the construction phase for models that don't need any more work done
- :class:`.ResolveDynamicTable`
- :class:`.ResolveDatasetAsDict`
- :class:`.ResolveScalars`
"""
phase = ReadPhases.construct
priority = 1
@classmethod
def check(
cls, src: H5ReadResult, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
) -> bool:
passthrough_ops = ("ResolveDynamicTable", "ResolveDatasetAsDict", "ResolveScalars")
return any(hasattr(src, "applied") and op in src.applied for op in passthrough_ops)
@classmethod
def apply(
cls, src: H5ReadResult, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
) -> H5ReadResult:
return src
class CompleteContainerGroups(HDF5Map):
"""
Complete container groups (usually top-level groups like /acquisition)
that do not have a ndueodata type of their own by resolving them as dictionaries
of values (that will then be given to their parent model)
"""
phase = ReadPhases.construct
priority = 3
@classmethod
def check(
cls, src: H5ReadResult, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
) -> bool:
return (
src.model is None
and src.neurodata_type is None
and src.source.h5_type == "group"
and all([depend in completed for depend in src.depends])
)
@classmethod
def apply(
cls, src: H5ReadResult, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
) -> H5ReadResult:
res, errors, completes = resolve_references(src.result, completed)
return H5ReadResult(
result=res,
errors=errors,
completes=completes,
**src.model_dump(exclude={"result", "errors", "completes"}),
)
class CompleteModelGroups(HDF5Map):
phase = ReadPhases.construct
priority = 4
@classmethod
def check(
cls, src: H5ReadResult, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
) -> bool:
return (
src.model is not None
and src.source.h5_type == "group"
and src.neurodata_type != "NWBFile"
and all([depend in completed for depend in src.depends])
)
@classmethod
def apply(
cls, src: H5ReadResult, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
) -> H5ReadResult:
# gather any results that were left for completion elsewhere
# first get all already-completed items
res = {k: v for k, v in src.result.items() if not isinstance(v, HDF5_Path)}
unpacked_results, errors, completes = resolve_references(src.result, completed)
res.update(unpacked_results)
# now that we have the model in hand, we can solve any datasets that had an array
# but whose attributes are fixed (and thus should just be an array, rather than a subclass)
for k, v in src.model.model_fields.items():
annotation = unwrap_optional(v.annotation)
if (
inspect.isclass(annotation)
and not issubclass(annotation, BaseModel)
and isinstance(res, dict)
and k in res
and isinstance(res[k], dict)
and "array" in res[k]
):
res[k] = res[k]["array"]
instance = src.model(**res)
return H5ReadResult(
path=src.path,
source=src,
result=instance,
model=src.model,
completed=True,
completes=completes,
neurodata_type=src.neurodata_type,
namespace=src.namespace,
applied=src.applied + ["CompleteModelGroups"],
errors=errors,
)
class CompleteNWBFile(HDF5Map):
"""
The Top-Level NWBFile class is so special cased we just make its own completion special case!
.. todo::
This is truly hideous, just meant as a way to get to the finish line on a late night,
will be cleaned up later
"""
phase = ReadPhases.construct
priority = 11
@classmethod
def check(
cls, src: H5ReadResult, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
) -> bool:
return src.neurodata_type == "NWBFile" and all(
[depend in completed for depend in src.depends]
)
@classmethod
def apply(
cls, src: H5ReadResult, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
) -> H5ReadResult:
res = {k: v for k, v in src.result.items() if not isinstance(v, HDF5_Path)}
unpacked_results, errors, completes = resolve_references(src.result, completed)
res.update(unpacked_results)
res["name"] = "root"
res["file_create_date"] = [
datetime.datetime.fromisoformat(ts.decode("utf-8"))
for ts in res["file_create_date"]["array"][:]
]
if "stimulus" not in res:
res["stimulus"] = provider.get_class("core", "NWBFileStimulus")()
electrode_groups = []
egroup_keys = list(res["general"].get("extracellular_ephys", {}).keys())
egroup_dict = {}
for k in egroup_keys:
if k != "electrodes":
egroup = res["general"]["extracellular_ephys"][k]
electrode_groups.append(egroup)
egroup_dict[egroup.hdf5_path] = egroup
del res["general"]["extracellular_ephys"][k]
if len(electrode_groups) > 0:
res["general"]["extracellular_ephys"]["electrode_group"] = electrode_groups
trode_type = provider.get_class("core", "NWBFileGeneralExtracellularEphysElectrodes")
# anmro = list(type(res['general']['extracellular_ephys']['electrodes']).__mro__)
# anmro.insert(1, trode_type)
trodes_original = res["general"]["extracellular_ephys"]["electrodes"]
trodes = trode_type.model_construct(trodes_original.model_dump())
res["general"]["extracellular_ephys"]["electrodes"] = trodes
instance = src.model(**res)
return H5ReadResult(
path=src.path,
source=src,
result=instance,
model=src.model,
completed=True,
completes=completes,
neurodata_type=src.neurodata_type,
namespace=src.namespace,
applied=src.applied + ["CompleteModelGroups"],
errors=errors,
)
class ReadQueue(BaseModel):
"""Container model to store items as they are built"""
h5f: Path = Field(
description=(
"Path to the source hdf5 file used when resolving the queue! "
"Each translation step should handle opening and closing the file, "
"rather than passing a handle around"
)
)
provider: "SchemaProvider" = Field(
description="SchemaProvider used by each of the items in the read queue"
)
queue: Dict[str, H5SourceItem | H5ReadResult] = Field(
default_factory=dict,
description="Items left to be instantiated, keyed by hdf5 path",
)
completed: Dict[str, H5ReadResult] = Field(
default_factory=dict,
description="Items that have already been instantiated, keyed by hdf5 path",
)
model_config = ConfigDict(arbitrary_types_allowed=True)
phases_completed: List[ReadPhases] = Field(
default_factory=list, description="Phases that have already been completed"
)
def apply_phase(self, phase: ReadPhases, max_passes: int = 5) -> None:
phase_maps = [m for m in HDF5Map.__subclasses__() if m.phase == phase]
phase_maps = sorted(phase_maps, key=lambda x: x.priority)
results = []
# TODO: Thread/multiprocess this
for item in self.queue.values():
for op in phase_maps:
if op.check(item, self.provider, self.completed):
# Formerly there was an "exclusive" property in the maps which let
# potentially multiple operations be applied per stage,
# except if an operation was `exclusive` which would break
# iteration over the operations.
# This was removed because it was badly implemented,
# but if there is ever a need to do that,
# then we would need to decide what to do with the multiple results.
results.append(op.apply(item, self.provider, self.completed))
break # out of inner iteration
# remake the source queue and save results
completes = []
for res in results:
# remove the original item
del self.queue[res.path]
if res.completed:
# if the item has been finished and there is some result, add it to the results
if res.result is not None:
self.completed[res.path] = res
# otherwise if the item has been completed and there was no result,
# just drop it.
# if we have completed other things, delete them from the queue
completes.extend(res.completes)
else:
# if we didn't complete the item (eg. we found we needed more dependencies),
# add the updated source to the queue again
if phase != ReadPhases.construct:
self.queue[res.path] = res.source
else:
self.queue[res.path] = res
# delete the ones that were already completed but might have been
# incorrectly added back in the pile
for c in completes:
with contextlib.suppress(KeyError):
del self.queue[c]
# if we have nothing left in our queue, we have completed this phase
# and prepare only ever has one pass
if phase == ReadPhases.plan:
self.phases_completed.append(phase)
return
if len(self.queue) == 0:
self.phases_completed.append(phase)
if phase != ReadPhases.construct:
# if we're not in the last phase, move our completed to our queue
self.queue = self.completed
self.completed = {}
elif max_passes > 0:
self.apply_phase(phase, max_passes=max_passes - 1)
def flatten_hdf(
h5f: h5py.File | h5py.Group, skip: str = "specifications"
) -> Dict[str, H5SourceItem]:
"""
Flatten all child elements of hdf element into a dict of :class:`.H5SourceItem` s
keyed by their path
Args:
h5f (:class:`h5py.File` | :class:`h5py.Group`): HDF file or group to flatten!
"""
items = {}
def _itemize(name: str, obj: h5py.Dataset | h5py.Group) -> None:
if skip in name:
return
leaf = isinstance(obj, h5py.Dataset) or len(obj.keys()) == 0
if isinstance(obj, h5py.Dataset):
h5_type = "dataset"
elif isinstance(obj, h5py.Group):
h5_type = "group"
else:
raise ValueError(f"Object must be a dataset or group! {obj}")
# get references in attrs and datasets to populate dependencies
# depends = get_references(obj)
if not name.startswith("/"):
name = "/" + name
attrs = dict(obj.attrs.items())
items[name] = H5SourceItem.model_construct(
path=name,
h5f_path=h5f.file.filename,
leaf=leaf,
# depends = depends,
h5_type=h5_type,
attrs=attrs,
namespace=attrs.get("namespace"),
neurodata_type=attrs.get("neurodata_type"),
)
h5f.visititems(_itemize)
# then add the root item
_itemize(h5f.name, h5f)
return items
def get_references(obj: h5py.Dataset | h5py.Group) -> List[str]:
@@ -791,57 +66,18 @@ def get_references(obj: h5py.Dataset | h5py.Group) -> List[str]:
        List[str]: List of paths that are referenced within this object
    """
    # Find references in attrs
    refs = [ref for ref in obj.attrs.values() if isinstance(ref, h5py.h5r.Reference)]

    # For datasets, apply checks depending on shape of data.
    if isinstance(obj, h5py.Dataset):
        if obj.shape == ():
            # scalar
            if isinstance(obj[()], h5py.h5r.Reference):
                refs.append(obj[()])
        elif len(obj) > 0 and isinstance(obj[0], h5py.h5r.Reference):
            # single-column
            refs.extend(obj[:].tolist())
        elif len(obj.dtype) > 1:
            # "compound" datasets
            for name in obj.dtype.names:
                if isinstance(obj[name][0], h5py.h5r.Reference):
                    refs.extend(obj[name].tolist())

    # dereference and get name of reference
    if isinstance(obj, h5py.Dataset):
        depends = list(set([obj.parent.get(i).name for i in refs]))
    else:
        depends = list(set([obj.get(i).name for i in refs]))
    return depends
+    attr_refs = get_attr_references(obj)
+    dataset_refs = get_dataset_references(obj)
+
+    # flatten to list
+    refs = [ref for ref in attr_refs.values()]
+    if isinstance(dataset_refs, list):
+        refs.extend(dataset_refs)
+    else:
+        for v in dataset_refs.values():
+            refs.extend(v)
+
+    return refs
def resolve_references(
src: dict, completed: Dict[str, H5ReadResult]
) -> Tuple[dict, List[str], List[HDF5_Path]]:
"""
Recursively replace references to other completed items with their results
"""
completes = []
errors = []
res = {}
for path, item in src.items():
if isinstance(item, HDF5_Path):
other_item = completed.get(item)
if other_item is None:
errors.append(f"Couldn't find: {item}")
res[path] = other_item.result
completes.append(item)
elif isinstance(item, dict):
inner_res, inner_error, inner_completes = resolve_references(item, completed)
res[path] = inner_res
errors.extend(inner_error)
completes.extend(inner_completes)
else:
res[path] = item
return res, errors, completes
def resolve_hardlink(obj: Union[h5py.Group, h5py.Dataset]) -> str:
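What survives of nwb_linkml.maps.hdf5 after this deletion is the small set of reference helpers that io/hdf5.py now imports. A rough usage sketch, with a placeholder file path and object path:

import h5py

from nwb_linkml.maps.hdf5 import get_attr_references, get_dataset_references, get_references

with h5py.File("example.nwb", "r") as h5f:
    obj = h5f["/units"]  # any group or dataset; path is a placeholder

    # mapping of attribute name -> path of the object it references
    attr_refs = get_attr_references(obj)

    # list of referenced paths, or {column: [paths]} for compound datasets
    dset_refs = get_dataset_references(obj)

    # both of the above flattened into a single list of paths
    all_refs = get_references(obj)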

Binary file not shown.


@@ -15,7 +15,12 @@ def tmp_output_dir(request: pytest.FixtureRequest) -> Path:
        if subdir.name == "git":
            # don't wipe out git repos every time, they don't rly change
            continue
-        elif subdir.is_file() and subdir.parent != path:
+        elif (
+            subdir.is_file()
+            and subdir.parent != path
+            or subdir.is_file()
+            and subdir.suffix == ".nwb"
+        ):
            continue
        elif subdir.is_file():
            subdir.unlink(missing_ok=True)
@@ -54,5 +59,5 @@ def tmp_output_dir_mod(tmp_output_dir) -> Path:

@pytest.fixture(scope="session")
def data_dir() -> Path:
-    path = Path(__file__).parent.resolve() / "data"
+    path = Path(__file__).parents[1].resolve() / "data"
    return path
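The data_dir change means the shared data directory is now looked up one level above the directory that contains this conftest.py rather than beside it. A small illustration of the difference (the layout shown is hypothetical):

from pathlib import Path

conftest = Path("nwb_linkml/tests/fixtures/conftest.py")  # hypothetical location

print(conftest.parent / "data")      # old behaviour: nwb_linkml/tests/fixtures/data
print(conftest.parents[1] / "data")  # new behaviour: nwb_linkml/tests/data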


@@ -1,11 +1,10 @@
-import pdb
-
import h5py
import networkx as nx
import numpy as np
import pytest

from nwb_linkml.io.hdf5 import HDF5IO, filter_dependency_graph, hdf_dependency_graph, truncate_file
+from nwb_linkml.maps.hdf5 import resolve_hardlink


@pytest.mark.skip()
@@ -14,7 +13,7 @@ def test_hdf_read(data_dir, dset):
    NWBFILE = data_dir / dset
    io = HDF5IO(path=NWBFILE)
    # the test for now is just whether we can read it lol
-    model = io.read()
+    _ = io.read()


def test_truncate_file(tmp_output_dir):
@@ -87,35 +86,6 @@ def test_truncate_file(tmp_output_dir):
    assert target_h5f["data"]["dataset_contig"].attrs["anattr"] == 1


-@pytest.mark.skip()
-def test_flatten_hdf():
-    from nwb_linkml.maps.hdf5 import flatten_hdf
-
-    path = "/Users/jonny/Dropbox/lab/p2p_ld/data/nwb/sub-738651046_ses-760693773.nwb"
-    import h5py
-
-    h5f = h5py.File(path)
-    flat = flatten_hdf(h5f)
-    assert not any(["specifications" in v.path for v in flat.values()])
-    pdb.set_trace()
-    raise NotImplementedError("Just a stub for local testing for now, finish me!")
-
-
-@pytest.mark.dev
-def test_dependency_graph(nwb_file, tmp_output_dir):
-    """
-    dependency graph is correctly constructed from an HDF5 file
-    """
-    graph = hdf_dependency_graph(nwb_file)
-    A_unfiltered = nx.nx_agraph.to_agraph(graph)
-    A_unfiltered.draw(tmp_output_dir / "test_nwb_unfiltered.png", prog="dot")
-    graph = filter_dependency_graph(graph)
-    A_filtered = nx.nx_agraph.to_agraph(graph)
-    A_filtered.draw(tmp_output_dir / "test_nwb_filtered.png", prog="dot")
-    pass
-
-
-@pytest.mark.skip
def test_dependencies_hardlink(nwb_file):
    """
    Test that hardlinks are resolved (eg. from /processing/ecephys/LFP/ElectricalSeries/electrodes
@@ -126,4 +96,50 @@ def test_dependencies_hardlink(nwb_file):

    Returns:
    """
-    pass
+    parent = "/processing/ecephys/LFP/ElectricalSeries"
+    source = "/processing/ecephys/LFP/ElectricalSeries/electrodes"
+    target = "/acquisition/ElectricalSeries/electrodes"
+
+    # assert that the hardlink exists in the test file
+    with h5py.File(str(nwb_file), "r") as h5f:
+        node = h5f.get(source)
+        linked_node = resolve_hardlink(node)
+        assert linked_node == target
+
+    graph = hdf_dependency_graph(nwb_file)
+    # the parent should link to the target as a child
+    assert (parent, target) in graph.edges([parent])
+    assert graph.edges[parent, target]["label"] == "child"
+
+
+@pytest.mark.dev
+def test_dependency_graph_images(nwb_file, tmp_output_dir):
+    """
+    Generate images of the dependency graph
+    """
+    graph = hdf_dependency_graph(nwb_file)
+    A_unfiltered = nx.nx_agraph.to_agraph(graph)
+    A_unfiltered.draw(tmp_output_dir / "test_nwb_unfiltered.png", prog="dot")
+    graph = filter_dependency_graph(graph)
+    A_filtered = nx.nx_agraph.to_agraph(graph)
+    A_filtered.draw(tmp_output_dir / "test_nwb_filtered.png", prog="dot")
+
+
+@pytest.mark.parametrize(
+    "dset",
+    [
+        {"name": "aibs.nwb", "source": "sub-738651046_ses-760693773.nwb"},
+        {
+            "name": "aibs_ecephys.nwb",
+            "source": "sub-738651046_ses-760693773_probe-769322820_ecephys.nwb",
+        },
+    ],
+)
+@pytest.mark.dev
+def test_make_truncated_datasets(tmp_output_dir, data_dir, dset):
+    input_file = tmp_output_dir / dset["source"]
+    output_file = data_dir / dset["name"]
+    if not input_file.exists():
+        return
+
+    truncate_file(input_file, output_file, 10)
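Since the point of the regenerated fixture files is that their internal references still resolve after truncation, a follow-up check along these lines could be added (not part of this commit; the file path is a placeholder):

import h5py


def assert_no_dangling_attr_refs(path: str) -> None:
    """Open a truncated file and make sure every attribute reference still dereferences."""
    with h5py.File(path, "r") as h5f:

        def _check(name, obj):
            for key, value in obj.attrs.items():
                if isinstance(value, h5py.h5r.Reference):
                    # indexing the file with a dangling reference raises
                    assert h5f[value] is not None, f"dangling reference at {name}:{key}"

        h5f.visititems(_check)


assert_no_dangling_attr_refs("tests/data/aibs.nwb")  # placeholder path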