mirror of
https://github.com/p2p-ld/nwb-linkml.git
synced 2024-11-12 17:54:29 +00:00
CHECKPOINT WITH IT WORKING before cleanup and model regeneration
This commit is contained in:
parent
d1498a3733
commit
8078492f90
11 changed files with 178 additions and 77 deletions
|
@ -854,3 +854,12 @@ class DatasetAdapter(ClassAdapter):
|
|||
return None
|
||||
else:
|
||||
return matches[0]
|
||||
|
||||
def special_cases(self, res: BuildResult) -> BuildResult:
|
||||
"""
|
||||
Apply special cases to build result
|
||||
"""
|
||||
res = self._datetime_or_str(res)
|
||||
|
||||
def _datetime_or_str(self, res: BuildResult) -> BuildResult:
|
||||
"""HDF5 doesn't support datetime, so"""
|
||||
|
|
|
@ -22,10 +22,16 @@ BASEMODEL_COERCE_VALUE = """
|
|||
return handler(v)
|
||||
except Exception as e1:
|
||||
try:
|
||||
if hasattr(v, "value"):
|
||||
return handler(v.value)
|
||||
else:
|
||||
return handler(v["value"])
|
||||
except Exception as e2:
|
||||
raise e2 from e1
|
||||
return handler(v.value)
|
||||
except:
|
||||
raise e1
|
||||
"""
|
||||
|
||||
BASEMODEL_COERCE_CHILD = """
|
||||
@field_validator("*", mode="before")
|
||||
@classmethod
|
||||
def coerce_subclass(cls, v: Any, info) -> Any:
|
||||
\"\"\"Recast parent classes into child classes\"\"\"
|
||||
return v
|
||||
pdb.set_trace()
|
||||
"""
|
||||
|
|
|
@ -53,6 +53,7 @@ class DynamicTableMixin(BaseModel):
|
|||
NON_COLUMN_FIELDS: ClassVar[tuple[str]] = (
|
||||
"id",
|
||||
"name",
|
||||
"categories",
|
||||
"colnames",
|
||||
"description",
|
||||
"hdf5_path",
|
||||
|
|
|
@ -39,7 +39,7 @@ from numpydantic.interface.hdf5 import H5ArrayPath
|
|||
from pydantic import BaseModel
|
||||
from tqdm import tqdm
|
||||
|
||||
from nwb_linkml.maps.hdf5 import get_references
|
||||
from nwb_linkml.maps.hdf5 import get_references, resolve_hardlink
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nwb_linkml.providers.schema import SchemaProvider
|
||||
|
@ -51,7 +51,7 @@ else:
|
|||
from typing_extensions import Never
|
||||
|
||||
|
||||
SKIP_PATTERN = re.compile("^/specifications.*")
|
||||
SKIP_PATTERN = re.compile("(^/specifications.*)|(\.specloc)")
|
||||
"""Nodes to always skip in reading e.g. because they are handled elsewhere"""
|
||||
|
||||
|
||||
|
@ -95,7 +95,11 @@ def hdf_dependency_graph(h5f: Path | h5py.File | h5py.Group) -> nx.DiGraph:
|
|||
|
||||
# add children, if group
|
||||
if isinstance(node, h5py.Group):
|
||||
children = [child.name for child in node.values() if not SKIP_PATTERN.match(child.name)]
|
||||
children = [
|
||||
resolve_hardlink(child)
|
||||
for child in node.values()
|
||||
if not SKIP_PATTERN.match(child.name)
|
||||
]
|
||||
edges = [(node.name, ref) for ref in children if not SKIP_PATTERN.match(ref)]
|
||||
g.add_edges_from(edges, label="child")
|
||||
|
||||
|
@ -157,21 +161,15 @@ def _load_node(
|
|||
else:
|
||||
raise TypeError(f"Nodes can only be h5py Datasets and Groups, got {obj}")
|
||||
|
||||
# if obj.name == "/general/intracellular_ephys/simultaneous_recordings/recordings":
|
||||
# pdb.set_trace()
|
||||
|
||||
# resolve attr references
|
||||
for k, v in args.items():
|
||||
if isinstance(v, h5py.h5r.Reference):
|
||||
ref_path = h5f[v].name
|
||||
args[k] = context[ref_path]
|
||||
|
||||
model = provider.get_class(obj.attrs["namespace"], obj.attrs["neurodata_type"])
|
||||
|
||||
# add additional needed params
|
||||
args["hdf5_path"] = path
|
||||
args["name"] = path.split("/")[-1]
|
||||
return model(**args)
|
||||
if "neurodata_type" in obj.attrs:
|
||||
model = provider.get_class(obj.attrs["namespace"], obj.attrs["neurodata_type"])
|
||||
return model(**args)
|
||||
else:
|
||||
if "name" in args:
|
||||
del args["name"]
|
||||
if "hdf5_path" in args:
|
||||
del args["hdf5_path"]
|
||||
return args
|
||||
|
||||
|
||||
def _load_dataset(
|
||||
|
@ -214,6 +212,15 @@ def _load_dataset(
|
|||
res["name"] = dataset.name.split("/")[-1]
|
||||
res["hdf5_path"] = dataset.name
|
||||
|
||||
# resolve attr references
|
||||
for k, v in res.items():
|
||||
if isinstance(v, h5py.h5r.Reference):
|
||||
ref_path = h5f[v].name
|
||||
if SKIP_PATTERN.match(ref_path):
|
||||
res[k] = ref_path
|
||||
else:
|
||||
res[k] = context[ref_path]
|
||||
|
||||
if len(res) == 1:
|
||||
return res["value"]
|
||||
else:
|
||||
|
@ -242,8 +249,20 @@ def _load_group(group: h5py.Group, h5f: h5py.File, context: dict) -> dict:
|
|||
del res["namespace"]
|
||||
if "neurodata_type" in res:
|
||||
del res["neurodata_type"]
|
||||
res["name"] = group.name.split("/")[-1]
|
||||
res["hdf5_path"] = group.name
|
||||
name = group.name.split("/")[-1]
|
||||
if name:
|
||||
res["name"] = name
|
||||
res["hdf5_path"] = group.name
|
||||
|
||||
# resolve attr references
|
||||
for k, v in res.items():
|
||||
if isinstance(v, h5py.h5r.Reference):
|
||||
ref_path = h5f[v].name
|
||||
if SKIP_PATTERN.match(ref_path):
|
||||
res[k] = ref_path
|
||||
else:
|
||||
res[k] = context[ref_path]
|
||||
|
||||
return res
|
||||
|
||||
|
||||
|
@ -319,6 +338,10 @@ class HDF5IO:
|
|||
res = _load_node(node, h5f, provider, context)
|
||||
context[node] = res
|
||||
|
||||
if path is None:
|
||||
path = "/"
|
||||
return context[path]
|
||||
|
||||
pdb.set_trace()
|
||||
|
||||
def write(self, path: Path) -> Never:
|
||||
|
|
|
@ -844,7 +844,7 @@ def resolve_references(
|
|||
return res, errors, completes
|
||||
|
||||
|
||||
def resolve_hardlink(obj: Union[h5py.Group, h5py.Dataset]) -> HDF5_Path:
|
||||
def resolve_hardlink(obj: Union[h5py.Group, h5py.Dataset]) -> str:
|
||||
"""
|
||||
Unhelpfully, hardlinks are pretty challenging to detect with h5py, so we have
|
||||
to do extra work to check if an item is "real" or a hardlink to another item.
|
||||
|
@ -856,4 +856,4 @@ def resolve_hardlink(obj: Union[h5py.Group, h5py.Dataset]) -> HDF5_Path:
|
|||
We basically dereference the object and return that path instead of the path
|
||||
given by the object's ``name``
|
||||
"""
|
||||
return HDF5_Path(obj.file[obj.ref].name)
|
||||
return obj.file[obj.ref].name
|
||||
|
|
|
@ -113,3 +113,17 @@ def test_dependency_graph(nwb_file, tmp_output_dir):
|
|||
A_filtered = nx.nx_agraph.to_agraph(graph)
|
||||
A_filtered.draw(tmp_output_dir / "test_nwb_filtered.png", prog="dot")
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
def test_dependencies_hardlink(nwb_file):
|
||||
"""
|
||||
Test that hardlinks are resolved (eg. from /processing/ecephys/LFP/ElectricalSeries/electrodes
|
||||
to /acquisition/ElectricalSeries/electrodes
|
||||
Args:
|
||||
nwb_file:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
pass
|
||||
|
|
|
@ -147,7 +147,7 @@ class TimeIntervals(DynamicTable):
|
|||
}
|
||||
},
|
||||
)
|
||||
tags: VectorData[Optional[NDArray[Any, str]]] = Field(
|
||||
tags: Optional[VectorData[NDArray[Any, str]]] = Field(
|
||||
None,
|
||||
description="""User-defined tags that identify or categorize events.""",
|
||||
json_schema_extra={
|
||||
|
@ -168,7 +168,7 @@ class TimeIntervals(DynamicTable):
|
|||
}
|
||||
},
|
||||
)
|
||||
timeseries: Named[Optional[TimeSeriesReferenceVectorData]] = Field(
|
||||
timeseries: Optional[Named[TimeSeriesReferenceVectorData]] = Field(
|
||||
None,
|
||||
description="""An index into a TimeSeries object.""",
|
||||
json_schema_extra={
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import pdb
|
||||
import re
|
||||
import sys
|
||||
from datetime import date, datetime, time
|
||||
|
@ -72,12 +73,24 @@ class ConfiguredBaseModel(BaseModel):
|
|||
return handler(v)
|
||||
except Exception as e1:
|
||||
try:
|
||||
if hasattr(v, "value"):
|
||||
return handler(v.value)
|
||||
else:
|
||||
return handler(v.value)
|
||||
except AttributeError:
|
||||
try:
|
||||
return handler(v["value"])
|
||||
except Exception as e2:
|
||||
raise e2 from e1
|
||||
except (KeyError, TypeError):
|
||||
raise e1
|
||||
|
||||
@field_validator("*", mode="before")
|
||||
@classmethod
|
||||
def coerce_subclass(cls, v: Any, info) -> Any:
|
||||
"""Recast parent classes into child classes"""
|
||||
if isinstance(v, BaseModel):
|
||||
annotation = cls.model_fields[info.field_name].annotation
|
||||
annotation = annotation.__args__[0] if hasattr(annotation, "__args__") else annotation
|
||||
# pdb.set_trace()
|
||||
if issubclass(annotation, type(v)) and annotation is not type(v):
|
||||
v = annotation(**{**v.__dict__, **v.__pydantic_extra__})
|
||||
return v
|
||||
|
||||
|
||||
class LinkMLMeta(RootModel):
|
||||
|
@ -176,28 +189,28 @@ class NWBFile(NWBContainer):
|
|||
...,
|
||||
description="""Date and time corresponding to time zero of all timestamps. The date is stored in UTC with local timezone offset as ISO 8601 extended formatted string: 2018-09-28T14:43:54.123+02:00. Dates stored in UTC end in \"Z\" with no timezone offset. Date accuracy is up to milliseconds. All times stored in the file use this time as reference (i.e., time zero).""",
|
||||
)
|
||||
acquisition: Optional[List[Union[DynamicTable, NWBDataInterface]]] = Field(
|
||||
acquisition: Optional[dict[str, Union[DynamicTable, NWBDataInterface]]] = Field(
|
||||
None,
|
||||
description="""Data streams recorded from the system, including ephys, ophys, tracking, etc. This group should be read-only after the experiment is completed and timestamps are corrected to a common timebase. The data stored here may be links to raw data stored in external NWB files. This will allow keeping bulky raw data out of the file while preserving the option of keeping some/all in the file. Acquired data includes tracking and experimental data streams (i.e., everything measured from the system). If bulky data is stored in the /acquisition group, the data can exist in a separate NWB file that is linked to by the file being used for processing and analysis.""",
|
||||
json_schema_extra={
|
||||
"linkml_meta": {"any_of": [{"range": "NWBDataInterface"}, {"range": "DynamicTable"}]}
|
||||
},
|
||||
)
|
||||
analysis: Optional[List[Union[DynamicTable, NWBContainer]]] = Field(
|
||||
analysis: Optional[dict[str, Union[DynamicTable, NWBContainer]]] = Field(
|
||||
None,
|
||||
description="""Lab-specific and custom scientific analysis of data. There is no defined format for the content of this group - the format is up to the individual user/lab. To facilitate sharing analysis data between labs, the contents here should be stored in standard types (e.g., neurodata_types) and appropriately documented. The file can store lab-specific and custom data analysis without restriction on its form or schema, reducing data formatting restrictions on end users. Such data should be placed in the analysis group. The analysis data should be documented so that it could be shared with other labs.""",
|
||||
json_schema_extra={
|
||||
"linkml_meta": {"any_of": [{"range": "NWBContainer"}, {"range": "DynamicTable"}]}
|
||||
},
|
||||
)
|
||||
scratch: Optional[List[Union[DynamicTable, NWBContainer]]] = Field(
|
||||
scratch: Optional[dict[str, Union[DynamicTable, NWBContainer]]] = Field(
|
||||
None,
|
||||
description="""A place to store one-off analysis results. Data placed here is not intended for sharing. By placing data here, users acknowledge that there is no guarantee that their data meets any standard.""",
|
||||
json_schema_extra={
|
||||
"linkml_meta": {"any_of": [{"range": "NWBContainer"}, {"range": "DynamicTable"}]}
|
||||
},
|
||||
)
|
||||
processing: Optional[List[ProcessingModule]] = Field(
|
||||
processing: Optional[dict[str, ProcessingModule]] = Field(
|
||||
None,
|
||||
description="""The home for ProcessingModules. These modules perform intermediate analysis of data that is necessary to perform before scientific analysis. Examples include spike clustering, extracting position from tracking data, stitching together image slices. ProcessingModules can be large and express many data sets from relatively complex analysis (e.g., spike detection and clustering) or small, representing extraction of position information from tracking video, or even binary lick/no-lick decisions. Common software tools (e.g., klustakwik, MClust) are expected to read/write data here. 'Processing' refers to intermediate analysis of the acquired data to make it more amenable to scientific analysis.""",
|
||||
json_schema_extra={"linkml_meta": {"any_of": [{"range": "ProcessingModule"}]}},
|
||||
|
@ -230,7 +243,7 @@ class NWBFileStimulus(ConfiguredBaseModel):
|
|||
"linkml_meta": {"equals_string": "stimulus", "ifabsent": "string(stimulus)"}
|
||||
},
|
||||
)
|
||||
presentation: Optional[List[Union[DynamicTable, NWBDataInterface, TimeSeries]]] = Field(
|
||||
presentation: Optional[dict[str, Union[DynamicTable, NWBDataInterface, TimeSeries]]] = Field(
|
||||
None,
|
||||
description="""Stimuli presented during the experiment.""",
|
||||
json_schema_extra={
|
||||
|
@ -243,7 +256,7 @@ class NWBFileStimulus(ConfiguredBaseModel):
|
|||
}
|
||||
},
|
||||
)
|
||||
templates: Optional[List[Union[Images, TimeSeries]]] = Field(
|
||||
templates: Optional[dict[str, Union[Images, TimeSeries]]] = Field(
|
||||
None,
|
||||
description="""Template stimuli. Timestamps in templates are based on stimulus design and are relative to the beginning of the stimulus. When templates are used, the stimulus instances must convert presentation times to the experiment`s time reference frame.""",
|
||||
json_schema_extra={
|
||||
|
@ -327,7 +340,7 @@ class NWBFileGeneral(ConfiguredBaseModel):
|
|||
None,
|
||||
description="""Place-holder than can be extended so that lab-specific meta-data can be placed in /general.""",
|
||||
)
|
||||
devices: Optional[List[Device]] = Field(
|
||||
devices: Optional[dict[str, Device]] = Field(
|
||||
None,
|
||||
description="""Description of hardware devices used during experiment, e.g., monitors, ADC boards, microscopes, etc.""",
|
||||
json_schema_extra={"linkml_meta": {"any_of": [{"range": "Device"}]}},
|
||||
|
@ -408,7 +421,7 @@ class ExtracellularEphysElectrodes(DynamicTable):
|
|||
"linkml_meta": {"equals_string": "electrodes", "ifabsent": "string(electrodes)"}
|
||||
},
|
||||
)
|
||||
x: VectorData[Optional[NDArray[Any, float]]] = Field(
|
||||
x: Optional[VectorData[NDArray[Any, float]]] = Field(
|
||||
None,
|
||||
description="""x coordinate of the channel location in the brain (+x is posterior).""",
|
||||
json_schema_extra={
|
||||
|
@ -417,7 +430,7 @@ class ExtracellularEphysElectrodes(DynamicTable):
|
|||
}
|
||||
},
|
||||
)
|
||||
y: VectorData[Optional[NDArray[Any, float]]] = Field(
|
||||
y: Optional[VectorData[NDArray[Any, float]]] = Field(
|
||||
None,
|
||||
description="""y coordinate of the channel location in the brain (+y is inferior).""",
|
||||
json_schema_extra={
|
||||
|
@ -426,7 +439,7 @@ class ExtracellularEphysElectrodes(DynamicTable):
|
|||
}
|
||||
},
|
||||
)
|
||||
z: VectorData[Optional[NDArray[Any, float]]] = Field(
|
||||
z: Optional[VectorData[NDArray[Any, float]]] = Field(
|
||||
None,
|
||||
description="""z coordinate of the channel location in the brain (+z is right).""",
|
||||
json_schema_extra={
|
||||
|
@ -435,7 +448,7 @@ class ExtracellularEphysElectrodes(DynamicTable):
|
|||
}
|
||||
},
|
||||
)
|
||||
imp: VectorData[Optional[NDArray[Any, float]]] = Field(
|
||||
imp: Optional[VectorData[NDArray[Any, float]]] = Field(
|
||||
None,
|
||||
description="""Impedance of the channel, in ohms.""",
|
||||
json_schema_extra={
|
||||
|
@ -453,7 +466,7 @@ class ExtracellularEphysElectrodes(DynamicTable):
|
|||
}
|
||||
},
|
||||
)
|
||||
filtering: VectorData[Optional[NDArray[Any, str]]] = Field(
|
||||
filtering: Optional[VectorData[NDArray[Any, str]]] = Field(
|
||||
None,
|
||||
description="""Description of hardware filtering, including the filter name and frequency cutoffs.""",
|
||||
json_schema_extra={
|
||||
|
@ -474,7 +487,7 @@ class ExtracellularEphysElectrodes(DynamicTable):
|
|||
}
|
||||
},
|
||||
)
|
||||
rel_x: VectorData[Optional[NDArray[Any, float]]] = Field(
|
||||
rel_x: Optional[VectorData[NDArray[Any, float]]] = Field(
|
||||
None,
|
||||
description="""x coordinate in electrode group""",
|
||||
json_schema_extra={
|
||||
|
@ -483,7 +496,7 @@ class ExtracellularEphysElectrodes(DynamicTable):
|
|||
}
|
||||
},
|
||||
)
|
||||
rel_y: VectorData[Optional[NDArray[Any, float]]] = Field(
|
||||
rel_y: Optional[VectorData[NDArray[Any, float]]] = Field(
|
||||
None,
|
||||
description="""y coordinate in electrode group""",
|
||||
json_schema_extra={
|
||||
|
@ -492,7 +505,7 @@ class ExtracellularEphysElectrodes(DynamicTable):
|
|||
}
|
||||
},
|
||||
)
|
||||
rel_z: VectorData[Optional[NDArray[Any, float]]] = Field(
|
||||
rel_z: Optional[VectorData[NDArray[Any, float]]] = Field(
|
||||
None,
|
||||
description="""z coordinate in electrode group""",
|
||||
json_schema_extra={
|
||||
|
@ -501,7 +514,7 @@ class ExtracellularEphysElectrodes(DynamicTable):
|
|||
}
|
||||
},
|
||||
)
|
||||
reference: VectorData[Optional[NDArray[Any, str]]] = Field(
|
||||
reference: Optional[VectorData[NDArray[Any, str]]] = Field(
|
||||
None,
|
||||
description="""Description of the reference electrode and/or reference scheme used for this electrode, e.g., \"stainless steel skull screw\" or \"online common average referencing\".""",
|
||||
json_schema_extra={
|
||||
|
|
|
@ -518,7 +518,7 @@ class Units(DynamicTable):
|
|||
}
|
||||
},
|
||||
)
|
||||
obs_intervals: VectorData[Optional[NDArray[Shape["* num_intervals, 2 start_end"], float]]] = (
|
||||
obs_intervals: Optional[VectorData[NDArray[Shape["* num_intervals, 2 start_end"], float]]] = (
|
||||
Field(
|
||||
None,
|
||||
description="""Observation intervals for each unit.""",
|
||||
|
@ -546,7 +546,7 @@ class Units(DynamicTable):
|
|||
}
|
||||
},
|
||||
)
|
||||
electrodes: Named[Optional[DynamicTableRegion]] = Field(
|
||||
electrodes: Optional[Named[DynamicTableRegion]] = Field(
|
||||
None,
|
||||
description="""Electrode that each spike unit came from, specified using a DynamicTableRegion.""",
|
||||
json_schema_extra={
|
||||
|
@ -561,23 +561,23 @@ class Units(DynamicTable):
|
|||
electrode_group: Optional[List[ElectrodeGroup]] = Field(
|
||||
None, description="""Electrode group that each spike unit came from."""
|
||||
)
|
||||
waveform_mean: VectorData[
|
||||
Optional[
|
||||
waveform_mean: Optional[
|
||||
VectorData[
|
||||
Union[
|
||||
NDArray[Shape["* num_units, * num_samples"], float],
|
||||
NDArray[Shape["* num_units, * num_samples, * num_electrodes"], float],
|
||||
]
|
||||
]
|
||||
] = Field(None, description="""Spike waveform mean for each spike unit.""")
|
||||
waveform_sd: VectorData[
|
||||
Optional[
|
||||
waveform_sd: Optional[
|
||||
VectorData[
|
||||
Union[
|
||||
NDArray[Shape["* num_units, * num_samples"], float],
|
||||
NDArray[Shape["* num_units, * num_samples, * num_electrodes"], float],
|
||||
]
|
||||
]
|
||||
] = Field(None, description="""Spike waveform standard deviation for each spike unit.""")
|
||||
waveforms: VectorData[Optional[NDArray[Shape["* num_waveforms, * num_samples"], float]]] = (
|
||||
waveforms: Optional[VectorData[NDArray[Shape["* num_waveforms, * num_samples"], float]]] = (
|
||||
Field(
|
||||
None,
|
||||
description="""Individual waveforms for each spike on each electrode. This is a doubly indexed column. The 'waveforms_index' column indexes which waveforms in this column belong to the same spike event for a given unit, where each waveform was recorded from a different electrode. The 'waveforms_index_index' column indexes the 'waveforms_index' column to indicate which spike events belong to a given unit. For example, if the 'waveforms_index_index' column has values [2, 5, 6], then the first 2 elements of the 'waveforms_index' column correspond to the 2 spike events of the first unit, the next 3 elements of the 'waveforms_index' column correspond to the 3 spike events of the second unit, and the next 1 element of the 'waveforms_index' column corresponds to the 1 spike event of the third unit. If the 'waveforms_index' column has values [3, 6, 8, 10, 12, 13], then the first 3 elements of the 'waveforms' column contain the 3 spike waveforms that were recorded from 3 different electrodes for the first spike time of the first unit. See https://nwb-schema.readthedocs.io/en/stable/format_description.html#doubly-ragged-arrays for a graphical representation of this example. When there is only one electrode for each unit (i.e., each spike time is associated with a single waveform), then the 'waveforms_index' column will have values 1, 2, ..., N, where N is the number of spike events. The number of electrodes for each spike event should be the same within a given unit. The 'electrodes' column should be used to indicate which electrodes are associated with each unit, and the order of the waveforms within a given unit x spike event should be in the same order as the electrodes referenced in the 'electrodes' column of this table. The number of samples for each waveform must be the same.""",
|
||||
|
|
|
@ -46,12 +46,23 @@ class ConfiguredBaseModel(BaseModel):
|
|||
return handler(v)
|
||||
except Exception as e1:
|
||||
try:
|
||||
if hasattr(v, "value"):
|
||||
return handler(v.value)
|
||||
else:
|
||||
return handler(v.value)
|
||||
except AttributeError:
|
||||
try:
|
||||
return handler(v["value"])
|
||||
except Exception as e2:
|
||||
raise e2 from e1
|
||||
except (KeyError, TypeError):
|
||||
raise e1
|
||||
|
||||
@field_validator("*", mode="before")
|
||||
@classmethod
|
||||
def coerce_subclass(cls, v: Any, info) -> Any:
|
||||
"""Recast parent classes into child classes"""
|
||||
if isinstance(v, BaseModel):
|
||||
annotation = cls.model_fields[info.field_name].annotation
|
||||
annotation = annotation.__args__[0] if hasattr(annotation, "__args__") else annotation
|
||||
if issubclass(annotation, type(v)) and annotation is not type(v):
|
||||
v = annotation(**v.__dict__)
|
||||
return v
|
||||
|
||||
|
||||
class LinkMLMeta(RootModel):
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import pdb
|
||||
import re
|
||||
import sys
|
||||
from datetime import date, datetime, time
|
||||
|
@ -50,6 +51,7 @@ class ConfiguredBaseModel(BaseModel):
|
|||
arbitrary_types_allowed=True,
|
||||
use_enum_values=True,
|
||||
strict=False,
|
||||
validation_error_cause=True,
|
||||
)
|
||||
hdf5_path: Optional[str] = Field(
|
||||
None, description="The absolute path that this object is stored in an NWB file"
|
||||
|
@ -67,18 +69,35 @@ class ConfiguredBaseModel(BaseModel):
|
|||
|
||||
@field_validator("*", mode="wrap")
|
||||
@classmethod
|
||||
def coerce_value(cls, v: Any, handler) -> Any:
|
||||
def coerce_value(cls, v: Any, handler, info) -> Any:
|
||||
"""Try to rescue instantiation by using the value field"""
|
||||
try:
|
||||
return handler(v)
|
||||
except Exception as e1:
|
||||
try:
|
||||
if hasattr(v, "value"):
|
||||
return handler(v.value)
|
||||
else:
|
||||
return handler(v.value)
|
||||
except AttributeError:
|
||||
try:
|
||||
return handler(v["value"])
|
||||
except Exception as e2:
|
||||
raise e2 from e1
|
||||
except (KeyError, TypeError):
|
||||
raise e1
|
||||
# try:
|
||||
# if hasattr(v, "value"):
|
||||
# else:
|
||||
# return handler(v["value"])
|
||||
# except Exception as e2:
|
||||
# raise e2 from e1
|
||||
|
||||
@field_validator("*", mode="before")
|
||||
@classmethod
|
||||
def coerce_subclass(cls, v: Any, info) -> Any:
|
||||
"""Recast parent classes into child classes"""
|
||||
if isinstance(v, BaseModel):
|
||||
annotation = cls.model_fields[info.field_name].annotation
|
||||
annotation = annotation.__args__[0] if hasattr(annotation, "__args__") else annotation
|
||||
if issubclass(annotation, type(v)) and annotation is not type(v):
|
||||
v = annotation(**v.__dict__)
|
||||
return v
|
||||
|
||||
|
||||
class LinkMLMeta(RootModel):
|
||||
|
@ -310,11 +329,12 @@ class DynamicTableMixin(BaseModel):
|
|||
but simplifying along the way :)
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="allow", validate_assignment=True)
|
||||
model_config = ConfigDict(extra="allow", validate_assignment=True, validation_error_cause=True)
|
||||
__pydantic_extra__: Dict[str, Union["VectorDataMixin", "VectorIndexMixin", "NDArray", list]]
|
||||
NON_COLUMN_FIELDS: ClassVar[tuple[str]] = (
|
||||
"id",
|
||||
"name",
|
||||
"categories",
|
||||
"colnames",
|
||||
"description",
|
||||
"hdf5_path",
|
||||
|
@ -510,6 +530,7 @@ class DynamicTableMixin(BaseModel):
|
|||
if k not in cls.NON_COLUMN_FIELDS
|
||||
and not k.endswith("_index")
|
||||
and not isinstance(model[k], VectorIndexMixin)
|
||||
and model[k] is not None
|
||||
]
|
||||
model["colnames"] = colnames
|
||||
else:
|
||||
|
@ -525,6 +546,7 @@ class DynamicTableMixin(BaseModel):
|
|||
and not k.endswith("_index")
|
||||
and k not in model["colnames"]
|
||||
and not isinstance(model[k], VectorIndexMixin)
|
||||
and model[k] is not None
|
||||
]
|
||||
)
|
||||
model["colnames"] = colnames
|
||||
|
@ -597,9 +619,9 @@ class DynamicTableMixin(BaseModel):
|
|||
"""
|
||||
Ensure that all columns are equal length
|
||||
"""
|
||||
lengths = [len(v) for v in self._columns.values()] + [len(self.id)]
|
||||
lengths = [len(v) for v in self._columns.values() if v is not None] + [len(self.id)]
|
||||
assert all([length == lengths[0] for length in lengths]), (
|
||||
"Columns are not of equal length! "
|
||||
"DynamicTable Columns are not of equal length! "
|
||||
f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
|
||||
)
|
||||
return self
|
||||
|
@ -645,7 +667,7 @@ class AlignedDynamicTableMixin(BaseModel):
|
|||
and also it's not so easy to copy a pydantic validator method.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="allow", validate_assignment=True)
|
||||
model_config = ConfigDict(extra="allow", validate_assignment=True, validation_error_cause=True)
|
||||
__pydantic_extra__: Dict[str, Union["DynamicTableMixin", "VectorDataMixin", "VectorIndexMixin"]]
|
||||
|
||||
NON_CATEGORY_FIELDS: ClassVar[tuple[str]] = (
|
||||
|
@ -654,6 +676,7 @@ class AlignedDynamicTableMixin(BaseModel):
|
|||
"colnames",
|
||||
"description",
|
||||
"hdf5_path",
|
||||
"id",
|
||||
"object_id",
|
||||
)
|
||||
|
||||
|
@ -684,7 +707,7 @@ class AlignedDynamicTableMixin(BaseModel):
|
|||
elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], str):
|
||||
# get a slice of a single table
|
||||
return self._categories[item[1]][item[0]]
|
||||
elif isinstance(item, (int, slice, Iterable)):
|
||||
elif isinstance(item, (int, slice, Iterable, np.int_)):
|
||||
# get a slice of all the tables
|
||||
ids = self.id[item]
|
||||
if not isinstance(ids, Iterable):
|
||||
|
@ -696,9 +719,9 @@ class AlignedDynamicTableMixin(BaseModel):
|
|||
if isinstance(table, pd.DataFrame):
|
||||
table = table.reset_index()
|
||||
elif isinstance(table, np.ndarray):
|
||||
table = pd.DataFrame({category_name: [table]})
|
||||
table = pd.DataFrame({category_name: [table]}, index=ids.index)
|
||||
elif isinstance(table, Iterable):
|
||||
table = pd.DataFrame({category_name: table})
|
||||
table = pd.DataFrame({category_name: table}, index=ids.index)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Don't know how to construct category table for {category_name}"
|
||||
|
@ -708,6 +731,7 @@ class AlignedDynamicTableMixin(BaseModel):
|
|||
names = [self.name] + self.categories
|
||||
# construct below in case we need to support array indexing in the future
|
||||
else:
|
||||
pdb.set_trace()
|
||||
raise ValueError(
|
||||
f"Dont know how to index with {item}, "
|
||||
"need an int, string, slice, ndarray, or tuple[int | slice, str]"
|
||||
|
@ -818,7 +842,7 @@ class AlignedDynamicTableMixin(BaseModel):
|
|||
"""
|
||||
lengths = [len(v) for v in self._categories.values()] + [len(self.id)]
|
||||
assert all([length == lengths[0] for length in lengths]), (
|
||||
"Columns are not of equal length! "
|
||||
"AlignedDynamicTable Columns are not of equal length! "
|
||||
f"Got colnames:\n{self.categories}\nand lengths: {lengths}"
|
||||
)
|
||||
return self
|
||||
|
|
Loading…
Reference in a new issue