aligned dynamictable tests

This commit is contained in:
sneakers-the-rat 2024-08-15 00:57:44 -07:00
parent 7e7cbc1ac1
commit ce096db349
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
2 changed files with 233 additions and 26 deletions

View file

@ -33,7 +33,7 @@ from pydantic import (
model_validator,
)
if TYPE_CHECKING:
if TYPE_CHECKING: # pragma: no cover
from nwb_linkml.models import VectorData, VectorIndex
T = TypeVar("T", bound=NDArray)
@ -211,6 +211,8 @@ class DynamicTableMixin(BaseModel):
"""
Create ID column if not provided
"""
if not isinstance(model, dict):
return model
if "id" not in model:
lengths = []
for key, val in model.items():
@ -235,6 +237,8 @@ class DynamicTableMixin(BaseModel):
the model dict is ordered after python3.6, so we can use that minus
anything in :attr:`.NON_COLUMN_FIELDS` to determine order implied from passage order
"""
if not isinstance(model, dict):
return model
if "colnames" not in model:
colnames = [
k
@ -270,6 +274,8 @@ class DynamicTableMixin(BaseModel):
See :meth:`.cast_specified_columns` for handling columns in the class specification
"""
# if columns are not in the specification, cast to a generic VectorData
if isinstance(model, dict):
for key, val in model.items():
if key in cls.model_fields:
continue
@ -437,7 +443,7 @@ class VectorIndexMixin(BaseModel, Generic[T]):
if isinstance(item, slice):
item = range(*item.indices(len(self.value)))
return [self.target.value[self._slice(i)] for i in item]
else:
else: # pragma: no cover
raise AttributeError(f"Could not index with {item}")
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
@ -530,7 +536,7 @@ class DynamicTableRegionMixin(BaseModel):
# so we index table with an array to construct
# a list of lists of rows
return [self.table[idx] for idx in self._index[item]]
else:
else: # pragma: no cover
raise ValueError(f"Dont know how to index with {item}, need an int or a slice")
else:
if isinstance(item, (int, np.integer)):
@ -543,19 +549,26 @@ class DynamicTableRegionMixin(BaseModel):
if isinstance(item, slice):
item = range(*item.indices(len(self.value)))
return [self.table[self.value[i]] for i in item]
else:
else: # pragma: no cover
raise ValueError(f"Dont know how to index with {item}, need an int or a slice")
def __setitem__(self, key: Union[int, str, slice], value: Any) -> None:
self.table[self.value[key]] = value
# self.table[self.value[key]] = value
raise NotImplementedError(
"Assigning values to tables is not implemented yet!"
) # pragma: no cover
class AlignedDynamicTableMixin(DynamicTableMixin):
class AlignedDynamicTableMixin(BaseModel):
"""
Mixin to allow indexing multiple tables that are aligned on a common ID
A great deal of code duplication because we need to avoid diamond inheritance
and also it's not so easy to copy a pydantic validator method.
"""
__pydantic_extra__: Dict[str, "DynamicTableMixin"]
model_config = ConfigDict(extra="allow", validate_assignment=True)
__pydantic_extra__: Dict[str, Union["DynamicTableMixin", "VectorDataMixin", "VectorIndexMixin"]]
NON_CATEGORY_FIELDS: ClassVar[tuple[str]] = (
"name",
@ -573,7 +586,7 @@ class AlignedDynamicTableMixin(DynamicTableMixin):
return {k: getattr(self, k) for i, k in enumerate(self.categories)}
def __getitem__(
self, item: Union[int, str, slice, Tuple[Union[int, slice], str]]
self, item: Union[int, str, slice, NDArray[Shape["*"], int], Tuple[Union[int, slice], str]]
) -> pd.DataFrame:
"""
Mimic hdmf:
@ -591,25 +604,78 @@ class AlignedDynamicTableMixin(DynamicTableMixin):
elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], str):
# get a slice of a single table
return self._categories[item[1]][item[0]]
elif isinstance(item, (int, slice)):
elif isinstance(item, (int, slice, Iterable)):
# get a slice of all the tables
ids = self.id[item]
if not isinstance(ids, Iterable):
ids = pd.Series([ids])
ids = pd.DataFrame({"id": ids})
tables = [ids] + [table[item].reset_index() for table in self._categories.values()]
tables = [ids]
for category_name, category in self._categories.items():
table = category[item]
if isinstance(table, pd.DataFrame):
table = table.reset_index()
elif isinstance(table, np.ndarray):
table = pd.DataFrame({category_name: [table]})
elif isinstance(table, Iterable):
table = pd.DataFrame({category_name: table})
else:
raise ValueError(
f"Don't know how to construct category table for {category_name}"
)
tables.append(table)
names = [self.name] + self.categories
# construct below in case we need to support array indexing in the future
else:
raise ValueError(
f"Dont know how to index with {item}, "
"need an int, string, slice, or tuple[int | slice, str]"
"need an int, string, slice, ndarray, or tuple[int | slice, str]"
)
df = pd.concat(tables, axis=1, keys=names)
df.set_index((self.name, "id"), drop=True, inplace=True)
return df
def __getattr__(self, item: str) -> Any:
"""Try and use pandas df attrs if we don't have them"""
try:
return BaseModel.__getattr__(self, item)
except AttributeError as e:
try:
return getattr(self[:], item)
except AttributeError:
raise e from None
def __len__(self) -> int:
"""
Use the id column to determine length.
If the id column doesn't represent length accurately, it's a bug
"""
return len(self.id)
@model_validator(mode="before")
@classmethod
def create_id(cls, model: Dict[str, Any]) -> Dict:
"""
Create ID column if not provided
"""
if "id" not in model:
lengths = []
for key, val in model.items():
# don't get lengths of columns with an index
if (
f"{key}_index" in model
or (isinstance(val, VectorData) and val._index)
or key in cls.NON_CATEGORY_FIELDS
):
continue
lengths.append(len(val))
model["id"] = np.arange(np.max(lengths))
return model
@model_validator(mode="before")
@classmethod
def create_categories(cls, model: Dict[str, Any]) -> Dict:
@ -636,6 +702,42 @@ class AlignedDynamicTableMixin(DynamicTableMixin):
model["categories"].extend(categories)
return model
@model_validator(mode="after")
def resolve_targets(self) -> "DynamicTableMixin":
"""
Ensure that any implicitly indexed columns are linked, and create backlinks
"""
for key, col in self._categories.items():
if isinstance(col, VectorData):
# find an index
idx = None
for field_name in self.model_fields_set:
if field_name in self.NON_CATEGORY_FIELDS or field_name == key:
continue
# implicit name-based index
field = getattr(self, field_name)
if isinstance(field, VectorIndex) and (
field_name == f"{key}_index" or field.target is col
):
idx = field
break
if idx is not None:
col._index = idx
idx.target = col
return self
@model_validator(mode="after")
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
"""
Ensure that all columns are equal length
"""
lengths = [len(v) for v in self._categories.values()] + [len(self.id)]
assert all([length == lengths[0] for length in lengths]), (
"Columns are not of equal length! "
f"Got colnames:\n{self.categories}\nand lengths: {lengths}"
)
return self
class TimeSeriesReferenceVectorDataMixin(VectorDataMixin):
"""

View file

@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Type
import numpy as np
import pandas as pd
@ -7,7 +7,12 @@ from numpydantic import NDArray, Shape
from pydantic import ValidationError
from nwb_linkml.includes import hdmf
from nwb_linkml.includes.hdmf import DynamicTableMixin, VectorDataMixin, VectorIndexMixin
from nwb_linkml.includes.hdmf import (
AlignedDynamicTableMixin,
DynamicTableMixin,
VectorDataMixin,
VectorIndexMixin,
)
# FIXME: Make this just be the output of the provider by patching into import machinery
from nwb_linkml.models.pydantic.core.v2_7_0.namespace import (
@ -39,6 +44,33 @@ def basic_table() -> tuple[DynamicTableMixin, dict[str, NDArray[Shape["10"], int
return MyData, cols
@pytest.fixture()
def aligned_table() -> tuple[Type[AlignedDynamicTableMixin], dict[str, DynamicTableMixin]]:
class Table1(DynamicTableMixin):
col1: hdmf.VectorData[NDArray[Shape["*"], int]]
col2: hdmf.VectorData[NDArray[Shape["*"], int]]
class Table2(DynamicTableMixin):
col3: hdmf.VectorData[NDArray[Shape["*"], int]]
col4: hdmf.VectorData[NDArray[Shape["*"], int]]
class Table3(DynamicTableMixin):
col5: hdmf.VectorData[NDArray[Shape["*"], int]]
col6: hdmf.VectorData[NDArray[Shape["*"], int]]
array = np.arange(10)
table1 = Table1(col1=array, col2=array)
table2 = Table2(col3=array, col4=array)
table3 = Table3(col5=array, col6=array)
class AlignedTable(AlignedDynamicTableMixin):
table1: Table1
table2: Table2
return AlignedTable, {"table1": table1, "table2": table2, "table3": table3}
def test_dynamictable_mixin_indexing(basic_table):
"""
Can index values from a dynamictable
@ -357,6 +389,8 @@ def test_vectordata_indexing():
assert all(data[0] == 6)
assert all(data[1] == 6)
assert all(data[2] == 6)
with pytest.raises(ValueError, match=".*equal-length.*"):
data[0:3] = [5, 4]
def test_vectordata_getattr():
@ -506,6 +540,77 @@ def test_dynamictable_region_ragged():
assert df.equals(region[i])
def test_aligned_dynamictable_indexing(aligned_table):
"""
Should be able to index aligned dynamic tables to yield a multi-index df
"""
AlignedTable, tables = aligned_table
atable = AlignedTable(**tables)
row = atable[0]
assert all(
row.columns
== pd.MultiIndex.from_tuples(
[
("table1", "index"),
("table1", "col1"),
("table1", "col2"),
("table2", "index"),
("table2", "col3"),
("table2", "col4"),
("table3", "index"),
("table3", "col5"),
("table3", "col6"),
]
)
)
for i in range(len(atable)):
vals = atable[i]
assert vals.shape == (1, 9)
assert all(vals == i)
# mildly different, indexing with a slice.
rows = atable[0:3]
for i, row in enumerate(rows.iterrows()):
vals = row[1]
assert len(vals) == 9
assert all(vals == i)
# index just a single table
row = atable[0:3, "table3"]
assert all(row.columns.to_numpy() == ["col5", "col6"])
assert row.shape == (3, 2)
# index out of order
rows = atable[np.array([0, 2, 1])]
assert all(rows.iloc[:, 0] == [0, 2, 1])
def test_mixed_aligned_dynamictable(aligned_table):
"""
Aligned dynamictable should also accept vectordata/vector index pairs
"""
AlignedTable, cols = aligned_table
value_array, index_array = _ragged_array(10)
value_array = np.concat(value_array)
data = hdmf.VectorData(value=value_array)
index = hdmf.VectorIndex(value=index_array)
atable = AlignedTable(**cols, extra_col=data, extra_col_index=index)
atable[0]
assert atable[0].columns[-1] == ("extra_col", "extra_col")
for i, row in enumerate(atable[:].extra_col.iterrows()):
array = row[1].iloc[0]
assert all(array == i)
if i > 0:
assert len(array) == index_array[i] - index_array[i - 1]
else:
assert len(array) == index_array[i]
# --------------------------------------------------
# Model-based tests
# --------------------------------------------------