# nwb-linkml/nwb_linkml/tests/test_includes/test_hdmf.py
#
# 483 lines
# 15 KiB
# Python

import numpy as np
import pandas as pd
import pytest
from numpydantic import NDArray, Shape
from pydantic import ValidationError
from nwb_linkml.includes import hdmf
from nwb_linkml.includes.hdmf import DynamicTableMixin, VectorDataMixin, VectorIndexMixin
# FIXME: Make this just be the output of the provider by patching into import machinery
from nwb_linkml.models.pydantic.core.v2_7_0.namespace import (
DynamicTable,
DynamicTableRegion,
ElectrodeGroup,
VectorIndex,
VoltageClampStimulusSeries,
)
from .conftest import _ragged_array
def test_dynamictable_indexing(electrical_series):
    """
    Rows, columns, single cells and 2D slices can all be selected from a DynamicTable
    """
    _, electrodes = electrical_series

    expected_columns = [
        "id",
        "x",
        "y",
        "group",
        "group_name",
        "location",
        "extra_column",
    ]
    # one int64 column, two float64 columns, then four object-typed columns
    expected_dtypes = [
        np.dtype("int64"),
        np.dtype("float64"),
        np.dtype("float64"),
        *([np.dtype("O")] * 4),
    ]

    # an integer selects a single row, a slice selects a range of rows;
    # either way we should get a frame carrying every column
    for selector, n_rows in ((0, 1), (slice(0, 3), 3)):
        frame = electrodes[selector]
        assert frame.shape == (n_rows, 7)
        assert frame.dtypes.values.tolist() == expected_dtypes
        assert frame.columns.tolist() == expected_columns

    # a string selects a whole column
    y_column = electrodes["y"]
    assert all(y_column.value == [5, 6, 7, 8, 9])

    # a (row, column) pair selects one cell, with the column given by name or by position
    for column_selector in ("y", 2):
        cell = electrodes[0, column_selector]
        assert cell == 5

    # slicing both axes gives the matching sub-frame
    block = electrodes[0:3, 0:3]
    assert block.shape == (3, 3)
    assert block.columns.tolist() == expected_columns[0:3]
    assert block.dtypes.values.tolist() == expected_dtypes[0:3]
def test_dynamictable_ragged(units):
    """
    Ragged arrays should be indexable through an implicit ``_index`` column

    Also covers passing plain arrays instead of explicitly wrapping them in
    VectorData/VectorIndex objects: if the models in the fixture instantiate,
    that part works.
    """
    table, spike_times, _ = units

    # indexing one row must give back exactly one row, not a pivot to long form
    single_row = table[0]
    assert single_row.shape[0] == 1

    # check that we got the indexing boundaries correct
    # (reading ``shape`` and ``iloc`` off the table also checks that attribute
    # access is forwarded to the dataframe)
    n_rows = table.shape[0]
    for row_idx in range(n_rows):
        cell = table.iloc[row_idx, 0]
        assert np.all(cell == spike_times[row_idx])
def test_dynamictable_region_basic(electrical_series):
    """
    A DynamicTableRegion, itself a column of a table, should be able to
    refer to one or several rows of another table
    """
    series, _ = electrical_series
    region = series.electrodes

    first = region[0]
    # Position 0 of the region is expected to resolve to the row with id 4 rather than id 0
    # (per the original author's note, the fixture builds the region with inverted indexes).
    # Only the basics are checked here because
    # a) indexing of the referenced table is covered by the tests above, and
    # b) every other object in the chain is strictly validated,
    # so a frame of the right shape is assumed to be the right frame.
    assert all(first.id == 4)
    assert first.shape == (1, 7)

    # The cell contents must still be the original model: the "group" column of the
    # returned row holds ElectrodeGroup objects. ElectrodeGroup validates strictly, so an
    # instance check is as far as this basic test follows the recursion.
    assert isinstance(first.group.values[0], ElectrodeGroup)

    # Slicing the region yields a list of table rows, which is the intended behavior:
    # that list is the content of a cell in another table.
    selected = region[0:3]
    for frame, expected_id in zip(selected, [4, 3, 2]):
        assert all(frame.id == expected_id)
