checkpointing working on model loading. it's a sloggggggggg

This commit is contained in:
sneakers-the-rat 2024-09-03 00:54:56 -07:00
parent cd3d7ca78e
commit d1498a3733
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
11 changed files with 258 additions and 127 deletions

View file

@ -22,7 +22,7 @@ dependencies = [
"pydantic-settings>=2.0.3", "pydantic-settings>=2.0.3",
"tqdm>=4.66.1", "tqdm>=4.66.1",
'typing-extensions>=4.12.2;python_version<"3.11"', 'typing-extensions>=4.12.2;python_version<"3.11"',
"numpydantic>=1.3.3", "numpydantic>=1.5.0",
"black>=24.4.2", "black>=24.4.2",
"pandas>=2.2.2", "pandas>=2.2.2",
"networkx>=3.3", "networkx>=3.3",

View file

@ -10,7 +10,7 @@ import sys
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from types import ModuleType from types import ModuleType
from typing import ClassVar, Dict, List, Optional, Tuple from typing import ClassVar, Dict, List, Optional, Tuple, Literal
from linkml.generators import PydanticGenerator from linkml.generators import PydanticGenerator
from linkml.generators.pydanticgen.array import ArrayRepresentation, NumpydanticArray from linkml.generators.pydanticgen.array import ArrayRepresentation, NumpydanticArray
@ -27,7 +27,7 @@ from linkml_runtime.utils.compile_python import file_text
from linkml_runtime.utils.formatutils import remove_empty_items from linkml_runtime.utils.formatutils import remove_empty_items
from linkml_runtime.utils.schemaview import SchemaView from linkml_runtime.utils.schemaview import SchemaView
from nwb_linkml.includes.base import BASEMODEL_GETITEM from nwb_linkml.includes.base import BASEMODEL_GETITEM, BASEMODEL_COERCE_VALUE
from nwb_linkml.includes.hdmf import ( from nwb_linkml.includes.hdmf import (
DYNAMIC_TABLE_IMPORTS, DYNAMIC_TABLE_IMPORTS,
DYNAMIC_TABLE_INJECTS, DYNAMIC_TABLE_INJECTS,
@ -52,6 +52,7 @@ class NWBPydanticGenerator(PydanticGenerator):
), ),
'object_id: Optional[str] = Field(None, description="Unique UUID for each object")', 'object_id: Optional[str] = Field(None, description="Unique UUID for each object")',
BASEMODEL_GETITEM, BASEMODEL_GETITEM,
BASEMODEL_COERCE_VALUE,
) )
split: bool = True split: bool = True
imports: list[Import] = field(default_factory=lambda: [Import(module="numpy", alias="np")]) imports: list[Import] = field(default_factory=lambda: [Import(module="numpy", alias="np")])
@ -66,6 +67,7 @@ class NWBPydanticGenerator(PydanticGenerator):
emit_metadata: bool = True emit_metadata: bool = True
gen_classvars: bool = True gen_classvars: bool = True
gen_slots: bool = True gen_slots: bool = True
extra_fields: Literal["allow", "forbid", "ignore"] = "allow"
skip_meta: ClassVar[Tuple[str]] = ("domain_of", "alias") skip_meta: ClassVar[Tuple[str]] = ("domain_of", "alias")
@ -131,6 +133,7 @@ class NWBPydanticGenerator(PydanticGenerator):
"""Customize dynamictable behavior""" """Customize dynamictable behavior"""
cls = AfterGenerateClass.inject_dynamictable(cls) cls = AfterGenerateClass.inject_dynamictable(cls)
cls = AfterGenerateClass.wrap_dynamictable_columns(cls, sv) cls = AfterGenerateClass.wrap_dynamictable_columns(cls, sv)
cls = AfterGenerateClass.inject_elementidentifiers(cls, sv, self._get_element_import)
return cls return cls
def before_render_template(self, template: PydanticModule, sv: SchemaView) -> PydanticModule: def before_render_template(self, template: PydanticModule, sv: SchemaView) -> PydanticModule:
@ -255,7 +258,7 @@ class AfterGenerateClass:
""" """
if cls.cls.name in "DynamicTable": if cls.cls.name in "DynamicTable":
cls.cls.bases = ["DynamicTableMixin"] cls.cls.bases = ["DynamicTableMixin", "ConfiguredBaseModel"]
if cls.injected_classes is None: if cls.injected_classes is None:
cls.injected_classes = DYNAMIC_TABLE_INJECTS.copy() cls.injected_classes = DYNAMIC_TABLE_INJECTS.copy()
@ -269,13 +272,21 @@ class AfterGenerateClass:
else: else:
cls.imports = DYNAMIC_TABLE_IMPORTS.model_copy() cls.imports = DYNAMIC_TABLE_IMPORTS.model_copy()
elif cls.cls.name == "VectorData": elif cls.cls.name == "VectorData":
cls.cls.bases = ["VectorDataMixin"] cls.cls.bases = ["VectorDataMixin", "ConfiguredBaseModel"]
# make ``value`` generic on T
if "value" in cls.cls.attributes:
cls.cls.attributes["value"].range = "Optional[T]"
elif cls.cls.name == "VectorIndex": elif cls.cls.name == "VectorIndex":
cls.cls.bases = ["VectorIndexMixin"] cls.cls.bases = ["VectorIndexMixin", "ConfiguredBaseModel"]
elif cls.cls.name == "DynamicTableRegion": elif cls.cls.name == "DynamicTableRegion":
cls.cls.bases = ["DynamicTableRegionMixin", "VectorData"] cls.cls.bases = ["DynamicTableRegionMixin", "VectorData", "ConfiguredBaseModel"]
elif cls.cls.name == "AlignedDynamicTable": elif cls.cls.name == "AlignedDynamicTable":
cls.cls.bases = ["AlignedDynamicTableMixin", "DynamicTable"] cls.cls.bases = ["AlignedDynamicTableMixin", "DynamicTable"]
elif cls.cls.name == "ElementIdentifiers":
cls.cls.bases = ["ElementIdentifiersMixin", "Data", "ConfiguredBaseModel"]
# make ``value`` generic on T
if "value" in cls.cls.attributes:
cls.cls.attributes["value"].range = "Optional[T]"
elif cls.cls.name == "TimeSeriesReferenceVectorData": elif cls.cls.name == "TimeSeriesReferenceVectorData":
# in core.nwb.base, so need to inject and import again # in core.nwb.base, so need to inject and import again
cls.cls.bases = ["TimeSeriesReferenceVectorDataMixin", "VectorData"] cls.cls.bases = ["TimeSeriesReferenceVectorDataMixin", "VectorData"]
@ -305,14 +316,31 @@ class AfterGenerateClass:
): ):
for an_attr in cls.cls.attributes: for an_attr in cls.cls.attributes:
if "NDArray" in (slot_range := cls.cls.attributes[an_attr].range): if "NDArray" in (slot_range := cls.cls.attributes[an_attr].range):
if an_attr == "id":
cls.cls.attributes[an_attr].range = "ElementIdentifiers"
return cls
if an_attr.endswith("_index"): if an_attr.endswith("_index"):
cls.cls.attributes[an_attr].range = "".join( wrap_cls = "VectorIndex"
["VectorIndex[", slot_range, "]"]
)
else: else:
cls.cls.attributes[an_attr].range = "".join( wrap_cls = "VectorData"
["VectorData[", slot_range, "]"]
) cls.cls.attributes[an_attr].range = "".join([wrap_cls, "[", slot_range, "]"])
return cls
@staticmethod
def inject_elementidentifiers(cls: ClassResult, sv: SchemaView, import_method) -> ClassResult:
"""
Inject ElementIdentifiers into module that define dynamictables -
needed to handle ID columns
"""
if (
cls.source.is_a == "DynamicTable"
or "DynamicTable" in sv.class_ancestors(cls.source.name)
) and sv.schema.name != "hdmf-common.table":
imp = import_method("ElementIdentifiers")
cls.imports += [imp]
return cls return cls

