nwb-linkml/nwb_linkml/tests/fixtures.py

762 lines
26 KiB
Python

import shutil
from dataclasses import dataclass, field
from datetime import datetime
from itertools import product
from pathlib import Path
from types import ModuleType
from typing import Dict, Optional
import numpy as np
import pytest
from linkml_runtime.dumpers import yaml_dumper
from linkml_runtime.linkml_model import (
ClassDefinition,
Prefix,
SchemaDefinition,
SlotDefinition,
TypeDefinition,
)
from pynwb import NWBHDF5IO, NWBFile, TimeSeries
from pynwb.base import TimeSeriesReference, TimeSeriesReferenceVectorData
from pynwb.behavior import Position, SpatialSeries
from pynwb.core import DynamicTable, VectorData
from pynwb.ecephys import LFP, ElectricalSeries
from pynwb.file import Subject
from pynwb.icephys import VoltageClampSeries, VoltageClampStimulusSeries
from pynwb.image import ImageSeries
from pynwb.ophys import (
CorrectedImageStack,
Fluorescence,
ImageSegmentation,
MotionCorrection,
OnePhotonSeries,
OpticalChannel,
RoiResponseSeries,
TwoPhotonSeries,
)
from nwb_linkml.adapters.namespaces import NamespacesAdapter
from nwb_linkml.io import schema as io
from nwb_linkml.providers import LinkMLProvider, PydanticProvider
from nwb_linkml.providers.linkml import LinkMLSchemaBuild
from nwb_schema_language import Attribute, Dataset, Group
__all__ = [
"NWBSchemaTest",
"TestSchemas",
"data_dir",
"linkml_schema",
"linkml_schema_bare",
"nwb_core_fixture",
"nwb_file",
"nwb_schema",
"tmp_output_dir",
"tmp_output_dir_func",
"tmp_output_dir_mod",
]
@pytest.fixture(scope="session")
def tmp_output_dir() -> Path:
path = Path(__file__).parent.resolve() / "__tmp__"
if path.exists():
for subdir in path.iterdir():
if subdir.name == "git":
# don't wipe out git repos every time, they don't rly change
continue
elif subdir.is_file() and subdir.parent != path:
continue
elif subdir.is_file():
subdir.unlink(missing_ok=True)
else:
shutil.rmtree(str(subdir))
path.mkdir(exist_ok=True)
return path
@pytest.fixture(scope="function")
def tmp_output_dir_func(tmp_output_dir) -> Path:
"""
tmp output dir that gets cleared between every function
cleans at the start rather than at cleanup in case the output is to be inspected
"""
subpath = tmp_output_dir / "__tmpfunc__"
if subpath.exists():
shutil.rmtree(str(subpath))
subpath.mkdir()
return subpath
@pytest.fixture(scope="module")
def tmp_output_dir_mod(tmp_output_dir) -> Path:
"""
tmp output dir that gets cleared between every function
cleans at the start rather than at cleanup in case the output is to be inspected
"""
subpath = tmp_output_dir / "__tmpmod__"
if subpath.exists():
shutil.rmtree(str(subpath))
subpath.mkdir()
return subpath
@pytest.fixture(scope="session", params=[{"core_version": "2.7.0", "hdmf_version": "1.8.0"}])
def nwb_core_fixture(request) -> NamespacesAdapter:
nwb_core = io.load_nwb_core(**request.param)
assert (
request.param["core_version"] in nwb_core.versions["core"]
) # 2.6.0 is actually 2.6.0-alpha
assert nwb_core.versions["hdmf-common"] == request.param["hdmf_version"]
return nwb_core
@pytest.fixture(scope="session")
def nwb_core_linkml(nwb_core_fixture, tmp_output_dir) -> LinkMLSchemaBuild:
provider = LinkMLProvider(tmp_output_dir, allow_repo=False, verbose=False)
result = provider.build(ns_adapter=nwb_core_fixture, force=True)
return result["core"]
@pytest.fixture(scope="session")
def nwb_core_module(nwb_core_linkml: LinkMLSchemaBuild, tmp_output_dir) -> ModuleType:
"""
Generated pydantic namespace from nwb core
"""
provider = PydanticProvider(tmp_output_dir, verbose=False)
result = provider.build(nwb_core_linkml.namespace, force=True)
mod = provider.get("core", version=nwb_core_linkml.version, allow_repo=False)
return mod
@pytest.fixture(scope="session")
def data_dir() -> Path:
path = Path(__file__).parent.resolve() / "data"
return path
@dataclass
class TestSchemas:
    """
    Bundle of LinkML test schemas: a core schema, the schema it imports, and
    a namespace schema importing both, plus their on-disk paths once dumped.
    """

    # keep pytest from collecting this dataclass despite its "Test" prefix
    __test__ = False

    core: SchemaDefinition
    imported: SchemaDefinition
    namespace: SchemaDefinition
    # populated by the linkml_schema fixture after dumping each schema to YAML
    core_path: Optional[Path] = None
    imported_path: Optional[Path] = None
    namespace_path: Optional[Path] = None
