mirror of
https://github.com/p2p-ld/nwb-linkml.git
synced 2025-01-09 05:34:28 +00:00
checkpointing working on model loading. it's a sloggggggggg
This commit is contained in:
parent
cd3d7ca78e
commit
d1498a3733
11 changed files with 258 additions and 127 deletions
|
@ -22,7 +22,7 @@ dependencies = [
|
|||
"pydantic-settings>=2.0.3",
|
||||
"tqdm>=4.66.1",
|
||||
'typing-extensions>=4.12.2;python_version<"3.11"',
|
||||
"numpydantic>=1.3.3",
|
||||
"numpydantic>=1.5.0",
|
||||
"black>=24.4.2",
|
||||
"pandas>=2.2.2",
|
||||
"networkx>=3.3",
|
||||
|
|
|
@ -10,7 +10,7 @@ import sys
|
|||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import ClassVar, Dict, List, Optional, Tuple
|
||||
from typing import ClassVar, Dict, List, Optional, Tuple, Literal
|
||||
|
||||
from linkml.generators import PydanticGenerator
|
||||
from linkml.generators.pydanticgen.array import ArrayRepresentation, NumpydanticArray
|
||||
|
@ -27,7 +27,7 @@ from linkml_runtime.utils.compile_python import file_text
|
|||
from linkml_runtime.utils.formatutils import remove_empty_items
|
||||
from linkml_runtime.utils.schemaview import SchemaView
|
||||
|
||||
from nwb_linkml.includes.base import BASEMODEL_GETITEM
|
||||
from nwb_linkml.includes.base import BASEMODEL_GETITEM, BASEMODEL_COERCE_VALUE
|
||||
from nwb_linkml.includes.hdmf import (
|
||||
DYNAMIC_TABLE_IMPORTS,
|
||||
DYNAMIC_TABLE_INJECTS,
|
||||
|
@ -52,6 +52,7 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
),
|
||||
'object_id: Optional[str] = Field(None, description="Unique UUID for each object")',
|
||||
BASEMODEL_GETITEM,
|
||||
BASEMODEL_COERCE_VALUE,
|
||||
)
|
||||
split: bool = True
|
||||
imports: list[Import] = field(default_factory=lambda: [Import(module="numpy", alias="np")])
|
||||
|
@ -66,6 +67,7 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
emit_metadata: bool = True
|
||||
gen_classvars: bool = True
|
||||
gen_slots: bool = True
|
||||
extra_fields: Literal["allow", "forbid", "ignore"] = "allow"
|
||||
|
||||
skip_meta: ClassVar[Tuple[str]] = ("domain_of", "alias")
|
||||
|
||||
|
@ -131,6 +133,7 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
"""Customize dynamictable behavior"""
|
||||
cls = AfterGenerateClass.inject_dynamictable(cls)
|
||||
cls = AfterGenerateClass.wrap_dynamictable_columns(cls, sv)
|
||||
cls = AfterGenerateClass.inject_elementidentifiers(cls, sv, self._get_element_import)
|
||||
return cls
|
||||
|
||||
def before_render_template(self, template: PydanticModule, sv: SchemaView) -> PydanticModule:
|
||||
|
@ -255,7 +258,7 @@ class AfterGenerateClass:
|
|||
|
||||
"""
|
||||
if cls.cls.name in "DynamicTable":
|
||||
cls.cls.bases = ["DynamicTableMixin"]
|
||||
cls.cls.bases = ["DynamicTableMixin", "ConfiguredBaseModel"]
|
||||
|
||||
if cls.injected_classes is None:
|
||||
cls.injected_classes = DYNAMIC_TABLE_INJECTS.copy()
|
||||
|
@ -269,13 +272,21 @@ class AfterGenerateClass:
|
|||
else:
|
||||
cls.imports = DYNAMIC_TABLE_IMPORTS.model_copy()
|
||||
elif cls.cls.name == "VectorData":
|
||||
cls.cls.bases = ["VectorDataMixin"]
|
||||
cls.cls.bases = ["VectorDataMixin", "ConfiguredBaseModel"]
|
||||
# make ``value`` generic on T
|
||||
if "value" in cls.cls.attributes:
|
||||
cls.cls.attributes["value"].range = "Optional[T]"
|
||||
elif cls.cls.name == "VectorIndex":
|
||||
cls.cls.bases = ["VectorIndexMixin"]
|
||||
cls.cls.bases = ["VectorIndexMixin", "ConfiguredBaseModel"]
|
||||
elif cls.cls.name == "DynamicTableRegion":
|
||||
cls.cls.bases = ["DynamicTableRegionMixin", "VectorData"]
|
||||
cls.cls.bases = ["DynamicTableRegionMixin", "VectorData", "ConfiguredBaseModel"]
|
||||
elif cls.cls.name == "AlignedDynamicTable":
|
||||
cls.cls.bases = ["AlignedDynamicTableMixin", "DynamicTable"]
|
||||
elif cls.cls.name == "ElementIdentifiers":
|
||||
cls.cls.bases = ["ElementIdentifiersMixin", "Data", "ConfiguredBaseModel"]
|
||||
# make ``value`` generic on T
|
||||
if "value" in cls.cls.attributes:
|
||||
cls.cls.attributes["value"].range = "Optional[T]"
|
||||
elif cls.cls.name == "TimeSeriesReferenceVectorData":
|
||||
# in core.nwb.base, so need to inject and import again
|
||||
cls.cls.bases = ["TimeSeriesReferenceVectorDataMixin", "VectorData"]
|
||||
|
@ -305,14 +316,31 @@ class AfterGenerateClass:
|
|||
):
|
||||
for an_attr in cls.cls.attributes:
|
||||
if "NDArray" in (slot_range := cls.cls.attributes[an_attr].range):
|
||||
if an_attr == "id":
|
||||
cls.cls.attributes[an_attr].range = "ElementIdentifiers"
|
||||
return cls
|
||||
|
||||
if an_attr.endswith("_index"):
|
||||
cls.cls.attributes[an_attr].range = "".join(
|
||||
["VectorIndex[", slot_range, "]"]
|
||||
)
|
||||
wrap_cls = "VectorIndex"
|
||||
else:
|
||||
cls.cls.attributes[an_attr].range = "".join(
|
||||
["VectorData[", slot_range, "]"]
|
||||
)
|
||||
wrap_cls = "VectorData"
|
||||
|
||||
cls.cls.attributes[an_attr].range = "".join([wrap_cls, "[", slot_range, "]"])
|
||||
|
||||
return cls
|
||||
|
||||
@staticmethod
|
||||
def inject_elementidentifiers(cls: ClassResult, sv: SchemaView, import_method) -> ClassResult:
|
||||
"""
|
||||
Inject ElementIdentifiers into module that define dynamictables -
|
||||
needed to handle ID columns
|
||||
"""
|
||||
if (
|
||||
cls.source.is_a == "DynamicTable"
|
||||
or "DynamicTable" in sv.class_ancestors(cls.source.name)
|
||||
) and sv.schema.name != "hdmf-common.table":
|
||||
imp = import_method("ElementIdentifiers")
|
||||
cls.imports += [imp]
|
||||
return cls
|
||||
|
||||
|
||||
|
|
|
@ -12,3 +12,20 @@ BASEMODEL_GETITEM = """
|
|||
else:
|
||||
raise KeyError("No value or data field to index from")
|
||||
"""
|
||||
|
||||
BASEMODEL_COERCE_VALUE = """
|
||||
@field_validator("*", mode="wrap")
|
||||
@classmethod
|
||||
def coerce_value(cls, v: Any, handler) -> Any:
|
||||
\"\"\"Try to rescue instantiation by using the value field\"\"\"
|
||||
try:
|
||||
return handler(v)
|
||||
except Exception as e1:
|
||||
try:
|
||||
if hasattr(v, "value"):
|
||||
return handler(v.value)
|
||||
else:
|
||||
return handler(v["value"])
|
||||
except Exception as e2:
|
||||
raise e2 from e1
|
||||
"""
|
||||
|
|
|
@ -253,6 +253,8 @@ class DynamicTableMixin(BaseModel):
|
|||
else:
|
||||
# add any columns not explicitly given an order at the end
|
||||
colnames = model["colnames"].copy()
|
||||
if isinstance(colnames, np.ndarray):
|
||||
colnames = colnames.tolist()
|
||||
colnames.extend(
|
||||
[
|
||||
k
|
||||
|
@ -284,9 +286,13 @@ class DynamicTableMixin(BaseModel):
|
|||
if not isinstance(val, (VectorData, VectorIndex)):
|
||||
try:
|
||||
if key.endswith("_index"):
|
||||
model[key] = VectorIndex(name=key, description="", value=val)
|
||||
to_cast = VectorIndex
|
||||
else:
|
||||
model[key] = VectorData(name=key, description="", value=val)
|
||||
to_cast = VectorData
|
||||
if isinstance(val, dict):
|
||||
model[key] = to_cast(**val)
|
||||
else:
|
||||
model[key] = VectorIndex(name=key, description="", value=val)
|
||||
except ValidationError as e: # pragma: no cover
|
||||
raise ValidationError.from_exception_data(
|
||||
title=f"field {key} cannot be cast to VectorData from {val}",
|
||||
|
@ -379,10 +385,10 @@ class VectorDataMixin(BaseModel, Generic[T]):
|
|||
# redefined in `VectorData`, but included here for testing and type checking
|
||||
value: Optional[T] = None
|
||||
|
||||
def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
# def __init__(self, value: Optional[NDArray] = None, **kwargs):
|
||||
# if value is not None and "value" not in kwargs:
|
||||
# kwargs["value"] = value
|
||||
# super().__init__(**kwargs)
|
||||
|
||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||
if self._index:
|
||||
|
@ -703,14 +709,19 @@ class AlignedDynamicTableMixin(BaseModel):
|
|||
model["categories"] = categories
|
||||
else:
|
||||
# add any columns not explicitly given an order at the end
|
||||
categories = [
|
||||
k
|
||||
for k in model
|
||||
if k not in cls.NON_COLUMN_FIELDS
|
||||
and not k.endswith("_index")
|
||||
and k not in model["categories"]
|
||||
]
|
||||
model["categories"].extend(categories)
|
||||
categories = model["categories"].copy()
|
||||
if isinstance(categories, np.ndarray):
|
||||
categories = categories.tolist()
|
||||
categories.extend(
|
||||
[
|
||||
k
|
||||
for k in model
|
||||
if k not in cls.NON_CATEGORY_FIELDS
|
||||
and not k.endswith("_index")
|
||||
and k not in model["categories"]
|
||||
]
|
||||
)
|
||||
model["categories"] = categories
|
||||
return model
|
||||
|
||||
@model_validator(mode="after")
|
||||
|
@ -839,6 +850,13 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin):
|
|||
)
|
||||
|
||||
|
||||
class ElementIdentifiersMixin(VectorDataMixin):
|
||||
"""
|
||||
Mixin class for ElementIdentifiers - allow treating
|
||||
as generic, and give general indexing methods from VectorData
|
||||
"""
|
||||
|
||||
|
||||
DYNAMIC_TABLE_IMPORTS = Imports(
|
||||
imports=[
|
||||
Import(module="pandas", alias="pd"),
|
||||
|
@ -882,6 +900,7 @@ DYNAMIC_TABLE_INJECTS = [
|
|||
DynamicTableRegionMixin,
|
||||
DynamicTableMixin,
|
||||
AlignedDynamicTableMixin,
|
||||
ElementIdentifiersMixin,
|
||||
]
|
||||
|
||||
TSRVD_IMPORTS = Imports(
|
||||
|
@ -923,3 +942,8 @@ if "pytest" in sys.modules:
|
|||
"""TimeSeriesReferenceVectorData subclass for testing"""
|
||||
|
||||
pass
|
||||
|
||||
class ElementIdentifiers(ElementIdentifiersMixin):
|
||||
"""ElementIdentifiers subclass for testing"""
|
||||
|
||||
pass
|
||||
|
|
|
@ -22,6 +22,7 @@ Other TODO:
|
|||
|
||||
import json
|
||||
import os
|
||||
import pdb
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
|
@ -34,10 +35,11 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union, overload
|
|||
import h5py
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
from numpydantic.interface.hdf5 import H5ArrayPath
|
||||
from pydantic import BaseModel
|
||||
from tqdm import tqdm
|
||||
|
||||
from nwb_linkml.maps.hdf5 import ReadPhases, ReadQueue, flatten_hdf, get_references
|
||||
from nwb_linkml.maps.hdf5 import get_references
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nwb_linkml.providers.schema import SchemaProvider
|
||||
|
@ -49,7 +51,11 @@ else:
|
|||
from typing_extensions import Never
|
||||
|
||||
|
||||
def hdf_dependency_graph(h5f: Path | h5py.File) -> nx.DiGraph:
|
||||
SKIP_PATTERN = re.compile("^/specifications.*")
|
||||
"""Nodes to always skip in reading e.g. because they are handled elsewhere"""
|
||||
|
||||
|
||||
def hdf_dependency_graph(h5f: Path | h5py.File | h5py.Group) -> nx.DiGraph:
|
||||
"""
|
||||
Directed dependency graph of dataset and group nodes in an NWBFile such that
|
||||
each node ``n_i`` is connected to node ``n_j`` if
|
||||
|
@ -63,14 +69,15 @@ def hdf_dependency_graph(h5f: Path | h5py.File) -> nx.DiGraph:
|
|||
* Dataset columns
|
||||
* Compound dtypes
|
||||
|
||||
Edges are labeled with ``reference`` or ``child`` depending on the type of edge it is,
|
||||
and attributes from the hdf5 file are added as node attributes.
|
||||
|
||||
Args:
|
||||
h5f (:class:`pathlib.Path` | :class:`h5py.File`): NWB file to graph
|
||||
|
||||
Returns:
|
||||
:class:`networkx.DiGraph`
|
||||
"""
|
||||
# detect nodes to skip
|
||||
skip_pattern = re.compile("^/specifications.*")
|
||||
|
||||
if isinstance(h5f, (Path, str)):
|
||||
h5f = h5py.File(h5f, "r")
|
||||
|
@ -78,17 +85,19 @@ def hdf_dependency_graph(h5f: Path | h5py.File) -> nx.DiGraph:
|
|||
g = nx.DiGraph()
|
||||
|
||||
def _visit_item(name: str, node: h5py.Dataset | h5py.Group) -> None:
|
||||
if skip_pattern.match(name):
|
||||
if SKIP_PATTERN.match(node.name):
|
||||
return
|
||||
# find references in attributes
|
||||
refs = get_references(node)
|
||||
if isinstance(node, h5py.Group):
|
||||
refs.extend([child.name for child in node.values()])
|
||||
refs = set(refs)
|
||||
# add edges from references
|
||||
edges = [(node.name, ref) for ref in refs if not SKIP_PATTERN.match(ref)]
|
||||
g.add_edges_from(edges, label="reference")
|
||||
|
||||
# add edges
|
||||
edges = [(node.name, ref) for ref in refs]
|
||||
g.add_edges_from(edges)
|
||||
# add children, if group
|
||||
if isinstance(node, h5py.Group):
|
||||
children = [child.name for child in node.values() if not SKIP_PATTERN.match(child.name)]
|
||||
edges = [(node.name, ref) for ref in children if not SKIP_PATTERN.match(ref)]
|
||||
g.add_edges_from(edges, label="child")
|
||||
|
||||
# ensure node added to graph
|
||||
if len(edges) == 0:
|
||||
|
@ -119,13 +128,125 @@ def filter_dependency_graph(g: nx.DiGraph) -> nx.DiGraph:
|
|||
node: str
|
||||
for node in g.nodes:
|
||||
ndtype = g.nodes[node].get("neurodata_type", None)
|
||||
if ndtype == "VectorData" or not ndtype and g.out_degree(node) == 0:
|
||||
if (ndtype is None and g.out_degree(node) == 0) or SKIP_PATTERN.match(node):
|
||||
remove_nodes.append(node)
|
||||
|
||||
g.remove_nodes_from(remove_nodes)
|
||||
return g
|
||||
|
||||
|
||||
def _load_node(
|
||||
path: str, h5f: h5py.File, provider: "SchemaProvider", context: dict
|
||||
) -> dict | BaseModel:
|
||||
"""
|
||||
Load an individual node in the graph, then removes it from the graph
|
||||
Args:
|
||||
path:
|
||||
g:
|
||||
context:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
obj = h5f.get(path)
|
||||
|
||||
if isinstance(obj, h5py.Dataset):
|
||||
args = _load_dataset(obj, h5f, context)
|
||||
elif isinstance(obj, h5py.Group):
|
||||
args = _load_group(obj, h5f, context)
|
||||
else:
|
||||
raise TypeError(f"Nodes can only be h5py Datasets and Groups, got {obj}")
|
||||
|
||||
# if obj.name == "/general/intracellular_ephys/simultaneous_recordings/recordings":
|
||||
# pdb.set_trace()
|
||||
|
||||
# resolve attr references
|
||||
for k, v in args.items():
|
||||
if isinstance(v, h5py.h5r.Reference):
|
||||
ref_path = h5f[v].name
|
||||
args[k] = context[ref_path]
|
||||
|
||||
model = provider.get_class(obj.attrs["namespace"], obj.attrs["neurodata_type"])
|
||||
|
||||
# add additional needed params
|
||||
args["hdf5_path"] = path
|
||||
args["name"] = path.split("/")[-1]
|
||||
return model(**args)
|
||||
|
||||
|
||||
def _load_dataset(
|
||||
dataset: h5py.Dataset, h5f: h5py.File, context: dict
|
||||
) -> Union[dict, str, int, float]:
|
||||
"""
|
||||
Resolves datasets that do not have a ``neurodata_type`` as a dictionary or a scalar.
|
||||
|
||||
If the dataset is a single value without attrs, load it and return as a scalar value.
|
||||
Otherwise return a :class:`.H5ArrayPath` as a reference to the dataset in the `value` key.
|
||||
"""
|
||||
res = {}
|
||||
if dataset.shape == ():
|
||||
val = dataset[()]
|
||||
if isinstance(val, h5py.h5r.Reference):
|
||||
val = context.get(h5f[val].name)
|
||||
# if this is just a scalar value, return it
|
||||
if not dataset.attrs:
|
||||
return val
|
||||
|
||||
res["value"] = val
|
||||
elif len(dataset) > 0 and isinstance(dataset[0], h5py.h5r.Reference):
|
||||
# vector of references
|
||||
res["value"] = [context.get(h5f[ref].name) for ref in dataset[:]]
|
||||
elif len(dataset.dtype) > 1:
|
||||
# compound dataset - check if any of the fields are references
|
||||
for name in dataset.dtype.names:
|
||||
if isinstance(dataset[name][0], h5py.h5r.Reference):
|
||||
res[name] = [context.get(h5f[ref].name) for ref in dataset[name]]
|
||||
else:
|
||||
res[name] = H5ArrayPath(h5f.filename, dataset.name, name)
|
||||
else:
|
||||
res["value"] = H5ArrayPath(h5f.filename, dataset.name)
|
||||
|
||||
res.update(dataset.attrs)
|
||||
if "namespace" in res:
|
||||
del res["namespace"]
|
||||
if "neurodata_type" in res:
|
||||
del res["neurodata_type"]
|
||||
res["name"] = dataset.name.split("/")[-1]
|
||||
res["hdf5_path"] = dataset.name
|
||||
|
||||
if len(res) == 1:
|
||||
return res["value"]
|
||||
else:
|
||||
return res
|
||||
|
||||
|
||||
def _load_group(group: h5py.Group, h5f: h5py.File, context: dict) -> dict:
|
||||
"""
|
||||
Load a group!
|
||||
"""
|
||||
res = {}
|
||||
res.update(group.attrs)
|
||||
for child_name, child in group.items():
|
||||
if child.name in context:
|
||||
res[child_name] = context[child.name]
|
||||
elif isinstance(child, h5py.Dataset):
|
||||
res[child_name] = _load_dataset(child, h5f, context)
|
||||
elif isinstance(child, h5py.Group):
|
||||
res[child_name] = _load_group(child, h5f, context)
|
||||
else:
|
||||
raise TypeError(
|
||||
"Can only handle preinstantiated child objects in context, datasets, and group,"
|
||||
f" got {child} for {child_name}"
|
||||
)
|
||||
if "namespace" in res:
|
||||
del res["namespace"]
|
||||
if "neurodata_type" in res:
|
||||
del res["neurodata_type"]
|
||||
res["name"] = group.name.split("/")[-1]
|
||||
res["hdf5_path"] = group.name
|
||||
return res
|
||||
|
||||
|
||||
class HDF5IO:
|
||||
"""
|
||||
Read (and eventually write) from an NWB HDF5 file.
|
||||
|
@ -185,28 +306,20 @@ class HDF5IO:
|
|||
|
||||
h5f = h5py.File(str(self.path))
|
||||
src = h5f.get(path) if path else h5f
|
||||
graph = hdf_dependency_graph(src)
|
||||
graph = filter_dependency_graph(graph)
|
||||
|
||||
# get all children of selected item
|
||||
if isinstance(src, (h5py.File, h5py.Group)):
|
||||
children = flatten_hdf(src)
|
||||
else:
|
||||
raise NotImplementedError("directly read individual datasets")
|
||||
# topo sort to get read order
|
||||
# TODO: This could be parallelized using `topological_generations`,
|
||||
# but it's not clear what the perf bonus would be because there are many generations
|
||||
# with few items
|
||||
topo_order = list(reversed(list(nx.topological_sort(graph))))
|
||||
context = {}
|
||||
for node in topo_order:
|
||||
res = _load_node(node, h5f, provider, context)
|
||||
context[node] = res
|
||||
|
||||
queue = ReadQueue(h5f=self.path, queue=children, provider=provider)
|
||||
|
||||
# Apply initial planning phase of reading
|
||||
queue.apply_phase(ReadPhases.plan)
|
||||
# Read operations gather the data before casting into models
|
||||
queue.apply_phase(ReadPhases.read)
|
||||
# Construction operations actually cast the models
|
||||
# this often needs to run several times as models with dependencies wait for their
|
||||
# dependents to be cast
|
||||
queue.apply_phase(ReadPhases.construct)
|
||||
|
||||
if path is None:
|
||||
return queue.completed["/"].result
|
||||
else:
|
||||
return queue.completed[path].result
|
||||
pdb.set_trace()
|
||||
|
||||
def write(self, path: Path) -> Never:
|
||||
"""
|
||||
|
@ -246,7 +359,7 @@ class HDF5IO:
|
|||
"""
|
||||
from nwb_linkml.providers.schema import SchemaProvider
|
||||
|
||||
h5f = h5py.File(str(self.path))
|
||||
h5f = h5py.File(str(self.path), "r")
|
||||
schema = read_specs_as_dicts(h5f.get("specifications"))
|
||||
|
||||
# get versions for each namespace
|
||||
|
@ -260,7 +373,7 @@ class HDF5IO:
|
|||
provider = SchemaProvider(versions=versions)
|
||||
|
||||
# build schema so we have them cached
|
||||
provider.build_from_dicts(schema)
|
||||
# provider.build_from_dicts(schema)
|
||||
h5f.close()
|
||||
return provider
|
||||
|
||||
|
|
|
@ -233,66 +233,6 @@ class PruneEmpty(HDF5Map):
|
|||
return H5ReadResult.model_construct(path=src.path, source=src, completed=True)
|
||||
|
||||
|
||||
#
|
||||
# class ResolveDynamicTable(HDF5Map):
|
||||
# """
|
||||
# Handle loading a dynamic table!
|
||||
#
|
||||
# Dynamic tables are sort of odd in that their models don't include their fields
|
||||
# (except as a list of strings in ``colnames`` ),
|
||||
# so we need to create a new model that includes fields for each column,
|
||||
# and then we include the datasets as :class:`~numpydantic.interface.hdf5.H5ArrayPath`
|
||||
# objects which lazy load the arrays in a thread/process safe way.
|
||||
#
|
||||
# This map also resolves the child elements,
|
||||
# indicating so by the ``completes`` field in the :class:`.ReadResult`
|
||||
# """
|
||||
#
|
||||
# phase = ReadPhases.read
|
||||
# priority = 1
|
||||
#
|
||||
# @classmethod
|
||||
# def check(
|
||||
# cls, src: H5SourceItem, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
|
||||
# ) -> bool:
|
||||
# if src.h5_type == "dataset":
|
||||
# return False
|
||||
# if "neurodata_type" in src.attrs:
|
||||
# if src.attrs["neurodata_type"] == "DynamicTable":
|
||||
# return True
|
||||
# # otherwise, see if it's a subclass
|
||||
# model = provider.get_class(src.attrs["namespace"], src.attrs["neurodata_type"])
|
||||
# # just inspect the MRO as strings rather than trying to check subclasses because
|
||||
# # we might replace DynamicTable in the future, and there isn't a stable DynamicTable
|
||||
# # class to inherit from anyway because of the whole multiple versions thing
|
||||
# parents = [parent.__name__ for parent in model.__mro__]
|
||||
# return "DynamicTable" in parents
|
||||
# else:
|
||||
# return False
|
||||
#
|
||||
# @classmethod
|
||||
# def apply(
|
||||
# cls, src: H5SourceItem, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
|
||||
# ) -> H5ReadResult:
|
||||
# with h5py.File(src.h5f_path, "r") as h5f:
|
||||
# obj = h5f.get(src.path)
|
||||
#
|
||||
# # make a populated model :)
|
||||
# base_model = provider.get_class(src.namespace, src.neurodata_type)
|
||||
# model = dynamictable_to_model(obj, base=base_model)
|
||||
#
|
||||
# completes = [HDF5_Path(child.name) for child in obj.values()]
|
||||
#
|
||||
# return H5ReadResult(
|
||||
# path=src.path,
|
||||
# source=src,
|
||||
# result=model,
|
||||
# completes=completes,
|
||||
# completed=True,
|
||||
# applied=["ResolveDynamicTable"],
|
||||
# )
|
||||
|
||||
|
||||
class ResolveModelGroup(HDF5Map):
|
||||
"""
|
||||
HDF5 Groups that have a model, as indicated by ``neurodata_type`` in their attrs.
|
||||
|
|
|
@ -97,9 +97,9 @@ class Provider(ABC):
|
|||
module_path = Path(importlib.util.find_spec("nwb_models").origin).parent
|
||||
|
||||
if self.PROVIDES == "linkml":
|
||||
namespace_path = module_path / "schema" / "linkml" / namespace
|
||||
namespace_path = module_path / "schema" / "linkml" / namespace_module
|
||||
elif self.PROVIDES == "pydantic":
|
||||
namespace_path = module_path / "models" / "pydantic" / namespace
|
||||
namespace_path = module_path / "models" / "pydantic" / namespace_module
|
||||
|
||||
if version is not None:
|
||||
version_path = namespace_path / version_module_case(version)
|
||||
|
|
3
nwb_linkml/tests/fixtures/__init__.py
vendored
3
nwb_linkml/tests/fixtures/__init__.py
vendored
|
@ -1,4 +1,4 @@
|
|||
from .nwb import nwb_file
|
||||
from .nwb import nwb_file, nwb_file_base
|
||||
from .paths import data_dir, tmp_output_dir, tmp_output_dir_func, tmp_output_dir_mod
|
||||
from .schema import (
|
||||
NWBSchemaTest,
|
||||
|
@ -21,6 +21,7 @@ __all__ = [
|
|||
"nwb_core_linkml",
|
||||
"nwb_core_module",
|
||||
"nwb_file",
|
||||
"nwb_file_base",
|
||||
"nwb_schema",
|
||||
"tmp_output_dir",
|
||||
"tmp_output_dir_func",
|
||||
|
|
2
nwb_linkml/tests/fixtures/paths.py
vendored
2
nwb_linkml/tests/fixtures/paths.py
vendored
|
@ -6,7 +6,7 @@ import pytest
|
|||
|
||||
@pytest.fixture(scope="session")
|
||||
def tmp_output_dir(request: pytest.FixtureRequest) -> Path:
|
||||
path = Path(__file__).parent.resolve() / "__tmp__"
|
||||
path = Path(__file__).parents[1].resolve() / "__tmp__"
|
||||
if path.exists():
|
||||
if request.config.getoption("--clean"):
|
||||
shutil.rmtree(path)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import pdb
|
||||
|
||||
import h5py
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
@ -100,10 +101,15 @@ def test_flatten_hdf():
|
|||
raise NotImplementedError("Just a stub for local testing for now, finish me!")
|
||||
|
||||
|
||||
def test_dependency_graph(nwb_file):
|
||||
@pytest.mark.dev
|
||||
def test_dependency_graph(nwb_file, tmp_output_dir):
|
||||
"""
|
||||
dependency graph is correctly constructed from an HDF5 file
|
||||
"""
|
||||
graph = hdf_dependency_graph(nwb_file)
|
||||
A_unfiltered = nx.nx_agraph.to_agraph(graph)
|
||||
A_unfiltered.draw(tmp_output_dir / "test_nwb_unfiltered.png", prog="dot")
|
||||
graph = filter_dependency_graph(graph)
|
||||
A_filtered = nx.nx_agraph.to_agraph(graph)
|
||||
A_filtered.draw(tmp_output_dir / "test_nwb_filtered.png", prog="dot")
|
||||
pass
|
||||
|
|
|
@ -2,12 +2,14 @@
|
|||
Placeholder test module to test reading from pynwb-generated NWB file
|
||||
"""
|
||||
|
||||
from nwb_linkml.io.hdf5 import HDF5IO
|
||||
|
||||
|
||||
def test_read_from_nwbfile(nwb_file):
|
||||
"""
|
||||
Read data from a pynwb HDF5 NWB file
|
||||
"""
|
||||
pass
|
||||
res = HDF5IO(nwb_file).read()
|
||||
|
||||
|
||||
def test_read_from_yaml(nwb_file):
|
||||
|
|
Loading…
Reference in a new issue