def test_dynamictable_region_ragged():
    """
    A DynamicTableRegion can carry its own index, making each of its cells
    a ragged group of rows from the referenced table
    """
    ragged_values, ragged_index = _ragged_array(24)
    flat_values = np.concatenate(ragged_values)

    table = DynamicTable(
        name="table",
        description="a table what else would it be",
        id=np.arange(len(ragged_index)),
        timeseries=flat_values,
        timeseries_index=ragged_index,
    )

    # a secondary index selecting overlapping three-row windows of the table:
    # rows (0, 1, 2), then (1, 2, 3), then (2, 3, 4)
    region_rows = np.array([0, 1, 2, 1, 2, 3, 2, 3, 4])
    region_bounds = np.array([3, 6, 9])

    region = DynamicTableRegion(
        name="dynamictableregion",
        description="this field should be optional",
        table=table,
        value=region_rows,
    )
    region._index = VectorIndex(
        name="index", description="hgggggggjjjj", target=region, value=region_bounds
    )

    window = region[1]

    # The second window should be three rows of the table (ids 1, 2, 3), each with a
    # ragged ``timeseries`` cell, roughly:
    #
    #    id  timeseries
    # 0   1  [1.0, 1.0, 1.0, ...
    # 1   2  [2.0, 2.0, 2.0, ...
    # 2   3  [3.0, 3.0, 3.0, ...
    assert window.shape == (3, 2)
    assert all(window.id == [1, 2, 3])
    for expected, (_, record) in zip([1, 2, 3], window.iterrows()):
        assert all(record.timeseries == expected)
def test_dynamictable_append_column():
    """
    Placeholder for appending a column to a DynamicTable - nothing is exercised yet
    """
def test_dynamictable_append_row():
    """
    Placeholder for appending a row to a DynamicTable - nothing is exercised yet
    """
def test_dynamictable_extra_coercion():
    """
    Extra fields passed as plain arrays should be coerced to VectorData,
    with their indexing relationships resolved.

    No assertions are written for this yet; the body is documentation only.
    """
def test_aligned_dynamictable(intracellular_recordings_table):
    """
    A table made of several aligned DynamicTables should be indexable,
    producing frames whose columns are a (table, column) MultiIndex
    """
    table = intracellular_recordings_table

    # selecting a single row must not raise; contents are checked on the slice below
    _ = table[0]

    # a member table can be fetched by its name
    stimuli = table["stimuli"]
    assert stimuli.shape == (10, 1)

    # a few rows are enough to build the combined dataframe
    rows = table[0:3]
    expected_columns = pd.MultiIndex.from_tuples(
        [
            ("electrodes", "index"),
            ("electrodes", "electrode"),
            ("stimuli", "index"),
            ("stimuli", "stimulus"),
            ("responses", "index"),
            ("responses", "response"),
        ]
    )
    assert all(rows.columns == expected_columns)

    # The TimeSeriesReferenceVectorData column should resolve to actual values
    # (also tested separately): each cell is a collection of VoltageClampStimulusSeries
    # that can be indexed in turn.
    stims = rows["stimuli", "stimulus"][0]
    for i in range(len(stims)):
        stim_series = stims[i]
        assert isinstance(stim_series, VoltageClampStimulusSeries)
        assert all(i == val for val in stim_series[:])
# --------------------------------------------------
# Direct mixin tests
# --------------------------------------------------
def test_dynamictable_mixin_indexing():
    """
    Rows, columns, cells and 2D slices can be selected from a bare DynamicTableMixin
    subclass, covering both declared columns and extra ones
    """

    class MyData(DynamicTableMixin):
        col_1: hdmf.VectorData[NDArray[Shape["*"], int]]
        col_2: hdmf.VectorData[NDArray[Shape["*"], int]]
        col_3: hdmf.VectorData[NDArray[Shape["*"], int]]

    # col_1..col_3 are declared on the model, col_4 and col_5 are extras
    cols = {f"col_{n}": np.arange(10) for n in range(1, 6)}
    colnames = list(cols)

    inst = MyData(**cols)

    # an integer selects a single row holding every column
    one_row = inst[0]
    assert one_row.shape == (1, 5)
    assert one_row.columns.tolist() == colnames

    # a slice selects a range of rows
    some_rows = inst[0:3]
    assert some_rows.shape == (3, 5)

    # a string selects a whole column
    first_col = inst["col_1"]
    assert all(first_col.value == np.arange(10))

    # a (row, column) pair selects one cell, with the column given by name or by position
    for column_selector in ("col_2", 1):
        cell = inst[5, column_selector]
        assert cell == 5

    # slicing both axes gives the matching sub-frame
    block = inst[0:3, 0:3]
    assert block.shape == (3, 3)
    assert block.columns.tolist() == colnames[0:3]