@pytest.fixture(scope="module")
def linkml_schema_bare() -> TestSchemas:
schema = TestSchemas(
core=SchemaDefinition(
name="core",
id="core",
version="1.0.1",
imports=["imported", "linkml:types"],
default_prefix="core",
prefixes={"linkml": Prefix("linkml", "https://w3id.org/linkml")},
description="Test core schema",
classes=[
ClassDefinition(
name="MainTopLevel",
description="The main class we are testing!",
is_a="MainThing",
tree_root=True,
attributes=[
SlotDefinition(
name="name",
description="A fixed property that should use Literal and be frozen",
range="string",
required=True,
ifabsent="string(toplevel)",
equals_string="toplevel",
identifier=True,
),
SlotDefinition(name="array", range="MainTopLevel__Array"),
SlotDefinition(
name="SkippableSlot", description="A slot that was meant to be skipped!"
),
SlotDefinition(
name="inline_dict",
description=(
"This should be inlined as a dictionary despite this class having"
" an identifier"
),
multivalued=True,
inlined=True,
inlined_as_list=False,
any_of=[{"range": "OtherClass"}, {"range": "StillAnotherClass"}],
),
],
),
ClassDefinition(
name="MainTopLevel__Array",
description="Main class's array",
is_a="Arraylike",
attributes=[
SlotDefinition(name="x", range="numeric", required=True),
SlotDefinition(name="y", range="numeric", required=True),
SlotDefinition(
name="z",
range="numeric",
required=False,
maximum_cardinality=3,
minimum_cardinality=3,
),
SlotDefinition(
name="a",
range="numeric",
required=False,
minimum_cardinality=4,
maximum_cardinality=4,
),
],
),
ClassDefinition(
name="skippable",
description="A class that lives to be skipped!",
),
ClassDefinition(
name="OtherClass",
description="Another class yno!",
attributes=[
SlotDefinition(name="name", range="string", required=True, identifier=True)
],
),
ClassDefinition(
name="StillAnotherClass",
description="And yet another!",
attributes=[
SlotDefinition(name="name", range="string", required=True, identifier=True)
],
),
],
types=[TypeDefinition(name="numeric", typeof="float")],
),
imported=SchemaDefinition(
name="imported",
id="imported",
version="1.4.5",
default_prefix="core",
imports=["linkml:types"],
prefixes={"linkml": Prefix("linkml", "https://w3id.org/linkml")},
classes=[
ClassDefinition(
name="MainThing",
description="Class imported by our main thing class!",
attributes=[SlotDefinition(name="meta_slot", range="string")],
),
ClassDefinition(name="Arraylike", abstract=True),
],
),
namespace=SchemaDefinition(
name="namespace",
id="namespace",
version="1.1.1",
default_prefix="namespace",
annotations=[
{"tag": "is_namespace", "value": "True"},
{"tag": "namespace", "value": "core"},
],
description="A namespace package that should import all other classes",
imports=["core", "imported"],
),
)
return schema
@pytest.fixture(scope="module")
def linkml_schema(tmp_output_dir_mod, linkml_schema_bare) -> TestSchemas:
"""
A test schema that includes
- Two schemas, one importing from the other
- Arraylike
- Required/static "name" field
- linkml metadata like tree_root
- skipping classes
"""
schema = linkml_schema_bare
test_schema_path = tmp_output_dir_mod / "test_schema"
test_schema_path.mkdir()
core_path = test_schema_path / "core.yaml"
imported_path = test_schema_path / "imported.yaml"
namespace_path = test_schema_path / "namespace.yaml"
schema.core_path = core_path
schema.imported_path = imported_path
schema.namespace_path = namespace_path
yaml_dumper.dump(schema.core, schema.core_path)
yaml_dumper.dump(schema.imported, schema.imported_path)
yaml_dumper.dump(schema.namespace, schema.namespace_path)
return schema
@dataclass
class NWBSchemaTest:
    """Minimal bundle of nwb-schema-language objects, keyed by short names."""

    datasets: Dict[str, Dataset] = field(default_factory=dict)
    groups: Dict[str, Group] = field(default_factory=dict)
