checkpointing working on model loading. it's a sloggggggggg

This commit is contained in:
sneakers-the-rat 2024-09-03 00:54:56 -07:00
parent cd3d7ca78e
commit d1498a3733
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
11 changed files with 258 additions and 127 deletions

View file

@ -22,7 +22,7 @@ dependencies = [
"pydantic-settings>=2.0.3",
"tqdm>=4.66.1",
'typing-extensions>=4.12.2;python_version<"3.11"',
"numpydantic>=1.3.3",
"numpydantic>=1.5.0",
"black>=24.4.2",
"pandas>=2.2.2",
"networkx>=3.3",

View file

@ -10,7 +10,7 @@ import sys
from dataclasses import dataclass, field
from pathlib import Path
from types import ModuleType
from typing import ClassVar, Dict, List, Optional, Tuple
from typing import ClassVar, Dict, List, Optional, Tuple, Literal
from linkml.generators import PydanticGenerator
from linkml.generators.pydanticgen.array import ArrayRepresentation, NumpydanticArray
@ -27,7 +27,7 @@ from linkml_runtime.utils.compile_python import file_text
from linkml_runtime.utils.formatutils import remove_empty_items
from linkml_runtime.utils.schemaview import SchemaView
from nwb_linkml.includes.base import BASEMODEL_GETITEM
from nwb_linkml.includes.base import BASEMODEL_GETITEM, BASEMODEL_COERCE_VALUE
from nwb_linkml.includes.hdmf import (
DYNAMIC_TABLE_IMPORTS,
DYNAMIC_TABLE_INJECTS,
@ -52,6 +52,7 @@ class NWBPydanticGenerator(PydanticGenerator):
),
'object_id: Optional[str] = Field(None, description="Unique UUID for each object")',
BASEMODEL_GETITEM,
BASEMODEL_COERCE_VALUE,
)
split: bool = True
imports: list[Import] = field(default_factory=lambda: [Import(module="numpy", alias="np")])
@ -66,6 +67,7 @@ class NWBPydanticGenerator(PydanticGenerator):
emit_metadata: bool = True
gen_classvars: bool = True
gen_slots: bool = True
extra_fields: Literal["allow", "forbid", "ignore"] = "allow"
skip_meta: ClassVar[Tuple[str]] = ("domain_of", "alias")
@ -131,6 +133,7 @@ class NWBPydanticGenerator(PydanticGenerator):
"""Customize dynamictable behavior"""
cls = AfterGenerateClass.inject_dynamictable(cls)
cls = AfterGenerateClass.wrap_dynamictable_columns(cls, sv)
cls = AfterGenerateClass.inject_elementidentifiers(cls, sv, self._get_element_import)
return cls
def before_render_template(self, template: PydanticModule, sv: SchemaView) -> PydanticModule:
@ -255,7 +258,7 @@ class AfterGenerateClass:
"""
if cls.cls.name in "DynamicTable":
cls.cls.bases = ["DynamicTableMixin"]
cls.cls.bases = ["DynamicTableMixin", "ConfiguredBaseModel"]
if cls.injected_classes is None:
cls.injected_classes = DYNAMIC_TABLE_INJECTS.copy()
@ -269,13 +272,21 @@ class AfterGenerateClass:
else:
cls.imports = DYNAMIC_TABLE_IMPORTS.model_copy()
elif cls.cls.name == "VectorData":
cls.cls.bases = ["VectorDataMixin"]
cls.cls.bases = ["VectorDataMixin", "ConfiguredBaseModel"]
# make ``value`` generic on T
if "value" in cls.cls.attributes:
cls.cls.attributes["value"].range = "Optional[T]"
elif cls.cls.name == "VectorIndex":
cls.cls.bases = ["VectorIndexMixin"]
cls.cls.bases = ["VectorIndexMixin", "ConfiguredBaseModel"]
elif cls.cls.name == "DynamicTableRegion":
cls.cls.bases = ["DynamicTableRegionMixin", "VectorData"]
cls.cls.bases = ["DynamicTableRegionMixin", "VectorData", "ConfiguredBaseModel"]
elif cls.cls.name == "AlignedDynamicTable":
cls.cls.bases = ["AlignedDynamicTableMixin", "DynamicTable"]
elif cls.cls.name == "ElementIdentifiers":
cls.cls.bases = ["ElementIdentifiersMixin", "Data", "ConfiguredBaseModel"]
# make ``value`` generic on T
if "value" in cls.cls.attributes:
cls.cls.attributes["value"].range = "Optional[T]"
elif cls.cls.name == "TimeSeriesReferenceVectorData":
# in core.nwb.base, so need to inject and import again
cls.cls.bases = ["TimeSeriesReferenceVectorDataMixin", "VectorData"]
@ -305,14 +316,31 @@ class AfterGenerateClass:
):
for an_attr in cls.cls.attributes:
if "NDArray" in (slot_range := cls.cls.attributes[an_attr].range):
if an_attr == "id":
cls.cls.attributes[an_attr].range = "ElementIdentifiers"
return cls
if an_attr.endswith("_index"):
cls.cls.attributes[an_attr].range = "".join(
["VectorIndex[", slot_range, "]"]
)
wrap_cls = "VectorIndex"
else:
cls.cls.attributes[an_attr].range = "".join(
["VectorData[", slot_range, "]"]
)
wrap_cls = "VectorData"
cls.cls.attributes[an_attr].range = "".join([wrap_cls, "[", slot_range, "]"])
return cls
@staticmethod
def inject_elementidentifiers(cls: ClassResult, sv: SchemaView, import_method) -> ClassResult:
"""
Inject ElementIdentifiers into module that define dynamictables -
needed to handle ID columns
"""
if (
cls.source.is_a == "DynamicTable"
or "DynamicTable" in sv.class_ancestors(cls.source.name)
) and sv.schema.name != "hdmf-common.table":
imp = import_method("ElementIdentifiers")
cls.imports += [imp]
return cls

View file

@ -12,3 +12,20 @@ BASEMODEL_GETITEM = """
else:
raise KeyError("No value or data field to index from")
"""
BASEMODEL_COERCE_VALUE = """
@field_validator("*", mode="wrap")
@classmethod
def coerce_value(cls, v: Any, handler) -> Any:
\"\"\"Try to rescue instantiation by using the value field\"\"\"
try:
return handler(v)
except Exception as e1:
try:
if hasattr(v, "value"):
return handler(v.value)
else:
return handler(v["value"])
except Exception as e2:
raise e2 from e1
"""

View file

@ -253,6 +253,8 @@ class DynamicTableMixin(BaseModel):
else:
# add any columns not explicitly given an order at the end
colnames = model["colnames"].copy()
if isinstance(colnames, np.ndarray):
colnames = colnames.tolist()
colnames.extend(
[
k
@ -284,9 +286,13 @@ class DynamicTableMixin(BaseModel):
if not isinstance(val, (VectorData, VectorIndex)):
try:
if key.endswith("_index"):
model[key] = VectorIndex(name=key, description="", value=val)
to_cast = VectorIndex
else:
model[key] = VectorData(name=key, description="", value=val)
to_cast = VectorData
if isinstance(val, dict):
model[key] = to_cast(**val)
else:
model[key] = VectorIndex(name=key, description="", value=val)
except ValidationError as e: # pragma: no cover
raise ValidationError.from_exception_data(
title=f"field {key} cannot be cast to VectorData from {val}",
@ -379,10 +385,10 @@ class VectorDataMixin(BaseModel, Generic[T]):
# redefined in `VectorData`, but included here for testing and type checking
value: Optional[T] = None
def __init__(self, value: Optional[NDArray] = None, **kwargs):
if value is not None and "value" not in kwargs:
kwargs["value"] = value
super().__init__(**kwargs)
# def __init__(self, value: Optional[NDArray] = None, **kwargs):
# if value is not None and "value" not in kwargs:
# kwargs["value"] = value
# super().__init__(**kwargs)
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
if self._index:
@ -703,14 +709,19 @@ class AlignedDynamicTableMixin(BaseModel):
model["categories"] = categories
else:
# add any columns not explicitly given an order at the end
categories = [
categories = model["categories"].copy()
if isinstance(categories, np.ndarray):
categories = categories.tolist()
categories.extend(
[
k
for k in model
if k not in cls.NON_COLUMN_FIELDS
if k not in cls.NON_CATEGORY_FIELDS
and not k.endswith("_index")
and k not in model["categories"]
]
model["categories"].extend(categories)
)
model["categories"] = categories
return model
@model_validator(mode="after")
@ -839,6 +850,13 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin):
)
class ElementIdentifiersMixin(VectorDataMixin):
"""
Mixin class for ElementIdentifiers - allow treating
as generic, and give general indexing methods from VectorData
"""
DYNAMIC_TABLE_IMPORTS = Imports(
imports=[
Import(module="pandas", alias="pd"),
@ -882,6 +900,7 @@ DYNAMIC_TABLE_INJECTS = [
DynamicTableRegionMixin,
DynamicTableMixin,
AlignedDynamicTableMixin,
ElementIdentifiersMixin,
]
TSRVD_IMPORTS = Imports(
@ -923,3 +942,8 @@ if "pytest" in sys.modules:
"""TimeSeriesReferenceVectorData subclass for testing"""
pass
class ElementIdentifiers(ElementIdentifiersMixin):
"""ElementIdentifiers subclass for testing"""
pass

View file

@ -22,6 +22,7 @@ Other TODO:
import json
import os
import pdb
import re
import shutil
import subprocess
@ -34,10 +35,11 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union, overload
import h5py
import networkx as nx
import numpy as np
from numpydantic.interface.hdf5 import H5ArrayPath
from pydantic import BaseModel
from tqdm import tqdm
from nwb_linkml.maps.hdf5 import ReadPhases, ReadQueue, flatten_hdf, get_references
from nwb_linkml.maps.hdf5 import get_references
if TYPE_CHECKING:
from nwb_linkml.providers.schema import SchemaProvider
@ -49,7 +51,11 @@ else:
from typing_extensions import Never
def hdf_dependency_graph(h5f: Path | h5py.File) -> nx.DiGraph:
SKIP_PATTERN = re.compile("^/specifications.*")
"""Nodes to always skip in reading e.g. because they are handled elsewhere"""
def hdf_dependency_graph(h5f: Path | h5py.File | h5py.Group) -> nx.DiGraph:
"""
Directed dependency graph of dataset and group nodes in an NWBFile such that
each node ``n_i`` is connected to node ``n_j`` if
@ -63,14 +69,15 @@ def hdf_dependency_graph(h5f: Path | h5py.File) -> nx.DiGraph:
* Dataset columns
* Compound dtypes
Edges are labeled with ``reference`` or ``child`` depending on the type of edge it is,
and attributes from the hdf5 file are added as node attributes.
Args:
h5f (:class:`pathlib.Path` | :class:`h5py.File`): NWB file to graph
Returns:
:class:`networkx.DiGraph`
"""
# detect nodes to skip
skip_pattern = re.compile("^/specifications.*")
if isinstance(h5f, (Path, str)):
h5f = h5py.File(h5f, "r")
@ -78,17 +85,19 @@ def hdf_dependency_graph(h5f: Path | h5py.File) -> nx.DiGraph:
g = nx.DiGraph()
def _visit_item(name: str, node: h5py.Dataset | h5py.Group) -> None:
if skip_pattern.match(name):
if SKIP_PATTERN.match(node.name):
return
# find references in attributes
refs = get_references(node)
if isinstance(node, h5py.Group):
refs.extend([child.name for child in node.values()])
refs = set(refs)
# add edges from references
edges = [(node.name, ref) for ref in refs if not SKIP_PATTERN.match(ref)]
g.add_edges_from(edges, label="reference")
# add edges
edges = [(node.name, ref) for ref in refs]
g.add_edges_from(edges)
# add children, if group
if isinstance(node, h5py.Group):
children = [child.name for child in node.values() if not SKIP_PATTERN.match(child.name)]
edges = [(node.name, ref) for ref in children if not SKIP_PATTERN.match(ref)]
g.add_edges_from(edges, label="child")
# ensure node added to graph
if len(edges) == 0:
@ -119,13 +128,125 @@ def filter_dependency_graph(g: nx.DiGraph) -> nx.DiGraph:
node: str
for node in g.nodes:
ndtype = g.nodes[node].get("neurodata_type", None)
if ndtype == "VectorData" or not ndtype and g.out_degree(node) == 0:
if (ndtype is None and g.out_degree(node) == 0) or SKIP_PATTERN.match(node):
remove_nodes.append(node)
g.remove_nodes_from(remove_nodes)
return g
def _load_node(
path: str, h5f: h5py.File, provider: "SchemaProvider", context: dict
) -> dict | BaseModel:
"""
Load an individual node in the graph, then removes it from the graph
Args:
path:
g:
context:
Returns:
"""
obj = h5f.get(path)
if isinstance(obj, h5py.Dataset):
args = _load_dataset(obj, h5f, context)
elif isinstance(obj, h5py.Group):
args = _load_group(obj, h5f, context)
else:
raise TypeError(f"Nodes can only be h5py Datasets and Groups, got {obj}")
# if obj.name == "/general/intracellular_ephys/simultaneous_recordings/recordings":
# pdb.set_trace()
# resolve attr references
for k, v in args.items():
if isinstance(v, h5py.h5r.Reference):
ref_path = h5f[v].name
args[k] = context[ref_path]
model = provider.get_class(obj.attrs["namespace"], obj.attrs["neurodata_type"])
# add additional needed params
args["hdf5_path"] = path
args["name"] = path.split("/")[-1]
return model(**args)
def _load_dataset(
dataset: h5py.Dataset, h5f: h5py.File, context: dict
) -> Union[dict, str, int, float]:
"""
Resolves datasets that do not have a ``neurodata_type`` as a dictionary or a scalar.
If the dataset is a single value without attrs, load it and return as a scalar value.
Otherwise return a :class:`.H5ArrayPath` as a reference to the dataset in the `value` key.
"""
res = {}
if dataset.shape == ():
val = dataset[()]
if isinstance(val, h5py.h5r.Reference):
val = context.get(h5f[val].name)
# if this is just a scalar value, return it
if not dataset.attrs:
return val
res["value"] = val
elif len(dataset) > 0 and isinstance(dataset[0], h5py.h5r.Reference):
# vector of references
res["value"] = [context.get(h5f[ref].name) for ref in dataset[:]]
elif len(dataset.dtype) > 1:
# compound dataset - check if any of the fields are references
for name in dataset.dtype.names:
if isinstance(dataset[name][0], h5py.h5r.Reference):
res[name] = [context.get(h5f[ref].name) for ref in dataset[name]]
else:
res[name] = H5ArrayPath(h5f.filename, dataset.name, name)
else:
res["value"] = H5ArrayPath(h5f.filename, dataset.name)
res.update(dataset.attrs)
if "namespace" in res:
del res["namespace"]
if "neurodata_type" in res:
del res["neurodata_type"]
res["name"] = dataset.name.split("/")[-1]
res["hdf5_path"] = dataset.name
if len(res) == 1:
return res["value"]
else:
return res
def _load_group(group: h5py.Group, h5f: h5py.File, context: dict) -> dict:
"""
Load a group!
"""
res = {}
res.update(group.attrs)
for child_name, child in group.items():
if child.name in context:
res[child_name] = context[child.name]
elif isinstance(child, h5py.Dataset):
res[child_name] = _load_dataset(child, h5f, context)
elif isinstance(child, h5py.Group):
res[child_name] = _load_group(child, h5f, context)
else:
raise TypeError(
"Can only handle preinstantiated child objects in context, datasets, and group,"
f" got {child} for {child_name}"
)
if "namespace" in res:
del res["namespace"]
if "neurodata_type" in res:
del res["neurodata_type"]
res["name"] = group.name.split("/")[-1]
res["hdf5_path"] = group.name
return res
class HDF5IO:
"""
Read (and eventually write) from an NWB HDF5 file.
@ -185,28 +306,20 @@ class HDF5IO:
h5f = h5py.File(str(self.path))
src = h5f.get(path) if path else h5f
graph = hdf_dependency_graph(src)
graph = filter_dependency_graph(graph)
# get all children of selected item
if isinstance(src, (h5py.File, h5py.Group)):
children = flatten_hdf(src)
else:
raise NotImplementedError("directly read individual datasets")
# topo sort to get read order
# TODO: This could be parallelized using `topological_generations`,
# but it's not clear what the perf bonus would be because there are many generations
# with few items
topo_order = list(reversed(list(nx.topological_sort(graph))))
context = {}
for node in topo_order:
res = _load_node(node, h5f, provider, context)
context[node] = res
queue = ReadQueue(h5f=self.path, queue=children, provider=provider)
# Apply initial planning phase of reading
queue.apply_phase(ReadPhases.plan)
# Read operations gather the data before casting into models
queue.apply_phase(ReadPhases.read)
# Construction operations actually cast the models
# this often needs to run several times as models with dependencies wait for their
# dependents to be cast
queue.apply_phase(ReadPhases.construct)
if path is None:
return queue.completed["/"].result
else:
return queue.completed[path].result
pdb.set_trace()
def write(self, path: Path) -> Never:
"""
@ -246,7 +359,7 @@ class HDF5IO:
"""
from nwb_linkml.providers.schema import SchemaProvider
h5f = h5py.File(str(self.path))
h5f = h5py.File(str(self.path), "r")
schema = read_specs_as_dicts(h5f.get("specifications"))
# get versions for each namespace
@ -260,7 +373,7 @@ class HDF5IO:
provider = SchemaProvider(versions=versions)
# build schema so we have them cached
provider.build_from_dicts(schema)
# provider.build_from_dicts(schema)
h5f.close()
return provider

