mirror of
https://github.com/p2p-ld/nwb-linkml.git
synced 2025-01-09 21:54:27 +00:00
checkpointing working on model loading. it's a sloggggggggg
This commit is contained in:
parent
cd3d7ca78e
commit
d1498a3733
11 changed files with 258 additions and 127 deletions
|
@ -22,7 +22,7 @@ dependencies = [
|
||||||
"pydantic-settings>=2.0.3",
|
"pydantic-settings>=2.0.3",
|
||||||
"tqdm>=4.66.1",
|
"tqdm>=4.66.1",
|
||||||
'typing-extensions>=4.12.2;python_version<"3.11"',
|
'typing-extensions>=4.12.2;python_version<"3.11"',
|
||||||
"numpydantic>=1.3.3",
|
"numpydantic>=1.5.0",
|
||||||
"black>=24.4.2",
|
"black>=24.4.2",
|
||||||
"pandas>=2.2.2",
|
"pandas>=2.2.2",
|
||||||
"networkx>=3.3",
|
"networkx>=3.3",
|
||||||
|
|
|
@ -10,7 +10,7 @@ import sys
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import ClassVar, Dict, List, Optional, Tuple
|
from typing import ClassVar, Dict, List, Optional, Tuple, Literal
|
||||||
|
|
||||||
from linkml.generators import PydanticGenerator
|
from linkml.generators import PydanticGenerator
|
||||||
from linkml.generators.pydanticgen.array import ArrayRepresentation, NumpydanticArray
|
from linkml.generators.pydanticgen.array import ArrayRepresentation, NumpydanticArray
|
||||||
|
@ -27,7 +27,7 @@ from linkml_runtime.utils.compile_python import file_text
|
||||||
from linkml_runtime.utils.formatutils import remove_empty_items
|
from linkml_runtime.utils.formatutils import remove_empty_items
|
||||||
from linkml_runtime.utils.schemaview import SchemaView
|
from linkml_runtime.utils.schemaview import SchemaView
|
||||||
|
|
||||||
from nwb_linkml.includes.base import BASEMODEL_GETITEM
|
from nwb_linkml.includes.base import BASEMODEL_GETITEM, BASEMODEL_COERCE_VALUE
|
||||||
from nwb_linkml.includes.hdmf import (
|
from nwb_linkml.includes.hdmf import (
|
||||||
DYNAMIC_TABLE_IMPORTS,
|
DYNAMIC_TABLE_IMPORTS,
|
||||||
DYNAMIC_TABLE_INJECTS,
|
DYNAMIC_TABLE_INJECTS,
|
||||||
|
@ -52,6 +52,7 @@ class NWBPydanticGenerator(PydanticGenerator):
|
||||||
),
|
),
|
||||||
'object_id: Optional[str] = Field(None, description="Unique UUID for each object")',
|
'object_id: Optional[str] = Field(None, description="Unique UUID for each object")',
|
||||||
BASEMODEL_GETITEM,
|
BASEMODEL_GETITEM,
|
||||||
|
BASEMODEL_COERCE_VALUE,
|
||||||
)
|
)
|
||||||
split: bool = True
|
split: bool = True
|
||||||
imports: list[Import] = field(default_factory=lambda: [Import(module="numpy", alias="np")])
|
imports: list[Import] = field(default_factory=lambda: [Import(module="numpy", alias="np")])
|
||||||
|
@ -66,6 +67,7 @@ class NWBPydanticGenerator(PydanticGenerator):
|
||||||
emit_metadata: bool = True
|
emit_metadata: bool = True
|
||||||
gen_classvars: bool = True
|
gen_classvars: bool = True
|
||||||
gen_slots: bool = True
|
gen_slots: bool = True
|
||||||
|
extra_fields: Literal["allow", "forbid", "ignore"] = "allow"
|
||||||
|
|
||||||
skip_meta: ClassVar[Tuple[str]] = ("domain_of", "alias")
|
skip_meta: ClassVar[Tuple[str]] = ("domain_of", "alias")
|
||||||
|
|
||||||
|
@ -131,6 +133,7 @@ class NWBPydanticGenerator(PydanticGenerator):
|
||||||
"""Customize dynamictable behavior"""
|
"""Customize dynamictable behavior"""
|
||||||
cls = AfterGenerateClass.inject_dynamictable(cls)
|
cls = AfterGenerateClass.inject_dynamictable(cls)
|
||||||
cls = AfterGenerateClass.wrap_dynamictable_columns(cls, sv)
|
cls = AfterGenerateClass.wrap_dynamictable_columns(cls, sv)
|
||||||
|
cls = AfterGenerateClass.inject_elementidentifiers(cls, sv, self._get_element_import)
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
def before_render_template(self, template: PydanticModule, sv: SchemaView) -> PydanticModule:
|
def before_render_template(self, template: PydanticModule, sv: SchemaView) -> PydanticModule:
|
||||||
|
@ -255,7 +258,7 @@ class AfterGenerateClass:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if cls.cls.name in "DynamicTable":
|
if cls.cls.name in "DynamicTable":
|
||||||
cls.cls.bases = ["DynamicTableMixin"]
|
cls.cls.bases = ["DynamicTableMixin", "ConfiguredBaseModel"]
|
||||||
|
|
||||||
if cls.injected_classes is None:
|
if cls.injected_classes is None:
|
||||||
cls.injected_classes = DYNAMIC_TABLE_INJECTS.copy()
|
cls.injected_classes = DYNAMIC_TABLE_INJECTS.copy()
|
||||||
|
@ -269,13 +272,21 @@ class AfterGenerateClass:
|
||||||
else:
|
else:
|
||||||
cls.imports = DYNAMIC_TABLE_IMPORTS.model_copy()
|
cls.imports = DYNAMIC_TABLE_IMPORTS.model_copy()
|
||||||
elif cls.cls.name == "VectorData":
|
elif cls.cls.name == "VectorData":
|
||||||
cls.cls.bases = ["VectorDataMixin"]
|
cls.cls.bases = ["VectorDataMixin", "ConfiguredBaseModel"]
|
||||||
|
# make ``value`` generic on T
|
||||||
|
if "value" in cls.cls.attributes:
|
||||||
|
cls.cls.attributes["value"].range = "Optional[T]"
|
||||||
elif cls.cls.name == "VectorIndex":
|
elif cls.cls.name == "VectorIndex":
|
||||||
cls.cls.bases = ["VectorIndexMixin"]
|
cls.cls.bases = ["VectorIndexMixin", "ConfiguredBaseModel"]
|
||||||
elif cls.cls.name == "DynamicTableRegion":
|
elif cls.cls.name == "DynamicTableRegion":
|
||||||
cls.cls.bases = ["DynamicTableRegionMixin", "VectorData"]
|
cls.cls.bases = ["DynamicTableRegionMixin", "VectorData", "ConfiguredBaseModel"]
|
||||||
elif cls.cls.name == "AlignedDynamicTable":
|
elif cls.cls.name == "AlignedDynamicTable":
|
||||||
cls.cls.bases = ["AlignedDynamicTableMixin", "DynamicTable"]
|
cls.cls.bases = ["AlignedDynamicTableMixin", "DynamicTable"]
|
||||||
|
elif cls.cls.name == "ElementIdentifiers":
|
||||||
|
cls.cls.bases = ["ElementIdentifiersMixin", "Data", "ConfiguredBaseModel"]
|
||||||
|
# make ``value`` generic on T
|
||||||
|
if "value" in cls.cls.attributes:
|
||||||
|
cls.cls.attributes["value"].range = "Optional[T]"
|
||||||
elif cls.cls.name == "TimeSeriesReferenceVectorData":
|
elif cls.cls.name == "TimeSeriesReferenceVectorData":
|
||||||
# in core.nwb.base, so need to inject and import again
|
# in core.nwb.base, so need to inject and import again
|
||||||
cls.cls.bases = ["TimeSeriesReferenceVectorDataMixin", "VectorData"]
|
cls.cls.bases = ["TimeSeriesReferenceVectorDataMixin", "VectorData"]
|
||||||
|
@ -305,14 +316,31 @@ class AfterGenerateClass:
|
||||||
):
|
):
|
||||||
for an_attr in cls.cls.attributes:
|
for an_attr in cls.cls.attributes:
|
||||||
if "NDArray" in (slot_range := cls.cls.attributes[an_attr].range):
|
if "NDArray" in (slot_range := cls.cls.attributes[an_attr].range):
|
||||||
|
if an_attr == "id":
|
||||||
|
cls.cls.attributes[an_attr].range = "ElementIdentifiers"
|
||||||
|
return cls
|
||||||
|
|
||||||
if an_attr.endswith("_index"):
|
if an_attr.endswith("_index"):
|
||||||
cls.cls.attributes[an_attr].range = "".join(
|
wrap_cls = "VectorIndex"
|
||||||
["VectorIndex[", slot_range, "]"]
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
cls.cls.attributes[an_attr].range = "".join(
|
wrap_cls = "VectorData"
|
||||||
["VectorData[", slot_range, "]"]
|
|
||||||
)
|
cls.cls.attributes[an_attr].range = "".join([wrap_cls, "[", slot_range, "]"])
|
||||||
|
|
||||||
|
return cls
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def inject_elementidentifiers(cls: ClassResult, sv: SchemaView, import_method) -> ClassResult:
|
||||||
|
"""
|
||||||
|
Inject ElementIdentifiers into module that define dynamictables -
|
||||||
|
needed to handle ID columns
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
cls.source.is_a == "DynamicTable"
|
||||||
|
or "DynamicTable" in sv.class_ancestors(cls.source.name)
|
||||||
|
) and sv.schema.name != "hdmf-common.table":
|
||||||
|
imp = import_method("ElementIdentifiers")
|
||||||
|
cls.imports += [imp]
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -12,3 +12,20 @@ BASEMODEL_GETITEM = """
|
||||||
else:
|
else:
|
||||||
raise KeyError("No value or data field to index from")
|
raise KeyError("No value or data field to index from")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
BASEMODEL_COERCE_VALUE = """
|
||||||
|
@field_validator("*", mode="wrap")
|
||||||
|
@classmethod
|
||||||
|
def coerce_value(cls, v: Any, handler) -> Any:
|
||||||
|
\"\"\"Try to rescue instantiation by using the value field\"\"\"
|
||||||
|
try:
|
||||||
|
return handler(v)
|
||||||
|
except Exception as e1:
|
||||||
|
try:
|
||||||
|
if hasattr(v, "value"):
|
||||||
|
return handler(v.value)
|
||||||
|
else:
|
||||||
|
return handler(v["value"])
|
||||||
|
except Exception as e2:
|
||||||
|
raise e2 from e1
|
||||||
|
"""
|
||||||
|
|
|
@ -253,6 +253,8 @@ class DynamicTableMixin(BaseModel):
|
||||||
else:
|
else:
|
||||||
# add any columns not explicitly given an order at the end
|
# add any columns not explicitly given an order at the end
|
||||||
colnames = model["colnames"].copy()
|
colnames = model["colnames"].copy()
|
||||||
|
if isinstance(colnames, np.ndarray):
|
||||||
|
colnames = colnames.tolist()
|
||||||
colnames.extend(
|
colnames.extend(
|
||||||
[
|
[
|
||||||
k
|
k
|
||||||
|
@ -284,9 +286,13 @@ class DynamicTableMixin(BaseModel):
|
||||||
if not isinstance(val, (VectorData, VectorIndex)):
|
if not isinstance(val, (VectorData, VectorIndex)):
|
||||||
try:
|
try:
|
||||||
if key.endswith("_index"):
|
if key.endswith("_index"):
|
||||||
model[key] = VectorIndex(name=key, description="", value=val)
|
to_cast = VectorIndex
|
||||||
else:
|
else:
|
||||||
model[key] = VectorData(name=key, description="", value=val)
|
to_cast = VectorData
|
||||||
|
if isinstance(val, dict):
|
||||||
|
model[key] = to_cast(**val)
|
||||||
|
else:
|
||||||
|
model[key] = VectorIndex(name=key, description="", value=val)
|
||||||
except ValidationError as e: # pragma: no cover
|
except ValidationError as e: # pragma: no cover
|
||||||
raise ValidationError.from_exception_data(
|
raise ValidationError.from_exception_data(
|
||||||
title=f"field {key} cannot be cast to VectorData from {val}",
|
title=f"field {key} cannot be cast to VectorData from {val}",
|
||||||
|
@ -379,10 +385,10 @@ class VectorDataMixin(BaseModel, Generic[T]):
|
||||||
# redefined in `VectorData`, but included here for testing and type checking
|
# redefined in `VectorData`, but included here for testing and type checking
|
||||||
value: Optional[T] = None
|
value: Optional[T] = None
|
||||||
|
|
||||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
# def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||||
if value is not None and "value" not in kwargs:
|
# if value is not None and "value" not in kwargs:
|
||||||
kwargs["value"] = value
|
# kwargs["value"] = value
|
||||||
super().__init__(**kwargs)
|
# super().__init__(**kwargs)
|
||||||
|
|
||||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||||
if self._index:
|
if self._index:
|
||||||
|
@ -703,14 +709,19 @@ class AlignedDynamicTableMixin(BaseModel):
|
||||||
model["categories"] = categories
|
model["categories"] = categories
|
||||||
else:
|
else:
|
||||||
# add any columns not explicitly given an order at the end
|
# add any columns not explicitly given an order at the end
|
||||||
categories = [
|
categories = model["categories"].copy()
|
||||||
|
if isinstance(categories, np.ndarray):
|
||||||
|
categories = categories.tolist()
|
||||||
|
categories.extend(
|
||||||
|
[
|
||||||
k
|
k
|
||||||
for k in model
|
for k in model
|
||||||
if k not in cls.NON_COLUMN_FIELDS
|
if k not in cls.NON_CATEGORY_FIELDS
|
||||||
and not k.endswith("_index")
|
and not k.endswith("_index")
|
||||||
and k not in model["categories"]
|
and k not in model["categories"]
|
||||||
]
|
]
|
||||||
model["categories"].extend(categories)
|
)
|
||||||
|
model["categories"] = categories
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
|
@ -839,6 +850,13 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ElementIdentifiersMixin(VectorDataMixin):
|
||||||
|
"""
|
||||||
|
Mixin class for ElementIdentifiers - allow treating
|
||||||
|
as generic, and give general indexing methods from VectorData
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
DYNAMIC_TABLE_IMPORTS = Imports(
|
DYNAMIC_TABLE_IMPORTS = Imports(
|
||||||
imports=[
|
imports=[
|
||||||
Import(module="pandas", alias="pd"),
|
Import(module="pandas", alias="pd"),
|
||||||
|
@ -882,6 +900,7 @@ DYNAMIC_TABLE_INJECTS = [
|
||||||
DynamicTableRegionMixin,
|
DynamicTableRegionMixin,
|
||||||
DynamicTableMixin,
|
DynamicTableMixin,
|
||||||
AlignedDynamicTableMixin,
|
AlignedDynamicTableMixin,
|
||||||
|
ElementIdentifiersMixin,
|
||||||
]
|
]
|
||||||
|
|
||||||
TSRVD_IMPORTS = Imports(
|
TSRVD_IMPORTS = Imports(
|
||||||
|
@ -923,3 +942,8 @@ if "pytest" in sys.modules:
|
||||||
"""TimeSeriesReferenceVectorData subclass for testing"""
|
"""TimeSeriesReferenceVectorData subclass for testing"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class ElementIdentifiers(ElementIdentifiersMixin):
|
||||||
|
"""ElementIdentifiers subclass for testing"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
|
@ -22,6 +22,7 @@ Other TODO:
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import pdb
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
|
@ -34,10 +35,11 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union, overload
|
||||||
import h5py
|
import h5py
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from numpydantic.interface.hdf5 import H5ArrayPath
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from nwb_linkml.maps.hdf5 import ReadPhases, ReadQueue, flatten_hdf, get_references
|
from nwb_linkml.maps.hdf5 import get_references
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from nwb_linkml.providers.schema import SchemaProvider
|
from nwb_linkml.providers.schema import SchemaProvider
|
||||||
|
@ -49,7 +51,11 @@ else:
|
||||||
from typing_extensions import Never
|
from typing_extensions import Never
|
||||||
|
|
||||||
|
|
||||||
def hdf_dependency_graph(h5f: Path | h5py.File) -> nx.DiGraph:
|
SKIP_PATTERN = re.compile("^/specifications.*")
|
||||||
|
"""Nodes to always skip in reading e.g. because they are handled elsewhere"""
|
||||||
|
|
||||||
|
|
||||||
|
def hdf_dependency_graph(h5f: Path | h5py.File | h5py.Group) -> nx.DiGraph:
|
||||||
"""
|
"""
|
||||||
Directed dependency graph of dataset and group nodes in an NWBFile such that
|
Directed dependency graph of dataset and group nodes in an NWBFile such that
|
||||||
each node ``n_i`` is connected to node ``n_j`` if
|
each node ``n_i`` is connected to node ``n_j`` if
|
||||||
|
@ -63,14 +69,15 @@ def hdf_dependency_graph(h5f: Path | h5py.File) -> nx.DiGraph:
|
||||||
* Dataset columns
|
* Dataset columns
|
||||||
* Compound dtypes
|
* Compound dtypes
|
||||||
|
|
||||||
|
Edges are labeled with ``reference`` or ``child`` depending on the type of edge it is,
|
||||||
|
and attributes from the hdf5 file are added as node attributes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
h5f (:class:`pathlib.Path` | :class:`h5py.File`): NWB file to graph
|
h5f (:class:`pathlib.Path` | :class:`h5py.File`): NWB file to graph
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
:class:`networkx.DiGraph`
|
:class:`networkx.DiGraph`
|
||||||
"""
|
"""
|
||||||
# detect nodes to skip
|
|
||||||
skip_pattern = re.compile("^/specifications.*")
|
|
||||||
|
|
||||||
if isinstance(h5f, (Path, str)):
|
if isinstance(h5f, (Path, str)):
|
||||||
h5f = h5py.File(h5f, "r")
|
h5f = h5py.File(h5f, "r")
|
||||||
|
@ -78,17 +85,19 @@ def hdf_dependency_graph(h5f: Path | h5py.File) -> nx.DiGraph:
|
||||||
g = nx.DiGraph()
|
g = nx.DiGraph()
|
||||||
|
|
||||||
def _visit_item(name: str, node: h5py.Dataset | h5py.Group) -> None:
|
def _visit_item(name: str, node: h5py.Dataset | h5py.Group) -> None:
|
||||||
if skip_pattern.match(name):
|
if SKIP_PATTERN.match(node.name):
|
||||||
return
|
return
|
||||||
# find references in attributes
|
# find references in attributes
|
||||||
refs = get_references(node)
|
refs = get_references(node)
|
||||||
if isinstance(node, h5py.Group):
|
# add edges from references
|
||||||
refs.extend([child.name for child in node.values()])
|
edges = [(node.name, ref) for ref in refs if not SKIP_PATTERN.match(ref)]
|
||||||
refs = set(refs)
|
g.add_edges_from(edges, label="reference")
|
||||||
|
|
||||||
# add edges
|
# add children, if group
|
||||||
edges = [(node.name, ref) for ref in refs]
|
if isinstance(node, h5py.Group):
|
||||||
g.add_edges_from(edges)
|
children = [child.name for child in node.values() if not SKIP_PATTERN.match(child.name)]
|
||||||
|
edges = [(node.name, ref) for ref in children if not SKIP_PATTERN.match(ref)]
|
||||||
|
g.add_edges_from(edges, label="child")
|
||||||
|
|
||||||
# ensure node added to graph
|
# ensure node added to graph
|
||||||
if len(edges) == 0:
|
if len(edges) == 0:
|
||||||
|
@ -119,13 +128,125 @@ def filter_dependency_graph(g: nx.DiGraph) -> nx.DiGraph:
|
||||||
node: str
|
node: str
|
||||||
for node in g.nodes:
|
for node in g.nodes:
|
||||||
ndtype = g.nodes[node].get("neurodata_type", None)
|
ndtype = g.nodes[node].get("neurodata_type", None)
|
||||||
if ndtype == "VectorData" or not ndtype and g.out_degree(node) == 0:
|
if (ndtype is None and g.out_degree(node) == 0) or SKIP_PATTERN.match(node):
|
||||||
remove_nodes.append(node)
|
remove_nodes.append(node)
|
||||||
|
|
||||||
g.remove_nodes_from(remove_nodes)
|
g.remove_nodes_from(remove_nodes)
|
||||||
return g
|
return g
|
||||||
|
|
||||||
|
|
||||||
|
def _load_node(
|
||||||
|
path: str, h5f: h5py.File, provider: "SchemaProvider", context: dict
|
||||||
|
) -> dict | BaseModel:
|
||||||
|
"""
|
||||||
|
Load an individual node in the graph, then removes it from the graph
|
||||||
|
Args:
|
||||||
|
path:
|
||||||
|
g:
|
||||||
|
context:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
obj = h5f.get(path)
|
||||||
|
|
||||||
|
if isinstance(obj, h5py.Dataset):
|
||||||
|
args = _load_dataset(obj, h5f, context)
|
||||||
|
elif isinstance(obj, h5py.Group):
|
||||||
|
args = _load_group(obj, h5f, context)
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Nodes can only be h5py Datasets and Groups, got {obj}")
|
||||||
|
|
||||||
|
# if obj.name == "/general/intracellular_ephys/simultaneous_recordings/recordings":
|
||||||
|
# pdb.set_trace()
|
||||||
|
|
||||||
|
# resolve attr references
|
||||||
|
for k, v in args.items():
|
||||||
|
if isinstance(v, h5py.h5r.Reference):
|
||||||
|
ref_path = h5f[v].name
|
||||||
|
args[k] = context[ref_path]
|
||||||
|
|
||||||
|
model = provider.get_class(obj.attrs["namespace"], obj.attrs["neurodata_type"])
|
||||||
|
|
||||||
|
# add additional needed params
|
||||||
|
args["hdf5_path"] = path
|
||||||
|
args["name"] = path.split("/")[-1]
|
||||||
|
return model(**args)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_dataset(
|
||||||
|
dataset: h5py.Dataset, h5f: h5py.File, context: dict
|
||||||
|
) -> Union[dict, str, int, float]:
|
||||||
|
"""
|
||||||
|
Resolves datasets that do not have a ``neurodata_type`` as a dictionary or a scalar.
|
||||||
|
|
||||||
|
If the dataset is a single value without attrs, load it and return as a scalar value.
|
||||||
|
Otherwise return a :class:`.H5ArrayPath` as a reference to the dataset in the `value` key.
|
||||||
|
"""
|
||||||
|
res = {}
|
||||||
|
if dataset.shape == ():
|
||||||
|
val = dataset[()]
|
||||||
|
if isinstance(val, h5py.h5r.Reference):
|
||||||
|
val = context.get(h5f[val].name)
|
||||||
|
# if this is just a scalar value, return it
|
||||||
|
if not dataset.attrs:
|
||||||
|
return val
|
||||||
|
|
||||||
|
res["value"] = val
|
||||||
|
elif len(dataset) > 0 and isinstance(dataset[0], h5py.h5r.Reference):
|
||||||
|
# vector of references
|
||||||
|
res["value"] = [context.get(h5f[ref].name) for ref in dataset[:]]
|
||||||
|
elif len(dataset.dtype) > 1:
|
||||||
|
# compound dataset - check if any of the fields are references
|
||||||
|
for name in dataset.dtype.names:
|
||||||
|
if isinstance(dataset[name][0], h5py.h5r.Reference):
|
||||||
|
res[name] = [context.get(h5f[ref].name) for ref in dataset[name]]
|
||||||
|
else:
|
||||||
|
res[name] = H5ArrayPath(h5f.filename, dataset.name, name)
|
||||||
|
else:
|
||||||
|
res["value"] = H5ArrayPath(h5f.filename, dataset.name)
|
||||||
|
|
||||||
|
res.update(dataset.attrs)
|
||||||
|
if "namespace" in res:
|
||||||
|
del res["namespace"]
|
||||||
|
if "neurodata_type" in res:
|
||||||
|
del res["neurodata_type"]
|
||||||
|
res["name"] = dataset.name.split("/")[-1]
|
||||||
|
res["hdf5_path"] = dataset.name
|
||||||
|
|
||||||
|
if len(res) == 1:
|
||||||
|
return res["value"]
|
||||||
|
else:
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def _load_group(group: h5py.Group, h5f: h5py.File, context: dict) -> dict:
|
||||||
|
"""
|
||||||
|
Load a group!
|
||||||
|
"""
|
||||||
|
res = {}
|
||||||
|
res.update(group.attrs)
|
||||||
|
for child_name, child in group.items():
|
||||||
|
if child.name in context:
|
||||||
|
res[child_name] = context[child.name]
|
||||||
|
elif isinstance(child, h5py.Dataset):
|
||||||
|
res[child_name] = _load_dataset(child, h5f, context)
|
||||||
|
elif isinstance(child, h5py.Group):
|
||||||
|
res[child_name] = _load_group(child, h5f, context)
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
"Can only handle preinstantiated child objects in context, datasets, and group,"
|
||||||
|
f" got {child} for {child_name}"
|
||||||
|
)
|
||||||
|
if "namespace" in res:
|
||||||
|
del res["namespace"]
|
||||||
|
if "neurodata_type" in res:
|
||||||
|
del res["neurodata_type"]
|
||||||
|
res["name"] = group.name.split("/")[-1]
|
||||||
|
res["hdf5_path"] = group.name
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
class HDF5IO:
|
class HDF5IO:
|
||||||
"""
|
"""
|
||||||
Read (and eventually write) from an NWB HDF5 file.
|
Read (and eventually write) from an NWB HDF5 file.
|
||||||
|
@ -185,28 +306,20 @@ class HDF5IO:
|
||||||
|
|
||||||
h5f = h5py.File(str(self.path))
|
h5f = h5py.File(str(self.path))
|
||||||
src = h5f.get(path) if path else h5f
|
src = h5f.get(path) if path else h5f
|
||||||
|
graph = hdf_dependency_graph(src)
|
||||||
|
graph = filter_dependency_graph(graph)
|
||||||
|
|
||||||
# get all children of selected item
|
# topo sort to get read order
|
||||||
if isinstance(src, (h5py.File, h5py.Group)):
|
# TODO: This could be parallelized using `topological_generations`,
|
||||||
children = flatten_hdf(src)
|
# but it's not clear what the perf bonus would be because there are many generations
|
||||||
else:
|
# with few items
|
||||||
raise NotImplementedError("directly read individual datasets")
|
topo_order = list(reversed(list(nx.topological_sort(graph))))
|
||||||
|
context = {}
|
||||||
|
for node in topo_order:
|
||||||
|
res = _load_node(node, h5f, provider, context)
|
||||||
|
context[node] = res
|
||||||
|
|
||||||
queue = ReadQueue(h5f=self.path, queue=children, provider=provider)
|
pdb.set_trace()
|
||||||
|
|
||||||
# Apply initial planning phase of reading
|
|
||||||
queue.apply_phase(ReadPhases.plan)
|
|
||||||
# Read operations gather the data before casting into models
|
|
||||||
queue.apply_phase(ReadPhases.read)
|
|
||||||
# Construction operations actually cast the models
|
|
||||||
# this often needs to run several times as models with dependencies wait for their
|
|
||||||
# dependents to be cast
|
|
||||||
queue.apply_phase(ReadPhases.construct)
|
|
||||||
|
|
||||||
if path is None:
|
|
||||||
return queue.completed["/"].result
|
|
||||||
else:
|
|
||||||
return queue.completed[path].result
|
|
||||||
|
|
||||||
def write(self, path: Path) -> Never:
|
def write(self, path: Path) -> Never:
|
||||||
"""
|
"""
|
||||||
|
@ -246,7 +359,7 @@ class HDF5IO:
|
||||||
"""
|
"""
|
||||||
from nwb_linkml.providers.schema import SchemaProvider
|
from nwb_linkml.providers.schema import SchemaProvider
|
||||||
|
|
||||||
h5f = h5py.File(str(self.path))
|
h5f = h5py.File(str(self.path), "r")
|
||||||
schema = read_specs_as_dicts(h5f.get("specifications"))
|
schema = read_specs_as_dicts(h5f.get("specifications"))
|
||||||
|
|
||||||
# get versions for each namespace
|
# get versions for each namespace
|
||||||
|
@ -260,7 +373,7 @@ class HDF5IO:
|
||||||
provider = SchemaProvider(versions=versions)
|
provider = SchemaProvider(versions=versions)
|
||||||
|
|
||||||
# build schema so we have them cached
|
# build schema so we have them cached
|
||||||
provider.build_from_dicts(schema)
|
# provider.build_from_dicts(schema)
|
||||||
h5f.close()
|
h5f.close()
|
||||||
return provider
|
return provider
|
||||||
|
|
||||||
|
|
|
@ -233,66 +233,6 @@ class PruneEmpty(HDF5Map):
|
||||||
return H5ReadResult.model_construct(path=src.path, source=src, completed=True)
|
return H5ReadResult.model_construct(path=src.path, source=src, completed=True)
|
||||||
|
|
||||||
|
|
||||||
#
|
|
||||||
# class ResolveDynamicTable(HDF5Map):
|
|
||||||
# """
|
|
||||||
# Handle loading a dynamic table!
|
|
||||||
#
|
|
||||||
# Dynamic tables are sort of odd in that their models don't include their fields
|
|
||||||
# (except as a list of strings in ``colnames`` ),
|
|
||||||
# so we need to create a new model that includes fields for each column,
|
|
||||||
# and then we include the datasets as :class:`~numpydantic.interface.hdf5.H5ArrayPath`
|
|
||||||
# objects which lazy load the arrays in a thread/process safe way.
|
|
||||||
#
|
|
||||||
# This map also resolves the child elements,
|
|
||||||
# indicating so by the ``completes`` field in the :class:`.ReadResult`
|
|
||||||
# """
|
|
||||||
#
|
|
||||||
# phase = ReadPhases.read
|
|
||||||
# priority = 1
|
|
||||||
#
|
|
||||||
# @classmethod
|
|
||||||
# def check(
|
|
||||||
# cls, src: H5SourceItem, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
|
|
||||||
# ) -> bool:
|
|
||||||
# if src.h5_type == "dataset":
|
|
||||||
# return False
|
|
||||||
# if "neurodata_type" in src.attrs:
|
|
||||||
# if src.attrs["neurodata_type"] == "DynamicTable":
|
|
||||||
# return True
|
|
||||||
# # otherwise, see if it's a subclass
|
|
||||||
# model = provider.get_class(src.attrs["namespace"], src.attrs["neurodata_type"])
|
|
||||||
# # just inspect the MRO as strings rather than trying to check subclasses because
|
|
||||||
# # we might replace DynamicTable in the future, and there isn't a stable DynamicTable
|
|
||||||
# # class to inherit from anyway because of the whole multiple versions thing
|
|
||||||
# parents = [parent.__name__ for parent in model.__mro__]
|
|
||||||
# return "DynamicTable" in parents
|
|
||||||
# else:
|
|
||||||
# return False
|
|
||||||
#
|
|
||||||
# @classmethod
|
|
||||||
# def apply(
|
|
||||||
# cls, src: H5SourceItem, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
|
|
||||||
# ) -> H5ReadResult:
|
|
||||||
# with h5py.File(src.h5f_path, "r") as h5f:
|
|
||||||
# obj = h5f.get(src.path)
|
|
||||||
#
|
|
||||||
# # make a populated model :)
|
|
||||||
# base_model = provider.get_class(src.namespace, src.neurodata_type)
|
|
||||||
# model = dynamictable_to_model(obj, base=base_model)
|
|
||||||
#
|
|
||||||
# completes = [HDF5_Path(child.name) for child in obj.values()]
|
|
||||||
#
|
|
||||||
# return H5ReadResult(
|
|
||||||
# path=src.path,
|
|
||||||
# source=src,
|
|
||||||
# result=model,
|
|
||||||
# completes=completes,
|
|
||||||
# completed=True,
|
|
||||||
# applied=["ResolveDynamicTable"],
|
|
||||||
# )
|
|
||||||
|
|
||||||
|
|
||||||
class ResolveModelGroup(HDF5Map):
|
class ResolveModelGroup(HDF5Map):
|
||||||
"""
|
"""
|
||||||
HDF5 Groups that have a model, as indicated by ``neurodata_type`` in their attrs.
|
HDF5 Groups that have a model, as indicated by ``neurodata_type`` in their attrs.
|
||||||
|
|
|
@ -97,9 +97,9 @@ class Provider(ABC):
|
||||||
module_path = Path(importlib.util.find_spec("nwb_models").origin).parent
|
module_path = Path(importlib.util.find_spec("nwb_models").origin).parent
|
||||||
|
|
||||||
if self.PROVIDES == "linkml":
|
if self.PROVIDES == "linkml":
|
||||||
namespace_path = module_path / "schema" / "linkml" / namespace
|
namespace_path = module_path / "schema" / "linkml" / namespace_module
|
||||||
elif self.PROVIDES == "pydantic":
|
elif self.PROVIDES == "pydantic":
|
||||||
namespace_path = module_path / "models" / "pydantic" / namespace
|
namespace_path = module_path / "models" / "pydantic" / namespace_module
|
||||||
|
|
||||||
if version is not None:
|
if version is not None:
|
||||||
version_path = namespace_path / version_module_case(version)
|
version_path = namespace_path / version_module_case(version)
|
||||||
|
|
3
nwb_linkml/tests/fixtures/__init__.py
vendored
3
nwb_linkml/tests/fixtures/__init__.py
vendored
|
@ -1,4 +1,4 @@
|
||||||
from .nwb import nwb_file
|
from .nwb import nwb_file, nwb_file_base
|
||||||
from .paths import data_dir, tmp_output_dir, tmp_output_dir_func, tmp_output_dir_mod
|
from .paths import data_dir, tmp_output_dir, tmp_output_dir_func, tmp_output_dir_mod
|
||||||
from .schema import (
|
from .schema import (
|
||||||
NWBSchemaTest,
|
NWBSchemaTest,
|
||||||
|
@ -21,6 +21,7 @@ __all__ = [
|
||||||
"nwb_core_linkml",
|
"nwb_core_linkml",
|
||||||
"nwb_core_module",
|
"nwb_core_module",
|
||||||
"nwb_file",
|
"nwb_file",
|
||||||
|
"nwb_file_base",
|
||||||
"nwb_schema",
|
"nwb_schema",
|
||||||
"tmp_output_dir",
|
"tmp_output_dir",
|
||||||
"tmp_output_dir_func",
|
"tmp_output_dir_func",
|
||||||
|
|
2
nwb_linkml/tests/fixtures/paths.py
vendored
2
nwb_linkml/tests/fixtures/paths.py
vendored
|
@ -6,7 +6,7 @@ import pytest
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def tmp_output_dir(request: pytest.FixtureRequest) -> Path:
|
def tmp_output_dir(request: pytest.FixtureRequest) -> Path:
|
||||||
path = Path(__file__).parent.resolve() / "__tmp__"
|
path = Path(__file__).parents[1].resolve() / "__tmp__"
|
||||||
if path.exists():
|
if path.exists():
|
||||||
if request.config.getoption("--clean"):
|
if request.config.getoption("--clean"):
|
||||||
shutil.rmtree(path)
|
shutil.rmtree(path)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import pdb
|
import pdb
|
||||||
|
|
||||||
import h5py
|
import h5py
|
||||||
|
import networkx as nx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
@ -100,10 +101,15 @@ def test_flatten_hdf():
|
||||||
raise NotImplementedError("Just a stub for local testing for now, finish me!")
|
raise NotImplementedError("Just a stub for local testing for now, finish me!")
|
||||||
|
|
||||||
|
|
||||||
def test_dependency_graph(nwb_file):
|
@pytest.mark.dev
|
||||||
|
def test_dependency_graph(nwb_file, tmp_output_dir):
|
||||||
"""
|
"""
|
||||||
dependency graph is correctly constructed from an HDF5 file
|
dependency graph is correctly constructed from an HDF5 file
|
||||||
"""
|
"""
|
||||||
graph = hdf_dependency_graph(nwb_file)
|
graph = hdf_dependency_graph(nwb_file)
|
||||||
|
A_unfiltered = nx.nx_agraph.to_agraph(graph)
|
||||||
|
A_unfiltered.draw(tmp_output_dir / "test_nwb_unfiltered.png", prog="dot")
|
||||||
graph = filter_dependency_graph(graph)
|
graph = filter_dependency_graph(graph)
|
||||||
|
A_filtered = nx.nx_agraph.to_agraph(graph)
|
||||||
|
A_filtered.draw(tmp_output_dir / "test_nwb_filtered.png", prog="dot")
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -2,12 +2,14 @@
|
||||||
Placeholder test module to test reading from pynwb-generated NWB file
|
Placeholder test module to test reading from pynwb-generated NWB file
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from nwb_linkml.io.hdf5 import HDF5IO
|
||||||
|
|
||||||
|
|
||||||
def test_read_from_nwbfile(nwb_file):
|
def test_read_from_nwbfile(nwb_file):
|
||||||
"""
|
"""
|
||||||
Read data from a pynwb HDF5 NWB file
|
Read data from a pynwb HDF5 NWB file
|
||||||
"""
|
"""
|
||||||
pass
|
res = HDF5IO(nwb_file).read()
|
||||||
|
|
||||||
|
|
||||||
def test_read_from_yaml(nwb_file):
|
def test_read_from_yaml(nwb_file):
|
||||||
|
|
Loading…
Reference in a new issue