@pytest.fixture()
def nwb_schema() -> NWBSchemaTest:
    """Minimal NWB schema for testing"""
    # an Image dataset with three alternative dims/shape signatures
    image_dataset = Dataset(
        neurodata_type_def="Image",
        neurodata_type_inc="NWBData",
        dtype="numeric",
        doc="An image!",
        dims=[["x", "y"], ["x", "y", "r, g, b"], ["x", "y", "r, g, b, a"]],
        shape=[[None, None], [None, None, 3], [None, None, 4]],
        attributes=[
            Attribute(dtype="float32", name="resolution", doc="resolution!"),
            Attribute(dtype="text", name="description", doc="Description!"),
        ],
    )
    # a group holding one-or-more Images plus an optional ordering dataset
    member_datasets = [
        Dataset(neurodata_type_inc="Image", quantity="+", doc="images!"),
        Dataset(
            neurodata_type_inc="ImageReferences",
            name="order_of_images",
            doc="Image references!",
            quantity="?",
        ),
    ]
    images_group = Group(
        neurodata_type_def="Images",
        neurodata_type_inc="NWBDataInterface",
        default_name="Images",
        doc="Images!",
        attributes=[Attribute(dtype="text", name="description", doc="description!")],
        datasets=member_datasets,
    )
    return NWBSchemaTest(datasets={"image": image_dataset}, groups={"images": images_group})