View file

@ -12,3 +12,20 @@ BASEMODEL_GETITEM = """
else: else:
raise KeyError("No value or data field to index from") raise KeyError("No value or data field to index from")
""" """
BASEMODEL_COERCE_VALUE = """
@field_validator("*", mode="wrap")
@classmethod
def coerce_value(cls, v: Any, handler) -> Any:
\"\"\"Try to rescue instantiation by using the value field\"\"\"
try:
return handler(v)
except Exception as e1:
try:
if hasattr(v, "value"):
return handler(v.value)
else:
return handler(v["value"])
except Exception as e2:
raise e2 from e1
"""

View file

@ -253,6 +253,8 @@ class DynamicTableMixin(BaseModel):
else: else:
# add any columns not explicitly given an order at the end # add any columns not explicitly given an order at the end
colnames = model["colnames"].copy() colnames = model["colnames"].copy()
if isinstance(colnames, np.ndarray):
colnames = colnames.tolist()
colnames.extend( colnames.extend(
[ [
k k
@ -284,9 +286,13 @@ class DynamicTableMixin(BaseModel):
if not isinstance(val, (VectorData, VectorIndex)): if not isinstance(val, (VectorData, VectorIndex)):
try: try:
if key.endswith("_index"): if key.endswith("_index"):
model[key] = VectorIndex(name=key, description="", value=val) to_cast = VectorIndex
else: else:
model[key] = VectorData(name=key, description="", value=val) to_cast = VectorData
if isinstance(val, dict):
model[key] = to_cast(**val)
else:
model[key] = VectorIndex(name=key, description="", value=val)
except ValidationError as e: # pragma: no cover except ValidationError as e: # pragma: no cover
raise ValidationError.from_exception_data( raise ValidationError.from_exception_data(
title=f"field {key} cannot be cast to VectorData from {val}", title=f"field {key} cannot be cast to VectorData from {val}",
@ -379,10 +385,10 @@ class VectorDataMixin(BaseModel, Generic[T]):
# redefined in `VectorData`, but included here for testing and type checking # redefined in `VectorData`, but included here for testing and type checking
value: Optional[T] = None value: Optional[T] = None
def __init__(self, value: Optional[NDArray] = None, **kwargs): # def __init__(self, value: Optional[NDArray] = None, **kwargs):
if value is not None and "value" not in kwargs: # if value is not None and "value" not in kwargs:
kwargs["value"] = value # kwargs["value"] = value
super().__init__(**kwargs) # super().__init__(**kwargs)
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any: def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
if self._index: if self._index:
@ -703,14 +709,19 @@ class AlignedDynamicTableMixin(BaseModel):
model["categories"] = categories model["categories"] = categories
else: else:
# add any columns not explicitly given an order at the end # add any columns not explicitly given an order at the end
categories = [ categories = model["categories"].copy()
if isinstance(categories, np.ndarray):
categories = categories.tolist()
categories.extend(
[
k k
for k in model for k in model
if k not in cls.NON_COLUMN_FIELDS if k not in cls.NON_CATEGORY_FIELDS
and not k.endswith("_index") and not k.endswith("_index")
and k not in model["categories"] and k not in model["categories"]
] ]
model["categories"].extend(categories) )
model["categories"] = categories
return model return model
@model_validator(mode="after") @model_validator(mode="after")
@ -839,6 +850,13 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin):
) )
class ElementIdentifiersMixin(VectorDataMixin):
"""
Mixin class for ElementIdentifiers - allow treating
as generic, and give general indexing methods from VectorData
"""
DYNAMIC_TABLE_IMPORTS = Imports( DYNAMIC_TABLE_IMPORTS = Imports(
imports=[ imports=[
Import(module="pandas", alias="pd"), Import(module="pandas", alias="pd"),
@ -882,6 +900,7 @@ DYNAMIC_TABLE_INJECTS = [
DynamicTableRegionMixin, DynamicTableRegionMixin,
DynamicTableMixin, DynamicTableMixin,
AlignedDynamicTableMixin, AlignedDynamicTableMixin,
ElementIdentifiersMixin,
] ]
TSRVD_IMPORTS = Imports( TSRVD_IMPORTS = Imports(
@ -923,3 +942,8 @@ if "pytest" in sys.modules:
"""TimeSeriesReferenceVectorData subclass for testing""" """TimeSeriesReferenceVectorData subclass for testing"""
pass pass
class ElementIdentifiers(ElementIdentifiersMixin):
"""ElementIdentifiers subclass for testing"""
pass

View file

@ -22,6 +22,7 @@ Other TODO:
import json import json
import os import os
import pdb
import re import re
import shutil import shutil
import subprocess import subprocess
@ -34,10 +35,11 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union, overload
import h5py import h5py
import networkx as nx import networkx as nx
import numpy as np import numpy as np
from numpydantic.interface.hdf5 import H5ArrayPath
from pydantic import BaseModel from pydantic import BaseModel
from tqdm import tqdm from tqdm import tqdm
from nwb_linkml.maps.hdf5 import ReadPhases, ReadQueue, flatten_hdf, get_references from nwb_linkml.maps.hdf5 import get_references
if TYPE_CHECKING: if TYPE_CHECKING:
from nwb_linkml.providers.schema import SchemaProvider from nwb_linkml.providers.schema import SchemaProvider
@ -49,7 +51,11 @@ else:
from typing_extensions import Never from typing_extensions import Never
def hdf_dependency_graph(h5f: Path | h5py.File) -> nx.DiGraph: SKIP_PATTERN = re.compile("^/specifications.*")
"""Nodes to always skip in reading e.g. because they are handled elsewhere"""
def hdf_dependency_graph(h5f: Path | h5py.File | h5py.Group) -> nx.DiGraph:
""" """
Directed dependency graph of dataset and group nodes in an NWBFile such that Directed dependency graph of dataset and group nodes in an NWBFile such that
each node ``n_i`` is connected to node ``n_j`` if each node ``n_i`` is connected to node ``n_j`` if
@ -63,14 +69,15 @@ def hdf_dependency_graph(h5f: Path | h5py.File) -> nx.DiGraph:
* Dataset columns * Dataset columns
* Compound dtypes * Compound dtypes
Edges are labeled with ``reference`` or ``child`` depending on the type of edge it is,
and attributes from the hdf5 file are added as node attributes.
Args: Args:
h5f (:class:`pathlib.Path` | :class:`h5py.File`): NWB file to graph h5f (:class:`pathlib.Path` | :class:`h5py.File`): NWB file to graph
Returns: Returns:
:class:`networkx.DiGraph` :class:`networkx.DiGraph`
""" """
# detect nodes to skip
skip_pattern = re.compile("^/specifications.*")
if isinstance(h5f, (Path, str)): if isinstance(h5f, (Path, str)):
h5f = h5py.File(h5f, "r") h5f = h5py.File(h5f, "r")
@ -78,17 +85,19 @@ def hdf_dependency_graph(h5f: Path | h5py.File) -> nx.DiGraph:
g = nx.DiGraph() g = nx.DiGraph()
def _visit_item(name: str, node: h5py.Dataset | h5py.Group) -> None: def _visit_item(name: str, node: h5py.Dataset | h5py.Group) -> None:
if skip_pattern.match(name): if SKIP_PATTERN.match(node.name):
return return
# find references in attributes # find references in attributes
refs = get_references(node) refs = get_references(node)
if isinstance(node, h5py.Group): # add edges from references
refs.extend([child.name for child in node.values()]) edges = [(node.name, ref) for ref in refs if not SKIP_PATTERN.match(ref)]
refs = set(refs) g.add_edges_from(edges, label="reference")
# add edges # add children, if group
edges = [(node.name, ref) for ref in refs] if isinstance(node, h5py.Group):
g.add_edges_from(edges) children = [child.name for child in node.values() if not SKIP_PATTERN.match(child.name)]
edges = [(node.name, ref) for ref in children if not SKIP_PATTERN.match(ref)]
g.add_edges_from(edges, label="child")
# ensure node added to graph # ensure node added to graph
if len(edges) == 0: if len(edges) == 0:
@ -119,13 +128,125 @@ def filter_dependency_graph(g: nx.DiGraph) -> nx.DiGraph:
node: str node: str
for node in g.nodes: for node in g.nodes:
ndtype = g.nodes[node].get("neurodata_type", None) ndtype = g.nodes[node].get("neurodata_type", None)
if ndtype == "VectorData" or not ndtype and g.out_degree(node) == 0: if (ndtype is None and g.out_degree(node) == 0) or SKIP_PATTERN.match(node):
remove_nodes.append(node) remove_nodes.append(node)
g.remove_nodes_from(remove_nodes) g.remove_nodes_from(remove_nodes)
return g return g
def _load_node(
path: str, h5f: h5py.File, provider: "SchemaProvider", context: dict
) -> dict | BaseModel:
"""
Load an individual node in the graph, then removes it from the graph
Args:
path:
g:
context:
Returns:
"""
obj = h5f.get(path)
if isinstance(obj, h5py.Dataset):
args = _load_dataset(obj, h5f, context)
elif isinstance(obj, h5py.Group):
args = _load_group(obj, h5f, context)
else:
raise TypeError(f"Nodes can only be h5py Datasets and Groups, got {obj}")
# if obj.name == "/general/intracellular_ephys/simultaneous_recordings/recordings":
# pdb.set_trace()
# resolve attr references
for k, v in args.items():
if isinstance(v, h5py.h5r.Reference):
ref_path = h5f[v].name
args[k] = context[ref_path]
model = provider.get_class(obj.attrs["namespace"], obj.attrs["neurodata_type"])
# add additional needed params
args["hdf5_path"] = path
args["name"] = path.split("/")[-1]
return model(**args)
def _load_dataset(
dataset: h5py.Dataset, h5f: h5py.File, context: dict
) -> Union[dict, str, int, float]:
"""
Resolves datasets that do not have a ``neurodata_type`` as a dictionary or a scalar.
If the dataset is a single value without attrs, load it and return as a scalar value.
Otherwise return a :class:`.H5ArrayPath` as a reference to the dataset in the `value` key.
"""
res = {}
if dataset.shape == ():
val = dataset[()]
if isinstance(val, h5py.h5r.Reference):
val = context.get(h5f[val].name)
# if this is just a scalar value, return it
if not dataset.attrs:
return val
res["value"] = val
elif len(dataset) > 0 and isinstance(dataset[0], h5py.h5r.Reference):
# vector of references
res["value"] = [context.get(h5f[ref].name) for ref in dataset[:]]
elif len(dataset.dtype) > 1:
# compound dataset - check if any of the fields are references
for name in dataset.dtype.names:
if isinstance(dataset[name][0], h5py.h5r.Reference):
res[name] = [context.get(h5f[ref].name) for ref in dataset[name]]
else:
res[name] = H5ArrayPath(h5f.filename, dataset.name, name)
else:
res["value"] = H5ArrayPath(h5f.filename, dataset.name)
res.update(dataset.attrs)
if "namespace" in res:
del res["namespace"]
if "neurodata_type" in res:
del res["neurodata_type"]
res["name"] = dataset.name.split("/")[-1]
res["hdf5_path"] = dataset.name
if len(res) == 1:
return res["value"]
else:
return res
def _load_group(group: h5py.Group, h5f: h5py.File, context: dict) -> dict:
"""
Load a group!
"""
res = {}
res.update(group.attrs)
for child_name, child in group.items():
if child.name in context:
res[child_name] = context[child.name]
elif isinstance(child, h5py.Dataset):
res[child_name] = _load_dataset(child, h5f, context)
elif isinstance(child, h5py.Group):
res[child_name] = _load_group(child, h5f, context)
else:
raise TypeError(
"Can only handle preinstantiated child objects in context, datasets, and group,"
f" got {child} for {child_name}"
)
if "namespace" in res:
del res["namespace"]
if "neurodata_type" in res:
del res["neurodata_type"]
res["name"] = group.name.split("/")[-1]
res["hdf5_path"] = group.name
return res
class HDF5IO: class HDF5IO:
""" """
Read (and eventually write) from an NWB HDF5 file. Read (and eventually write) from an NWB HDF5 file.
@ -185,28 +306,20 @@ class HDF5IO:
h5f = h5py.File(str(self.path)) h5f = h5py.File(str(self.path))
src = h5f.get(path) if path else h5f src = h5f.get(path) if path else h5f
graph = hdf_dependency_graph(src)
graph = filter_dependency_graph(graph)
# get all children of selected item # topo sort to get read order
if isinstance(src, (h5py.File, h5py.Group)): # TODO: This could be parallelized using `topological_generations`,
children = flatten_hdf(src) # but it's not clear what the perf bonus would be because there are many generations
else: # with few items
raise NotImplementedError("directly read individual datasets") topo_order = list(reversed(list(nx.topological_sort(graph))))
context = {}
for node in topo_order:
res = _load_node(node, h5f, provider, context)
context[node] = res
queue = ReadQueue(h5f=self.path, queue=children, provider=provider) pdb.set_trace()
# Apply initial planning phase of reading
queue.apply_phase(ReadPhases.plan)
# Read operations gather the data before casting into models
queue.apply_phase(ReadPhases.read)
# Construction operations actually cast the models
# this often needs to run several times as models with dependencies wait for their
# dependents to be cast
queue.apply_phase(ReadPhases.construct)
if path is None:
return queue.completed["/"].result
else:
return queue.completed[path].result
def write(self, path: Path) -> Never: def write(self, path: Path) -> Never:
""" """
@ -246,7 +359,7 @@ class HDF5IO:
""" """
from nwb_linkml.providers.schema import SchemaProvider from nwb_linkml.providers.schema import SchemaProvider
h5f = h5py.File(str(self.path)) h5f = h5py.File(str(self.path), "r")
schema = read_specs_as_dicts(h5f.get("specifications")) schema = read_specs_as_dicts(h5f.get("specifications"))
# get versions for each namespace # get versions for each namespace
@ -260,7 +373,7 @@ class HDF5IO:
provider = SchemaProvider(versions=versions) provider = SchemaProvider(versions=versions)
# build schema so we have them cached # build schema so we have them cached
provider.build_from_dicts(schema) # provider.build_from_dicts(schema)
h5f.close() h5f.close()
return provider return provider

View file

@ -233,66 +233,6 @@ class PruneEmpty(HDF5Map):
return H5ReadResult.model_construct(path=src.path, source=src, completed=True) return H5ReadResult.model_construct(path=src.path, source=src, completed=True)
#
# class ResolveDynamicTable(HDF5Map):
# """
# Handle loading a dynamic table!
#
# Dynamic tables are sort of odd in that their models don't include their fields
# (except as a list of strings in ``colnames`` ),
# so we need to create a new model that includes fields for each column,
# and then we include the datasets as :class:`~numpydantic.interface.hdf5.H5ArrayPath`
# objects which lazy load the arrays in a thread/process safe way.
#
# This map also resolves the child elements,
# indicating so by the ``completes`` field in the :class:`.ReadResult`
# """
#
# phase = ReadPhases.read
# priority = 1
#
# @classmethod
# def check(
# cls, src: H5SourceItem, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
# ) -> bool:
# if src.h5_type == "dataset":
# return False
# if "neurodata_type" in src.attrs:
# if src.attrs["neurodata_type"] == "DynamicTable":
# return True
# # otherwise, see if it's a subclass
# model = provider.get_class(src.attrs["namespace"], src.attrs["neurodata_type"])
# # just inspect the MRO as strings rather than trying to check subclasses because
# # we might replace DynamicTable in the future, and there isn't a stable DynamicTable
# # class to inherit from anyway because of the whole multiple versions thing
# parents = [parent.__name__ for parent in model.__mro__]
# return "DynamicTable" in parents
# else:
# return False
#
# @classmethod
# def apply(
# cls, src: H5SourceItem, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
# ) -> H5ReadResult:
# with h5py.File(src.h5f_path, "r") as h5f:
# obj = h5f.get(src.path)
#
# # make a populated model :)
# base_model = provider.get_class(src.namespace, src.neurodata_type)
# model = dynamictable_to_model(obj, base=base_model)
#
# completes = [HDF5_Path(child.name) for child in obj.values()]
#
# return H5ReadResult(
# path=src.path,
# source=src,
# result=model,
# completes=completes,
# completed=True,
# applied=["ResolveDynamicTable"],
# )
class ResolveModelGroup(HDF5Map): class ResolveModelGroup(HDF5Map):
""" """
HDF5 Groups that have a model, as indicated by ``neurodata_type`` in their attrs. HDF5 Groups that have a model, as indicated by ``neurodata_type`` in their attrs.

