tidy hdf5 io module, copy rather than get references

This commit is contained in:
sneakers-the-rat 2024-10-03 00:10:10 -07:00
parent 9560b9f839
commit 748b304426
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
2 changed files with 320 additions and 343 deletions

View file

@ -29,7 +29,7 @@ import sys
import warnings import warnings
from pathlib import Path from pathlib import Path
from types import ModuleType from types import ModuleType
from typing import TYPE_CHECKING, Dict, List, Optional, Union, overload from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, overload
import h5py import h5py
import networkx as nx import networkx as nx
@ -38,13 +38,6 @@ from numpydantic.interface.hdf5 import H5ArrayPath
from pydantic import BaseModel from pydantic import BaseModel
from tqdm import tqdm from tqdm import tqdm
from nwb_linkml.maps.hdf5 import (
get_attr_references,
get_dataset_references,
get_references,
resolve_hardlink,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from nwb_linkml.providers.schema import SchemaProvider from nwb_linkml.providers.schema import SchemaProvider
from nwb_models.models import NWBFile from nwb_models.models import NWBFile
@ -59,222 +52,6 @@ SKIP_PATTERN = re.compile("(^/specifications.*)|(\.specloc)")
"""Nodes to always skip in reading e.g. because they are handled elsewhere""" """Nodes to always skip in reading e.g. because they are handled elsewhere"""
def hdf_dependency_graph(h5f: Path | h5py.File | h5py.Group) -> nx.DiGraph:
"""
Directed dependency graph of dataset and group nodes in an NWBFile such that
each node ``n_i`` is connected to node ``n_j`` if
* ``n_j`` is ``n_i``'s child
* ``n_i`` contains a reference to ``n_j``
Resolve references in
* Attributes
* Dataset columns
* Compound dtypes
Edges are labeled with ``reference`` or ``child`` depending on the type of edge it is,
and attributes from the hdf5 file are added as node attributes.
Args:
h5f (:class:`pathlib.Path` | :class:`h5py.File`): NWB file to graph
Returns:
:class:`networkx.DiGraph`
"""
if isinstance(h5f, (Path, str)):
h5f = h5py.File(h5f, "r")
g = nx.DiGraph()
def _visit_item(name: str, node: h5py.Dataset | h5py.Group) -> None:
if SKIP_PATTERN.match(node.name):
return
# find references in attributes
refs = get_references(node)
# add edges from references
edges = [(node.name, ref) for ref in refs if not SKIP_PATTERN.match(ref)]
g.add_edges_from(edges, label="reference")
# add children, if group
if isinstance(node, h5py.Group):
children = [
resolve_hardlink(child)
for child in node.values()
if not SKIP_PATTERN.match(child.name)
]
edges = [(node.name, ref) for ref in children if not SKIP_PATTERN.match(ref)]
g.add_edges_from(edges, label="child")
# ensure node added to graph
if len(edges) == 0:
g.add_node(node.name)
# store attrs in node
g.nodes[node.name].update(node.attrs)
# apply to root
_visit_item(h5f.name, h5f)
h5f.visititems(_visit_item)
return g
def filter_dependency_graph(g: nx.DiGraph) -> nx.DiGraph:
"""
Remove nodes from a dependency graph if they
* have no neurodata type AND
* have no outbound edges
OR
* are a VectorIndex (which are handled by the dynamictable mixins)
"""
remove_nodes = []
node: str
for node in g.nodes:
ndtype = g.nodes[node].get("neurodata_type", None)
if (ndtype is None and g.out_degree(node) == 0) or SKIP_PATTERN.match(node):
remove_nodes.append(node)
g.remove_nodes_from(remove_nodes)
return g
def _load_node(
path: str, h5f: h5py.File, provider: "SchemaProvider", context: dict
) -> dict | BaseModel:
"""
Load an individual node in the graph, then removes it from the graph
Args:
path:
g:
context:
Returns:
"""
obj = h5f.get(path)
if isinstance(obj, h5py.Dataset):
args = _load_dataset(obj, h5f, context)
elif isinstance(obj, h5py.Group):
args = _load_group(obj, h5f, context)
else:
raise TypeError(f"Nodes can only be h5py Datasets and Groups, got {obj}")
if "neurodata_type" in obj.attrs:
# SPECIAL CASE: ignore `.specloc`
if ".specloc" in args:
del args[".specloc"]
model = provider.get_class(obj.attrs["namespace"], obj.attrs["neurodata_type"])
return model(**args)
else:
if "name" in args:
del args["name"]
if "hdf5_path" in args:
del args["hdf5_path"]
return args
def _load_dataset(
dataset: h5py.Dataset, h5f: h5py.File, context: dict
) -> Union[dict, str, int, float]:
"""
Resolves datasets that do not have a ``neurodata_type`` as a dictionary or a scalar.
If the dataset is a single value without attrs, load it and return as a scalar value.
Otherwise return a :class:`.H5ArrayPath` as a reference to the dataset in the `value` key.
"""
res = {}
if dataset.shape == ():
val = dataset[()]
if isinstance(val, h5py.h5r.Reference):
val = context.get(h5f[val].name)
# if this is just a scalar value, return it
if not dataset.attrs:
return val
res["value"] = val
elif len(dataset) > 0 and isinstance(dataset[0], h5py.h5r.Reference):
# vector of references
res["value"] = [context.get(h5f[ref].name) for ref in dataset[:]]
elif len(dataset.dtype) > 1:
# compound dataset - check if any of the fields are references
for name in dataset.dtype.names:
if isinstance(dataset[name][0], h5py.h5r.Reference):
res[name] = [context.get(h5f[ref].name) for ref in dataset[name]]
else:
res[name] = H5ArrayPath(h5f.filename, dataset.name, name)
else:
res["value"] = H5ArrayPath(h5f.filename, dataset.name)
res.update(dataset.attrs)
if "namespace" in res:
del res["namespace"]
if "neurodata_type" in res:
del res["neurodata_type"]
res["name"] = dataset.name.split("/")[-1]
res["hdf5_path"] = dataset.name
# resolve attr references
for k, v in res.items():
if isinstance(v, h5py.h5r.Reference):
ref_path = h5f[v].name
if SKIP_PATTERN.match(ref_path):
res[k] = ref_path
else:
res[k] = context[ref_path]
if len(res) == 1:
return res["value"]
else:
return res
def _load_group(group: h5py.Group, h5f: h5py.File, context: dict) -> dict:
"""
Load a group!
"""
res = {}
res.update(group.attrs)
for child_name, child in group.items():
if child.name in context:
res[child_name] = context[child.name]
elif isinstance(child, h5py.Dataset):
res[child_name] = _load_dataset(child, h5f, context)
elif isinstance(child, h5py.Group):
res[child_name] = _load_group(child, h5f, context)
else:
raise TypeError(
"Can only handle preinstantiated child objects in context, datasets, and group,"
f" got {child} for {child_name}"
)
if "namespace" in res:
del res["namespace"]
if "neurodata_type" in res:
del res["neurodata_type"]
name = group.name.split("/")[-1]
if name:
res["name"] = name
res["hdf5_path"] = group.name
# resolve attr references
for k, v in res.items():
if isinstance(v, h5py.h5r.Reference):
ref_path = h5f[v].name
if SKIP_PATTERN.match(ref_path):
res[k] = ref_path
else:
res[k] = context[ref_path]
return res
class HDF5IO: class HDF5IO:
""" """
Read (and eventually write) from an NWB HDF5 file. Read (and eventually write) from an NWB HDF5 file.
@ -294,32 +71,9 @@ class HDF5IO:
""" """
Read data into models from an NWB File. Read data into models from an NWB File.
The read process is in several stages:
* Use :meth:`.make_provider` to generate any needed LinkML Schema or Pydantic Classes
using a :class:`.SchemaProvider`
* :func:`flatten_hdf` file into a :class:`.ReadQueue` of nodes.
* Apply the queue's :class:`ReadPhases` :
* ``plan`` - trim any blank nodes, sort nodes to read, etc.
* ``read`` - load the actual data into temporary holding objects
* ``construct`` - cast the read data into models.
Read is split into stages like this to handle references between objects,
where the read result of one node
might depend on another having already been completed.
It also allows us to parallelize the operations
since each mapping operation is independent of the results of all the others in that pass.
.. todo:: .. todo::
Implement reading, skipping arrays - they are fast to read with the ArrayProxy class Document this!
and dask, but there are times when we might want to leave them out of the read entirely.
This might be better implemented as a filter on ``model_dump`` ,
but to investigate further how best to support reading just metadata,
or even some specific field value, or if
we should leave that to other implementations like eg. after we do SQL export then
not rig up a whole query system ourselves.
Args: Args:
path (Optional[str]): If ``None`` (default), read whole file. path (Optional[str]): If ``None`` (default), read whole file.
@ -408,6 +162,240 @@ class HDF5IO:
return provider return provider
def hdf_dependency_graph(h5f: Path | h5py.File | h5py.Group) -> nx.DiGraph:
"""
Directed dependency graph of dataset and group nodes in an NWBFile such that
each node ``n_i`` is connected to node ``n_j`` if
* ``n_j`` is ``n_i``'s child
* ``n_i`` contains a reference to ``n_j``
Resolve references in
* Attributes
* Dataset columns
* Compound dtypes
Edges are labeled with ``reference`` or ``child`` depending on the type of edge it is,
and attributes from the hdf5 file are added as node attributes.
Args:
h5f (:class:`pathlib.Path` | :class:`h5py.File`): NWB file to graph
Returns:
:class:`networkx.DiGraph`
"""
if isinstance(h5f, (Path, str)):
h5f = h5py.File(h5f, "r")
g = nx.DiGraph()
def _visit_item(name: str, node: h5py.Dataset | h5py.Group) -> None:
if SKIP_PATTERN.match(node.name):
return
# find references in attributes
refs = get_references(node)
# add edges from references
edges = [(node.name, ref) for ref in refs if not SKIP_PATTERN.match(ref)]
g.add_edges_from(edges, label="reference")
# add children, if group
if isinstance(node, h5py.Group):
children = [
resolve_hardlink(child)
for child in node.values()
if not SKIP_PATTERN.match(child.name)
]
edges = [(node.name, ref) for ref in children if not SKIP_PATTERN.match(ref)]
g.add_edges_from(edges, label="child")
# ensure node added to graph
if len(edges) == 0:
g.add_node(node.name)
# store attrs in node
g.nodes[node.name].update(node.attrs)
# apply to root
_visit_item(h5f.name, h5f)
h5f.visititems(_visit_item)
return g
def filter_dependency_graph(g: nx.DiGraph) -> nx.DiGraph:
"""
Remove nodes from a dependency graph if they
* have no neurodata type AND
* have no outbound edges
OR
* They match the :ref:`.SKIP_PATTERN`
"""
remove_nodes = []
node: str
for node in g.nodes:
ndtype = g.nodes[node].get("neurodata_type", None)
if (ndtype is None and g.out_degree(node) == 0) or SKIP_PATTERN.match(node):
remove_nodes.append(node)
g.remove_nodes_from(remove_nodes)
return g
def _load_node(
path: str, h5f: h5py.File, provider: "SchemaProvider", context: dict
) -> dict | BaseModel:
"""
Load an individual node in the graph, then removes it from the graph
Args:
path:
g:
context:
Returns:
"""
obj = h5f.get(path)
if isinstance(obj, h5py.Dataset):
args = _load_dataset(obj, h5f, context)
elif isinstance(obj, h5py.Group):
args = _load_group(obj, h5f, context)
else:
raise TypeError(f"Nodes can only be h5py Datasets and Groups, got {obj}")
if "neurodata_type" in obj.attrs:
# SPECIAL CASE: ignore `.specloc`
if ".specloc" in args:
del args[".specloc"]
model = provider.get_class(obj.attrs["namespace"], obj.attrs["neurodata_type"])
return model(**args)
else:
if "name" in args:
del args["name"]
if "hdf5_path" in args:
del args["hdf5_path"]
return args
def _load_dataset(
dataset: h5py.Dataset, h5f: h5py.File, context: dict
) -> Union[dict, str, int, float]:
"""
Resolves datasets that do not have a ``neurodata_type`` as a dictionary or a scalar.
If the dataset is a single value without attrs, load it and return as a scalar value.
Otherwise return a :class:`.H5ArrayPath` as a reference to the dataset in the `value` key.
"""
res = {}
if dataset.shape == ():
val = dataset[()]
if isinstance(val, h5py.h5r.Reference):
val = _copy(context.get(h5f[val].name))
# if this is just a scalar value, return it
if not dataset.attrs:
return val
res["value"] = val
elif len(dataset) > 0 and isinstance(dataset[0], h5py.h5r.Reference):
# vector of references
res["value"] = [_copy(context.get(h5f[ref].name)) for ref in dataset[:]]
elif len(dataset.dtype) > 1:
# compound dataset - check if any of the fields are references
for name in dataset.dtype.names:
if isinstance(dataset[name][0], h5py.h5r.Reference):
res[name] = [_copy(context.get(h5f[ref].name)) for ref in dataset[name]]
else:
res[name] = H5ArrayPath(h5f.filename, dataset.name, name)
else:
res["value"] = H5ArrayPath(h5f.filename, dataset.name)
res.update(dataset.attrs)
if "namespace" in res:
del res["namespace"]
if "neurodata_type" in res:
del res["neurodata_type"]
res["name"] = dataset.name.split("/")[-1]
res["hdf5_path"] = dataset.name
# resolve attr references
res = _resolve_attr_references(res, h5f, context)
if len(res) == 1 and "value" in res:
return res["value"]
else:
return res
def _load_group(group: h5py.Group, h5f: h5py.File, context: dict) -> dict:
"""
Load a group!
"""
res = {}
res.update(group.attrs)
for child_name, child in group.items():
if child.name in context:
res[child_name] = _copy(context[child.name])
elif isinstance(child, h5py.Dataset):
res[child_name] = _load_dataset(child, h5f, context)
elif isinstance(child, h5py.Group):
res[child_name] = _load_group(child, h5f, context)
else:
raise TypeError(
"Can only handle preinstantiated child objects in context, datasets, and group,"
f" got {child} for {child_name}"
)
if "namespace" in res:
del res["namespace"]
if "neurodata_type" in res:
del res["neurodata_type"]
name = group.name.split("/")[-1]
if name:
res["name"] = name
res["hdf5_path"] = group.name
res = _resolve_attr_references(res, h5f, context)
return res
def _resolve_attr_references(res: dict, h5f: h5py.File, context: dict) -> dict:
"""Resolve references to objects that have already been created"""
for k, v in res.items():
if isinstance(v, h5py.h5r.Reference):
ref_path = h5f[v].name
if SKIP_PATTERN.match(ref_path):
res[k] = ref_path
else:
res[k] = _copy(context[ref_path])
return res
def _copy(obj: Any) -> Any:
"""
Get a copy of an object, using model_copy if we're a pydantic model.
Used to get shallow copies to avoid object ID overlaps while dumping,
pydantic treats any repeat appearance of an id
"""
if isinstance(obj, BaseModel):
return obj.model_copy()
else:
try:
return obj.copy()
except AttributeError:
# no copy method, fine
return obj
def read_specs_as_dicts(group: h5py.Group) -> dict: def read_specs_as_dicts(group: h5py.Group) -> dict:
""" """
Utility function to iterate through the `/specifications` group and Utility function to iterate through the `/specifications` group and
@ -491,6 +479,90 @@ def find_references(h5f: h5py.File, path: str) -> List[str]:
return references return references
def get_attr_references(obj: h5py.Dataset | h5py.Group) -> dict[str, str]:
"""
Get any references in object attributes
"""
refs = {
k: obj.file.get(ref).name
for k, ref in obj.attrs.items()
if isinstance(ref, h5py.h5r.Reference)
}
return refs
def get_dataset_references(obj: h5py.Dataset | h5py.Group) -> list[str] | dict[str, str]:
"""
Get references in datasets
"""
refs = []
# For datasets, apply checks depending on shape of data.
if isinstance(obj, h5py.Dataset):
if obj.shape == ():
# scalar
if isinstance(obj[()], h5py.h5r.Reference):
refs = [obj.file.get(obj[()]).name]
elif len(obj) > 0 and isinstance(obj[0], h5py.h5r.Reference):
# single-column
refs = [obj.file.get(ref).name for ref in obj[:]]
elif len(obj.dtype) > 1:
# "compound" datasets
refs = {}
for name in obj.dtype.names:
if isinstance(obj[name][0], h5py.h5r.Reference):
refs[name] = [obj.file.get(ref).name for ref in obj[name]]
return refs
def get_references(obj: h5py.Dataset | h5py.Group) -> List[str]:
"""
Find all hdf5 object references in a dataset or group
Locate references in
* Attrs
* Scalar datasets
* Single-column datasets
* Multi-column datasets
Distinct from :func:`.find_references` which finds a references *to* an object.
Args:
obj (:class:`h5py.Dataset` | :class:`h5py.Group`): Object to evaluate
Returns:
List[str]: List of paths that are referenced within this object
"""
# Find references in attrs
attr_refs = get_attr_references(obj)
dataset_refs = get_dataset_references(obj)
# flatten to list
refs = [ref for ref in attr_refs.values()]
if isinstance(dataset_refs, list):
refs.extend(dataset_refs)
else:
for v in dataset_refs.values():
refs.extend(v)
return refs
def resolve_hardlink(obj: Union[h5py.Group, h5py.Dataset]) -> str:
"""
Unhelpfully, hardlinks are pretty challenging to detect with h5py, so we have
to do extra work to check if an item is "real" or a hardlink to another item.
Particularly, an item will be excluded from the ``visititems`` method used by
:func:`.flatten_hdf` if it is a hardlink rather than an "original" dataset,
meaning that we don't even have them in our sources list when start reading.
We basically dereference the object and return that path instead of the path
given by the object's ``name``
"""
return obj.file[obj.ref].name
def truncate_file(source: Path, target: Optional[Path] = None, n: int = 10) -> Path | None: def truncate_file(source: Path, target: Optional[Path] = None, n: int = 10) -> Path | None:
""" """
Create a truncated HDF5 file where only the first few samples are kept. Create a truncated HDF5 file where only the first few samples are kept.

View file

@ -1,95 +0,0 @@
"""
Maps for reading and writing from HDF5
We have sort of diverged from the initial idea of a generalized map as in :class:`linkml.map.Map` ,
so we will make our own mapping class here and re-evaluate whether they should be unified later
"""
# ruff: noqa: D102
# ruff: noqa: D101
from typing import List, Union
import h5py
def get_attr_references(obj: h5py.Dataset | h5py.Group) -> dict[str, str]:
"""
Get any references in object attributes
"""
refs = {
k: obj.file.get(ref).name
for k, ref in obj.attrs.items()
if isinstance(ref, h5py.h5r.Reference)
}
return refs
def get_dataset_references(obj: h5py.Dataset | h5py.Group) -> list[str] | dict[str, str]:
"""
Get references in datasets
"""
refs = []
# For datasets, apply checks depending on shape of data.
if isinstance(obj, h5py.Dataset):
if obj.shape == ():
# scalar
if isinstance(obj[()], h5py.h5r.Reference):
refs = [obj.file.get(obj[()]).name]
elif len(obj) > 0 and isinstance(obj[0], h5py.h5r.Reference):
# single-column
refs = [obj.file.get(ref).name for ref in obj[:]]
elif len(obj.dtype) > 1:
# "compound" datasets
refs = {}
for name in obj.dtype.names:
if isinstance(obj[name][0], h5py.h5r.Reference):
refs[name] = [obj.file.get(ref).name for ref in obj[name]]
return refs
def get_references(obj: h5py.Dataset | h5py.Group) -> List[str]:
"""
Find all hdf5 object references in a dataset or group
Locate references in
* Attrs
* Scalar datasets
* Single-column datasets
* Multi-column datasets
Args:
obj (:class:`h5py.Dataset` | :class:`h5py.Group`): Object to evaluate
Returns:
List[str]: List of paths that are referenced within this object
"""
# Find references in attrs
attr_refs = get_attr_references(obj)
dataset_refs = get_dataset_references(obj)
# flatten to list
refs = [ref for ref in attr_refs.values()]
if isinstance(dataset_refs, list):
refs.extend(dataset_refs)
else:
for v in dataset_refs.values():
refs.extend(v)
return refs
def resolve_hardlink(obj: Union[h5py.Group, h5py.Dataset]) -> str:
"""
Unhelpfully, hardlinks are pretty challenging to detect with h5py, so we have
to do extra work to check if an item is "real" or a hardlink to another item.
Particularly, an item will be excluded from the ``visititems`` method used by
:func:`.flatten_hdf` if it is a hardlink rather than an "original" dataset,
meaning that we don't even have them in our sources list when start reading.
We basically dereference the object and return that path instead of the path
given by the object's ``name``
"""
return obj.file[obj.ref].name