add dynamictable, vectordata, vectorindex mixins

This commit is contained in:
sneakers-the-rat 2024-07-31 01:13:31 -07:00
parent 723f53035d
commit d118477d8a
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
10 changed files with 361 additions and 40 deletions

View file

@@ -284,6 +284,7 @@ api/nwb_linkml/schema/index
meta/todo
meta/changelog
meta/references
genindex
```

View file

@@ -20,6 +20,11 @@
### DynamicTable
```{note}
See the [DynamicTable](https://hdmf-common-schema.readthedocs.io/en/stable/format_description.html#dynamictable)
reference docs
```
One of the major special cases in NWB is the use of `DynamicTable` to store tabular data
with columns that are not part of the base spec.
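For example, hdmf allows columns to be added to a table at runtime (a minimal sketch using `hdmf.common.DynamicTable`; the table and column names here are invented):

```python
from hdmf.common import DynamicTable

# a table whose columns are not fixed by the base schema
trials = DynamicTable(name="trials", description="per-trial metadata")
trials.add_column(name="reaction_time", description="reaction time (s)")
trials.add_row(reaction_time=0.21)
trials.add_row(reaction_time=0.35)
```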

docs/meta/references.md Normal file
View file

@@ -0,0 +1,11 @@
# References
## Documentation
- [hdmf](https://hdmf.readthedocs.io/en/stable/)
- [hdmf-common-schema](https://hdmf-common-schema.readthedocs.io/en/stable/)
- [pynwb](https://pynwb.readthedocs.io/en/latest/)
```{todo}
Add the bibtex refs to NWB papers :)
```

View file

@@ -40,7 +40,7 @@ from types import ModuleType
from typing import ClassVar, Dict, List, Optional, Tuple, Type, Union
from linkml.generators import PydanticGenerator
from linkml.generators.pydanticgen.build import SlotResult
from linkml.generators.pydanticgen.build import SlotResult, ClassResult
from linkml.generators.pydanticgen.array import ArrayRepresentation, NumpydanticArray
from linkml.generators.pydanticgen.template import PydanticModule, Import, Imports
from linkml_runtime.linkml_model.meta import (
@@ -63,6 +63,7 @@ from pydantic import BaseModel
from nwb_linkml.maps import flat_to_nptyping
from nwb_linkml.maps.naming import module_case, version_module_case
from nwb_linkml.includes.types import ModelTypeString, _get_name, NamedString, NamedImports
from nwb_linkml.includes.hdmf import DYNAMIC_TABLE_IMPORTS, DYNAMIC_TABLE_INJECTS
OPTIONAL_PATTERN = re.compile(r"Optional\[([\w\.]*)\]")
@@ -96,6 +97,9 @@ class NWBPydanticGenerator(PydanticGenerator):
def _check_anyof(
self, s: SlotDefinition, sn: SlotDefinitionName, sv: SchemaView
): # pragma: no cover
"""
Overridden to allow `array` in any_of
"""
# Confirm that the original slot range (ignoring the default that comes in from
# induced_slot) isn't in addition to setting any_of
allowed_keys = ("array",)
@@ -127,6 +131,10 @@ class NWBPydanticGenerator(PydanticGenerator):
return slot
def after_generate_class(self, cls: ClassResult, sv: SchemaView) -> ClassResult:
cls = AfterGenerateClass.inject_dynamictable(cls)
return cls
def before_render_template(self, template: PydanticModule, sv: SchemaView) -> PydanticModule:
if "source_file" in template.meta:
del template.meta["source_file"]
@@ -226,6 +234,33 @@ class AfterGenerateSlot:
slot.imports = NamedImports
return slot
class AfterGenerateClass:
"""
Container class for class-modification methods
"""
@staticmethod
def inject_dynamictable(cls: ClassResult) -> ClassResult:
if cls.cls.name == "DynamicTable":
cls.cls.bases = ["DynamicTableMixin"]
if cls.injected_classes is None:
cls.injected_classes = DYNAMIC_TABLE_INJECTS.copy()
else:
cls.injected_classes.extend(DYNAMIC_TABLE_INJECTS.copy())
if isinstance(cls.imports, Imports):
cls.imports += DYNAMIC_TABLE_IMPORTS
elif isinstance(cls.imports, list):
cls.imports = Imports(imports=cls.imports) + DYNAMIC_TABLE_IMPORTS
else:
cls.imports = DYNAMIC_TABLE_IMPORTS.model_copy()
elif cls.cls.name == "VectorData":
cls.cls.bases = ["VectorDataMixin"]
elif cls.cls.name == "VectorIndex":
cls.cls.bases = ["VectorIndexMixin"]
return cls
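In effect, `inject_dynamictable` swaps the generated class's base so the mixin's methods take precedence over the plain pydantic model. A toy sketch of the pattern (not the generator's literal output; class bodies invented):

```python
from pydantic import BaseModel

class DynamicTableMixin(BaseModel):
    """Stand-in mixin: intercepts indexing for any subclass"""
    def __getitem__(self, item):
        return f"mixin handles {item!r}"

# the generator rewrites `class DynamicTable(ConfiguredBaseModel):`
# into `class DynamicTable(DynamicTableMixin):`
class DynamicTable(DynamicTableMixin):
    name: str = "table"

assert DynamicTable()[0] == "mixin handles 0"
```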
def compile_python(
text_or_fn: str, package_path: Path = None, module_name: str = "test"

View file

@@ -1,39 +1,248 @@
"""
Special types for mimicking HDMF special case behavior
"""
from typing import Any, ClassVar, Dict, List, Optional, Union, Tuple, overload, TYPE_CHECKING
from typing import Any
from pydantic import BaseModel, ConfigDict
from linkml.generators.pydanticgen.template import Imports, Import, ObjectImport
from numpydantic import NDArray
from pandas import DataFrame
from pydantic import BaseModel, ConfigDict, Field, model_validator
if TYPE_CHECKING:
from nwb_linkml.models import VectorData, VectorIndex
class DynamicTableMixin(BaseModel):
"""
Mixin to make DynamicTable subclasses behave like tables/dataframes
Mimicking some of the behavior from :class:`hdmf.common.table.DynamicTable`
but simplifying along the way :)
"""
model_config = ConfigDict(extra="allow")
__pydantic_extra__: Dict[str, Union[list, "NDArray", "VectorData"]]
NON_COLUMN_FIELDS: ClassVar[Tuple[str, ...]] = ("name", "colnames", "description")
# overridden by subclass but implemented here for testing and typechecking purposes :)
colnames: List[str] = Field(default_factory=list)
@property
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorData"]]:
return {k: getattr(self, k) for k in self.colnames}
@property
def _columns_list(self) -> List[Union[list, "NDArray", "VectorData"]]:
return [getattr(self, k) for k in self.colnames]
@overload
def __getitem__(self, item: str) -> Union[list, "NDArray", "VectorData"]: ...
@overload
def __getitem__(self, item: int) -> DataFrame: ...
@overload
def __getitem__(self, item: Tuple[int, Union[int, str]]) -> Any: ...
@overload
def __getitem__(self, item: Tuple[Union[int, slice], ...]) -> Union[DataFrame, list, "NDArray", "VectorData"]: ...
@overload
def __getitem__(self, item: slice) -> DataFrame: ...
def __getitem__(self, item: Union[str, int, slice, Tuple[int, Union[int, str]], Tuple[Union[int, slice], ...]]) -> Any:
"""
Get an item from the table
If item is...
- ``str`` : get the column with this name
- ``int`` : get the row at this index
- ``tuple[int, int]`` : get a specific cell value, e.g. ``(0, 1)`` gets the 0th row and 1st column
- ``tuple[int, str]`` : get a specific cell value, e.g. ``(0, 'colname')``
gets the 0th row from ``colname``
- ``tuple[int | slice, int | slice]`` : get a range of cells from a range of columns.
returns as a :class:`pandas.DataFrame`
"""
if isinstance(item, str):
return self._columns[item]
if isinstance(item, (int, slice)):
return DataFrame.from_dict(self._slice_range(item))
elif isinstance(item, tuple):
if len(item) != 2:
raise ValueError(
f"DynamicTables are 2-dimensional, can't index with more than 2 indices like {item}")
# all other cases are tuples of (rows, cols)
rows, cols = item
if isinstance(cols, (int, slice)):
cols = self.colnames[cols]
data = self._slice_range(rows, cols)
return DataFrame.from_dict(data)
else:
raise ValueError(f"Unsure how to get item with key {item}")
def _slice_range(self, rows: Union[int, slice], cols: Optional[Union[str, List[str]]] = None) -> Dict[str, Union[list, "NDArray", "VectorData"]]:
if cols is None:
cols = self.colnames
elif isinstance(cols, str):
cols = [cols]
data = {
k: self._columns[k][rows] for k in cols
}
return data
# @model_validator(mode='after')
# def ensure_equal_length(cls, model: 'DynamicTableMixin') -> 'DynamicTableMixin':
# """
# Ensure all vectors are of equal length
# """
# raise NotImplementedError('TODO')
#
# @model_validator(mode="after")
# def create_index_backrefs(cls, model: 'DynamicTableMixin') -> 'DynamicTableMixin':
# """
# Ensure that vectordata with vectorindexes know about them
# """
# raise NotImplementedError('TODO')
def __getitem__(self, item: str) -> Any:
raise NotImplementedError("TODO")
def __setitem__(self, key: str, value: Any) -> None:
raise NotImplementedError("TODO")
def __setattr__(self, key: str, value: Union[list, "NDArray", "VectorData"]):
"""
Add a column, appending it to ``colnames``
"""
# don't use this while building the model
if not getattr(self, '__pydantic_complete__', False):
return super().__setattr__(key, value)
if key not in self.model_fields_set and not key.endswith('_index'):
self.colnames.append(key)
return super().__setattr__(key, value)
@model_validator(mode="before")
@classmethod
def create_colnames(cls, model: Dict[str, Any]) -> Dict[str, Any]:
"""
Construct colnames from arguments.
The model dict is ordered after python 3.6, so its insertion order, minus
anything in :attr:`.NON_COLUMN_FIELDS` or any ``*_index`` field, gives the implied column order
"""
if 'colnames' not in model:
colnames = [k for k in model.keys()
if k not in cls.NON_COLUMN_FIELDS
and not k.endswith('_index')]
model['colnames'] = colnames
else:
# add any columns not explicitly given an order at the end
colnames = [k for k in model.keys() if
k not in cls.NON_COLUMN_FIELDS
and not k.endswith('_index')
and k not in model['colnames']
]
model['colnames'].extend(colnames)
return model
@model_validator(mode="after")
def resolve_targets(self) -> "DynamicTableMixin":
"""
Ensure that any implicitly indexed columns are linked, and create backlinks
"""
for key, col in self._columns.items():
if isinstance(col, VectorData):
# find an index
idx = None
for field_name in self.model_fields_set:
# implicit name-based index
field = getattr(self, field_name)
if isinstance(field, VectorIndex):
if field_name == f"{key}_index":
idx = field
break
elif field.target is col:
idx = field
break
if idx is not None:
col._index = idx
idx.target = col
return self
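A standalone sketch of the column-ordering logic in the `create_colnames` validator above (field names invented; only the dict logic is shown, since instantiating the mixin itself assumes a generated module where `VectorData`/`VectorIndex` are real classes):

```python
NON_COLUMN_FIELDS = ("name", "colnames", "description")

# kwargs as they would arrive in the `before` validator
model = {"name": "electrodes", "x": [1.0, 2.0], "y": [3.0, 4.0], "x_index": [1, 2]}

colnames = [
    k for k in model
    if k not in NON_COLUMN_FIELDS and not k.endswith("_index")
]
assert colnames == ["x", "y"]  # insertion order, with indices and metadata excluded
```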
class VectorDataMixin(BaseModel):
"""
Mixin class to give VectorData indexing abilities
"""
_index: Optional["VectorIndex"] = None
# redefined in `VectorData`, but included here for testing and type checking
array: Optional[NDArray] = None
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
if self._index:
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
return self._index[item]
else:
return self.array[item]
def __setitem__(self, key, value) -> None:
if self._index:
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
self._index[key] = value
else:
self.array[key] = value
class VectorIndexMixin(BaseModel):
"""
Mixin class to give VectorIndex indexing abilities
"""
# redefined in `VectorIndex`, but included here for testing and type checking
array: Optional[NDArray] = None
target: Optional["VectorData"] = None
def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
"""
Mimicking :meth:`hdmf.common.table.VectorIndex.__getitem_helper`
"""
start = 0 if arg == 0 else self.array[arg - 1]
end = self.array[arg]
return self.target.array[slice(start, end)]
def __getitem__(self, item: Union[int, slice]) -> Any:
if self.target is None:
return self.array[item]
elif type(self.target).__name__ == "VectorData":
if isinstance(item, int):
return self._getitem_helper(item)
else:
idx = range(*item.indices(len(self.array)))
return [self._getitem_helper(i) for i in idx]
else:
raise NotImplementedError("DynamicTableRange not supported yet")
def __setitem__(self, key, value) -> None:
if self.target is not None:
# a ragged write would have to be translated through the index into
# positions on the target, which isn't implemented yet
raise NotImplementedError("Setting values on ragged arrays is not supported yet")
else:
self.array[key] = value
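The index array stores cumulative end offsets into the flat target array; a standalone sketch of the arithmetic in `_getitem_helper` (data invented):

```python
index_array = [2, 3, 6]            # row i ends at index_array[i]
target_array = [1, 2, 3, 4, 5, 6]  # flat, concatenated rows

def row(arg: int) -> list:
    # same logic as _getitem_helper above
    start = 0 if arg == 0 else index_array[arg - 1]
    end = index_array[arg]
    return target_array[start:end]

assert row(0) == [1, 2]
assert row(1) == [3]
assert row(2) == [4, 5, 6]
```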
DYNAMIC_TABLE_IMPORTS = Imports(
imports=[
Import(module="pandas", objects=[ObjectImport(name="DataFrame")]),
Import(
module="typing",
objects=[ObjectImport(name="ClassVar"), ObjectImport(name="overload"), ObjectImport(name="Tuple")],
),
Import(module="numpydantic", objects=[ObjectImport(name="NDArray")]),
Import(module="pydantic", objects=[ObjectImport(name="model_validator")]),
]
)
"""
Imports required for the dynamic table mixin
VectorData is purposefully excluded as an import or an inject so that it resolves
to the VectorData class defined in the generated module
"""
DYNAMIC_TABLE_INJECTS = [VectorDataMixin, VectorIndexMixin, DynamicTableMixin]
# class VectorDataMixin(BaseModel):
# index: Optional[BaseModel] = None

View file

@@ -120,7 +120,7 @@ def load_namespace_adapter(
return adapter
def load_nwb_core(core_version: str = "2.7.0", hdmf_version: str = "1.8.0") -> NamespacesAdapter:
def load_nwb_core(core_version: str = "2.7.0", hdmf_version: str = "1.8.0", hdmf_only: bool = False) -> NamespacesAdapter:
"""
Convenience function for loading the NWB core schema + hdmf-common as a namespace adapter.
@@ -136,14 +136,18 @@ def load_nwb_core(core_version: str = "2.7.0", hdmf_version: str = "1.8.0") -> N
Args:
core_version (str): an entry in :attr:`.NWB_CORE_REPO.versions`
hdmf_version (str): an entry in :attr:`.HDMF_COMMON_REPO.versions`
hdmf_only (bool): Only return the hdmf common schema
Returns:
:class:`.NamespacesAdapter`
"""
# First get hdmf-common:
hdmf_schema = load_namespace_adapter(HDMF_COMMON_REPO, version=hdmf_version)
schema = load_namespace_adapter(NWB_CORE_REPO, version=core_version)
if hdmf_only:
schema = hdmf_schema
else:
schema = load_namespace_adapter(NWB_CORE_REPO, version=core_version)
schema.imported.append(hdmf_schema)
schema.imported.append(hdmf_schema)
return schema
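Usage sketch (version arguments as defaulted above):

```python
from nwb_linkml.io import schema as io

core = io.load_nwb_core()                # NWB core with hdmf-common attached as an import
hdmf = io.load_nwb_core(hdmf_only=True)  # just the hdmf-common schema
```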

View file

@@ -5,6 +5,7 @@ Provider for LinkML schema built from NWB schema
import shutil
from pathlib import Path
from typing import Dict, Optional, TypedDict
from dataclasses import dataclass
from linkml_runtime import SchemaView
from linkml_runtime.dumpers import yaml_dumper
@@ -19,7 +20,8 @@ from nwb_linkml.ui import AdapterProgress
from nwb_schema_language import Namespaces
class LinkMLSchemaBuild(TypedDict):
@dataclass
class LinkMLSchemaBuild:
"""Build result from :meth:`.LinkMLProvider.build`"""
version: str
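Since consumers previously indexed the build result as a dict, the move from TypedDict to a dataclass means attribute access instead (a sketch; fields other than `version` and `namespace` are elided, and the field types are assumptions):

```python
from dataclasses import dataclass
from pathlib import Path

@dataclass
class LinkMLSchemaBuild:
    version: str
    namespace: Path  # other fields elided in this sketch

build = LinkMLSchemaBuild(version="2.7.0", namespace=Path("core.yaml"))
build.namespace       # dataclass attribute access (new)
# build["namespace"]  # TypedDict-style access (old) -- now a TypeError
```

This is why the build script below changes `res['namespace']` to `res.namespace`.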

View file

@@ -1,6 +1,7 @@
import shutil
from dataclasses import dataclass, field
from pathlib import Path
from types import ModuleType
from typing import Dict, Optional
import pytest
@@ -14,6 +15,8 @@ from linkml_runtime.linkml_model import (
)
from nwb_linkml.adapters.namespaces import NamespacesAdapter
from nwb_linkml.providers import LinkMLProvider, PydanticProvider
from nwb_linkml.providers.linkml import LinkMLSchemaBuild
from nwb_linkml.io import schema as io
from nwb_schema_language import Attribute, Dataset, Group
@@ -87,6 +90,26 @@ def nwb_core_fixture(request) -> NamespacesAdapter:
return nwb_core
@pytest.fixture(scope="session")
def nwb_core_linkml(nwb_core_fixture, tmp_output_dir) -> LinkMLSchemaBuild:
provider = LinkMLProvider(tmp_output_dir, allow_repo=False, verbose=False)
result = provider.build(ns_adapter=nwb_core_fixture, force=True)
return result['core']
@pytest.fixture(scope="session")
def nwb_core_module(nwb_core_linkml: LinkMLSchemaBuild, tmp_output_dir) -> ModuleType:
"""
Generated pydantic namespace from nwb core
"""
provider = PydanticProvider(tmp_output_dir, verbose=False)
result = provider.build(nwb_core_linkml.namespace, force=True)
mod = provider.get('core', version=nwb_core_linkml.version, allow_repo=False)
return mod
@pytest.fixture(scope="session")
def data_dir() -> Path:

View file

@@ -1,16 +1,19 @@
from typing import Tuple
from typing import Tuple, TYPE_CHECKING
from types import ModuleType
import numpy as np
import pytest
# FIXME: Make this just be the output of the provider by patching into import machinery
from nwb_linkml.models.pydantic.core.v2_7_0.namespace import (
ElectricalSeries,
ElectrodeGroup,
NWBFileGeneralExtracellularEphysElectrodes,
)
@pytest.fixture()
def electrical_series() -> Tuple[ElectricalSeries, NWBFileGeneralExtracellularEphysElectrodes]:
def electrical_series() -> Tuple["ElectricalSeries", "NWBFileGeneralExtracellularEphysElectrodes"]:
"""
Demo electrical series with adjoining electrodes
"""
@@ -19,9 +22,16 @@ def electrical_series() -> Tuple[ElectricalSeries, NWBFileGeneralExtracellularEp
data = np.arange(0, n_electrodes * n_times).reshape(n_times, n_electrodes)
timestamps = np.linspace(0, 1, n_times)
# electrode group is the physical description of the electrodes
electrode_group = ElectrodeGroup(
name="GroupA",
)
# make electrodes tables
electrodes = NWBFileGeneralExtracellularEphysElectrodes(
id=np.arange(0, n_electrodes),
x=np.arange(0, n_electrodes),
y=np.arange(n_electrodes, n_electrodes * 2),
group=[electrode_group]*n_electrodes,
)
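Assuming the models have been regenerated with the mixin-injecting generator above, the fixture's table would support dataframe-style access (a hedged sketch, not part of the fixture):

```python
electrodes["x"]       # column access -> the x VectorData
electrodes[0:2]       # first two rows as a pandas DataFrame
electrodes[0:2, "y"]  # same rows, only the y column
```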

View file

@@ -14,13 +14,13 @@ from rich import print
from nwb_linkml.generators.pydantic import NWBPydanticGenerator
from nwb_linkml.providers import LinkMLProvider, PydanticProvider
from nwb_linkml.providers.git import NWB_CORE_REPO, GitRepo
from nwb_linkml.providers.git import NWB_CORE_REPO, HDMF_COMMON_REPO, GitRepo
from nwb_linkml.io import schema as io
def generate_core_yaml(output_path:Path, dry_run:bool=False):
def generate_core_yaml(output_path:Path, dry_run:bool=False, hdmf_only:bool=False):
"""Just build the latest version of the core schema"""
core = io.load_nwb_core()
core = io.load_nwb_core(hdmf_only=hdmf_only)
built_schemas = core.build().schemas
for schema in built_schemas:
output_file = output_path / (schema.name + '.yaml')
@@ -45,11 +45,10 @@ def generate_core_pydantic(yaml_path:Path, output_path:Path, dry_run:bool=False)
with open(pydantic_file, 'w') as pfile:
pfile.write(gen_pydantic)
def generate_versions(yaml_path:Path, pydantic_path:Path, dry_run:bool=False):
def generate_versions(yaml_path:Path, pydantic_path:Path, dry_run:bool=False, repo:GitRepo=GitRepo(NWB_CORE_REPO), hdmf_only:bool=False):
"""
Generate linkml models for all versions
"""
repo = GitRepo(NWB_CORE_REPO)
#repo.clone(force=True)
repo.clone()
@@ -81,7 +80,7 @@ def generate_versions(yaml_path:Path, pydantic_path:Path, dry_run:bool=False):
linkml_task = None
pydantic_task = None
for version in NWB_CORE_REPO.versions:
for version in repo.namespace.versions:
# build linkml
try:
# check out the version (this should also refresh the hdmf-common schema)
@@ -91,9 +90,11 @@
# first load the core namespace
core_ns = io.load_namespace_adapter(repo.namespace_file)
# then the hdmf-common namespace
hdmf_common_ns = io.load_namespace_adapter(repo.temp_directory / 'hdmf-common-schema' / 'common' / 'namespace.yaml')
core_ns.imported.append(hdmf_common_ns)
if repo.namespace == NWB_CORE_REPO:
# then the hdmf-common namespace
hdmf_common_ns = io.load_namespace_adapter(repo.temp_directory / 'hdmf-common-schema' / 'common' / 'namespace.yaml')
core_ns.imported.append(hdmf_common_ns)
build_progress.update(linkml_task, advance=1, action="Build LinkML")
@@ -101,7 +102,7 @@
build_progress.update(linkml_task, advance=1, action="Built LinkML")
# build pydantic
ns_files = [res['namespace'] for res in linkml_res.values()]
ns_files = [res.namespace for res in linkml_res.values()]
pydantic_task = build_progress.add_task('', name=version, action='', total=len(ns_files))
for schema in ns_files:
@@ -129,10 +130,20 @@
pydantic_task = None
if not dry_run:
shutil.rmtree(yaml_path / 'linkml')
shutil.rmtree(pydantic_path / 'pydantic')
shutil.move(tmp_dir / 'linkml', yaml_path)
shutil.move(tmp_dir / 'pydantic', pydantic_path)
if hdmf_only:
shutil.rmtree(yaml_path / 'linkml' / 'hdmf_common')
shutil.rmtree(yaml_path / 'linkml' / 'hdmf_experimental')
shutil.rmtree(pydantic_path / 'pydantic' / 'hdmf_common')
shutil.rmtree(pydantic_path / 'pydantic' / 'hdmf_experimental')
shutil.move(tmp_dir / 'linkml' / 'hdmf_common', yaml_path / 'linkml')
shutil.move(tmp_dir / 'linkml' / 'hdmf_experimental', yaml_path / 'linkml')
shutil.move(tmp_dir / 'pydantic' / 'hdmf_common', pydantic_path / 'pydantic')
shutil.move(tmp_dir / 'pydantic' / 'hdmf_experimental', pydantic_path / 'pydantic')
else:
shutil.rmtree(yaml_path / 'linkml')
shutil.rmtree(pydantic_path / 'pydantic')
shutil.move(tmp_dir / 'linkml', yaml_path)
shutil.move(tmp_dir / 'pydantic', pydantic_path)
# import the most recent version of the schemas we built
latest_version = sorted((pydantic_path / 'pydantic' / 'core').iterdir(), key=os.path.getmtime)[-1]
@@ -167,6 +178,11 @@
type=Path,
default=Path(__file__).parent.parent / 'nwb_linkml' / 'src' / 'nwb_linkml' / 'models'
)
parser.add_argument(
'--hdmf',
help="Only generate the HDMF namespaces",
action="store_true"
)
parser.add_argument(
'--latest',
help="Only generate the latest version of the core schemas.",
@@ -182,14 +198,19 @@ def parser() -> ArgumentParser:
def main():
args = parser().parse_args()
if args.hdmf:
repo = GitRepo(HDMF_COMMON_REPO)
else:
repo = GitRepo(NWB_CORE_REPO)
if not args.dry_run:
args.yaml.mkdir(exist_ok=True)
args.pydantic.mkdir(exist_ok=True)
if args.latest:
generate_core_yaml(args.yaml, args.dry_run)
generate_core_yaml(args.yaml, args.dry_run, args.hdmf)
generate_core_pydantic(args.yaml, args.pydantic, args.dry_run)
else:
generate_versions(args.yaml, args.pydantic, args.dry_run)
generate_versions(args.yaml, args.pydantic, args.dry_run, repo, args.hdmf)
if __name__ == "__main__":
main()