View file

@ -233,66 +233,6 @@ class PruneEmpty(HDF5Map):
return H5ReadResult.model_construct(path=src.path, source=src, completed=True)
#
# class ResolveDynamicTable(HDF5Map):
# """
# Handle loading a dynamic table!
#
# Dynamic tables are sort of odd in that their models don't include their fields
# (except as a list of strings in ``colnames`` ),
# so we need to create a new model that includes fields for each column,
# and then we include the datasets as :class:`~numpydantic.interface.hdf5.H5ArrayPath`
# objects which lazy load the arrays in a thread/process safe way.
#
# This map also resolves the child elements,
# indicating so by the ``completes`` field in the :class:`.ReadResult`
# """
#
# phase = ReadPhases.read
# priority = 1
#
# @classmethod
# def check(
# cls, src: H5SourceItem, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
# ) -> bool:
# if src.h5_type == "dataset":
# return False
# if "neurodata_type" in src.attrs:
# if src.attrs["neurodata_type"] == "DynamicTable":
# return True
# # otherwise, see if it's a subclass
# model = provider.get_class(src.attrs["namespace"], src.attrs["neurodata_type"])
# # just inspect the MRO as strings rather than trying to check subclasses because
# # we might replace DynamicTable in the future, and there isn't a stable DynamicTable
# # class to inherit from anyway because of the whole multiple versions thing
# parents = [parent.__name__ for parent in model.__mro__]
# return "DynamicTable" in parents
# else:
# return False
#
# @classmethod
# def apply(
# cls, src: H5SourceItem, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
# ) -> H5ReadResult:
# with h5py.File(src.h5f_path, "r") as h5f:
# obj = h5f.get(src.path)
#
# # make a populated model :)
# base_model = provider.get_class(src.namespace, src.neurodata_type)
# model = dynamictable_to_model(obj, base=base_model)
#
# completes = [HDF5_Path(child.name) for child in obj.values()]
#
# return H5ReadResult(
# path=src.path,
# source=src,
# result=model,
# completes=completes,
# completed=True,
# applied=["ResolveDynamicTable"],
# )
class ResolveModelGroup(HDF5Map):
"""
HDF5 Groups that have a model, as indicated by ``neurodata_type`` in their attrs.

View file

@ -97,9 +97,9 @@ class Provider(ABC):
module_path = Path(importlib.util.find_spec("nwb_models").origin).parent
if self.PROVIDES == "linkml":
namespace_path = module_path / "schema" / "linkml" / namespace
namespace_path = module_path / "schema" / "linkml" / namespace_module
elif self.PROVIDES == "pydantic":
namespace_path = module_path / "models" / "pydantic" / namespace
namespace_path = module_path / "models" / "pydantic" / namespace_module
if version is not None:
version_path = namespace_path / version_module_case(version)

View file

@ -1,4 +1,4 @@
from .nwb import nwb_file
from .nwb import nwb_file, nwb_file_base
from .paths import data_dir, tmp_output_dir, tmp_output_dir_func, tmp_output_dir_mod
from .schema import (
NWBSchemaTest,
@ -21,6 +21,7 @@ __all__ = [
"nwb_core_linkml",
"nwb_core_module",
"nwb_file",
"nwb_file_base",
"nwb_schema",
"tmp_output_dir",
"tmp_output_dir_func",

View file

@ -6,7 +6,7 @@ import pytest
@pytest.fixture(scope="session")
def tmp_output_dir(request: pytest.FixtureRequest) -> Path:
path = Path(__file__).parent.resolve() / "__tmp__"
path = Path(__file__).parents[1].resolve() / "__tmp__"
if path.exists():
if request.config.getoption("--clean"):
shutil.rmtree(path)

View file

@ -1,6 +1,7 @@
import pdb
import h5py
import networkx as nx
import numpy as np
import pytest
@ -100,10 +101,15 @@ def test_flatten_hdf():
raise NotImplementedError("Just a stub for local testing for now, finish me!")
def test_dependency_graph(nwb_file):
@pytest.mark.dev
def test_dependency_graph(nwb_file, tmp_output_dir):
"""
dependency graph is correctly constructed from an HDF5 file
"""
graph = hdf_dependency_graph(nwb_file)
A_unfiltered = nx.nx_agraph.to_agraph(graph)
A_unfiltered.draw(tmp_output_dir / "test_nwb_unfiltered.png", prog="dot")
graph = filter_dependency_graph(graph)
A_filtered = nx.nx_agraph.to_agraph(graph)
A_filtered.draw(tmp_output_dir / "test_nwb_filtered.png", prog="dot")
pass

View file

@ -2,12 +2,14 @@
Placeholder test module to test reading from pynwb-generated NWB file
"""
from nwb_linkml.io.hdf5 import HDF5IO
def test_read_from_nwbfile(nwb_file):
"""
Read data from a pynwb HDF5 NWB file
"""
pass
res = HDF5IO(nwb_file).read()
def test_read_from_yaml(nwb_file):