View file

@ -97,9 +97,9 @@ class Provider(ABC):
module_path = Path(importlib.util.find_spec("nwb_models").origin).parent module_path = Path(importlib.util.find_spec("nwb_models").origin).parent
if self.PROVIDES == "linkml": if self.PROVIDES == "linkml":
namespace_path = module_path / "schema" / "linkml" / namespace namespace_path = module_path / "schema" / "linkml" / namespace_module
elif self.PROVIDES == "pydantic": elif self.PROVIDES == "pydantic":
namespace_path = module_path / "models" / "pydantic" / namespace namespace_path = module_path / "models" / "pydantic" / namespace_module
if version is not None: if version is not None:
version_path = namespace_path / version_module_case(version) version_path = namespace_path / version_module_case(version)

View file

@ -1,4 +1,4 @@
from .nwb import nwb_file from .nwb import nwb_file, nwb_file_base
from .paths import data_dir, tmp_output_dir, tmp_output_dir_func, tmp_output_dir_mod from .paths import data_dir, tmp_output_dir, tmp_output_dir_func, tmp_output_dir_mod
from .schema import ( from .schema import (
NWBSchemaTest, NWBSchemaTest,
@ -21,6 +21,7 @@ __all__ = [
"nwb_core_linkml", "nwb_core_linkml",
"nwb_core_module", "nwb_core_module",
"nwb_file", "nwb_file",
"nwb_file_base",
"nwb_schema", "nwb_schema",
"tmp_output_dir", "tmp_output_dir",
"tmp_output_dir_func", "tmp_output_dir_func",

View file

@ -6,7 +6,7 @@ import pytest
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def tmp_output_dir(request: pytest.FixtureRequest) -> Path: def tmp_output_dir(request: pytest.FixtureRequest) -> Path:
path = Path(__file__).parent.resolve() / "__tmp__" path = Path(__file__).parents[1].resolve() / "__tmp__"
if path.exists(): if path.exists():
if request.config.getoption("--clean"): if request.config.getoption("--clean"):
shutil.rmtree(path) shutil.rmtree(path)

View file

@ -1,6 +1,7 @@
import pdb import pdb
import h5py import h5py
import networkx as nx
import numpy as np import numpy as np
import pytest import pytest
@ -100,10 +101,15 @@ def test_flatten_hdf():
raise NotImplementedError("Just a stub for local testing for now, finish me!") raise NotImplementedError("Just a stub for local testing for now, finish me!")
def test_dependency_graph(nwb_file): @pytest.mark.dev
def test_dependency_graph(nwb_file, tmp_output_dir):
""" """
dependency graph is correctly constructed from an HDF5 file dependency graph is correctly constructed from an HDF5 file
""" """
graph = hdf_dependency_graph(nwb_file) graph = hdf_dependency_graph(nwb_file)
A_unfiltered = nx.nx_agraph.to_agraph(graph)
A_unfiltered.draw(tmp_output_dir / "test_nwb_unfiltered.png", prog="dot")
graph = filter_dependency_graph(graph) graph = filter_dependency_graph(graph)
A_filtered = nx.nx_agraph.to_agraph(graph)
A_filtered.draw(tmp_output_dir / "test_nwb_filtered.png", prog="dot")
pass pass

View file

@ -2,12 +2,14 @@
Placeholder test module to test reading from pynwb-generated NWB file Placeholder test module to test reading from pynwb-generated NWB file
""" """
from nwb_linkml.io.hdf5 import HDF5IO
def test_read_from_nwbfile(nwb_file): def test_read_from_nwbfile(nwb_file):
""" """
Read data from a pynwb HDF5 NWB file Read data from a pynwb HDF5 NWB file
""" """
pass res = HDF5IO(nwb_file).read()
def test_read_from_yaml(nwb_file): def test_read_from_yaml(nwb_file):