first impl of dynamictable working!

This commit is contained in:
sneakers-the-rat 2024-08-05 20:51:52 -07:00
parent e5d1cc52de
commit c06859a537
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
5 changed files with 336 additions and 235 deletions

File diff suppressed because it is too large Load diff

View file

@ -22,7 +22,7 @@ dependencies = [
"dask>=2023.9.2",
"tqdm>=4.66.1",
'typing-extensions>=4.12.2;python_version<"3.11"',
"numpydantic>=1.2.2",
"numpydantic>=1.3.0",
"black>=24.4.2",
"pandas>=2.2.2",
]

View file

@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple, Un
from linkml.generators.pydanticgen.template import Import, Imports, ObjectImport
from numpydantic import NDArray
from pandas import DataFrame
from pandas import DataFrame, Series
from pydantic import BaseModel, ConfigDict, Field, model_validator
if TYPE_CHECKING:
@ -98,6 +98,11 @@ class DynamicTableMixin(BaseModel):
rows, cols = item
if isinstance(cols, (int, slice)):
cols = self.colnames[cols]
if isinstance(rows, int) and isinstance(cols, str):
# single scalar value
return self._columns[cols][rows]
data = self._slice_range(rows, cols)
return DataFrame.from_dict(data)
else:
@ -110,7 +115,14 @@ class DynamicTableMixin(BaseModel):
cols = self.colnames
elif isinstance(cols, str):
cols = [cols]
data = {}
for k in cols:
val = self._columns[k][rows]
if isinstance(val, BaseModel):
# special case where pandas will unpack a pydantic model
# into {n_fields} rows, rather than keeping it in a dict
val = Series([val])
data[k] = val
data = {k: self._columns[k][rows] for k in cols}
return data
@ -244,7 +256,9 @@ class VectorIndexMixin(BaseModel):
DYNAMIC_TABLE_IMPORTS = Imports(
imports=[
Import(module="pandas", objects=[ObjectImport(name="DataFrame")]),
Import(
module="pandas", objects=[ObjectImport(name="DataFrame"), ObjectImport(name="Series")]
),
Import(
module="typing",
objects=[

View file

@ -1,14 +1,9 @@
from __future__ import annotations
from datetime import datetime, date
from decimal import Decimal
from enum import Enum
import re
import sys
import numpy as np
from ...hdmf_common.v1_8_0.hdmf_common_base import Data, Container
from pandas import DataFrame
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
from ...hdmf_common.v1_8_0.hdmf_common_base import Data
from pandas import DataFrame, Series
from typing import Any, ClassVar, List, Dict, Optional, Union, overload, Tuple
from pydantic import BaseModel, ConfigDict, Field, RootModel, model_validator
from numpydantic import NDArray, Shape
metamodel_version = "None"
@ -198,6 +193,11 @@ class DynamicTableMixin(BaseModel):
rows, cols = item
if isinstance(cols, (int, slice)):
cols = self.colnames[cols]
if isinstance(rows, int) and isinstance(cols, str):
# single scalar value
return self._columns[cols][rows]
data = self._slice_range(rows, cols)
return DataFrame.from_dict(data)
else:
@ -210,8 +210,14 @@ class DynamicTableMixin(BaseModel):
cols = self.colnames
elif isinstance(cols, str):
cols = [cols]
data = {k: self._columns[k][rows] for k in cols}
data = {}
for k in cols:
val = self._columns[k][rows]
if isinstance(val, BaseModel):
# special case where pandas will unpack a pydantic model
# into {n_fields} rows, rather than keeping it in a dict
val = Series([val])
data[k] = val
return data
def __setitem__(self, key: str, value: Any) -> None:

View file

@ -5,6 +5,8 @@ import pytest
# FIXME: Make this just be the output of the provider by patching into import machinery
from nwb_linkml.models.pydantic.core.v2_7_0.namespace import (
Device,
DynamicTableRegion,
ElectricalSeries,
ElectrodeGroup,
ExtracellularEphysElectrodes,
@ -18,18 +20,95 @@ def electrical_series() -> Tuple["ElectricalSeries", "ExtracellularEphysElectrod
"""
n_electrodes = 5
n_times = 100
data = np.arange(0, n_electrodes * n_times).reshape(n_times, n_electrodes)
data = np.arange(0, n_electrodes * n_times).reshape(n_times, n_electrodes).astype(float)
timestamps = np.linspace(0, 1, n_times)
device = Device(name="my electrode")
# electrode group is the physical description of the electrodes
electrode_group = ElectrodeGroup(
name="GroupA",
device=device,
description="an electrode group",
location="you know where it is",
)
# make electrodes tables
electrodes = ExtracellularEphysElectrodes(
description="idk these are also electrodes",
id=np.arange(0, n_electrodes),
x=np.arange(0, n_electrodes),
y=np.arange(n_electrodes, n_electrodes * 2),
x=np.arange(0, n_electrodes).astype(float),
y=np.arange(n_electrodes, n_electrodes * 2).astype(float),
group=[electrode_group] * n_electrodes,
group_name=[electrode_group.name] * n_electrodes,
location=[str(i) for i in range(n_electrodes)],
extra_column=["sup"] * n_electrodes,
)
electrical_series = ElectricalSeries(
name="my recording!",
electrodes=DynamicTableRegion(
table=electrodes, value=np.arange(0, n_electrodes), name="electrodes", description="hey"
),
timestamps=timestamps,
data=data,
)
return electrical_series, electrodes
def test_dynamictable_indexing(electrical_series):
"""
Can index values from a dynamictable
"""
series, electrodes = electrical_series
colnames = [
"id",
"x",
"y",
"group",
"group_name",
"location",
"extra_column",
]
dtypes = [
np.dtype("int64"),
np.dtype("float64"),
np.dtype("float64"),
] + ([np.dtype("O")] * 4)
row = electrodes[0]
# successfully get a single row :)
assert row.shape == (1, 7)
assert row.dtypes.values.tolist() == dtypes
assert row.columns.tolist() == colnames
# slice a range of rows
rows = electrodes[0:3]
assert rows.shape == (3, 7)
assert rows.dtypes.values.tolist() == dtypes
assert rows.columns.tolist() == colnames
# get a single column
col = electrodes["y"]
assert all(col == [5, 6, 7, 8, 9])
# get a single cell
val = electrodes[0, "y"]
assert val == 5
val = electrodes[0, 2]
assert val == 5
# get a slice of rows and columns
subsection = electrodes[0:3, 0:3]
assert subsection.shape == (3, 3)
assert subsection.columns.tolist() == colnames[0:3]
assert subsection.dtypes.values.tolist() == dtypes[0:3]
def test_dynamictable_append_column():
pass
def test_dynamictable_append_row():
pass