From d1498a37332d57279cc0772071db71937e5db6cd Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Tue, 3 Sep 2024 00:54:56 -0700 Subject: [PATCH] checkpointing working on model loading. it's a sloggggggggg --- nwb_linkml/pyproject.toml | 2 +- .../src/nwb_linkml/generators/pydantic.py | 52 +++-- nwb_linkml/src/nwb_linkml/includes/base.py | 17 ++ nwb_linkml/src/nwb_linkml/includes/hdmf.py | 52 +++-- nwb_linkml/src/nwb_linkml/io/hdf5.py | 181 ++++++++++++++---- nwb_linkml/src/nwb_linkml/maps/hdf5.py | 60 ------ .../src/nwb_linkml/providers/provider.py | 4 +- nwb_linkml/tests/fixtures/__init__.py | 3 +- nwb_linkml/tests/fixtures/paths.py | 2 +- nwb_linkml/tests/test_io/test_io_hdf5.py | 8 +- nwb_linkml/tests/test_io/test_io_nwb.py | 4 +- 11 files changed, 258 insertions(+), 127 deletions(-) diff --git a/nwb_linkml/pyproject.toml b/nwb_linkml/pyproject.toml index 6e86158..5e00933 100644 --- a/nwb_linkml/pyproject.toml +++ b/nwb_linkml/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "pydantic-settings>=2.0.3", "tqdm>=4.66.1", 'typing-extensions>=4.12.2;python_version<"3.11"', - "numpydantic>=1.3.3", + "numpydantic>=1.5.0", "black>=24.4.2", "pandas>=2.2.2", "networkx>=3.3", diff --git a/nwb_linkml/src/nwb_linkml/generators/pydantic.py b/nwb_linkml/src/nwb_linkml/generators/pydantic.py index 0f824af..20619d2 100644 --- a/nwb_linkml/src/nwb_linkml/generators/pydantic.py +++ b/nwb_linkml/src/nwb_linkml/generators/pydantic.py @@ -10,7 +10,7 @@ import sys from dataclasses import dataclass, field from pathlib import Path from types import ModuleType -from typing import ClassVar, Dict, List, Optional, Tuple +from typing import ClassVar, Dict, List, Optional, Tuple, Literal from linkml.generators import PydanticGenerator from linkml.generators.pydanticgen.array import ArrayRepresentation, NumpydanticArray @@ -27,7 +27,7 @@ from linkml_runtime.utils.compile_python import file_text from linkml_runtime.utils.formatutils import remove_empty_items from linkml_runtime.utils.schemaview import SchemaView -from nwb_linkml.includes.base import BASEMODEL_GETITEM +from nwb_linkml.includes.base import BASEMODEL_GETITEM, BASEMODEL_COERCE_VALUE from nwb_linkml.includes.hdmf import ( DYNAMIC_TABLE_IMPORTS, DYNAMIC_TABLE_INJECTS, @@ -52,6 +52,7 @@ class NWBPydanticGenerator(PydanticGenerator): ), 'object_id: Optional[str] = Field(None, description="Unique UUID for each object")', BASEMODEL_GETITEM, + BASEMODEL_COERCE_VALUE, ) split: bool = True imports: list[Import] = field(default_factory=lambda: [Import(module="numpy", alias="np")]) @@ -66,6 +67,7 @@ class NWBPydanticGenerator(PydanticGenerator): emit_metadata: bool = True gen_classvars: bool = True gen_slots: bool = True + extra_fields: Literal["allow", "forbid", "ignore"] = "allow" skip_meta: ClassVar[Tuple[str]] = ("domain_of", "alias") @@ -131,6 +133,7 @@ class NWBPydanticGenerator(PydanticGenerator): """Customize dynamictable behavior""" cls = AfterGenerateClass.inject_dynamictable(cls) cls = AfterGenerateClass.wrap_dynamictable_columns(cls, sv) + cls = AfterGenerateClass.inject_elementidentifiers(cls, sv, self._get_element_import) return cls def before_render_template(self, template: PydanticModule, sv: SchemaView) -> PydanticModule: @@ -255,7 +258,7 @@ class AfterGenerateClass: """ if cls.cls.name in "DynamicTable": - cls.cls.bases = ["DynamicTableMixin"] + cls.cls.bases = ["DynamicTableMixin", "ConfiguredBaseModel"] if cls.injected_classes is None: cls.injected_classes = DYNAMIC_TABLE_INJECTS.copy() @@ -269,13 +272,21 @@ class AfterGenerateClass: else: 
cls.imports = DYNAMIC_TABLE_IMPORTS.model_copy() elif cls.cls.name == "VectorData": - cls.cls.bases = ["VectorDataMixin"] + cls.cls.bases = ["VectorDataMixin", "ConfiguredBaseModel"] + # make ``value`` generic on T + if "value" in cls.cls.attributes: + cls.cls.attributes["value"].range = "Optional[T]" elif cls.cls.name == "VectorIndex": - cls.cls.bases = ["VectorIndexMixin"] + cls.cls.bases = ["VectorIndexMixin", "ConfiguredBaseModel"] elif cls.cls.name == "DynamicTableRegion": - cls.cls.bases = ["DynamicTableRegionMixin", "VectorData"] + cls.cls.bases = ["DynamicTableRegionMixin", "VectorData", "ConfiguredBaseModel"] elif cls.cls.name == "AlignedDynamicTable": cls.cls.bases = ["AlignedDynamicTableMixin", "DynamicTable"] + elif cls.cls.name == "ElementIdentifiers": + cls.cls.bases = ["ElementIdentifiersMixin", "Data", "ConfiguredBaseModel"] + # make ``value`` generic on T + if "value" in cls.cls.attributes: + cls.cls.attributes["value"].range = "Optional[T]" elif cls.cls.name == "TimeSeriesReferenceVectorData": # in core.nwb.base, so need to inject and import again cls.cls.bases = ["TimeSeriesReferenceVectorDataMixin", "VectorData"] @@ -305,14 +316,31 @@ class AfterGenerateClass: ): for an_attr in cls.cls.attributes: if "NDArray" in (slot_range := cls.cls.attributes[an_attr].range): + if an_attr == "id": + cls.cls.attributes[an_attr].range = "ElementIdentifiers" + return cls + if an_attr.endswith("_index"): - cls.cls.attributes[an_attr].range = "".join( - ["VectorIndex[", slot_range, "]"] - ) + wrap_cls = "VectorIndex" else: - cls.cls.attributes[an_attr].range = "".join( - ["VectorData[", slot_range, "]"] - ) + wrap_cls = "VectorData" + + cls.cls.attributes[an_attr].range = "".join([wrap_cls, "[", slot_range, "]"]) + + return cls + + @staticmethod + def inject_elementidentifiers(cls: ClassResult, sv: SchemaView, import_method) -> ClassResult: + """ + Inject ElementIdentifiers into module that define dynamictables - + needed to handle ID columns + """ + if ( + cls.source.is_a == "DynamicTable" + or "DynamicTable" in sv.class_ancestors(cls.source.name) + ) and sv.schema.name != "hdmf-common.table": + imp = import_method("ElementIdentifiers") + cls.imports += [imp] return cls diff --git a/nwb_linkml/src/nwb_linkml/includes/base.py b/nwb_linkml/src/nwb_linkml/includes/base.py index ed69bf3..d3ad3f7 100644 --- a/nwb_linkml/src/nwb_linkml/includes/base.py +++ b/nwb_linkml/src/nwb_linkml/includes/base.py @@ -12,3 +12,20 @@ BASEMODEL_GETITEM = """ else: raise KeyError("No value or data field to index from") """ + +BASEMODEL_COERCE_VALUE = """ + @field_validator("*", mode="wrap") + @classmethod + def coerce_value(cls, v: Any, handler) -> Any: + \"\"\"Try to rescue instantiation by using the value field\"\"\" + try: + return handler(v) + except Exception as e1: + try: + if hasattr(v, "value"): + return handler(v.value) + else: + return handler(v["value"]) + except Exception as e2: + raise e2 from e1 +""" diff --git a/nwb_linkml/src/nwb_linkml/includes/hdmf.py b/nwb_linkml/src/nwb_linkml/includes/hdmf.py index 631ad67..573a9c1 100644 --- a/nwb_linkml/src/nwb_linkml/includes/hdmf.py +++ b/nwb_linkml/src/nwb_linkml/includes/hdmf.py @@ -253,6 +253,8 @@ class DynamicTableMixin(BaseModel): else: # add any columns not explicitly given an order at the end colnames = model["colnames"].copy() + if isinstance(colnames, np.ndarray): + colnames = colnames.tolist() colnames.extend( [ k @@ -284,9 +286,13 @@ class DynamicTableMixin(BaseModel): if not isinstance(val, (VectorData, VectorIndex)): try: if 
key.endswith("_index"): - model[key] = VectorIndex(name=key, description="", value=val) + to_cast = VectorIndex else: - model[key] = VectorData(name=key, description="", value=val) + to_cast = VectorData + if isinstance(val, dict): + model[key] = to_cast(**val) + else: + model[key] = VectorIndex(name=key, description="", value=val) except ValidationError as e: # pragma: no cover raise ValidationError.from_exception_data( title=f"field {key} cannot be cast to VectorData from {val}", @@ -379,10 +385,10 @@ class VectorDataMixin(BaseModel, Generic[T]): # redefined in `VectorData`, but included here for testing and type checking value: Optional[T] = None - def __init__(self, value: Optional[NDArray] = None, **kwargs): - if value is not None and "value" not in kwargs: - kwargs["value"] = value - super().__init__(**kwargs) + # def __init__(self, value: Optional[NDArray] = None, **kwargs): + # if value is not None and "value" not in kwargs: + # kwargs["value"] = value + # super().__init__(**kwargs) def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any: if self._index: @@ -703,14 +709,19 @@ class AlignedDynamicTableMixin(BaseModel): model["categories"] = categories else: # add any columns not explicitly given an order at the end - categories = [ - k - for k in model - if k not in cls.NON_COLUMN_FIELDS - and not k.endswith("_index") - and k not in model["categories"] - ] - model["categories"].extend(categories) + categories = model["categories"].copy() + if isinstance(categories, np.ndarray): + categories = categories.tolist() + categories.extend( + [ + k + for k in model + if k not in cls.NON_CATEGORY_FIELDS + and not k.endswith("_index") + and k not in model["categories"] + ] + ) + model["categories"] = categories return model @model_validator(mode="after") @@ -839,6 +850,13 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin): ) +class ElementIdentifiersMixin(VectorDataMixin): + """ + Mixin class for ElementIdentifiers - allow treating + as generic, and give general indexing methods from VectorData + """ + + DYNAMIC_TABLE_IMPORTS = Imports( imports=[ Import(module="pandas", alias="pd"), @@ -882,6 +900,7 @@ DYNAMIC_TABLE_INJECTS = [ DynamicTableRegionMixin, DynamicTableMixin, AlignedDynamicTableMixin, + ElementIdentifiersMixin, ] TSRVD_IMPORTS = Imports( @@ -923,3 +942,8 @@ if "pytest" in sys.modules: """TimeSeriesReferenceVectorData subclass for testing""" pass + + class ElementIdentifiers(ElementIdentifiersMixin): + """ElementIdentifiers subclass for testing""" + + pass diff --git a/nwb_linkml/src/nwb_linkml/io/hdf5.py b/nwb_linkml/src/nwb_linkml/io/hdf5.py index ba1d017..d7692d7 100644 --- a/nwb_linkml/src/nwb_linkml/io/hdf5.py +++ b/nwb_linkml/src/nwb_linkml/io/hdf5.py @@ -22,6 +22,7 @@ Other TODO: import json import os +import pdb import re import shutil import subprocess @@ -34,10 +35,11 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union, overload import h5py import networkx as nx import numpy as np +from numpydantic.interface.hdf5 import H5ArrayPath from pydantic import BaseModel from tqdm import tqdm -from nwb_linkml.maps.hdf5 import ReadPhases, ReadQueue, flatten_hdf, get_references +from nwb_linkml.maps.hdf5 import get_references if TYPE_CHECKING: from nwb_linkml.providers.schema import SchemaProvider @@ -49,7 +51,11 @@ else: from typing_extensions import Never -def hdf_dependency_graph(h5f: Path | h5py.File) -> nx.DiGraph: +SKIP_PATTERN = re.compile("^/specifications.*") +"""Nodes to always skip in reading e.g. 
because they are handled elsewhere""" + + +def hdf_dependency_graph(h5f: Path | h5py.File | h5py.Group) -> nx.DiGraph: """ Directed dependency graph of dataset and group nodes in an NWBFile such that each node ``n_i`` is connected to node ``n_j`` if @@ -63,14 +69,15 @@ def hdf_dependency_graph(h5f: Path | h5py.File) -> nx.DiGraph: * Dataset columns * Compound dtypes + Edges are labeled with ``reference`` or ``child`` depending on the type of edge it is, + and attributes from the hdf5 file are added as node attributes. + Args: h5f (:class:`pathlib.Path` | :class:`h5py.File`): NWB file to graph Returns: :class:`networkx.DiGraph` """ - # detect nodes to skip - skip_pattern = re.compile("^/specifications.*") if isinstance(h5f, (Path, str)): h5f = h5py.File(h5f, "r") @@ -78,17 +85,19 @@ def hdf_dependency_graph(h5f: Path | h5py.File) -> nx.DiGraph: g = nx.DiGraph() def _visit_item(name: str, node: h5py.Dataset | h5py.Group) -> None: - if skip_pattern.match(name): + if SKIP_PATTERN.match(node.name): return # find references in attributes refs = get_references(node) - if isinstance(node, h5py.Group): - refs.extend([child.name for child in node.values()]) - refs = set(refs) + # add edges from references + edges = [(node.name, ref) for ref in refs if not SKIP_PATTERN.match(ref)] + g.add_edges_from(edges, label="reference") - # add edges - edges = [(node.name, ref) for ref in refs] - g.add_edges_from(edges) + # add children, if group + if isinstance(node, h5py.Group): + children = [child.name for child in node.values() if not SKIP_PATTERN.match(child.name)] + edges = [(node.name, ref) for ref in children if not SKIP_PATTERN.match(ref)] + g.add_edges_from(edges, label="child") # ensure node added to graph if len(edges) == 0: @@ -119,13 +128,125 @@ def filter_dependency_graph(g: nx.DiGraph) -> nx.DiGraph: node: str for node in g.nodes: ndtype = g.nodes[node].get("neurodata_type", None) - if ndtype == "VectorData" or not ndtype and g.out_degree(node) == 0: + if (ndtype is None and g.out_degree(node) == 0) or SKIP_PATTERN.match(node): remove_nodes.append(node) g.remove_nodes_from(remove_nodes) return g +def _load_node( + path: str, h5f: h5py.File, provider: "SchemaProvider", context: dict +) -> dict | BaseModel: + """ + Load an individual node in the graph, then removes it from the graph + Args: + path: + g: + context: + + Returns: + + """ + obj = h5f.get(path) + + if isinstance(obj, h5py.Dataset): + args = _load_dataset(obj, h5f, context) + elif isinstance(obj, h5py.Group): + args = _load_group(obj, h5f, context) + else: + raise TypeError(f"Nodes can only be h5py Datasets and Groups, got {obj}") + + # if obj.name == "/general/intracellular_ephys/simultaneous_recordings/recordings": + # pdb.set_trace() + + # resolve attr references + for k, v in args.items(): + if isinstance(v, h5py.h5r.Reference): + ref_path = h5f[v].name + args[k] = context[ref_path] + + model = provider.get_class(obj.attrs["namespace"], obj.attrs["neurodata_type"]) + + # add additional needed params + args["hdf5_path"] = path + args["name"] = path.split("/")[-1] + return model(**args) + + +def _load_dataset( + dataset: h5py.Dataset, h5f: h5py.File, context: dict +) -> Union[dict, str, int, float]: + """ + Resolves datasets that do not have a ``neurodata_type`` as a dictionary or a scalar. + + If the dataset is a single value without attrs, load it and return as a scalar value. + Otherwise return a :class:`.H5ArrayPath` as a reference to the dataset in the `value` key. 
+ """ + res = {} + if dataset.shape == (): + val = dataset[()] + if isinstance(val, h5py.h5r.Reference): + val = context.get(h5f[val].name) + # if this is just a scalar value, return it + if not dataset.attrs: + return val + + res["value"] = val + elif len(dataset) > 0 and isinstance(dataset[0], h5py.h5r.Reference): + # vector of references + res["value"] = [context.get(h5f[ref].name) for ref in dataset[:]] + elif len(dataset.dtype) > 1: + # compound dataset - check if any of the fields are references + for name in dataset.dtype.names: + if isinstance(dataset[name][0], h5py.h5r.Reference): + res[name] = [context.get(h5f[ref].name) for ref in dataset[name]] + else: + res[name] = H5ArrayPath(h5f.filename, dataset.name, name) + else: + res["value"] = H5ArrayPath(h5f.filename, dataset.name) + + res.update(dataset.attrs) + if "namespace" in res: + del res["namespace"] + if "neurodata_type" in res: + del res["neurodata_type"] + res["name"] = dataset.name.split("/")[-1] + res["hdf5_path"] = dataset.name + + if len(res) == 1: + return res["value"] + else: + return res + + +def _load_group(group: h5py.Group, h5f: h5py.File, context: dict) -> dict: + """ + Load a group! + """ + res = {} + res.update(group.attrs) + for child_name, child in group.items(): + if child.name in context: + res[child_name] = context[child.name] + elif isinstance(child, h5py.Dataset): + res[child_name] = _load_dataset(child, h5f, context) + elif isinstance(child, h5py.Group): + res[child_name] = _load_group(child, h5f, context) + else: + raise TypeError( + "Can only handle preinstantiated child objects in context, datasets, and group," + f" got {child} for {child_name}" + ) + if "namespace" in res: + del res["namespace"] + if "neurodata_type" in res: + del res["neurodata_type"] + res["name"] = group.name.split("/")[-1] + res["hdf5_path"] = group.name + return res + + class HDF5IO: """ Read (and eventually write) from an NWB HDF5 file. 
@@ -185,28 +306,20 @@ class HDF5IO: h5f = h5py.File(str(self.path)) src = h5f.get(path) if path else h5f + graph = hdf_dependency_graph(src) + graph = filter_dependency_graph(graph) - # get all children of selected item - if isinstance(src, (h5py.File, h5py.Group)): - children = flatten_hdf(src) - else: - raise NotImplementedError("directly read individual datasets") + # topo sort to get read order + # TODO: This could be parallelized using `topological_generations`, + # but it's not clear what the perf bonus would be because there are many generations + # with few items + topo_order = list(reversed(list(nx.topological_sort(graph)))) + context = {} + for node in topo_order: + res = _load_node(node, h5f, provider, context) + context[node] = res - queue = ReadQueue(h5f=self.path, queue=children, provider=provider) - - # Apply initial planning phase of reading - queue.apply_phase(ReadPhases.plan) - # Read operations gather the data before casting into models - queue.apply_phase(ReadPhases.read) - # Construction operations actually cast the models - # this often needs to run several times as models with dependencies wait for their - # dependents to be cast - queue.apply_phase(ReadPhases.construct) - - if path is None: - return queue.completed["/"].result - else: - return queue.completed[path].result + pdb.set_trace() def write(self, path: Path) -> Never: """ @@ -246,7 +359,7 @@ class HDF5IO: """ from nwb_linkml.providers.schema import SchemaProvider - h5f = h5py.File(str(self.path)) + h5f = h5py.File(str(self.path), "r") schema = read_specs_as_dicts(h5f.get("specifications")) # get versions for each namespace @@ -260,7 +373,7 @@ class HDF5IO: provider = SchemaProvider(versions=versions) # build schema so we have them cached - provider.build_from_dicts(schema) + # provider.build_from_dicts(schema) h5f.close() return provider diff --git a/nwb_linkml/src/nwb_linkml/maps/hdf5.py b/nwb_linkml/src/nwb_linkml/maps/hdf5.py index e554dc3..0299136 100644 --- a/nwb_linkml/src/nwb_linkml/maps/hdf5.py +++ b/nwb_linkml/src/nwb_linkml/maps/hdf5.py @@ -233,66 +233,6 @@ class PruneEmpty(HDF5Map): return H5ReadResult.model_construct(path=src.path, source=src, completed=True) -# -# class ResolveDynamicTable(HDF5Map): -# """ -# Handle loading a dynamic table! -# -# Dynamic tables are sort of odd in that their models don't include their fields -# (except as a list of strings in ``colnames`` ), -# so we need to create a new model that includes fields for each column, -# and then we include the datasets as :class:`~numpydantic.interface.hdf5.H5ArrayPath` -# objects which lazy load the arrays in a thread/process safe way. 
-# -# This map also resolves the child elements, -# indicating so by the ``completes`` field in the :class:`.ReadResult` -# """ -# -# phase = ReadPhases.read -# priority = 1 -# -# @classmethod -# def check( -# cls, src: H5SourceItem, provider: "SchemaProvider", completed: Dict[str, H5ReadResult] -# ) -> bool: -# if src.h5_type == "dataset": -# return False -# if "neurodata_type" in src.attrs: -# if src.attrs["neurodata_type"] == "DynamicTable": -# return True -# # otherwise, see if it's a subclass -# model = provider.get_class(src.attrs["namespace"], src.attrs["neurodata_type"]) -# # just inspect the MRO as strings rather than trying to check subclasses because -# # we might replace DynamicTable in the future, and there isn't a stable DynamicTable -# # class to inherit from anyway because of the whole multiple versions thing -# parents = [parent.__name__ for parent in model.__mro__] -# return "DynamicTable" in parents -# else: -# return False -# -# @classmethod -# def apply( -# cls, src: H5SourceItem, provider: "SchemaProvider", completed: Dict[str, H5ReadResult] -# ) -> H5ReadResult: -# with h5py.File(src.h5f_path, "r") as h5f: -# obj = h5f.get(src.path) -# -# # make a populated model :) -# base_model = provider.get_class(src.namespace, src.neurodata_type) -# model = dynamictable_to_model(obj, base=base_model) -# -# completes = [HDF5_Path(child.name) for child in obj.values()] -# -# return H5ReadResult( -# path=src.path, -# source=src, -# result=model, -# completes=completes, -# completed=True, -# applied=["ResolveDynamicTable"], -# ) - - class ResolveModelGroup(HDF5Map): """ HDF5 Groups that have a model, as indicated by ``neurodata_type`` in their attrs. diff --git a/nwb_linkml/src/nwb_linkml/providers/provider.py b/nwb_linkml/src/nwb_linkml/providers/provider.py index 87f6567..ff349af 100644 --- a/nwb_linkml/src/nwb_linkml/providers/provider.py +++ b/nwb_linkml/src/nwb_linkml/providers/provider.py @@ -97,9 +97,9 @@ class Provider(ABC): module_path = Path(importlib.util.find_spec("nwb_models").origin).parent if self.PROVIDES == "linkml": - namespace_path = module_path / "schema" / "linkml" / namespace + namespace_path = module_path / "schema" / "linkml" / namespace_module elif self.PROVIDES == "pydantic": - namespace_path = module_path / "models" / "pydantic" / namespace + namespace_path = module_path / "models" / "pydantic" / namespace_module if version is not None: version_path = namespace_path / version_module_case(version) diff --git a/nwb_linkml/tests/fixtures/__init__.py b/nwb_linkml/tests/fixtures/__init__.py index e0ff5bd..f135929 100644 --- a/nwb_linkml/tests/fixtures/__init__.py +++ b/nwb_linkml/tests/fixtures/__init__.py @@ -1,4 +1,4 @@ -from .nwb import nwb_file +from .nwb import nwb_file, nwb_file_base from .paths import data_dir, tmp_output_dir, tmp_output_dir_func, tmp_output_dir_mod from .schema import ( NWBSchemaTest, @@ -21,6 +21,7 @@ __all__ = [ "nwb_core_linkml", "nwb_core_module", "nwb_file", + "nwb_file_base", "nwb_schema", "tmp_output_dir", "tmp_output_dir_func", diff --git a/nwb_linkml/tests/fixtures/paths.py b/nwb_linkml/tests/fixtures/paths.py index d81304a..d7f5f0c 100644 --- a/nwb_linkml/tests/fixtures/paths.py +++ b/nwb_linkml/tests/fixtures/paths.py @@ -6,7 +6,7 @@ import pytest @pytest.fixture(scope="session") def tmp_output_dir(request: pytest.FixtureRequest) -> Path: - path = Path(__file__).parent.resolve() / "__tmp__" + path = Path(__file__).parents[1].resolve() / "__tmp__" if path.exists(): if request.config.getoption("--clean"): shutil.rmtree(path) 
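As a standalone illustration of the wrap-validator pattern that BASEMODEL_COERCE_VALUE (includes/base.py above) injects into generated models, the sketch below shows the rescue-via-``value`` behavior in isolation. ``Holder`` and ``Target`` are hypothetical stand-ins for generated NWB classes, not part of this patch.

# Minimal sketch of the coerce_value wrap validator; Holder/Target are
# hypothetical models used only to demonstrate the rescue path.
from typing import Any

from pydantic import BaseModel, field_validator


class Holder(BaseModel):
    """Stand-in for an object that keeps its payload under ``value``."""

    value: int


class Target(BaseModel):
    x: int

    @field_validator("*", mode="wrap")
    @classmethod
    def coerce_value(cls, v: Any, handler) -> Any:
        """Try to rescue instantiation by using the value field"""
        try:
            return handler(v)
        except Exception as e1:
            try:
                if hasattr(v, "value"):
                    return handler(v.value)
                else:
                    return handler(v["value"])
            except Exception as e2:
                raise e2 from e1


assert Target(x=1).x == 1                # normal validation unchanged
assert Target(x=Holder(value=2)).x == 2  # rescued via ``.value``
assert Target(x={"value": 3}).x == 3     # rescued via ``["value"]``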
diff --git a/nwb_linkml/tests/test_io/test_io_hdf5.py b/nwb_linkml/tests/test_io/test_io_hdf5.py index 2d587e5..0f71e04 100644 --- a/nwb_linkml/tests/test_io/test_io_hdf5.py +++ b/nwb_linkml/tests/test_io/test_io_hdf5.py @@ -1,6 +1,7 @@ import pdb import h5py +import networkx as nx import numpy as np import pytest @@ -100,10 +101,15 @@ def test_flatten_hdf(): raise NotImplementedError("Just a stub for local testing for now, finish me!") -def test_dependency_graph(nwb_file): +@pytest.mark.dev +def test_dependency_graph(nwb_file, tmp_output_dir): """ dependency graph is correctly constructed from an HDF5 file """ graph = hdf_dependency_graph(nwb_file) + A_unfiltered = nx.nx_agraph.to_agraph(graph) + A_unfiltered.draw(tmp_output_dir / "test_nwb_unfiltered.png", prog="dot") graph = filter_dependency_graph(graph) + A_filtered = nx.nx_agraph.to_agraph(graph) + A_filtered.draw(tmp_output_dir / "test_nwb_filtered.png", prog="dot") pass diff --git a/nwb_linkml/tests/test_io/test_io_nwb.py b/nwb_linkml/tests/test_io/test_io_nwb.py index a6eb230..54f4d0f 100644 --- a/nwb_linkml/tests/test_io/test_io_nwb.py +++ b/nwb_linkml/tests/test_io/test_io_nwb.py @@ -2,12 +2,14 @@ Placeholder test module to test reading from pynwb-generated NWB file """ +from nwb_linkml.io.hdf5 import HDF5IO + def test_read_from_nwbfile(nwb_file): """ Read data from a pynwb HDF5 NWB file """ - pass + res = HDF5IO(nwb_file).read() def test_read_from_yaml(nwb_file):
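For orientation, here is a minimal sketch of the read path these changes converge on: build the HDF5 dependency graph, filter out untyped leaf nodes, then instantiate nodes in reverse topological order so that anything a node references already exists in ``context``. It mirrors ``HDF5IO.read()`` above; the file path is hypothetical, and real usage should simply call ``HDF5IO(path).read()``.

# Sketch only — assumes the file's namespaces are already cached so that
# make_provider() can return a usable SchemaProvider for this file.
import h5py
import networkx as nx

from nwb_linkml.io.hdf5 import (
    HDF5IO,
    _load_node,
    filter_dependency_graph,
    hdf_dependency_graph,
)

path = "example.nwb"  # hypothetical NWB file
provider = HDF5IO(path).make_provider()

with h5py.File(path, "r") as h5f:
    graph = filter_dependency_graph(hdf_dependency_graph(h5f))
    # reversed topological order: referenced datasets/groups are built first,
    # so each _load_node call can resolve its references from ``context``
    context = {}
    for node in reversed(list(nx.topological_sort(graph))):
        context[node] = _load_node(node, h5f, provider, context)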