@pytest.fixture(scope="session")
def nwb_file(tmp_output_dir) -> Path:
"""
NWB File created with pynwb that uses all the weird language features
Borrowing code from pynwb docs in one humonogous fixture function
since there's not really a reason to
"""
generator = np.random.default_rng()
nwb_path = tmp_output_dir / "test_nwb.nwb"
if nwb_path.exists():
return nwb_path
nwbfile = NWBFile(
session_description="All that you touch, you change.", # required
identifier="1111-1111-1111-1111", # required
session_start_time=datetime(year=2024, month=1, day=1), # required
session_id="session_1234", # optional
experimenter=[
"Lauren Oya Olamina",
], # optional
institution="Earthseed Research Institute", # optional
experiment_description="All that you change, changes you.", # optional
keywords=["behavior", "belief"], # optional
related_publications="doi:10.1016/j.neuron.2016.12.011", # optional
)
subject = Subject(
subject_id="001",
age="P90D",
description="mouse 5",
species="Mus musculus",
sex="M",
)
nwbfile.subject = subject
data = np.arange(100, 200, 10)
timestamps = np.arange(10.0)
time_series_with_timestamps = TimeSeries(
name="test_timeseries",
description="an example time series",
data=data,
unit="m",
timestamps=timestamps,
)
nwbfile.add_acquisition(time_series_with_timestamps)
position_data = np.array([np.linspace(0, 10, 50), np.linspace(0, 8, 50)]).T
position_timestamps = np.linspace(0, 50).astype(float) / 200
spatial_series_obj = SpatialSeries(
name="SpatialSeries",
description="(x,y) position in open field",
data=position_data,
timestamps=position_timestamps,
reference_frame="(0,0) is bottom left corner",
)
# name is set to "Position" by default
position_obj = Position(spatial_series=spatial_series_obj)
behavior_module = nwbfile.create_processing_module(
name="behavior", description="processed behavioral data"
)
behavior_module.add(position_obj)
nwbfile.add_trial_column(
name="correct",
description="whether the trial was correct",
)
nwbfile.add_trial(start_time=1.0, stop_time=5.0, correct=True)
nwbfile.add_trial(start_time=6.0, stop_time=10.0, correct=False)
# --------------------------------------------------
# Extracellular Ephys
# https://pynwb.readthedocs.io/en/latest/tutorials/domain/ecephys.html
# --------------------------------------------------
device = nwbfile.create_device(name="array", description="old reliable", manufacturer="diy")
nwbfile.add_electrode_column(name="label", description="label of electrode")
nshanks = 4
nchannels_per_shank = 3
electrode_counter = 0
for ishank in range(nshanks):
# create an electrode group for this shank
electrode_group = nwbfile.create_electrode_group(
name=f"shank{ishank}",
description=f"electrode group for shank {ishank}",
device=device,
location="brain area",
)
# add electrodes to the electrode table
for ielec in range(nchannels_per_shank):
nwbfile.add_electrode(
group=electrode_group,
label=f"shank{ishank}elec{ielec}",
location="brain area",
)
electrode_counter += 1
all_table_region = nwbfile.create_electrode_table_region(
region=list(range(electrode_counter)), # reference row indices 0 to N-1
description="all electrodes",
)
raw_data = generator.standard_normal((50, 12))
raw_electrical_series = ElectricalSeries(
name="ElectricalSeries",
description="Raw acquisition traces",
data=raw_data,
electrodes=all_table_region,
starting_time=0.0,
# timestamp of the first sample in seconds relative to the session start time
rate=20000.0, # in Hz
)
nwbfile.add_acquisition(raw_electrical_series)
# --------------------------------------------------
# LFP
# --------------------------------------------------
lfp_data = generator.standard_normal((50, 12))
lfp_electrical_series = ElectricalSeries(
name="ElectricalSeries",
description="LFP data",
data=lfp_data,
electrodes=all_table_region,
starting_time=0.0,
rate=200.0,
)
lfp = LFP(electrical_series=lfp_electrical_series)
ecephys_module = nwbfile.create_processing_module(
name="ecephys", description="processed extracellular electrophysiology data"
)
ecephys_module.add(lfp)
# Spike Times
nwbfile.add_unit_column(name="quality", description="sorting quality")
firing_rate = 20
n_units = 10
res = 1000
duration = 20
for _ in range(n_units):
spike_times = np.where(generator.random(res * duration) < (firing_rate / res))[0] / res
nwbfile.add_unit(spike_times=spike_times, quality="good")
# --------------------------------------------------
# Intracellular ephys
# --------------------------------------------------
device = nwbfile.create_device(name="Heka ITC-1600")
electrode = nwbfile.create_icephys_electrode(
name="elec0", description="a mock intracellular electrode", device=device
)
stimulus = VoltageClampStimulusSeries(
name="ccss",
data=[1, 2, 3, 4, 5],
starting_time=123.6,
rate=10e3,
electrode=electrode,
gain=0.02,
sweep_number=np.uint64(15),
)
# Create and icephys response
response = VoltageClampSeries(
name="vcs",
data=[0.1, 0.2, 0.3, 0.4, 0.5],
conversion=1e-12,
resolution=np.nan,
starting_time=123.6,
rate=20e3,
electrode=electrode,
gain=0.02,
capacitance_slow=100e-12,
resistance_comp_correction=70.0,
sweep_number=np.uint64(15),
)
# we can also add stimulus template data as follows
rowindex = nwbfile.add_intracellular_recording(
electrode=electrode, stimulus=stimulus, response=response, id=10
)
rowindex2 = nwbfile.add_intracellular_recording(
electrode=electrode,
stimulus=stimulus,
stimulus_start_index=1,
stimulus_index_count=3,
response=response,
response_start_index=2,
response_index_count=3,
id=11,
)
rowindex3 = nwbfile.add_intracellular_recording(electrode=electrode, response=response, id=12)
nwbfile.intracellular_recordings.add_column(
name="recording_tag",
data=["A1", "A2", "A3"],
description="String with a recording tag",
)
location_column = VectorData(
name="location",
data=["Mordor", "Gondor", "Rohan"],
description="Recording location in Middle Earth",
)
lab_category = DynamicTable(
name="recording_lab_data",
description="category table for lab-specific recording metadata",
colnames=[
"location",
],
columns=[
location_column,
],
)
# Add the table as a new category to our intracellular_recordings
nwbfile.intracellular_recordings.add_category(category=lab_category)
nwbfile.intracellular_recordings.add_column(
name="voltage_threshold",
data=[0.1, 0.12, 0.13],
description="Just an example column on the electrodes category table",
category="electrodes",
)
stimulus_template = VoltageClampStimulusSeries(
name="ccst",
data=[0, 1, 2, 3, 4],
starting_time=0.0,
rate=10e3,
electrode=electrode,
gain=0.02,
)
nwbfile.add_stimulus_template(stimulus_template)
nwbfile.intracellular_recordings.add_column(
name="stimulus_template",
data=[
TimeSeriesReference(0, 5, stimulus_template),
# (start_index, index_count, stimulus_template)
TimeSeriesReference(1, 3, stimulus_template),
TimeSeriesReference.empty(stimulus_template),
],
# if there was no data for that recording, use empty reference
description=(
"Column storing the reference to the stimulus template for the recording (rows)."
),
category="stimuli",
col_cls=TimeSeriesReferenceVectorData,
)
icephys_simultaneous_recordings = nwbfile.get_icephys_simultaneous_recordings()
icephys_simultaneous_recordings.add_column(
name="simultaneous_recording_tag",
description="A custom tag for simultaneous_recordings",
)
simultaneous_index = nwbfile.add_icephys_simultaneous_recording(
recordings=[rowindex, rowindex2, rowindex3],
id=12,
simultaneous_recording_tag="LabTag1",
)
repetition_index = nwbfile.add_icephys_repetition(
sequential_recordings=[simultaneous_index], id=17
)
nwbfile.add_icephys_experimental_condition(repetitions=[repetition_index], id=19)
nwbfile.icephys_experimental_conditions.add_column(
name="tag",
data=np.arange(1),
description="integer tag for a experimental condition",
)
# --------------------------------------------------
# Calcium Imaging
# https://pynwb.readthedocs.io/en/latest/tutorials/domain/ophys.html
# --------------------------------------------------
device = nwbfile.create_device(
name="Microscope",
description="My two-photon microscope",
manufacturer="The best microscope manufacturer",
)
optical_channel = OpticalChannel(
name="OpticalChannel",
description="an optical channel",
emission_lambda=500.0,
)
imaging_plane = nwbfile.create_imaging_plane(
name="ImagingPlane",
optical_channel=optical_channel,
imaging_rate=30.0,
description="a very interesting part of the brain",
device=device,
excitation_lambda=600.0,
indicator="GFP",
location="V1",
grid_spacing=[0.01, 0.01],
grid_spacing_unit="meters",
origin_coords=[1.0, 2.0, 3.0],
origin_coords_unit="meters",
)
one_p_series = OnePhotonSeries(
name="OnePhotonSeries",
description="Raw 1p data",
data=np.ones((1000, 100, 100)),
imaging_plane=imaging_plane,
rate=1.0,
unit="normalized amplitude",
)
nwbfile.add_acquisition(one_p_series)
two_p_series = TwoPhotonSeries(
name="TwoPhotonSeries",
description="Raw 2p data",
data=np.ones((1000, 100, 100)),
imaging_plane=imaging_plane,
rate=1.0,
unit="normalized amplitude",
)
nwbfile.add_acquisition(two_p_series)
corrected = ImageSeries(
name="corrected", # this must be named "corrected"
description="A motion corrected image stack",
data=np.ones((1000, 100, 100)),
unit="na",
format="raw",
starting_time=0.0,
rate=1.0,
)
xy_translation = TimeSeries(
name="xy_translation",
description="x,y translation in pixels",
data=np.ones((1000, 2)),
unit="pixels",
starting_time=0.0,
rate=1.0,
)
corrected_image_stack = CorrectedImageStack(
corrected=corrected,
original=one_p_series,
xy_translation=xy_translation,
)
motion_correction = MotionCorrection(corrected_image_stacks=[corrected_image_stack])
ophys_module = nwbfile.create_processing_module(
name="ophys", description="optical physiology processed data"
)
ophys_module.add(motion_correction)
img_seg = ImageSegmentation()
ps = img_seg.create_plane_segmentation(
name="PlaneSegmentation",
description="output from segmenting my favorite imaging plane",
imaging_plane=imaging_plane,
reference_images=one_p_series, # optional
)
ophys_module.add(img_seg)
for _ in range(30):
image_mask = np.zeros((100, 100))
# randomly generate example image masks
x = generator.integers(0, 95)
y = generator.integers(0, 95)
image_mask[x : x + 5, y : y + 5] = 1
# add image mask to plane segmentation
ps.add_roi(image_mask=image_mask)
ps2 = img_seg.create_plane_segmentation(
name="PlaneSegmentation2",
description="output from segmenting my favorite imaging plane",
imaging_plane=imaging_plane,
reference_images=one_p_series, # optional
)
for _ in range(30):
# randomly generate example starting points for region
x = generator.integers(0, 95)
y = generator.integers(0, 95)
# define an example 4 x 3 region of pixels of weight '1'
pixel_mask = [(ix, iy, 1) for ix in range(x, x + 4) for iy in range(y, y + 3)]
# add pixel mask to plane segmentation
ps2.add_roi(pixel_mask=pixel_mask)
ps3 = img_seg.create_plane_segmentation(
name="PlaneSegmentation3",
description="output from segmenting my favorite imaging plane",
imaging_plane=imaging_plane,
reference_images=one_p_series, # optional
)
for _ in range(30):
# randomly generate example starting points for region
x = generator.integers(0, 95)
y = generator.integers(0, 95)
z = generator.integers(0, 15)
# define an example 4 x 3 x 2 voxel region of weight '0.5'
voxel_mask = []
for ix, iy, iz in product(range(x, x + 4), range(y, y + 3), range(z, z + 2)):
voxel_mask.append((ix, iy, iz, 0.5))
# add voxel mask to plane segmentation
ps3.add_roi(voxel_mask=voxel_mask)
rt_region = ps.create_roi_table_region(region=[0, 1], description="the first of two ROIs")
roi_resp_series = RoiResponseSeries(
name="RoiResponseSeries",
description="Fluorescence responses for two ROIs",
data=np.ones((50, 2)), # 50 samples, 2 ROIs
rois=rt_region,
unit="lumens",
rate=30.0,
)
fl = Fluorescence(roi_response_series=roi_resp_series)
ophys_module.add(fl)
with NWBHDF5IO(nwb_path, "w") as io:
io.write(nwbfile)
return nwb_path