def test_dynamictable_mixin_colnames():
    """
    ``colnames`` should be inferred from the declared column plus any extra columns
    """

    class MyDT(DynamicTableMixin):
        existing_col: NDArray[Shape["* col"], int]

    extra_columns = {
        name: VectorDataMixin(value=np.arange(10)) for name in ("new_col_1", "new_col_2")
    }
    inst = MyDT(existing_col=np.arange(10), **extra_columns)

    assert inst.colnames == ["existing_col", "new_col_1", "new_col_2"]
def test_dynamictable_mixin_colnames_index():
    """
    Index columns should be left out of ``colnames``
    """

    class MyDT(DynamicTableMixin):
        existing_col: NDArray[Shape["* col"], int]

    existing = np.arange(10)
    explicitly_indexed = hdmf.VectorData(value=np.arange(10))
    implicitly_indexed = hdmf.VectorData(value=np.arange(10))

    inst = MyDT(
        existing_col=existing,
        new_col_1=explicitly_indexed,
        new_col_2=implicitly_indexed,
        # explicit index: the name does not match, the target is given directly
        weirdname_index=VectorIndexMixin(value=np.arange(10), target=explicitly_indexed),
        # implicit index: no target, only the ``<column>_index`` name ties it to new_col_2
        new_col_2_index=VectorIndexMixin(value=np.arange(10)),
    )

    assert inst.colnames == ["existing_col", "new_col_1", "new_col_2"]
def test_dynamictable_mixin_colnames_ordered():
    """
    An explicit ``colnames`` argument should set the column order
    """

    class MyDT(DynamicTableMixin):
        existing_col: NDArray[Shape["* col"], int]

    cols = {"existing_col": np.arange(10)}
    for name in ("new_col_1", "new_col_2", "new_col_3"):
        cols[name] = hdmf.VectorData(value=np.arange(10))

    # a complete ordering is taken as-is...
    full_order = ["new_col_2", "existing_col", "new_col_1", "new_col_3"]
    ordered = MyDT(**cols, colnames=full_order)
    assert ordered.colnames == full_order
    # ...and shows up in the columns selector and in the dataframe that is produced
    assert all(expected == actual for expected, actual in zip(full_order, ordered._columns))
    assert all(ordered[0].columns == full_order)

    # a partial ordering comes first, with the columns it does not name appended at the end
    partial_order = ["new_col_3", "new_col_2"]
    partially_ordered = MyDT(**cols, colnames=partial_order)
    assert partially_ordered.colnames == partial_order + ["existing_col", "new_col_1"]
def test_dynamictable_mixin_getattr():
    """
    Attribute lookups the table does not know about should be forwarded to the dataframe
    """

    class MyDT(DynamicTableMixin):
        existing_col: hdmf.VectorData[NDArray[Shape["* col"], int]]

    source_column = hdmf.VectorData(value=np.arange(10))
    inst = MyDT(existing_col=source_column)

    # attributes that exist on the model are looked up normally
    stored_column = inst.existing_col
    assert isinstance(stored_column, hdmf.VectorData)
    assert all(stored_column.value == source_column.value)

    # attributes that do not exist on the model come from the dataframe
    assert isinstance(inst.columns, pd.Index)
def test_dynamictable_coercion():
    """
    Plain arrays should be coerced into VectorData objects, for declared and extra columns
    """

    class MyDT(DynamicTableMixin):
        existing_col: hdmf.VectorData[NDArray[Shape["* col"], int]]

    # existing_col is declared on the model, new_col_1 is an extra column
    inst = MyDT(existing_col=np.arange(10), new_col_1=np.arange(10))

    for column in (inst.existing_col, inst.new_col_1):
        assert isinstance(column, hdmf.VectorData)

    for column in (inst.existing_col, inst.new_col_1):
        assert all(column.value == np.arange(10))
def test_dynamictable_create_id():
    """
    A table built without an ``id`` should get one that counts up over its rows
    """

    class MyDT(DynamicTableMixin):
        existing_col: hdmf.VectorData[NDArray[Shape["* col"], int]]

    # no ``id`` is passed, only a ten-row column
    inst = MyDT(existing_col=np.arange(10))

    assert all(inst.id == np.arange(10))
def test_dynamictable_resolve_index():
    """
    Columns and their indices should be linked to each other in both directions,
    whether the link is explicit (``target``) or implicit (``<column>_index`` name)
    """

    class MyDT(DynamicTableMixin):
        existing_col: hdmf.VectorData[NDArray[Shape["* col"], int]]

    existing = np.arange(10)
    explicitly_indexed = hdmf.VectorData(value=np.arange(10))
    implicitly_indexed = hdmf.VectorData(value=np.arange(10))

    inst = MyDT(
        existing_col=existing,
        new_col_1=explicitly_indexed,
        new_col_2=implicitly_indexed,
        # explicit index: the name does not match, the target is given directly
        weirdname_index=hdmf.VectorIndex(value=np.arange(10), target=explicitly_indexed),
        # implicit index: no target, only the name ties it to new_col_2
        new_col_2_index=hdmf.VectorIndex(value=np.arange(10)),
    )

    # index -> column
    assert inst.weirdname_index.target is inst.new_col_1
    assert inst.new_col_2_index.target is inst.new_col_2
    # column -> index
    assert inst.new_col_1._index is inst.weirdname_index
    assert inst.new_col_2._index is inst.new_col_2_index
