first impl of dynamictable working!

This commit is contained in:
sneakers-the-rat 2024-08-05 20:51:52 -07:00
parent e5d1cc52de
commit c06859a537
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
5 changed files with 336 additions and 235 deletions

File diff suppressed because it is too large Load diff

View file

@ -22,7 +22,7 @@ dependencies = [
"dask>=2023.9.2", "dask>=2023.9.2",
"tqdm>=4.66.1", "tqdm>=4.66.1",
'typing-extensions>=4.12.2;python_version<"3.11"', 'typing-extensions>=4.12.2;python_version<"3.11"',
"numpydantic>=1.2.2", "numpydantic>=1.3.0",
"black>=24.4.2", "black>=24.4.2",
"pandas>=2.2.2", "pandas>=2.2.2",
] ]

View file

@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple, Un
from linkml.generators.pydanticgen.template import Import, Imports, ObjectImport from linkml.generators.pydanticgen.template import Import, Imports, ObjectImport
from numpydantic import NDArray from numpydantic import NDArray
from pandas import DataFrame from pandas import DataFrame, Series
from pydantic import BaseModel, ConfigDict, Field, model_validator from pydantic import BaseModel, ConfigDict, Field, model_validator
if TYPE_CHECKING: if TYPE_CHECKING:
@ -98,6 +98,11 @@ class DynamicTableMixin(BaseModel):
rows, cols = item rows, cols = item
if isinstance(cols, (int, slice)): if isinstance(cols, (int, slice)):
cols = self.colnames[cols] cols = self.colnames[cols]
if isinstance(rows, int) and isinstance(cols, str):
# single scalar value
return self._columns[cols][rows]
data = self._slice_range(rows, cols) data = self._slice_range(rows, cols)
return DataFrame.from_dict(data) return DataFrame.from_dict(data)
else: else:
@ -110,7 +115,14 @@ class DynamicTableMixin(BaseModel):
cols = self.colnames cols = self.colnames
elif isinstance(cols, str): elif isinstance(cols, str):
cols = [cols] cols = [cols]
data = {}
for k in cols:
val = self._columns[k][rows]
if isinstance(val, BaseModel):
# special case where pandas will unpack a pydantic model
# into {n_fields} rows, rather than keeping it in a dict
val = Series([val])
data[k] = val
data = {k: self._columns[k][rows] for k in cols} data = {k: self._columns[k][rows] for k in cols}
return data return data
@ -244,7 +256,9 @@ class VectorIndexMixin(BaseModel):
DYNAMIC_TABLE_IMPORTS = Imports( DYNAMIC_TABLE_IMPORTS = Imports(
imports=[ imports=[
Import(module="pandas", objects=[ObjectImport(name="DataFrame")]), Import(
module="pandas", objects=[ObjectImport(name="DataFrame"), ObjectImport(name="Series")]
),
Import( Import(
module="typing", module="typing",
objects=[ objects=[

View file

@ -1,14 +1,9 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime, date
from decimal import Decimal from ...hdmf_common.v1_8_0.hdmf_common_base import Data
from enum import Enum from pandas import DataFrame, Series
import re from typing import Any, ClassVar, List, Dict, Optional, Union, overload, Tuple
import sys from pydantic import BaseModel, ConfigDict, Field, RootModel, model_validator
import numpy as np
from ...hdmf_common.v1_8_0.hdmf_common_base import Data, Container
from pandas import DataFrame
from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
from numpydantic import NDArray, Shape from numpydantic import NDArray, Shape
metamodel_version = "None" metamodel_version = "None"
@ -198,6 +193,11 @@ class DynamicTableMixin(BaseModel):
rows, cols = item rows, cols = item
if isinstance(cols, (int, slice)): if isinstance(cols, (int, slice)):
cols = self.colnames[cols] cols = self.colnames[cols]
if isinstance(rows, int) and isinstance(cols, str):
# single scalar value
return self._columns[cols][rows]
data = self._slice_range(rows, cols) data = self._slice_range(rows, cols)
return DataFrame.from_dict(data) return DataFrame.from_dict(data)
else: else:
@ -210,8 +210,14 @@ class DynamicTableMixin(BaseModel):
cols = self.colnames cols = self.colnames
elif isinstance(cols, str): elif isinstance(cols, str):
cols = [cols] cols = [cols]
data = {}
data = {k: self._columns[k][rows] for k in cols} for k in cols:
val = self._columns[k][rows]
if isinstance(val, BaseModel):
# special case where pandas will unpack a pydantic model
# into {n_fields} rows, rather than keeping it in a dict
val = Series([val])
data[k] = val
return data return data
def __setitem__(self, key: str, value: Any) -> None: def __setitem__(self, key: str, value: Any) -> None:

View file

@ -5,6 +5,8 @@ import pytest
# FIXME: Make this just be the output of the provider by patching into import machinery # FIXME: Make this just be the output of the provider by patching into import machinery
from nwb_linkml.models.pydantic.core.v2_7_0.namespace import ( from nwb_linkml.models.pydantic.core.v2_7_0.namespace import (
Device,
DynamicTableRegion,
ElectricalSeries, ElectricalSeries,
ElectrodeGroup, ElectrodeGroup,
ExtracellularEphysElectrodes, ExtracellularEphysElectrodes,
@ -18,18 +20,95 @@ def electrical_series() -> Tuple["ElectricalSeries", "ExtracellularEphysElectrod
""" """
n_electrodes = 5 n_electrodes = 5
n_times = 100 n_times = 100
data = np.arange(0, n_electrodes * n_times).reshape(n_times, n_electrodes) data = np.arange(0, n_electrodes * n_times).reshape(n_times, n_electrodes).astype(float)
timestamps = np.linspace(0, 1, n_times) timestamps = np.linspace(0, 1, n_times)
device = Device(name="my electrode")
# electrode group is the physical description of the electrodes # electrode group is the physical description of the electrodes
electrode_group = ElectrodeGroup( electrode_group = ElectrodeGroup(
name="GroupA", name="GroupA",
device=device,
description="an electrode group",
location="you know where it is",
) )
# make electrodes tables # make electrodes tables
electrodes = ExtracellularEphysElectrodes( electrodes = ExtracellularEphysElectrodes(
description="idk these are also electrodes",
id=np.arange(0, n_electrodes), id=np.arange(0, n_electrodes),
x=np.arange(0, n_electrodes), x=np.arange(0, n_electrodes).astype(float),
y=np.arange(n_electrodes, n_electrodes * 2), y=np.arange(n_electrodes, n_electrodes * 2).astype(float),
group=[electrode_group] * n_electrodes, group=[electrode_group] * n_electrodes,
group_name=[electrode_group.name] * n_electrodes,
location=[str(i) for i in range(n_electrodes)],
extra_column=["sup"] * n_electrodes,
) )
electrical_series = ElectricalSeries(
name="my recording!",
electrodes=DynamicTableRegion(
table=electrodes, value=np.arange(0, n_electrodes), name="electrodes", description="hey"
),
timestamps=timestamps,
data=data,
)
return electrical_series, electrodes
def test_dynamictable_indexing(electrical_series):
    """
    Exercise the supported indexing forms on the electrodes table.

    Covers a single row, a slice of rows, a column selected by name,
    a single cell (column given by name and by position), and a combined
    row/column slice, checking the shape, column order and dtypes of
    the returned frames.
    """
    # the fixture yields (series, electrodes); only the table is indexed here
    _, table = electrical_series

    expected_columns = [
        "id",
        "x",
        "y",
        "group",
        "group_name",
        "location",
        "extra_column",
    ]
    # an int64 ``id``, float64 ``x`` and ``y``, then four object-dtype columns
    object_dtype = np.dtype("O")
    expected_dtypes = [
        np.dtype("int64"),
        np.dtype("float64"),
        np.dtype("float64"),
        *([object_dtype] * 4),
    ]

    # an integer index is expected to give a one-row frame, not a scalar row
    single_row = table[0]
    assert single_row.shape == (1, 7)
    assert single_row.dtypes.values.tolist() == expected_dtypes
    assert single_row.columns.tolist() == expected_columns

    # a slice selects a contiguous range of rows
    first_three = table[0:3]
    assert first_three.shape == (3, 7)
    assert first_three.dtypes.values.tolist() == expected_dtypes
    assert first_three.columns.tolist() == expected_columns

    # a string selects a whole column
    y_column = table["y"]
    expected_y = [5, 6, 7, 8, 9]
    assert all(y_column == expected_y)

    # (row, column) addresses one cell, with the column given by name ...
    cell = table[0, "y"]
    assert cell == 5
    # ... or by position ("y" is the third column)
    cell = table[0, 2]
    assert cell == 5

    # slicing both axes gives the matching sub-table
    block = table[0:3, 0:3]
    assert block.shape == (3, 3)
    assert block.columns.tolist() == expected_columns[:3]
    assert block.dtypes.values.tolist() == expected_dtypes[:3]
def test_dynamictable_append_column():
    """
    Placeholder test with no assertions yet.

    TODO: presumably meant to cover appending a column to a dynamic table
    (inferred from the name only) - confirm and implement.
    """
def test_dynamictable_append_row():
    """
    Placeholder test with no assertions yet.

    TODO: presumably meant to cover appending a row to a dynamic table
    (inferred from the name only) - confirm and implement.
    """