mirror of
https://github.com/p2p-ld/nwb-linkml.git
synced 2024-11-12 17:54:29 +00:00
aligned dynamictable tests
This commit is contained in:
parent
7e7cbc1ac1
commit
ce096db349
2 changed files with 233 additions and 26 deletions
|
@ -33,7 +33,7 @@ from pydantic import (
|
|||
model_validator,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from nwb_linkml.models import VectorData, VectorIndex
|
||||
|
||||
T = TypeVar("T", bound=NDArray)
|
||||
|
@ -211,6 +211,8 @@ class DynamicTableMixin(BaseModel):
|
|||
"""
|
||||
Create ID column if not provided
|
||||
"""
|
||||
if not isinstance(model, dict):
|
||||
return model
|
||||
if "id" not in model:
|
||||
lengths = []
|
||||
for key, val in model.items():
|
||||
|
@ -235,6 +237,8 @@ class DynamicTableMixin(BaseModel):
|
|||
the model dict is ordered after python3.6, so we can use that minus
|
||||
anything in :attr:`.NON_COLUMN_FIELDS` to determine order implied from passage order
|
||||
"""
|
||||
if not isinstance(model, dict):
|
||||
return model
|
||||
if "colnames" not in model:
|
||||
colnames = [
|
||||
k
|
||||
|
@ -270,19 +274,21 @@ class DynamicTableMixin(BaseModel):
|
|||
See :meth:`.cast_specified_columns` for handling columns in the class specification
|
||||
"""
|
||||
# if columns are not in the specification, cast to a generic VectorData
|
||||
for key, val in model.items():
|
||||
if key in cls.model_fields:
|
||||
continue
|
||||
if not isinstance(val, (VectorData, VectorIndex)):
|
||||
try:
|
||||
if key.endswith("_index"):
|
||||
model[key] = VectorIndex(name=key, description="", value=val)
|
||||
else:
|
||||
model[key] = VectorData(name=key, description="", value=val)
|
||||
except ValidationError as e: # pragma: no cover
|
||||
raise ValidationError(
|
||||
f"field {key} cannot be cast to VectorData from {val}"
|
||||
) from e
|
||||
|
||||
if isinstance(model, dict):
|
||||
for key, val in model.items():
|
||||
if key in cls.model_fields:
|
||||
continue
|
||||
if not isinstance(val, (VectorData, VectorIndex)):
|
||||
try:
|
||||
if key.endswith("_index"):
|
||||
model[key] = VectorIndex(name=key, description="", value=val)
|
||||
else:
|
||||
model[key] = VectorData(name=key, description="", value=val)
|
||||
except ValidationError as e: # pragma: no cover
|
||||
raise ValidationError(
|
||||
f"field {key} cannot be cast to VectorData from {val}"
|
||||
) from e
|
||||
return model
|
||||
|
||||
@model_validator(mode="after")
|
||||
|
@ -437,7 +443,7 @@ class VectorIndexMixin(BaseModel, Generic[T]):
|
|||
if isinstance(item, slice):
|
||||
item = range(*item.indices(len(self.value)))
|
||||
return [self.target.value[self._slice(i)] for i in item]
|
||||
else:
|
||||
else: # pragma: no cover
|
||||
raise AttributeError(f"Could not index with {item}")
|
||||
|
||||
def __setitem__(self, key: Union[int, slice], value: Any) -> None:
|
||||
|
@ -530,7 +536,7 @@ class DynamicTableRegionMixin(BaseModel):
|
|||
# so we index table with an array to construct
|
||||
# a list of lists of rows
|
||||
return [self.table[idx] for idx in self._index[item]]
|
||||
else:
|
||||
else: # pragma: no cover
|
||||
raise ValueError(f"Dont know how to index with {item}, need an int or a slice")
|
||||
else:
|
||||
if isinstance(item, (int, np.integer)):
|
||||
|
@ -543,19 +549,26 @@ class DynamicTableRegionMixin(BaseModel):
|
|||
if isinstance(item, slice):
|
||||
item = range(*item.indices(len(self.value)))
|
||||
return [self.table[self.value[i]] for i in item]
|
||||
else:
|
||||
else: # pragma: no cover
|
||||
raise ValueError(f"Dont know how to index with {item}, need an int or a slice")
|
||||
|
||||
def __setitem__(self, key: Union[int, str, slice], value: Any) -> None:
|
||||
self.table[self.value[key]] = value
|
||||
# self.table[self.value[key]] = value
|
||||
raise NotImplementedError(
|
||||
"Assigning values to tables is not implemented yet!"
|
||||
) # pragma: no cover
|
||||
|
||||
|
||||
class AlignedDynamicTableMixin(DynamicTableMixin):
|
||||
class AlignedDynamicTableMixin(BaseModel):
|
||||
"""
|
||||
Mixin to allow indexing multiple tables that are aligned on a common ID
|
||||
|
||||
A great deal of code duplication because we need to avoid diamond inheritance
|
||||
and also it's not so easy to copy a pydantic validator method.
|
||||
"""
|
||||
|
||||
__pydantic_extra__: Dict[str, "DynamicTableMixin"]
|
||||
model_config = ConfigDict(extra="allow", validate_assignment=True)
|
||||
__pydantic_extra__: Dict[str, Union["DynamicTableMixin", "VectorDataMixin", "VectorIndexMixin"]]
|
||||
|
||||
NON_CATEGORY_FIELDS: ClassVar[tuple[str]] = (
|
||||
"name",
|
||||
|
@ -573,7 +586,7 @@ class AlignedDynamicTableMixin(DynamicTableMixin):
|
|||
return {k: getattr(self, k) for i, k in enumerate(self.categories)}
|
||||
|
||||
def __getitem__(
|
||||
self, item: Union[int, str, slice, Tuple[Union[int, slice], str]]
|
||||
self, item: Union[int, str, slice, NDArray[Shape["*"], int], Tuple[Union[int, slice], str]]
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Mimic hdmf:
|
||||
|
@ -591,25 +604,78 @@ class AlignedDynamicTableMixin(DynamicTableMixin):
|
|||
elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], str):
|
||||
# get a slice of a single table
|
||||
return self._categories[item[1]][item[0]]
|
||||
elif isinstance(item, (int, slice)):
|
||||
elif isinstance(item, (int, slice, Iterable)):
|
||||
# get a slice of all the tables
|
||||
ids = self.id[item]
|
||||
if not isinstance(ids, Iterable):
|
||||
ids = pd.Series([ids])
|
||||
ids = pd.DataFrame({"id": ids})
|
||||
tables = [ids] + [table[item].reset_index() for table in self._categories.values()]
|
||||
tables = [ids]
|
||||
for category_name, category in self._categories.items():
|
||||
table = category[item]
|
||||
if isinstance(table, pd.DataFrame):
|
||||
table = table.reset_index()
|
||||
elif isinstance(table, np.ndarray):
|
||||
table = pd.DataFrame({category_name: [table]})
|
||||
elif isinstance(table, Iterable):
|
||||
table = pd.DataFrame({category_name: table})
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Don't know how to construct category table for {category_name}"
|
||||
)
|
||||
tables.append(table)
|
||||
|
||||
names = [self.name] + self.categories
|
||||
# construct below in case we need to support array indexing in the future
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Dont know how to index with {item}, "
|
||||
"need an int, string, slice, or tuple[int | slice, str]"
|
||||
"need an int, string, slice, ndarray, or tuple[int | slice, str]"
|
||||
)
|
||||
|
||||
df = pd.concat(tables, axis=1, keys=names)
|
||||
df.set_index((self.name, "id"), drop=True, inplace=True)
|
||||
return df
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
"""Try and use pandas df attrs if we don't have them"""
|
||||
try:
|
||||
return BaseModel.__getattr__(self, item)
|
||||
except AttributeError as e:
|
||||
try:
|
||||
return getattr(self[:], item)
|
||||
except AttributeError:
|
||||
raise e from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Use the id column to determine length.
|
||||
|
||||
If the id column doesn't represent length accurately, it's a bug
|
||||
"""
|
||||
return len(self.id)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_id(cls, model: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Create ID column if not provided
|
||||
"""
|
||||
if "id" not in model:
|
||||
lengths = []
|
||||
for key, val in model.items():
|
||||
# don't get lengths of columns with an index
|
||||
if (
|
||||
f"{key}_index" in model
|
||||
or (isinstance(val, VectorData) and val._index)
|
||||
or key in cls.NON_CATEGORY_FIELDS
|
||||
):
|
||||
continue
|
||||
lengths.append(len(val))
|
||||
model["id"] = np.arange(np.max(lengths))
|
||||
|
||||
return model
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def create_categories(cls, model: Dict[str, Any]) -> Dict:
|
||||
|
@ -636,6 +702,42 @@ class AlignedDynamicTableMixin(DynamicTableMixin):
|
|||
model["categories"].extend(categories)
|
||||
return model
|
||||
|
||||
@model_validator(mode="after")
|
||||
def resolve_targets(self) -> "DynamicTableMixin":
|
||||
"""
|
||||
Ensure that any implicitly indexed columns are linked, and create backlinks
|
||||
"""
|
||||
for key, col in self._categories.items():
|
||||
if isinstance(col, VectorData):
|
||||
# find an index
|
||||
idx = None
|
||||
for field_name in self.model_fields_set:
|
||||
if field_name in self.NON_CATEGORY_FIELDS or field_name == key:
|
||||
continue
|
||||
# implicit name-based index
|
||||
field = getattr(self, field_name)
|
||||
if isinstance(field, VectorIndex) and (
|
||||
field_name == f"{key}_index" or field.target is col
|
||||
):
|
||||
idx = field
|
||||
break
|
||||
if idx is not None:
|
||||
col._index = idx
|
||||
idx.target = col
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_equal_length_cols(self) -> "DynamicTableMixin":
|
||||
"""
|
||||
Ensure that all columns are equal length
|
||||
"""
|
||||
lengths = [len(v) for v in self._categories.values()] + [len(self.id)]
|
||||
assert all([length == lengths[0] for length in lengths]), (
|
||||
"Columns are not of equal length! "
|
||||
f"Got colnames:\n{self.categories}\nand lengths: {lengths}"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class TimeSeriesReferenceVectorDataMixin(VectorDataMixin):
|
||||
"""
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional
|
||||
from typing import Optional, Type
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
@ -7,7 +7,12 @@ from numpydantic import NDArray, Shape
|
|||
from pydantic import ValidationError
|
||||
|
||||
from nwb_linkml.includes import hdmf
|
||||
from nwb_linkml.includes.hdmf import DynamicTableMixin, VectorDataMixin, VectorIndexMixin
|
||||
from nwb_linkml.includes.hdmf import (
|
||||
AlignedDynamicTableMixin,
|
||||
DynamicTableMixin,
|
||||
VectorDataMixin,
|
||||
VectorIndexMixin,
|
||||
)
|
||||
|
||||
# FIXME: Make this just be the output of the provider by patching into import machinery
|
||||
from nwb_linkml.models.pydantic.core.v2_7_0.namespace import (
|
||||
|
@ -39,6 +44,33 @@ def basic_table() -> tuple[DynamicTableMixin, dict[str, NDArray[Shape["10"], int
|
|||
return MyData, cols
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def aligned_table() -> tuple[Type[AlignedDynamicTableMixin], dict[str, DynamicTableMixin]]:
|
||||
class Table1(DynamicTableMixin):
|
||||
col1: hdmf.VectorData[NDArray[Shape["*"], int]]
|
||||
col2: hdmf.VectorData[NDArray[Shape["*"], int]]
|
||||
|
||||
class Table2(DynamicTableMixin):
|
||||
col3: hdmf.VectorData[NDArray[Shape["*"], int]]
|
||||
col4: hdmf.VectorData[NDArray[Shape["*"], int]]
|
||||
|
||||
class Table3(DynamicTableMixin):
|
||||
col5: hdmf.VectorData[NDArray[Shape["*"], int]]
|
||||
col6: hdmf.VectorData[NDArray[Shape["*"], int]]
|
||||
|
||||
array = np.arange(10)
|
||||
|
||||
table1 = Table1(col1=array, col2=array)
|
||||
table2 = Table2(col3=array, col4=array)
|
||||
table3 = Table3(col5=array, col6=array)
|
||||
|
||||
class AlignedTable(AlignedDynamicTableMixin):
|
||||
table1: Table1
|
||||
table2: Table2
|
||||
|
||||
return AlignedTable, {"table1": table1, "table2": table2, "table3": table3}
|
||||
|
||||
|
||||
def test_dynamictable_mixin_indexing(basic_table):
|
||||
"""
|
||||
Can index values from a dynamictable
|
||||
|
@ -357,6 +389,8 @@ def test_vectordata_indexing():
|
|||
assert all(data[0] == 6)
|
||||
assert all(data[1] == 6)
|
||||
assert all(data[2] == 6)
|
||||
with pytest.raises(ValueError, match=".*equal-length.*"):
|
||||
data[0:3] = [5, 4]
|
||||
|
||||
|
||||
def test_vectordata_getattr():
|
||||
|
@ -506,6 +540,77 @@ def test_dynamictable_region_ragged():
|
|||
assert df.equals(region[i])
|
||||
|
||||
|
||||
def test_aligned_dynamictable_indexing(aligned_table):
|
||||
"""
|
||||
Should be able to index aligned dynamic tables to yield a multi-index df
|
||||
"""
|
||||
AlignedTable, tables = aligned_table
|
||||
atable = AlignedTable(**tables)
|
||||
|
||||
row = atable[0]
|
||||
assert all(
|
||||
row.columns
|
||||
== pd.MultiIndex.from_tuples(
|
||||
[
|
||||
("table1", "index"),
|
||||
("table1", "col1"),
|
||||
("table1", "col2"),
|
||||
("table2", "index"),
|
||||
("table2", "col3"),
|
||||
("table2", "col4"),
|
||||
("table3", "index"),
|
||||
("table3", "col5"),
|
||||
("table3", "col6"),
|
||||
]
|
||||
)
|
||||
)
|
||||
for i in range(len(atable)):
|
||||
vals = atable[i]
|
||||
assert vals.shape == (1, 9)
|
||||
assert all(vals == i)
|
||||
|
||||
# mildly different, indexing with a slice.
|
||||
rows = atable[0:3]
|
||||
for i, row in enumerate(rows.iterrows()):
|
||||
vals = row[1]
|
||||
assert len(vals) == 9
|
||||
assert all(vals == i)
|
||||
|
||||
# index just a single table
|
||||
row = atable[0:3, "table3"]
|
||||
assert all(row.columns.to_numpy() == ["col5", "col6"])
|
||||
assert row.shape == (3, 2)
|
||||
|
||||
# index out of order
|
||||
rows = atable[np.array([0, 2, 1])]
|
||||
assert all(rows.iloc[:, 0] == [0, 2, 1])
|
||||
|
||||
|
||||
def test_mixed_aligned_dynamictable(aligned_table):
|
||||
"""
|
||||
Aligned dynamictable should also accept vectordata/vector index pairs
|
||||
"""
|
||||
|
||||
AlignedTable, cols = aligned_table
|
||||
value_array, index_array = _ragged_array(10)
|
||||
value_array = np.concat(value_array)
|
||||
|
||||
data = hdmf.VectorData(value=value_array)
|
||||
index = hdmf.VectorIndex(value=index_array)
|
||||
|
||||
atable = AlignedTable(**cols, extra_col=data, extra_col_index=index)
|
||||
atable[0]
|
||||
assert atable[0].columns[-1] == ("extra_col", "extra_col")
|
||||
|
||||
for i, row in enumerate(atable[:].extra_col.iterrows()):
|
||||
array = row[1].iloc[0]
|
||||
assert all(array == i)
|
||||
if i > 0:
|
||||
assert len(array) == index_array[i] - index_array[i - 1]
|
||||
else:
|
||||
assert len(array) == index_array[i]
|
||||
|
||||
|
||||
# --------------------------------------------------
|
||||
# Model-based tests
|
||||
# --------------------------------------------------
|
||||
|
|
Loading…
Reference in a new issue