CHECKPOINT WITH IT WORKING before cleanup and model regeneration

This commit is contained in:
sneakers-the-rat 2024-09-03 17:48:36 -07:00
parent d1498a3733
commit 8078492f90
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
11 changed files with 178 additions and 77 deletions

View file

@ -854,3 +854,12 @@ class DatasetAdapter(ClassAdapter):
return None return None
else: else:
return matches[0] return matches[0]
def special_cases(self, res: BuildResult) -> BuildResult:
    """
    Apply special cases to build result

    Args:
        res: The build result to post-process.

    Returns:
        The build result after the special-case handlers have been applied.
    """
    res = self._datetime_or_str(res)
    # Hand the processed result back to the caller; without an explicit
    # return this method would always evaluate to None.
    return res
def _datetime_or_str(self, res: BuildResult) -> BuildResult:
"""HDF5 doesn't support datetime, so"""

View file

@ -22,10 +22,16 @@ BASEMODEL_COERCE_VALUE = """
return handler(v) return handler(v)
except Exception as e1: except Exception as e1:
try: try:
if hasattr(v, "value"):
return handler(v.value) return handler(v.value)
else: except:
return handler(v["value"]) raise e1
except Exception as e2: """
raise e2 from e1
BASEMODEL_COERCE_CHILD = """
@field_validator("*", mode="before")
@classmethod
def coerce_subclass(cls, v: Any, info) -> Any:
\"\"\"Recast parent classes into child classes\"\"\"
return v
pdb.set_trace()
""" """

View file

@ -53,6 +53,7 @@ class DynamicTableMixin(BaseModel):
NON_COLUMN_FIELDS: ClassVar[tuple[str]] = ( NON_COLUMN_FIELDS: ClassVar[tuple[str]] = (
"id", "id",
"name", "name",
"categories",
"colnames", "colnames",
"description", "description",
"hdf5_path", "hdf5_path",

View file

@ -39,7 +39,7 @@ from numpydantic.interface.hdf5 import H5ArrayPath
from pydantic import BaseModel from pydantic import BaseModel
from tqdm import tqdm from tqdm import tqdm
from nwb_linkml.maps.hdf5 import get_references from nwb_linkml.maps.hdf5 import get_references, resolve_hardlink
if TYPE_CHECKING: if TYPE_CHECKING:
from nwb_linkml.providers.schema import SchemaProvider from nwb_linkml.providers.schema import SchemaProvider
@ -51,7 +51,7 @@ else:
from typing_extensions import Never from typing_extensions import Never
SKIP_PATTERN = re.compile("^/specifications.*") SKIP_PATTERN = re.compile("(^/specifications.*)|(\.specloc)")
"""Nodes to always skip in reading e.g. because they are handled elsewhere""" """Nodes to always skip in reading e.g. because they are handled elsewhere"""
@ -95,7 +95,11 @@ def hdf_dependency_graph(h5f: Path | h5py.File | h5py.Group) -> nx.DiGraph:
# add children, if group # add children, if group
if isinstance(node, h5py.Group): if isinstance(node, h5py.Group):
children = [child.name for child in node.values() if not SKIP_PATTERN.match(child.name)] children = [
resolve_hardlink(child)
for child in node.values()
if not SKIP_PATTERN.match(child.name)
]
edges = [(node.name, ref) for ref in children if not SKIP_PATTERN.match(ref)] edges = [(node.name, ref) for ref in children if not SKIP_PATTERN.match(ref)]
g.add_edges_from(edges, label="child") g.add_edges_from(edges, label="child")
@ -157,21 +161,15 @@ def _load_node(
else: else:
raise TypeError(f"Nodes can only be h5py Datasets and Groups, got {obj}") raise TypeError(f"Nodes can only be h5py Datasets and Groups, got {obj}")
# if obj.name == "/general/intracellular_ephys/simultaneous_recordings/recordings": if "neurodata_type" in obj.attrs:
# pdb.set_trace()
# resolve attr references
for k, v in args.items():
if isinstance(v, h5py.h5r.Reference):
ref_path = h5f[v].name
args[k] = context[ref_path]
model = provider.get_class(obj.attrs["namespace"], obj.attrs["neurodata_type"]) model = provider.get_class(obj.attrs["namespace"], obj.attrs["neurodata_type"])
# add additional needed params
args["hdf5_path"] = path
args["name"] = path.split("/")[-1]
return model(**args) return model(**args)
else:
if "name" in args:
del args["name"]
if "hdf5_path" in args:
del args["hdf5_path"]
return args
def _load_dataset( def _load_dataset(
@ -214,6 +212,15 @@ def _load_dataset(
res["name"] = dataset.name.split("/")[-1] res["name"] = dataset.name.split("/")[-1]
res["hdf5_path"] = dataset.name res["hdf5_path"] = dataset.name
# resolve attr references
for k, v in res.items():
if isinstance(v, h5py.h5r.Reference):
ref_path = h5f[v].name
if SKIP_PATTERN.match(ref_path):
res[k] = ref_path
else:
res[k] = context[ref_path]
if len(res) == 1: if len(res) == 1:
return res["value"] return res["value"]
else: else:
@ -242,8 +249,20 @@ def _load_group(group: h5py.Group, h5f: h5py.File, context: dict) -> dict:
del res["namespace"] del res["namespace"]
if "neurodata_type" in res: if "neurodata_type" in res:
del res["neurodata_type"] del res["neurodata_type"]
res["name"] = group.name.split("/")[-1] name = group.name.split("/")[-1]
if name:
res["name"] = name
res["hdf5_path"] = group.name res["hdf5_path"] = group.name
# resolve attr references
for k, v in res.items():
if isinstance(v, h5py.h5r.Reference):
ref_path = h5f[v].name
if SKIP_PATTERN.match(ref_path):
res[k] = ref_path
else:
res[k] = context[ref_path]
return res return res
@ -319,6 +338,10 @@ class HDF5IO:
res = _load_node(node, h5f, provider, context) res = _load_node(node, h5f, provider, context)
context[node] = res context[node] = res
if path is None:
path = "/"
return context[path]
pdb.set_trace() pdb.set_trace()
def write(self, path: Path) -> Never: def write(self, path: Path) -> Never:

View file

@ -844,7 +844,7 @@ def resolve_references(
return res, errors, completes return res, errors, completes
def resolve_hardlink(obj: Union[h5py.Group, h5py.Dataset]) -> HDF5_Path: def resolve_hardlink(obj: Union[h5py.Group, h5py.Dataset]) -> str:
""" """
Unhelpfully, hardlinks are pretty challenging to detect with h5py, so we have Unhelpfully, hardlinks are pretty challenging to detect with h5py, so we have
to do extra work to check if an item is "real" or a hardlink to another item. to do extra work to check if an item is "real" or a hardlink to another item.
@ -856,4 +856,4 @@ def resolve_hardlink(obj: Union[h5py.Group, h5py.Dataset]) -> HDF5_Path:
We basically dereference the object and return that path instead of the path We basically dereference the object and return that path instead of the path
given by the object's ``name`` given by the object's ``name``
""" """
return HDF5_Path(obj.file[obj.ref].name) return obj.file[obj.ref].name

View file

@ -113,3 +113,17 @@ def test_dependency_graph(nwb_file, tmp_output_dir):
A_filtered = nx.nx_agraph.to_agraph(graph) A_filtered = nx.nx_agraph.to_agraph(graph)
A_filtered.draw(tmp_output_dir / "test_nwb_filtered.png", prog="dot") A_filtered.draw(tmp_output_dir / "test_nwb_filtered.png", prog="dot")
pass pass
@pytest.mark.skip
def test_dependencies_hardlink(nwb_file):
    """
    Hardlinked nodes should be resolved to the location they point at, e.g.
    ``/processing/ecephys/LFP/ElectricalSeries/electrodes`` should resolve to
    ``/acquisition/ElectricalSeries/electrodes``.

    Placeholder: the test is skipped and has no assertions yet.
    """

View file

@ -147,7 +147,7 @@ class TimeIntervals(DynamicTable):
} }
}, },
) )
tags: VectorData[Optional[NDArray[Any, str]]] = Field( tags: Optional[VectorData[NDArray[Any, str]]] = Field(
None, None,
description="""User-defined tags that identify or categorize events.""", description="""User-defined tags that identify or categorize events.""",
json_schema_extra={ json_schema_extra={
@ -168,7 +168,7 @@ class TimeIntervals(DynamicTable):
} }
}, },
) )
timeseries: Named[Optional[TimeSeriesReferenceVectorData]] = Field( timeseries: Optional[Named[TimeSeriesReferenceVectorData]] = Field(
None, None,
description="""An index into a TimeSeries object.""", description="""An index into a TimeSeries object.""",
json_schema_extra={ json_schema_extra={

View file

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import pdb
import re import re
import sys import sys
from datetime import date, datetime, time from datetime import date, datetime, time
@ -72,12 +73,24 @@ class ConfiguredBaseModel(BaseModel):
return handler(v) return handler(v)
except Exception as e1: except Exception as e1:
try: try:
if hasattr(v, "value"):
return handler(v.value) return handler(v.value)
else: except AttributeError:
try:
return handler(v["value"]) return handler(v["value"])
except Exception as e2: except (KeyError, TypeError):
raise e2 from e1 raise e1
@field_validator("*", mode="before")
@classmethod
def coerce_subclass(cls, v: Any, info) -> Any:
    """
    Recast parent classes into child classes

    If ``v`` is a model instance whose class is a strict parent of the class
    the field is annotated with, rebuild it as an instance of the annotated
    (child) class from its fields and any extra fields it carries.
    Everything else is returned untouched.

    Args:
        v: The raw value handed to the field before validation.
        info: Validation info; only ``info.field_name`` is read.

    Returns:
        ``v`` itself, or a new instance of the annotated child class.
    """
    if not isinstance(v, BaseModel):
        return v

    annotation = cls.model_fields[info.field_name].annotation
    # Unwrap one level of a parameterised annotation by taking its first argument.
    annotation = annotation.__args__[0] if hasattr(annotation, "__args__") else annotation

    try:
        is_child = issubclass(annotation, type(v)) and annotation is not type(v)
    except TypeError:
        # issubclass() refuses anything that is not a class (nested generics,
        # unions, TypeVars, ...); there is nothing to recast in that case.
        is_child = False

    if is_child:
        # __pydantic_extra__ is None for models that do not allow extra fields.
        extra = v.__pydantic_extra__ or {}
        v = annotation(**{**v.__dict__, **extra})
    return v
class LinkMLMeta(RootModel): class LinkMLMeta(RootModel):
@ -176,28 +189,28 @@ class NWBFile(NWBContainer):
..., ...,
description="""Date and time corresponding to time zero of all timestamps. The date is stored in UTC with local timezone offset as ISO 8601 extended formatted string: 2018-09-28T14:43:54.123+02:00. Dates stored in UTC end in \"Z\" with no timezone offset. Date accuracy is up to milliseconds. All times stored in the file use this time as reference (i.e., time zero).""", description="""Date and time corresponding to time zero of all timestamps. The date is stored in UTC with local timezone offset as ISO 8601 extended formatted string: 2018-09-28T14:43:54.123+02:00. Dates stored in UTC end in \"Z\" with no timezone offset. Date accuracy is up to milliseconds. All times stored in the file use this time as reference (i.e., time zero).""",
) )
acquisition: Optional[List[Union[DynamicTable, NWBDataInterface]]] = Field( acquisition: Optional[dict[str, Union[DynamicTable, NWBDataInterface]]] = Field(
None, None,
description="""Data streams recorded from the system, including ephys, ophys, tracking, etc. This group should be read-only after the experiment is completed and timestamps are corrected to a common timebase. The data stored here may be links to raw data stored in external NWB files. This will allow keeping bulky raw data out of the file while preserving the option of keeping some/all in the file. Acquired data includes tracking and experimental data streams (i.e., everything measured from the system). If bulky data is stored in the /acquisition group, the data can exist in a separate NWB file that is linked to by the file being used for processing and analysis.""", description="""Data streams recorded from the system, including ephys, ophys, tracking, etc. This group should be read-only after the experiment is completed and timestamps are corrected to a common timebase. The data stored here may be links to raw data stored in external NWB files. This will allow keeping bulky raw data out of the file while preserving the option of keeping some/all in the file. Acquired data includes tracking and experimental data streams (i.e., everything measured from the system). If bulky data is stored in the /acquisition group, the data can exist in a separate NWB file that is linked to by the file being used for processing and analysis.""",
json_schema_extra={ json_schema_extra={
"linkml_meta": {"any_of": [{"range": "NWBDataInterface"}, {"range": "DynamicTable"}]} "linkml_meta": {"any_of": [{"range": "NWBDataInterface"}, {"range": "DynamicTable"}]}
}, },
) )
analysis: Optional[List[Union[DynamicTable, NWBContainer]]] = Field( analysis: Optional[dict[str, Union[DynamicTable, NWBContainer]]] = Field(
None, None,
description="""Lab-specific and custom scientific analysis of data. There is no defined format for the content of this group - the format is up to the individual user/lab. To facilitate sharing analysis data between labs, the contents here should be stored in standard types (e.g., neurodata_types) and appropriately documented. The file can store lab-specific and custom data analysis without restriction on its form or schema, reducing data formatting restrictions on end users. Such data should be placed in the analysis group. The analysis data should be documented so that it could be shared with other labs.""", description="""Lab-specific and custom scientific analysis of data. There is no defined format for the content of this group - the format is up to the individual user/lab. To facilitate sharing analysis data between labs, the contents here should be stored in standard types (e.g., neurodata_types) and appropriately documented. The file can store lab-specific and custom data analysis without restriction on its form or schema, reducing data formatting restrictions on end users. Such data should be placed in the analysis group. The analysis data should be documented so that it could be shared with other labs.""",
json_schema_extra={ json_schema_extra={
"linkml_meta": {"any_of": [{"range": "NWBContainer"}, {"range": "DynamicTable"}]} "linkml_meta": {"any_of": [{"range": "NWBContainer"}, {"range": "DynamicTable"}]}
}, },
) )
scratch: Optional[List[Union[DynamicTable, NWBContainer]]] = Field( scratch: Optional[dict[str, Union[DynamicTable, NWBContainer]]] = Field(
None, None,
description="""A place to store one-off analysis results. Data placed here is not intended for sharing. By placing data here, users acknowledge that there is no guarantee that their data meets any standard.""", description="""A place to store one-off analysis results. Data placed here is not intended for sharing. By placing data here, users acknowledge that there is no guarantee that their data meets any standard.""",
json_schema_extra={ json_schema_extra={
"linkml_meta": {"any_of": [{"range": "NWBContainer"}, {"range": "DynamicTable"}]} "linkml_meta": {"any_of": [{"range": "NWBContainer"}, {"range": "DynamicTable"}]}
}, },
) )
processing: Optional[List[ProcessingModule]] = Field( processing: Optional[dict[str, ProcessingModule]] = Field(
None, None,
description="""The home for ProcessingModules. These modules perform intermediate analysis of data that is necessary to perform before scientific analysis. Examples include spike clustering, extracting position from tracking data, stitching together image slices. ProcessingModules can be large and express many data sets from relatively complex analysis (e.g., spike detection and clustering) or small, representing extraction of position information from tracking video, or even binary lick/no-lick decisions. Common software tools (e.g., klustakwik, MClust) are expected to read/write data here. 'Processing' refers to intermediate analysis of the acquired data to make it more amenable to scientific analysis.""", description="""The home for ProcessingModules. These modules perform intermediate analysis of data that is necessary to perform before scientific analysis. Examples include spike clustering, extracting position from tracking data, stitching together image slices. ProcessingModules can be large and express many data sets from relatively complex analysis (e.g., spike detection and clustering) or small, representing extraction of position information from tracking video, or even binary lick/no-lick decisions. Common software tools (e.g., klustakwik, MClust) are expected to read/write data here. 'Processing' refers to intermediate analysis of the acquired data to make it more amenable to scientific analysis.""",
json_schema_extra={"linkml_meta": {"any_of": [{"range": "ProcessingModule"}]}}, json_schema_extra={"linkml_meta": {"any_of": [{"range": "ProcessingModule"}]}},
@ -230,7 +243,7 @@ class NWBFileStimulus(ConfiguredBaseModel):
"linkml_meta": {"equals_string": "stimulus", "ifabsent": "string(stimulus)"} "linkml_meta": {"equals_string": "stimulus", "ifabsent": "string(stimulus)"}
}, },
) )
presentation: Optional[List[Union[DynamicTable, NWBDataInterface, TimeSeries]]] = Field( presentation: Optional[dict[str, Union[DynamicTable, NWBDataInterface, TimeSeries]]] = Field(
None, None,
description="""Stimuli presented during the experiment.""", description="""Stimuli presented during the experiment.""",
json_schema_extra={ json_schema_extra={
@ -243,7 +256,7 @@ class NWBFileStimulus(ConfiguredBaseModel):
} }
}, },
) )
templates: Optional[List[Union[Images, TimeSeries]]] = Field( templates: Optional[dict[str, Union[Images, TimeSeries]]] = Field(
None, None,
description="""Template stimuli. Timestamps in templates are based on stimulus design and are relative to the beginning of the stimulus. When templates are used, the stimulus instances must convert presentation times to the experiment`s time reference frame.""", description="""Template stimuli. Timestamps in templates are based on stimulus design and are relative to the beginning of the stimulus. When templates are used, the stimulus instances must convert presentation times to the experiment`s time reference frame.""",
json_schema_extra={ json_schema_extra={
@ -327,7 +340,7 @@ class NWBFileGeneral(ConfiguredBaseModel):
None, None,
description="""Place-holder than can be extended so that lab-specific meta-data can be placed in /general.""", description="""Place-holder than can be extended so that lab-specific meta-data can be placed in /general.""",
) )
devices: Optional[List[Device]] = Field( devices: Optional[dict[str, Device]] = Field(
None, None,
description="""Description of hardware devices used during experiment, e.g., monitors, ADC boards, microscopes, etc.""", description="""Description of hardware devices used during experiment, e.g., monitors, ADC boards, microscopes, etc.""",
json_schema_extra={"linkml_meta": {"any_of": [{"range": "Device"}]}}, json_schema_extra={"linkml_meta": {"any_of": [{"range": "Device"}]}},
@ -408,7 +421,7 @@ class ExtracellularEphysElectrodes(DynamicTable):
"linkml_meta": {"equals_string": "electrodes", "ifabsent": "string(electrodes)"} "linkml_meta": {"equals_string": "electrodes", "ifabsent": "string(electrodes)"}
}, },
) )
x: VectorData[Optional[NDArray[Any, float]]] = Field( x: Optional[VectorData[NDArray[Any, float]]] = Field(
None, None,
description="""x coordinate of the channel location in the brain (+x is posterior).""", description="""x coordinate of the channel location in the brain (+x is posterior).""",
json_schema_extra={ json_schema_extra={
@ -417,7 +430,7 @@ class ExtracellularEphysElectrodes(DynamicTable):
} }
}, },
) )
y: VectorData[Optional[NDArray[Any, float]]] = Field( y: Optional[VectorData[NDArray[Any, float]]] = Field(
None, None,
description="""y coordinate of the channel location in the brain (+y is inferior).""", description="""y coordinate of the channel location in the brain (+y is inferior).""",
json_schema_extra={ json_schema_extra={
@ -426,7 +439,7 @@ class ExtracellularEphysElectrodes(DynamicTable):
} }
}, },
) )
z: VectorData[Optional[NDArray[Any, float]]] = Field( z: Optional[VectorData[NDArray[Any, float]]] = Field(
None, None,
description="""z coordinate of the channel location in the brain (+z is right).""", description="""z coordinate of the channel location in the brain (+z is right).""",
json_schema_extra={ json_schema_extra={
@ -435,7 +448,7 @@ class ExtracellularEphysElectrodes(DynamicTable):
} }
}, },
) )
imp: VectorData[Optional[NDArray[Any, float]]] = Field( imp: Optional[VectorData[NDArray[Any, float]]] = Field(
None, None,
description="""Impedance of the channel, in ohms.""", description="""Impedance of the channel, in ohms.""",
json_schema_extra={ json_schema_extra={
@ -453,7 +466,7 @@ class ExtracellularEphysElectrodes(DynamicTable):
} }
}, },
) )
filtering: VectorData[Optional[NDArray[Any, str]]] = Field( filtering: Optional[VectorData[NDArray[Any, str]]] = Field(
None, None,
description="""Description of hardware filtering, including the filter name and frequency cutoffs.""", description="""Description of hardware filtering, including the filter name and frequency cutoffs.""",
json_schema_extra={ json_schema_extra={
@ -474,7 +487,7 @@ class ExtracellularEphysElectrodes(DynamicTable):
} }
}, },
) )
rel_x: VectorData[Optional[NDArray[Any, float]]] = Field( rel_x: Optional[VectorData[NDArray[Any, float]]] = Field(
None, None,
description="""x coordinate in electrode group""", description="""x coordinate in electrode group""",
json_schema_extra={ json_schema_extra={
@ -483,7 +496,7 @@ class ExtracellularEphysElectrodes(DynamicTable):
} }
}, },
) )
rel_y: VectorData[Optional[NDArray[Any, float]]] = Field( rel_y: Optional[VectorData[NDArray[Any, float]]] = Field(
None, None,
description="""y coordinate in electrode group""", description="""y coordinate in electrode group""",
json_schema_extra={ json_schema_extra={
@ -492,7 +505,7 @@ class ExtracellularEphysElectrodes(DynamicTable):
} }
}, },
) )
rel_z: VectorData[Optional[NDArray[Any, float]]] = Field( rel_z: Optional[VectorData[NDArray[Any, float]]] = Field(
None, None,
description="""z coordinate in electrode group""", description="""z coordinate in electrode group""",
json_schema_extra={ json_schema_extra={
@ -501,7 +514,7 @@ class ExtracellularEphysElectrodes(DynamicTable):
} }
}, },
) )
reference: VectorData[Optional[NDArray[Any, str]]] = Field( reference: Optional[VectorData[NDArray[Any, str]]] = Field(
None, None,
description="""Description of the reference electrode and/or reference scheme used for this electrode, e.g., \"stainless steel skull screw\" or \"online common average referencing\".""", description="""Description of the reference electrode and/or reference scheme used for this electrode, e.g., \"stainless steel skull screw\" or \"online common average referencing\".""",
json_schema_extra={ json_schema_extra={

View file

@ -518,7 +518,7 @@ class Units(DynamicTable):
} }
}, },
) )
obs_intervals: VectorData[Optional[NDArray[Shape["* num_intervals, 2 start_end"], float]]] = ( obs_intervals: Optional[VectorData[NDArray[Shape["* num_intervals, 2 start_end"], float]]] = (
Field( Field(
None, None,
description="""Observation intervals for each unit.""", description="""Observation intervals for each unit.""",
@ -546,7 +546,7 @@ class Units(DynamicTable):
} }
}, },
) )
electrodes: Named[Optional[DynamicTableRegion]] = Field( electrodes: Optional[Named[DynamicTableRegion]] = Field(
None, None,
description="""Electrode that each spike unit came from, specified using a DynamicTableRegion.""", description="""Electrode that each spike unit came from, specified using a DynamicTableRegion.""",
json_schema_extra={ json_schema_extra={
@ -561,23 +561,23 @@ class Units(DynamicTable):
electrode_group: Optional[List[ElectrodeGroup]] = Field( electrode_group: Optional[List[ElectrodeGroup]] = Field(
None, description="""Electrode group that each spike unit came from.""" None, description="""Electrode group that each spike unit came from."""
) )
waveform_mean: VectorData[ waveform_mean: Optional[
Optional[ VectorData[
Union[ Union[
NDArray[Shape["* num_units, * num_samples"], float], NDArray[Shape["* num_units, * num_samples"], float],
NDArray[Shape["* num_units, * num_samples, * num_electrodes"], float], NDArray[Shape["* num_units, * num_samples, * num_electrodes"], float],
] ]
] ]
] = Field(None, description="""Spike waveform mean for each spike unit.""") ] = Field(None, description="""Spike waveform mean for each spike unit.""")
waveform_sd: VectorData[ waveform_sd: Optional[
Optional[ VectorData[
Union[ Union[
NDArray[Shape["* num_units, * num_samples"], float], NDArray[Shape["* num_units, * num_samples"], float],
NDArray[Shape["* num_units, * num_samples, * num_electrodes"], float], NDArray[Shape["* num_units, * num_samples, * num_electrodes"], float],
] ]
] ]
] = Field(None, description="""Spike waveform standard deviation for each spike unit.""") ] = Field(None, description="""Spike waveform standard deviation for each spike unit.""")
waveforms: VectorData[Optional[NDArray[Shape["* num_waveforms, * num_samples"], float]]] = ( waveforms: Optional[VectorData[NDArray[Shape["* num_waveforms, * num_samples"], float]]] = (
Field( Field(
None, None,
description="""Individual waveforms for each spike on each electrode. This is a doubly indexed column. The 'waveforms_index' column indexes which waveforms in this column belong to the same spike event for a given unit, where each waveform was recorded from a different electrode. The 'waveforms_index_index' column indexes the 'waveforms_index' column to indicate which spike events belong to a given unit. For example, if the 'waveforms_index_index' column has values [2, 5, 6], then the first 2 elements of the 'waveforms_index' column correspond to the 2 spike events of the first unit, the next 3 elements of the 'waveforms_index' column correspond to the 3 spike events of the second unit, and the next 1 element of the 'waveforms_index' column corresponds to the 1 spike event of the third unit. If the 'waveforms_index' column has values [3, 6, 8, 10, 12, 13], then the first 3 elements of the 'waveforms' column contain the 3 spike waveforms that were recorded from 3 different electrodes for the first spike time of the first unit. See https://nwb-schema.readthedocs.io/en/stable/format_description.html#doubly-ragged-arrays for a graphical representation of this example. When there is only one electrode for each unit (i.e., each spike time is associated with a single waveform), then the 'waveforms_index' column will have values 1, 2, ..., N, where N is the number of spike events. The number of electrodes for each spike event should be the same within a given unit. The 'electrodes' column should be used to indicate which electrodes are associated with each unit, and the order of the waveforms within a given unit x spike event should be in the same order as the electrodes referenced in the 'electrodes' column of this table. The number of samples for each waveform must be the same.""", description="""Individual waveforms for each spike on each electrode. This is a doubly indexed column. 
The 'waveforms_index' column indexes which waveforms in this column belong to the same spike event for a given unit, where each waveform was recorded from a different electrode. The 'waveforms_index_index' column indexes the 'waveforms_index' column to indicate which spike events belong to a given unit. For example, if the 'waveforms_index_index' column has values [2, 5, 6], then the first 2 elements of the 'waveforms_index' column correspond to the 2 spike events of the first unit, the next 3 elements of the 'waveforms_index' column correspond to the 3 spike events of the second unit, and the next 1 element of the 'waveforms_index' column corresponds to the 1 spike event of the third unit. If the 'waveforms_index' column has values [3, 6, 8, 10, 12, 13], then the first 3 elements of the 'waveforms' column contain the 3 spike waveforms that were recorded from 3 different electrodes for the first spike time of the first unit. See https://nwb-schema.readthedocs.io/en/stable/format_description.html#doubly-ragged-arrays for a graphical representation of this example. When there is only one electrode for each unit (i.e., each spike time is associated with a single waveform), then the 'waveforms_index' column will have values 1, 2, ..., N, where N is the number of spike events. The number of electrodes for each spike event should be the same within a given unit. The 'electrodes' column should be used to indicate which electrodes are associated with each unit, and the order of the waveforms within a given unit x spike event should be in the same order as the electrodes referenced in the 'electrodes' column of this table. The number of samples for each waveform must be the same.""",

View file

@ -46,12 +46,23 @@ class ConfiguredBaseModel(BaseModel):
return handler(v) return handler(v)
except Exception as e1: except Exception as e1:
try: try:
if hasattr(v, "value"):
return handler(v.value) return handler(v.value)
else: except AttributeError:
try:
return handler(v["value"]) return handler(v["value"])
except Exception as e2: except (KeyError, TypeError):
raise e2 from e1 raise e1
@field_validator("*", mode="before")
@classmethod
def coerce_subclass(cls, v: Any, info) -> Any:
    """
    Recast parent classes into child classes

    If ``v`` is a model instance whose class is a strict parent of the class
    the field is annotated with, rebuild it as an instance of the annotated
    (child) class from its fields. Everything else is returned untouched.

    Args:
        v: The raw value handed to the field before validation.
        info: Validation info; only ``info.field_name`` is read.

    Returns:
        ``v`` itself, or a new instance of the annotated child class.
    """
    if not isinstance(v, BaseModel):
        return v

    annotation = cls.model_fields[info.field_name].annotation
    # Unwrap one level of a parameterised annotation by taking its first argument.
    annotation = annotation.__args__[0] if hasattr(annotation, "__args__") else annotation

    try:
        is_child = issubclass(annotation, type(v)) and annotation is not type(v)
    except TypeError:
        # issubclass() refuses anything that is not a class (nested generics,
        # unions, TypeVars, ...); there is nothing to recast in that case.
        is_child = False

    if is_child:
        v = annotation(**v.__dict__)
    return v
class LinkMLMeta(RootModel): class LinkMLMeta(RootModel):

View file

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import pdb
import re import re
import sys import sys
from datetime import date, datetime, time from datetime import date, datetime, time
@ -50,6 +51,7 @@ class ConfiguredBaseModel(BaseModel):
arbitrary_types_allowed=True, arbitrary_types_allowed=True,
use_enum_values=True, use_enum_values=True,
strict=False, strict=False,
validation_error_cause=True,
) )
hdf5_path: Optional[str] = Field( hdf5_path: Optional[str] = Field(
None, description="The absolute path that this object is stored in an NWB file" None, description="The absolute path that this object is stored in an NWB file"
@ -67,18 +69,35 @@ class ConfiguredBaseModel(BaseModel):
@field_validator("*", mode="wrap") @field_validator("*", mode="wrap")
@classmethod @classmethod
def coerce_value(cls, v: Any, handler) -> Any: def coerce_value(cls, v: Any, handler, info) -> Any:
"""Try to rescue instantiation by using the value field""" """Try to rescue instantiation by using the value field"""
try: try:
return handler(v) return handler(v)
except Exception as e1: except Exception as e1:
try: try:
if hasattr(v, "value"):
return handler(v.value) return handler(v.value)
else: except AttributeError:
try:
return handler(v["value"]) return handler(v["value"])
except Exception as e2: except (KeyError, TypeError):
raise e2 from e1 raise e1
# try:
# if hasattr(v, "value"):
# else:
# return handler(v["value"])
# except Exception as e2:
# raise e2 from e1
@field_validator("*", mode="before")
@classmethod
def coerce_subclass(cls, v: Any, info) -> Any:
    """
    Recast parent classes into child classes

    If ``v`` is a model instance whose class is a strict parent of the class
    the field is annotated with, rebuild it as an instance of the annotated
    (child) class from its fields. Everything else is returned untouched.

    Args:
        v: The raw value handed to the field before validation.
        info: Validation info; only ``info.field_name`` is read.

    Returns:
        ``v`` itself, or a new instance of the annotated child class.
    """
    if not isinstance(v, BaseModel):
        return v

    annotation = cls.model_fields[info.field_name].annotation
    # Unwrap one level of a parameterised annotation by taking its first argument.
    annotation = annotation.__args__[0] if hasattr(annotation, "__args__") else annotation

    try:
        is_child = issubclass(annotation, type(v)) and annotation is not type(v)
    except TypeError:
        # issubclass() refuses anything that is not a class (nested generics,
        # unions, TypeVars, ...); there is nothing to recast in that case.
        is_child = False

    if is_child:
        v = annotation(**v.__dict__)
    return v
class LinkMLMeta(RootModel): class LinkMLMeta(RootModel):
@ -310,11 +329,12 @@ class DynamicTableMixin(BaseModel):
but simplifying along the way :) but simplifying along the way :)
""" """
model_config = ConfigDict(extra="allow", validate_assignment=True) model_config = ConfigDict(extra="allow", validate_assignment=True, validation_error_cause=True)
__pydantic_extra__: Dict[str, Union["VectorDataMixin", "VectorIndexMixin", "NDArray", list]] __pydantic_extra__: Dict[str, Union["VectorDataMixin", "VectorIndexMixin", "NDArray", list]]
NON_COLUMN_FIELDS: ClassVar[tuple[str]] = ( NON_COLUMN_FIELDS: ClassVar[tuple[str]] = (
"id", "id",
"name", "name",
"categories",
"colnames", "colnames",
"description", "description",
"hdf5_path", "hdf5_path",
@ -510,6 +530,7 @@ class DynamicTableMixin(BaseModel):
if k not in cls.NON_COLUMN_FIELDS if k not in cls.NON_COLUMN_FIELDS
and not k.endswith("_index") and not k.endswith("_index")
and not isinstance(model[k], VectorIndexMixin) and not isinstance(model[k], VectorIndexMixin)
and model[k] is not None
] ]
model["colnames"] = colnames model["colnames"] = colnames
else: else:
@ -525,6 +546,7 @@ class DynamicTableMixin(BaseModel):
and not k.endswith("_index") and not k.endswith("_index")
and k not in model["colnames"] and k not in model["colnames"]
and not isinstance(model[k], VectorIndexMixin) and not isinstance(model[k], VectorIndexMixin)
and model[k] is not None
] ]
) )
model["colnames"] = colnames model["colnames"] = colnames
@ -597,9 +619,9 @@ class DynamicTableMixin(BaseModel):
""" """
Ensure that all columns are equal length Ensure that all columns are equal length
""" """
lengths = [len(v) for v in self._columns.values()] + [len(self.id)] lengths = [len(v) for v in self._columns.values() if v is not None] + [len(self.id)]
assert all([length == lengths[0] for length in lengths]), ( assert all([length == lengths[0] for length in lengths]), (
"Columns are not of equal length! " "DynamicTable Columns are not of equal length! "
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}" f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
) )
return self return self
@ -645,7 +667,7 @@ class AlignedDynamicTableMixin(BaseModel):
and also it's not so easy to copy a pydantic validator method. and also it's not so easy to copy a pydantic validator method.
""" """
model_config = ConfigDict(extra="allow", validate_assignment=True) model_config = ConfigDict(extra="allow", validate_assignment=True, validation_error_cause=True)
__pydantic_extra__: Dict[str, Union["DynamicTableMixin", "VectorDataMixin", "VectorIndexMixin"]] __pydantic_extra__: Dict[str, Union["DynamicTableMixin", "VectorDataMixin", "VectorIndexMixin"]]
NON_CATEGORY_FIELDS: ClassVar[tuple[str]] = ( NON_CATEGORY_FIELDS: ClassVar[tuple[str]] = (
@ -654,6 +676,7 @@ class AlignedDynamicTableMixin(BaseModel):
"colnames", "colnames",
"description", "description",
"hdf5_path", "hdf5_path",
"id",
"object_id", "object_id",
) )
@ -684,7 +707,7 @@ class AlignedDynamicTableMixin(BaseModel):
elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], str): elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], str):
# get a slice of a single table # get a slice of a single table
return self._categories[item[1]][item[0]] return self._categories[item[1]][item[0]]
elif isinstance(item, (int, slice, Iterable)): elif isinstance(item, (int, slice, Iterable, np.int_)):
# get a slice of all the tables # get a slice of all the tables
ids = self.id[item] ids = self.id[item]
if not isinstance(ids, Iterable): if not isinstance(ids, Iterable):
@ -696,9 +719,9 @@ class AlignedDynamicTableMixin(BaseModel):
if isinstance(table, pd.DataFrame): if isinstance(table, pd.DataFrame):
table = table.reset_index() table = table.reset_index()
elif isinstance(table, np.ndarray): elif isinstance(table, np.ndarray):
table = pd.DataFrame({category_name: [table]}) table = pd.DataFrame({category_name: [table]}, index=ids.index)
elif isinstance(table, Iterable): elif isinstance(table, Iterable):
table = pd.DataFrame({category_name: table}) table = pd.DataFrame({category_name: table}, index=ids.index)
else: else:
raise ValueError( raise ValueError(
f"Don't know how to construct category table for {category_name}" f"Don't know how to construct category table for {category_name}"
@ -708,6 +731,7 @@ class AlignedDynamicTableMixin(BaseModel):
names = [self.name] + self.categories names = [self.name] + self.categories
# construct below in case we need to support array indexing in the future # construct below in case we need to support array indexing in the future
else: else:
pdb.set_trace()
raise ValueError( raise ValueError(
f"Dont know how to index with {item}, " f"Dont know how to index with {item}, "
"need an int, string, slice, ndarray, or tuple[int | slice, str]" "need an int, string, slice, ndarray, or tuple[int | slice, str]"
@ -818,7 +842,7 @@ class AlignedDynamicTableMixin(BaseModel):
""" """
lengths = [len(v) for v in self._categories.values()] + [len(self.id)] lengths = [len(v) for v in self._categories.values()] + [len(self.id)]
assert all([length == lengths[0] for length in lengths]), ( assert all([length == lengths[0] for length in lengths]), (
"Columns are not of equal length! " "AlignedDynamicTable Columns are not of equal length! "
f"Got colnames:\n{self.categories}\nand lengths: {lengths}" f"Got colnames:\n{self.categories}\nand lengths: {lengths}"
) )
return self return self