def dynamictable_assert_equal_length():
    """
    Dynamictable validates that columns are of equal length

    Checks four cases: an extra column longer than the declared one, an extra column
    shorter than the declared one, a longer column made valid by an index with the
    right number of entries, and a longer column whose index has the wrong number
    of entries.

    NOTE(review): this function has no ``test_`` prefix, so pytest's default
    collection rules never run it. Rename it to ``test_dynamictable_assert_equal_length``
    once it has been confirmed to pass against the current models.
    """

    class MyDT(DynamicTableMixin):
        existing_col: NDArray[Shape["* col"], int]

    # ``pytest.raises`` takes the expected-message regex as ``match``;
    # any other keyword (such as ``pattern``) is rejected with a TypeError.
    message = "Columns are not of equal length"

    # extra column is longer than the declared column
    cols = {
        "existing_col": np.arange(10),
        "new_col_1": hdmf.VectorData(value=np.arange(11)),
    }
    with pytest.raises(ValidationError, match=message):
        _ = MyDT(**cols)

    # extra column is shorter than the declared column
    cols = {
        "existing_col": np.arange(11),
        "new_col_1": hdmf.VectorData(value=np.arange(10)),
    }
    with pytest.raises(ValidationError, match=message):
        _ = MyDT(**cols)

    # wrong lengths are fine as long as the index is good:
    # 100 values split by 10 index entries lines up with the 10-row column
    cols = {
        "existing_col": np.arange(10),
        "new_col_1": hdmf.VectorData(value=np.arange(100)),
        "new_col_1_index": hdmf.VectorIndex(value=np.arange(0, 100, 10) + 10),
    }
    _ = MyDT(**cols)

    # but not fine if the index is not good:
    # 20 index entries do not line up with the 10-row column
    cols = {
        "existing_col": np.arange(10),
        "new_col_1": hdmf.VectorData(value=np.arange(100)),
        "new_col_1_index": hdmf.VectorIndex(value=np.arange(0, 100, 5) + 5),
    }
    with pytest.raises(ValidationError, match=message):
        _ = MyDT(**cols)
def test_vectordata_indexing():
    """
    Vectordata/VectorIndex pairs should know how to index off each other

    Without an index a VectorData behaves like a flat 1D array. Once an index is
    attached, indexing either object by row should give that row's ragged group of
    values, and writing through a row selection should change the underlying data.
    """
    n_rows = 50
    value_array, index_array = _ragged_array(n_rows)
    # ``np.concatenate`` rather than ``np.concat``: the latter only exists as an alias
    # from NumPy 2.0 on, and the rest of this module already uses ``np.concatenate``
    value_array = np.concatenate(value_array)

    data = hdmf.VectorData(value=value_array)

    # before we have an index, things should work as normal, indexing a 1D array
    assert data[0] == 0

    index = hdmf.VectorIndex(value=index_array, target=data)
    data._index = index

    # after an index, both objects should index raggedly:
    # every value in row ``i`` is expected to equal ``i``
    for i in range(len(index)):
        assert all(data[i] == i)
        assert all(index[i] == i)

    # slices give back one ragged group per selected row, from either object
    for item in (data, index):
        section = item[0:3]
        for i, subitem in enumerate(section):
            assert all(subitem == i)

    # setting uses the same indexing logic
    data[0][:] = 5
    assert all(data[0] == 5)
def test_vectordata_generic_numpydantic_validation():
    """
    A numpydantic array annotation on a table column should be enforced at validation time

    Deliberately minimal: numpydantic's validation is tested in numpydantic itself.
    This only checks that the annotation is applied and that a non-conforming
    array is rejected.
    """

    class MyDT(DynamicTableMixin):
        existing_col: NDArray[Shape["3 col"], int]

    # a (4, 5, 6) array matches neither the rank nor the leading dimension of "3 col"
    wrong_shape = np.zeros((4, 5, 6), dtype=int)
    with pytest.raises(ValidationError):
        _ = MyDT(existing_col=wrong_shape)