mirror of
https://github.com/p2p-ld/nwb-linkml.git
synced 2025-01-09 13:44:27 +00:00
add dynamictable, vectordata, vectorindex mixins
This commit is contained in:
parent
723f53035d
commit
d118477d8a
10 changed files with 361 additions and 40 deletions
|
@ -284,6 +284,7 @@ api/nwb_linkml/schema/index
|
||||||
|
|
||||||
meta/todo
|
meta/todo
|
||||||
meta/changelog
|
meta/changelog
|
||||||
|
meta/references
|
||||||
genindex
|
genindex
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,11 @@
|
||||||
|
|
||||||
### DynamicTable
|
### DynamicTable
|
||||||
|
|
||||||
|
```{note}
|
||||||
|
See the [DynamicTable](https://hdmf-common-schema.readthedocs.io/en/stable/format_description.html#dynamictable)
|
||||||
|
reference docs
|
||||||
|
```
|
||||||
|
|
||||||
One of the major special cases in NWB is the use of `DynamicTable` to contain tabular data that
|
One of the major special cases in NWB is the use of `DynamicTable` to contain tabular data that
|
||||||
contains columns that are not in the base spec.
|
contains columns that are not in the base spec.
|
||||||
|
|
||||||
|
|
11
docs/meta/references.md
Normal file
11
docs/meta/references.md
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
# References
|
||||||
|
|
||||||
|
## Documentation
|
||||||
|
|
||||||
|
- [hdmf](https://hdmf.readthedocs.io/en/stable/)
|
||||||
|
- [hdmf-common-schema](https://hdmf-common-schema.readthedocs.io/en/stable/)
|
||||||
|
- [pynwb](https://pynwb.readthedocs.io/en/latest/)
|
||||||
|
|
||||||
|
```{todo}
|
||||||
|
Add the bibtex refs to NWB papers :)
|
||||||
|
```
|
|
@ -40,7 +40,7 @@ from types import ModuleType
|
||||||
from typing import ClassVar, Dict, List, Optional, Tuple, Type, Union
|
from typing import ClassVar, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
from linkml.generators import PydanticGenerator
|
from linkml.generators import PydanticGenerator
|
||||||
from linkml.generators.pydanticgen.build import SlotResult
|
from linkml.generators.pydanticgen.build import SlotResult, ClassResult
|
||||||
from linkml.generators.pydanticgen.array import ArrayRepresentation, NumpydanticArray
|
from linkml.generators.pydanticgen.array import ArrayRepresentation, NumpydanticArray
|
||||||
from linkml.generators.pydanticgen.template import PydanticModule, Import, Imports
|
from linkml.generators.pydanticgen.template import PydanticModule, Import, Imports
|
||||||
from linkml_runtime.linkml_model.meta import (
|
from linkml_runtime.linkml_model.meta import (
|
||||||
|
@ -63,6 +63,7 @@ from pydantic import BaseModel
|
||||||
from nwb_linkml.maps import flat_to_nptyping
|
from nwb_linkml.maps import flat_to_nptyping
|
||||||
from nwb_linkml.maps.naming import module_case, version_module_case
|
from nwb_linkml.maps.naming import module_case, version_module_case
|
||||||
from nwb_linkml.includes.types import ModelTypeString, _get_name, NamedString, NamedImports
|
from nwb_linkml.includes.types import ModelTypeString, _get_name, NamedString, NamedImports
|
||||||
|
from nwb_linkml.includes.hdmf import DYNAMIC_TABLE_IMPORTS, DYNAMIC_TABLE_INJECTS
|
||||||
|
|
||||||
OPTIONAL_PATTERN = re.compile(r"Optional\[([\w\.]*)\]")
|
OPTIONAL_PATTERN = re.compile(r"Optional\[([\w\.]*)\]")
|
||||||
|
|
||||||
|
@ -96,6 +97,9 @@ class NWBPydanticGenerator(PydanticGenerator):
|
||||||
def _check_anyof(
|
def _check_anyof(
|
||||||
self, s: SlotDefinition, sn: SlotDefinitionName, sv: SchemaView
|
self, s: SlotDefinition, sn: SlotDefinitionName, sv: SchemaView
|
||||||
): # pragma: no cover
|
): # pragma: no cover
|
||||||
|
"""
|
||||||
|
Overridden to allow `array` in any_of
|
||||||
|
"""
|
||||||
# Confirm that the original slot range (ignoring the default that comes in from
|
# Confirm that the original slot range (ignoring the default that comes in from
|
||||||
# induced_slot) isn't in addition to setting any_of
|
# induced_slot) isn't in addition to setting any_of
|
||||||
allowed_keys = ("array",)
|
allowed_keys = ("array",)
|
||||||
|
@ -127,6 +131,10 @@ class NWBPydanticGenerator(PydanticGenerator):
|
||||||
|
|
||||||
return slot
|
return slot
|
||||||
|
|
||||||
|
def after_generate_class(self, cls: ClassResult, sv: SchemaView) -> ClassResult:
|
||||||
|
cls = AfterGenerateClass.inject_dynamictable(cls)
|
||||||
|
return cls
|
||||||
|
|
||||||
def before_render_template(self, template: PydanticModule, sv: SchemaView) -> PydanticModule:
|
def before_render_template(self, template: PydanticModule, sv: SchemaView) -> PydanticModule:
|
||||||
if "source_file" in template.meta:
|
if "source_file" in template.meta:
|
||||||
del template.meta["source_file"]
|
del template.meta["source_file"]
|
||||||
|
@ -226,6 +234,33 @@ class AfterGenerateSlot:
|
||||||
slot.imports = NamedImports
|
slot.imports = NamedImports
|
||||||
return slot
|
return slot
|
||||||
|
|
||||||
|
class AfterGenerateClass:
|
||||||
|
"""
|
||||||
|
Container class for class-modification methods
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def inject_dynamictable(cls: ClassResult) -> ClassResult:
|
||||||
|
if cls.cls.name == "DynamicTable":
|
||||||
|
cls.cls.bases = ["DynamicTableMixin"]
|
||||||
|
|
||||||
|
if cls.injected_classes is None:
|
||||||
|
cls.injected_classes = DYNAMIC_TABLE_INJECTS.copy()
|
||||||
|
else:
|
||||||
|
cls.injected_classes.extend(DYNAMIC_TABLE_INJECTS.copy())
|
||||||
|
|
||||||
|
if isinstance(cls.imports, Imports):
|
||||||
|
cls.imports += DYNAMIC_TABLE_IMPORTS
|
||||||
|
elif isinstance(cls.imports, list):
|
||||||
|
cls.imports = Imports(imports=cls.imports) + DYNAMIC_TABLE_IMPORTS
|
||||||
|
else:
|
||||||
|
cls.imports = DYNAMIC_TABLE_IMPORTS.model_copy()
|
||||||
|
elif cls.cls.name == "VectorData":
|
||||||
|
cls.cls.bases = ["VectorDataMixin"]
|
||||||
|
elif cls.cls.name == "VectorIndex":
|
||||||
|
cls.cls.bases = ["VectorIndexMixin"]
|
||||||
|
return cls
|
||||||
|
|
||||||
|
|
||||||
def compile_python(
|
def compile_python(
|
||||||
text_or_fn: str, package_path: Path = None, module_name: str = "test"
|
text_or_fn: str, package_path: Path = None, module_name: str = "test"
|
||||||
|
|
|
@ -1,39 +1,248 @@
|
||||||
"""
|
"""
|
||||||
Special types for mimicking HDMF special case behavior
|
Special types for mimicking HDMF special case behavior
|
||||||
"""
|
"""
|
||||||
|
from typing import Any, ClassVar, Dict, List, Optional, Union, Tuple, overload, TYPE_CHECKING
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from linkml.generators.pydanticgen.template import Imports, Import, ObjectImport
|
||||||
|
from numpydantic import NDArray
|
||||||
|
from pandas import DataFrame
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from nwb_linkml.models import VectorData, VectorIndex
|
||||||
|
|
||||||
|
|
||||||
class DynamicTableMixin(BaseModel):
|
class DynamicTableMixin(BaseModel):
|
||||||
"""
|
"""
|
||||||
Mixin to make DynamicTable subclasses behave like tables/dataframes
|
Mixin to make DynamicTable subclasses behave like tables/dataframes
|
||||||
|
|
||||||
|
Mimicking some of the behavior from :class:`hdmf.common.table.DynamicTable`
|
||||||
|
but simplifying along the way :)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
__pydantic_extra__: Dict[str, Union[list, "NDArray", "VectorData"]]
|
||||||
|
NON_COLUMN_FIELDS: ClassVar[tuple[str]] = ("name", "colnames", "description",)
|
||||||
|
|
||||||
|
# overridden by subclass but implemented here for testing and typechecking purposes :)
|
||||||
|
colnames: List[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorData"]]:
|
||||||
|
return {
|
||||||
|
k: getattr(self, k) for i, k in enumerate(self.colnames)
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _columns_list(self) -> List[Union[list, "NDArray", "VectorData"]]:
|
||||||
|
return [getattr(self, k) for i, k in enumerate(self.colnames)]
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __getitem__(self, item: str) -> Union[list, "NDArray", "VectorData"]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __getitem__(self, item: int) -> DataFrame: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __getitem__(self, item: Tuple[int, Union[int, str]]) -> Any: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __getitem__(self, item: Tuple[Union[int,slice], ...]) -> Union[DataFrame, list, "NDArray", "VectorData",]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __getitem__(self, item: slice) -> DataFrame: ...
|
||||||
|
|
||||||
|
def __getitem__(self, item: Union[str, int, slice, Tuple[int, Union[int, str]], Tuple[Union[int, slice], ...],]) -> Any:
|
||||||
|
"""
|
||||||
|
Get an item from the table
|
||||||
|
|
||||||
|
If item is...
|
||||||
|
|
||||||
|
- ``str`` : get the column with this name
|
||||||
|
- ``int`` : get the row at this index
|
||||||
|
- ``tuple[int, int]`` : get a specific cell value eg. (0,1) gets the 0th row and 1st column
|
||||||
|
- ``tuple[int, str]`` : get a specific cell value eg. (0, 'colname')
|
||||||
|
gets the 0th row from ``colname``
|
||||||
|
- ``tuple[int | slice, int | slice]`` : get a range of cells from a range of columns.
|
||||||
|
returns as a :class:`pandas.DataFrame`
|
||||||
|
"""
|
||||||
|
if isinstance(item, str):
|
||||||
|
return self._columns[item]
|
||||||
|
if isinstance(item, (int, slice)):
|
||||||
|
return DataFrame.from_dict(self._slice_range(item))
|
||||||
|
elif isinstance(item, tuple):
|
||||||
|
if len(item) != 2:
|
||||||
|
raise ValueError(
|
||||||
|
f"DynamicTables are 2-dimensional, can't index with more than 2 indices like {item}")
|
||||||
|
|
||||||
|
# all other cases are tuples of (rows, cols)
|
||||||
|
rows, cols = item
|
||||||
|
if isinstance(cols, (int, slice)):
|
||||||
|
cols = self.colnames[cols]
|
||||||
|
data = self._slice_range(rows, cols)
|
||||||
|
return DataFrame.from_dict(data)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsure how to get item with key {item}")
|
||||||
|
|
||||||
|
|
||||||
|
def _slice_range(self, rows: Union[int, slice], cols: Optional[Union[str, List[str]]] = None) -> Dict[str, Union[list, "NDArray", "VectorData"]]:
|
||||||
|
if cols is None:
|
||||||
|
cols = self.colnames
|
||||||
|
elif isinstance(cols, str):
|
||||||
|
cols = [cols]
|
||||||
|
|
||||||
|
data = {
|
||||||
|
k: self._columns[k][rows] for k in cols
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
|
||||||
# @model_validator(mode='after')
|
|
||||||
# def ensure_equal_length(cls, model: 'DynamicTableMixin') -> 'DynamicTableMixin':
|
|
||||||
# """
|
|
||||||
# Ensure all vectors are of equal length
|
|
||||||
# """
|
|
||||||
# raise NotImplementedError('TODO')
|
|
||||||
#
|
|
||||||
# @model_validator(mode="after")
|
|
||||||
# def create_index_backrefs(cls, model: 'DynamicTableMixin') -> 'DynamicTableMixin':
|
|
||||||
# """
|
|
||||||
# Ensure that vectordata with vectorindexes know about them
|
|
||||||
# """
|
|
||||||
# raise NotImplementedError('TODO')
|
|
||||||
|
|
||||||
def __getitem__(self, item: str) -> Any:
|
|
||||||
raise NotImplementedError("TODO")
|
|
||||||
|
|
||||||
def __setitem__(self, key: str, value: Any) -> None:
|
def __setitem__(self, key: str, value: Any) -> None:
|
||||||
raise NotImplementedError("TODO")
|
raise NotImplementedError("TODO")
|
||||||
|
|
||||||
|
def __setattr__(self, key: str, value: Union[list, "NDArray", "VectorData"]):
|
||||||
|
"""
|
||||||
|
Add a column, appending it to ``colnames``
|
||||||
|
"""
|
||||||
|
# don't use this while building the model
|
||||||
|
if not getattr(self, '__pydantic_complete__', False):
|
||||||
|
return super().__setattr__(key, value)
|
||||||
|
|
||||||
|
if key not in self.model_fields_set and not key.endswith('_index'):
|
||||||
|
self.colnames.append(key)
|
||||||
|
|
||||||
|
return super().__setattr__(key, value)
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def create_colnames(cls, model: Dict[str, Any]):
|
||||||
|
"""
|
||||||
|
Construct colnames from arguments.
|
||||||
|
|
||||||
|
the model dict is ordered after python3.6, so we can use that minus
|
||||||
|
anything in :attr:`.NON_COLUMN_FIELDS` to determine order implied from passage order
|
||||||
|
"""
|
||||||
|
if 'colnames' not in model:
|
||||||
|
colnames = [k for k in model.keys()
|
||||||
|
if k not in cls.NON_COLUMN_FIELDS
|
||||||
|
and not k.endswith('_index')]
|
||||||
|
model['colnames'] = colnames
|
||||||
|
else:
|
||||||
|
# add any columns not explicitly given an order at the end
|
||||||
|
colnames = [k for k in model.keys() if
|
||||||
|
k not in cls.NON_COLUMN_FIELDS
|
||||||
|
and not k.endswith('_index')
|
||||||
|
and k not in model['colnames'].keys()
|
||||||
|
]
|
||||||
|
model['colnames'].extend(colnames)
|
||||||
|
return model
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def resolve_targets(self) -> "DynamicTableMixin":
|
||||||
|
"""
|
||||||
|
Ensure that any implicitly indexed columns are linked, and create backlinks
|
||||||
|
"""
|
||||||
|
for key, col in self._columns.items():
|
||||||
|
if isinstance(col, VectorData):
|
||||||
|
# find an index
|
||||||
|
idx = None
|
||||||
|
for field_name in self.model_fields_set:
|
||||||
|
# implicit name-based index
|
||||||
|
field = getattr(self, field_name)
|
||||||
|
if isinstance(field, VectorIndex):
|
||||||
|
if field_name == f"{key}_index":
|
||||||
|
idx = field
|
||||||
|
break
|
||||||
|
elif field.target is col:
|
||||||
|
idx = field
|
||||||
|
break
|
||||||
|
if idx is not None:
|
||||||
|
col._index = idx
|
||||||
|
idx.target = col
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class VectorDataMixin(BaseModel):
|
||||||
|
"""
|
||||||
|
Mixin class to give VectorData indexing abilities
|
||||||
|
"""
|
||||||
|
_index: Optional["VectorIndex"] = None
|
||||||
|
|
||||||
|
# redefined in `VectorData`, but included here for testing and type checking
|
||||||
|
array: Optional[NDArray] = None
|
||||||
|
|
||||||
|
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||||
|
if self._index:
|
||||||
|
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||||
|
return self._index[item]
|
||||||
|
else:
|
||||||
|
return self.array[item]
|
||||||
|
|
||||||
|
def __setitem__(self, key, value) -> None:
|
||||||
|
if self._index:
|
||||||
|
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||||
|
self._index[key] = value
|
||||||
|
else:
|
||||||
|
self.array[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
class VectorIndexMixin(BaseModel):
|
||||||
|
"""
|
||||||
|
Mixin class to give VectorIndex indexing abilities
|
||||||
|
"""
|
||||||
|
# redefined in `VectorData`, but included here for testing and type checking
|
||||||
|
array: Optional[NDArray] = None
|
||||||
|
target: Optional["VectorData"] = None
|
||||||
|
|
||||||
|
def _getitem_helper(self, arg: int):
|
||||||
|
"""
|
||||||
|
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
|
||||||
|
"""
|
||||||
|
|
||||||
|
start = 0 if arg == 0 else self.array[arg - 1]
|
||||||
|
end = self.array[arg]
|
||||||
|
return self.target.array[slice(start, end)]
|
||||||
|
|
||||||
|
def __getitem__(self, item: Union[int, slice]) -> Any:
|
||||||
|
if self.target is None:
|
||||||
|
return self.array[item]
|
||||||
|
elif type(self.target).__name__ == "VectorData":
|
||||||
|
if isinstance(item, int):
|
||||||
|
return self._getitem_helper(item)
|
||||||
|
else:
|
||||||
|
idx = range(*item.indices(len(self.array)))
|
||||||
|
return [self._getitem_helper(i) for i in idx]
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("DynamicTableRange not supported yet")
|
||||||
|
|
||||||
|
|
||||||
|
def __setitem__(self, key, value) -> None:
|
||||||
|
if self._index:
|
||||||
|
# VectorIndex is the thing that knows how to do the slicing
|
||||||
|
self._index[key] = value
|
||||||
|
else:
|
||||||
|
self.array[key] = value
|
||||||
|
|
||||||
|
|
||||||
|
DYNAMIC_TABLE_IMPORTS = Imports(
|
||||||
|
imports = [
|
||||||
|
Import(module="pandas", objects=[ObjectImport(name="DataFrame")]),
|
||||||
|
Import(module="typing", objects=[ObjectImport(name="ClassVar"), ObjectImport(name="overload"), ObjectImport(name="Tuple")]),
|
||||||
|
Import(module='numpydantic', objects=[ObjectImport(name='NDArray')]),
|
||||||
|
Import(module="pydantic", objects=[ObjectImport(name="model_validator")])
|
||||||
|
]
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
Imports required for the dynamic table mixin
|
||||||
|
|
||||||
|
VectorData is purposefully excluded as an import or an inject so that it will be
|
||||||
|
resolved to the VectorData definition in the generated module
|
||||||
|
"""
|
||||||
|
DYNAMIC_TABLE_INJECTS = [VectorDataMixin, VectorIndexMixin, DynamicTableMixin]
|
||||||
|
|
||||||
# class VectorDataMixin(BaseModel):
|
# class VectorDataMixin(BaseModel):
|
||||||
# index: Optional[BaseModel] = None
|
# index: Optional[BaseModel] = None
|
||||||
|
|
|
@ -120,7 +120,7 @@ def load_namespace_adapter(
|
||||||
return adapter
|
return adapter
|
||||||
|
|
||||||
|
|
||||||
def load_nwb_core(core_version: str = "2.7.0", hdmf_version: str = "1.8.0") -> NamespacesAdapter:
|
def load_nwb_core(core_version: str = "2.7.0", hdmf_version: str = "1.8.0", hdmf_only:bool=False) -> NamespacesAdapter:
|
||||||
"""
|
"""
|
||||||
Convenience function for loading the NWB core schema + hdmf-common as a namespace adapter.
|
Convenience function for loading the NWB core schema + hdmf-common as a namespace adapter.
|
||||||
|
|
||||||
|
@ -136,14 +136,18 @@ def load_nwb_core(core_version: str = "2.7.0", hdmf_version: str = "1.8.0") -> N
|
||||||
Args:
|
Args:
|
||||||
core_version (str): an entry in :attr:`.NWB_CORE_REPO.versions`
|
core_version (str): an entry in :attr:`.NWB_CORE_REPO.versions`
|
||||||
hdmf_version (str): an entry in :attr:`.NWB_CORE_REPO.versions`
|
hdmf_version (str): an entry in :attr:`.NWB_CORE_REPO.versions`
|
||||||
|
hdmf_only (bool): Only return the hdmf common schema
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# First get hdmf-common:
|
# First get hdmf-common:
|
||||||
hdmf_schema = load_namespace_adapter(HDMF_COMMON_REPO, version=hdmf_version)
|
hdmf_schema = load_namespace_adapter(HDMF_COMMON_REPO, version=hdmf_version)
|
||||||
schema = load_namespace_adapter(NWB_CORE_REPO, version=core_version)
|
if hdmf_only:
|
||||||
|
schema = hdmf_schema
|
||||||
|
else:
|
||||||
|
schema = load_namespace_adapter(NWB_CORE_REPO, version=core_version)
|
||||||
|
|
||||||
schema.imported.append(hdmf_schema)
|
schema.imported.append(hdmf_schema)
|
||||||
|
|
||||||
return schema
|
return schema
|
||||||
|
|
|
@ -5,6 +5,7 @@ Provider for LinkML schema built from NWB schema
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional, TypedDict
|
from typing import Dict, Optional, TypedDict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from linkml_runtime import SchemaView
|
from linkml_runtime import SchemaView
|
||||||
from linkml_runtime.dumpers import yaml_dumper
|
from linkml_runtime.dumpers import yaml_dumper
|
||||||
|
@ -19,7 +20,8 @@ from nwb_linkml.ui import AdapterProgress
|
||||||
from nwb_schema_language import Namespaces
|
from nwb_schema_language import Namespaces
|
||||||
|
|
||||||
|
|
||||||
class LinkMLSchemaBuild(TypedDict):
|
@dataclass
|
||||||
|
class LinkMLSchemaBuild:
|
||||||
"""Build result from :meth:`.LinkMLProvider.build`"""
|
"""Build result from :meth:`.LinkMLProvider.build`"""
|
||||||
|
|
||||||
version: str
|
version: str
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import shutil
|
import shutil
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from types import ModuleType
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -14,6 +15,8 @@ from linkml_runtime.linkml_model import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from nwb_linkml.adapters.namespaces import NamespacesAdapter
|
from nwb_linkml.adapters.namespaces import NamespacesAdapter
|
||||||
|
from nwb_linkml.providers import LinkMLProvider, PydanticProvider
|
||||||
|
from nwb_linkml.providers.linkml import LinkMLSchemaBuild
|
||||||
from nwb_linkml.io import schema as io
|
from nwb_linkml.io import schema as io
|
||||||
from nwb_schema_language import Attribute, Dataset, Group
|
from nwb_schema_language import Attribute, Dataset, Group
|
||||||
|
|
||||||
|
@ -87,6 +90,26 @@ def nwb_core_fixture(request) -> NamespacesAdapter:
|
||||||
|
|
||||||
return nwb_core
|
return nwb_core
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def nwb_core_linkml(nwb_core_fixture, tmp_output_dir) -> LinkMLSchemaBuild:
|
||||||
|
provider = LinkMLProvider(tmp_output_dir, allow_repo=False, verbose=False)
|
||||||
|
result = provider.build(ns_adapter=nwb_core_fixture, force=True)
|
||||||
|
return result['core']
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def nwb_core_module(nwb_core_linkml: LinkMLSchemaBuild, tmp_output_dir) -> ModuleType:
|
||||||
|
"""
|
||||||
|
Generated pydantic namespace from nwb core
|
||||||
|
"""
|
||||||
|
provider = PydanticProvider(tmp_output_dir, verbose=False)
|
||||||
|
result = provider.build(nwb_core_linkml.namespace, force=True)
|
||||||
|
mod = provider.get('core', version=nwb_core_linkml.version, allow_repo=False)
|
||||||
|
return mod
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def data_dir() -> Path:
|
def data_dir() -> Path:
|
||||||
|
|
|
@ -1,16 +1,19 @@
|
||||||
from typing import Tuple
|
from typing import Tuple, TYPE_CHECKING
|
||||||
|
from types import ModuleType
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
# FIXME: Make this just be the output of the provider by patching into import machinery
|
||||||
from nwb_linkml.models.pydantic.core.v2_7_0.namespace import (
|
from nwb_linkml.models.pydantic.core.v2_7_0.namespace import (
|
||||||
ElectricalSeries,
|
ElectricalSeries,
|
||||||
|
ElectrodeGroup,
|
||||||
NWBFileGeneralExtracellularEphysElectrodes,
|
NWBFileGeneralExtracellularEphysElectrodes,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def electrical_series() -> Tuple[ElectricalSeries, NWBFileGeneralExtracellularEphysElectrodes]:
|
def electrical_series() -> Tuple["ElectricalSeries", "NWBFileGeneralExtracellularEphysElectrodes"]:
|
||||||
"""
|
"""
|
||||||
Demo electrical series with adjoining electrodes
|
Demo electrical series with adjoining electrodes
|
||||||
"""
|
"""
|
||||||
|
@ -19,9 +22,16 @@ def electrical_series() -> Tuple[ElectricalSeries, NWBFileGeneralExtracellularEp
|
||||||
data = np.arange(0, n_electrodes * n_times).reshape(n_times, n_electrodes)
|
data = np.arange(0, n_electrodes * n_times).reshape(n_times, n_electrodes)
|
||||||
timestamps = np.linspace(0, 1, n_times)
|
timestamps = np.linspace(0, 1, n_times)
|
||||||
|
|
||||||
|
# electrode group is the physical description of the electrodes
|
||||||
|
electrode_group = ElectrodeGroup(
|
||||||
|
name="GroupA",
|
||||||
|
)
|
||||||
|
|
||||||
# make electrodes tables
|
# make electrodes tables
|
||||||
electrodes = NWBFileGeneralExtracellularEphysElectrodes(
|
electrodes = NWBFileGeneralExtracellularEphysElectrodes(
|
||||||
id=np.arange(0, n_electrodes),
|
id=np.arange(0, n_electrodes),
|
||||||
x=np.arange(0, n_electrodes),
|
x=np.arange(0, n_electrodes),
|
||||||
y=np.arange(n_electrodes, n_electrodes * 2),
|
y=np.arange(n_electrodes, n_electrodes * 2),
|
||||||
|
group=[electrode_group]*n_electrodes,
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -14,13 +14,13 @@ from rich import print
|
||||||
from nwb_linkml.generators.pydantic import NWBPydanticGenerator
|
from nwb_linkml.generators.pydantic import NWBPydanticGenerator
|
||||||
|
|
||||||
from nwb_linkml.providers import LinkMLProvider, PydanticProvider
|
from nwb_linkml.providers import LinkMLProvider, PydanticProvider
|
||||||
from nwb_linkml.providers.git import NWB_CORE_REPO, GitRepo
|
from nwb_linkml.providers.git import NWB_CORE_REPO, HDMF_COMMON_REPO, GitRepo
|
||||||
from nwb_linkml.io import schema as io
|
from nwb_linkml.io import schema as io
|
||||||
|
|
||||||
def generate_core_yaml(output_path:Path, dry_run:bool=False):
|
def generate_core_yaml(output_path:Path, dry_run:bool=False, hdmf_only:bool=False):
|
||||||
"""Just build the latest version of the core schema"""
|
"""Just build the latest version of the core schema"""
|
||||||
|
|
||||||
core = io.load_nwb_core()
|
core = io.load_nwb_core(hdmf_only=hdmf_only)
|
||||||
built_schemas = core.build().schemas
|
built_schemas = core.build().schemas
|
||||||
for schema in built_schemas:
|
for schema in built_schemas:
|
||||||
output_file = output_path / (schema.name + '.yaml')
|
output_file = output_path / (schema.name + '.yaml')
|
||||||
|
@ -45,11 +45,10 @@ def generate_core_pydantic(yaml_path:Path, output_path:Path, dry_run:bool=False)
|
||||||
with open(pydantic_file, 'w') as pfile:
|
with open(pydantic_file, 'w') as pfile:
|
||||||
pfile.write(gen_pydantic)
|
pfile.write(gen_pydantic)
|
||||||
|
|
||||||
def generate_versions(yaml_path:Path, pydantic_path:Path, dry_run:bool=False):
|
def generate_versions(yaml_path:Path, pydantic_path:Path, dry_run:bool=False, repo:GitRepo=NWB_CORE_REPO, hdmf_only=False):
|
||||||
"""
|
"""
|
||||||
Generate linkml models for all versions
|
Generate linkml models for all versions
|
||||||
"""
|
"""
|
||||||
repo = GitRepo(NWB_CORE_REPO)
|
|
||||||
#repo.clone(force=True)
|
#repo.clone(force=True)
|
||||||
repo.clone()
|
repo.clone()
|
||||||
|
|
||||||
|
@ -81,7 +80,7 @@ def generate_versions(yaml_path:Path, pydantic_path:Path, dry_run:bool=False):
|
||||||
linkml_task = None
|
linkml_task = None
|
||||||
pydantic_task = None
|
pydantic_task = None
|
||||||
|
|
||||||
for version in NWB_CORE_REPO.versions:
|
for version in repo.namespace.versions:
|
||||||
# build linkml
|
# build linkml
|
||||||
try:
|
try:
|
||||||
# check out the version (this should also refresh the hdmf-common schema)
|
# check out the version (this should also refresh the hdmf-common schema)
|
||||||
|
@ -91,9 +90,11 @@ def generate_versions(yaml_path:Path, pydantic_path:Path, dry_run:bool=False):
|
||||||
|
|
||||||
# first load the core namespace
|
# first load the core namespace
|
||||||
core_ns = io.load_namespace_adapter(repo.namespace_file)
|
core_ns = io.load_namespace_adapter(repo.namespace_file)
|
||||||
# then the hdmf-common namespace
|
if repo.namespace == NWB_CORE_REPO:
|
||||||
hdmf_common_ns = io.load_namespace_adapter(repo.temp_directory / 'hdmf-common-schema' / 'common' / 'namespace.yaml')
|
# then the hdmf-common namespace
|
||||||
core_ns.imported.append(hdmf_common_ns)
|
hdmf_common_ns = io.load_namespace_adapter(repo.temp_directory / 'hdmf-common-schema' / 'common' / 'namespace.yaml')
|
||||||
|
core_ns.imported.append(hdmf_common_ns)
|
||||||
|
|
||||||
build_progress.update(linkml_task, advance=1, action="Build LinkML")
|
build_progress.update(linkml_task, advance=1, action="Build LinkML")
|
||||||
|
|
||||||
|
|
||||||
|
@ -101,7 +102,7 @@ def generate_versions(yaml_path:Path, pydantic_path:Path, dry_run:bool=False):
|
||||||
build_progress.update(linkml_task, advance=1, action="Built LinkML")
|
build_progress.update(linkml_task, advance=1, action="Built LinkML")
|
||||||
|
|
||||||
# build pydantic
|
# build pydantic
|
||||||
ns_files = [res['namespace'] for res in linkml_res.values()]
|
ns_files = [res.namespace for res in linkml_res.values()]
|
||||||
|
|
||||||
pydantic_task = build_progress.add_task('', name=version, action='', total=len(ns_files))
|
pydantic_task = build_progress.add_task('', name=version, action='', total=len(ns_files))
|
||||||
for schema in ns_files:
|
for schema in ns_files:
|
||||||
|
@ -129,10 +130,20 @@ def generate_versions(yaml_path:Path, pydantic_path:Path, dry_run:bool=False):
|
||||||
pydantic_task = None
|
pydantic_task = None
|
||||||
|
|
||||||
if not dry_run:
|
if not dry_run:
|
||||||
shutil.rmtree(yaml_path / 'linkml')
|
if hdmf_only:
|
||||||
shutil.rmtree(pydantic_path / 'pydantic')
|
shutil.rmtree(yaml_path / 'linkml' / 'hdmf_common')
|
||||||
shutil.move(tmp_dir / 'linkml', yaml_path)
|
shutil.rmtree(yaml_path / 'linkml' / 'hdmf_experimental')
|
||||||
shutil.move(tmp_dir / 'pydantic', pydantic_path)
|
shutil.rmtree(pydantic_path / 'pydantic' / 'hdmf_common')
|
||||||
|
shutil.rmtree(pydantic_path / 'pydantic' / 'hdmf_experimental')
|
||||||
|
shutil.move(tmp_dir / 'linkml' / 'hdmf_common', yaml_path / 'linkml')
|
||||||
|
shutil.move(tmp_dir / 'linkml' / 'hdmf_experimental', yaml_path / 'linkml')
|
||||||
|
shutil.move(tmp_dir / 'pydantic' / 'hdmf_common', pydantic_path / 'pydantic')
|
||||||
|
shutil.move(tmp_dir / 'pydantic' / 'hdmf_experimental', pydantic_path / 'pydantic')
|
||||||
|
else:
|
||||||
|
shutil.rmtree(yaml_path / 'linkml')
|
||||||
|
shutil.rmtree(pydantic_path / 'pydantic')
|
||||||
|
shutil.move(tmp_dir / 'linkml', yaml_path)
|
||||||
|
shutil.move(tmp_dir / 'pydantic', pydantic_path)
|
||||||
|
|
||||||
# import the most recent version of the schemaz we built
|
# import the most recent version of the schemaz we built
|
||||||
latest_version = sorted((pydantic_path / 'pydantic' / 'core').iterdir(), key=os.path.getmtime)[-1]
|
latest_version = sorted((pydantic_path / 'pydantic' / 'core').iterdir(), key=os.path.getmtime)[-1]
|
||||||
|
@ -167,6 +178,11 @@ def parser() -> ArgumentParser:
|
||||||
type=Path,
|
type=Path,
|
||||||
default=Path(__file__).parent.parent / 'nwb_linkml' / 'src' / 'nwb_linkml' / 'models'
|
default=Path(__file__).parent.parent / 'nwb_linkml' / 'src' / 'nwb_linkml' / 'models'
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--hdmf',
|
||||||
|
help="Only generate the HDMF namespaces",
|
||||||
|
action="store_true"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--latest',
|
'--latest',
|
||||||
help="Only generate the latest version of the core schemas.",
|
help="Only generate the latest version of the core schemas.",
|
||||||
|
@ -182,14 +198,19 @@ def parser() -> ArgumentParser:
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parser().parse_args()
|
args = parser().parse_args()
|
||||||
|
if args.hdmf:
|
||||||
|
repo = GitRepo(HDMF_COMMON_REPO)
|
||||||
|
else:
|
||||||
|
repo = GitRepo(NWB_CORE_REPO)
|
||||||
|
|
||||||
if not args.dry_run:
|
if not args.dry_run:
|
||||||
args.yaml.mkdir(exist_ok=True)
|
args.yaml.mkdir(exist_ok=True)
|
||||||
args.pydantic.mkdir(exist_ok=True)
|
args.pydantic.mkdir(exist_ok=True)
|
||||||
if args.latest:
|
if args.latest:
|
||||||
generate_core_yaml(args.yaml, args.dry_run)
|
generate_core_yaml(args.yaml, args.dry_run, args.hdmf)
|
||||||
generate_core_pydantic(args.yaml, args.pydantic, args.dry_run)
|
generate_core_pydantic(args.yaml, args.pydantic, args.dry_run)
|
||||||
else:
|
else:
|
||||||
generate_versions(args.yaml, args.pydantic, args.dry_run)
|
generate_versions(args.yaml, args.pydantic, args.dry_run, repo, args.hdmf)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
Loading…
Reference in a new issue