From d118477d8a06fe588d7b0668e18464e2c856bbb1 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Wed, 31 Jul 2024 01:13:31 -0700 Subject: [PATCH] add dynamictable, vectordata, vectorindex mixins --- docs/index.md | 1 + docs/intro/translation.md | 5 + docs/meta/references.md | 11 + .../src/nwb_linkml/generators/pydantic.py | 37 ++- nwb_linkml/src/nwb_linkml/includes/hdmf.py | 243 ++++++++++++++++-- nwb_linkml/src/nwb_linkml/io/schema.py | 10 +- nwb_linkml/src/nwb_linkml/providers/linkml.py | 4 +- nwb_linkml/tests/fixtures.py | 23 ++ nwb_linkml/tests/test_includes/test_hdmf.py | 14 +- scripts/generate_core.py | 53 ++-- 10 files changed, 361 insertions(+), 40 deletions(-) create mode 100644 docs/meta/references.md diff --git a/docs/index.md b/docs/index.md index 9ffc483..32ed37e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -284,6 +284,7 @@ api/nwb_linkml/schema/index meta/todo meta/changelog +meta/references genindex ``` diff --git a/docs/intro/translation.md b/docs/intro/translation.md index d5c078a..6170fee 100644 --- a/docs/intro/translation.md +++ b/docs/intro/translation.md @@ -20,6 +20,11 @@ ### DynamicTable +```{note} +See the [DynamicTable](https://hdmf-common-schema.readthedocs.io/en/stable/format_description.html#dynamictable) +reference docs +``` + One of the major special cases in NWB is the use of `DynamicTable` to contain tabular data that contains columns that are not in the base spec. diff --git a/docs/meta/references.md b/docs/meta/references.md new file mode 100644 index 0000000..dd36a1a --- /dev/null +++ b/docs/meta/references.md @@ -0,0 +1,11 @@ +# References + +## Documentation + +- [hdmf](https://hdmf.readthedocs.io/en/stable/) +- [hdmf-common-schema](https://hdmf-common-schema.readthedocs.io/en/stable/) +- [pynwb](https://pynwb.readthedocs.io/en/latest/) + +```{todo} +Add the bibtex refs to NWB papers :) +``` \ No newline at end of file diff --git a/nwb_linkml/src/nwb_linkml/generators/pydantic.py b/nwb_linkml/src/nwb_linkml/generators/pydantic.py index b42c83a..3ecf605 100644 --- a/nwb_linkml/src/nwb_linkml/generators/pydantic.py +++ b/nwb_linkml/src/nwb_linkml/generators/pydantic.py @@ -40,7 +40,7 @@ from types import ModuleType from typing import ClassVar, Dict, List, Optional, Tuple, Type, Union from linkml.generators import PydanticGenerator -from linkml.generators.pydanticgen.build import SlotResult +from linkml.generators.pydanticgen.build import SlotResult, ClassResult from linkml.generators.pydanticgen.array import ArrayRepresentation, NumpydanticArray from linkml.generators.pydanticgen.template import PydanticModule, Import, Imports from linkml_runtime.linkml_model.meta import ( @@ -63,6 +63,7 @@ from pydantic import BaseModel from nwb_linkml.maps import flat_to_nptyping from nwb_linkml.maps.naming import module_case, version_module_case from nwb_linkml.includes.types import ModelTypeString, _get_name, NamedString, NamedImports +from nwb_linkml.includes.hdmf import DYNAMIC_TABLE_IMPORTS, DYNAMIC_TABLE_INJECTS OPTIONAL_PATTERN = re.compile(r"Optional\[([\w\.]*)\]") @@ -96,6 +97,9 @@ class NWBPydanticGenerator(PydanticGenerator): def _check_anyof( self, s: SlotDefinition, sn: SlotDefinitionName, sv: SchemaView ): # pragma: no cover + """ + Overridden to allow `array` in any_of + """ # Confirm that the original slot range (ignoring the default that comes in from # induced_slot) isn't in addition to setting any_of allowed_keys = ("array",) @@ -127,6 +131,10 @@ class NWBPydanticGenerator(PydanticGenerator): return slot + def after_generate_class(self, cls: ClassResult, sv: SchemaView) -> ClassResult: + cls = AfterGenerateClass.inject_dynamictable(cls) + return cls + def before_render_template(self, template: PydanticModule, sv: SchemaView) -> PydanticModule: if "source_file" in template.meta: del template.meta["source_file"] @@ -226,6 +234,33 @@ class AfterGenerateSlot: slot.imports = NamedImports return slot +class AfterGenerateClass: + """ + Container class for class-modification methods + """ + + @staticmethod + def inject_dynamictable(cls: ClassResult) -> ClassResult: + if cls.cls.name == "DynamicTable": + cls.cls.bases = ["DynamicTableMixin"] + + if cls.injected_classes is None: + cls.injected_classes = DYNAMIC_TABLE_INJECTS.copy() + else: + cls.injected_classes.extend(DYNAMIC_TABLE_INJECTS.copy()) + + if isinstance(cls.imports, Imports): + cls.imports += DYNAMIC_TABLE_IMPORTS + elif isinstance(cls.imports, list): + cls.imports = Imports(imports=cls.imports) + DYNAMIC_TABLE_IMPORTS + else: + cls.imports = DYNAMIC_TABLE_IMPORTS.model_copy() + elif cls.cls.name == "VectorData": + cls.cls.bases = ["VectorDataMixin"] + elif cls.cls.name == "VectorIndex": + cls.cls.bases = ["VectorIndexMixin"] + return cls + def compile_python( text_or_fn: str, package_path: Path = None, module_name: str = "test" diff --git a/nwb_linkml/src/nwb_linkml/includes/hdmf.py b/nwb_linkml/src/nwb_linkml/includes/hdmf.py index 32647e7..fdbd355 100644 --- a/nwb_linkml/src/nwb_linkml/includes/hdmf.py +++ b/nwb_linkml/src/nwb_linkml/includes/hdmf.py @@ -1,39 +1,248 @@ """ Special types for mimicking HDMF special case behavior """ +from typing import Any, ClassVar, Dict, List, Optional, Union, Tuple, overload, TYPE_CHECKING -from typing import Any -from pydantic import BaseModel, ConfigDict +from linkml.generators.pydanticgen.template import Imports, Import, ObjectImport +from numpydantic import NDArray +from pandas import DataFrame +from pydantic import BaseModel, ConfigDict, Field, model_validator + +if TYPE_CHECKING: + from nwb_linkml.models import VectorData, VectorIndex class DynamicTableMixin(BaseModel): """ Mixin to make DynamicTable subclasses behave like tables/dataframes + + Mimicking some of the behavior from :class:`hdmf.common.table.DynamicTable` + but simplifying along the way :) """ model_config = ConfigDict(extra="allow") + __pydantic_extra__: Dict[str, Union[list, "NDArray", "VectorData"]] + NON_COLUMN_FIELDS: ClassVar[tuple[str]] = ("name", "colnames", "description",) + + # overridden by subclass but implemented here for testing and typechecking purposes :) + colnames: List[str] = Field(default_factory=list) + + @property + def _columns(self) -> Dict[str, Union[list, "NDArray", "VectorData"]]: + return { + k: getattr(self, k) for i, k in enumerate(self.colnames) + } + + @property + def _columns_list(self) -> List[Union[list, "NDArray", "VectorData"]]: + return [getattr(self, k) for i, k in enumerate(self.colnames)] + + @overload + def __getitem__(self, item: str) -> Union[list, "NDArray", "VectorData"]: ... + + @overload + def __getitem__(self, item: int) -> DataFrame: ... + + @overload + def __getitem__(self, item: Tuple[int, Union[int, str]]) -> Any: ... + + @overload + def __getitem__(self, item: Tuple[Union[int,slice], ...]) -> Union[DataFrame, list, "NDArray", "VectorData",]: ... + + @overload + def __getitem__(self, item: slice) -> DataFrame: ... + + def __getitem__(self, item: Union[str, int, slice, Tuple[int, Union[int, str]], Tuple[Union[int, slice], ...],]) -> Any: + """ + Get an item from the table + + If item is... + + - ``str`` : get the column with this name + - ``int`` : get the row at this index + - ``tuple[int, int]`` : get a specific cell value eg. (0,1) gets the 0th row and 1st column + - ``tuple[int, str]`` : get a specific cell value eg. (0, 'colname') + gets the 0th row from ``colname`` + - ``tuple[int | slice, int | slice]`` : get a range of cells from a range of columns. + returns as a :class:`pandas.DataFrame` + """ + if isinstance(item, str): + return self._columns[item] + if isinstance(item, (int, slice)): + return DataFrame.from_dict(self._slice_range(item)) + elif isinstance(item, tuple): + if len(item) != 2: + raise ValueError( + f"DynamicTables are 2-dimensional, can't index with more than 2 indices like {item}") + + # all other cases are tuples of (rows, cols) + rows, cols = item + if isinstance(cols, (int, slice)): + cols = self.colnames[cols] + data = self._slice_range(rows, cols) + return DataFrame.from_dict(data) + else: + raise ValueError(f"Unsure how to get item with key {item}") + + + def _slice_range(self, rows: Union[int, slice], cols: Optional[Union[str, List[str]]] = None) -> Dict[str, Union[list, "NDArray", "VectorData"]]: + if cols is None: + cols = self.colnames + elif isinstance(cols, str): + cols = [cols] + + data = { + k: self._columns[k][rows] for k in cols + } + return data - # @model_validator(mode='after') - # def ensure_equal_length(cls, model: 'DynamicTableMixin') -> 'DynamicTableMixin': - # """ - # Ensure all vectors are of equal length - # """ - # raise NotImplementedError('TODO') - # - # @model_validator(mode="after") - # def create_index_backrefs(cls, model: 'DynamicTableMixin') -> 'DynamicTableMixin': - # """ - # Ensure that vectordata with vectorindexes know about them - # """ - # raise NotImplementedError('TODO') - def __getitem__(self, item: str) -> Any: - raise NotImplementedError("TODO") def __setitem__(self, key: str, value: Any) -> None: raise NotImplementedError("TODO") + def __setattr__(self, key: str, value: Union[list, "NDArray", "VectorData"]): + """ + Add a column, appending it to ``colnames`` + """ + # don't use this while building the model + if not getattr(self, '__pydantic_complete__', False): + return super().__setattr__(key, value) + + if key not in self.model_fields_set and not key.endswith('_index'): + self.colnames.append(key) + + return super().__setattr__(key, value) + + @model_validator(mode="before") + @classmethod + def create_colnames(cls, model: Dict[str, Any]): + """ + Construct colnames from arguments. + + the model dict is ordered after python3.6, so we can use that minus + anything in :attr:`.NON_COLUMN_FIELDS` to determine order implied from passage order + """ + if 'colnames' not in model: + colnames = [k for k in model.keys() + if k not in cls.NON_COLUMN_FIELDS + and not k.endswith('_index')] + model['colnames'] = colnames + else: + # add any columns not explicitly given an order at the end + colnames = [k for k in model.keys() if + k not in cls.NON_COLUMN_FIELDS + and not k.endswith('_index') + and k not in model['colnames'].keys() + ] + model['colnames'].extend(colnames) + return model + + @model_validator(mode="after") + def resolve_targets(self) -> "DynamicTableMixin": + """ + Ensure that any implicitly indexed columns are linked, and create backlinks + """ + for key, col in self._columns.items(): + if isinstance(col, VectorData): + # find an index + idx = None + for field_name in self.model_fields_set: + # implicit name-based index + field = getattr(self, field_name) + if isinstance(field, VectorIndex): + if field_name == f"{key}_index": + idx = field + break + elif field.target is col: + idx = field + break + if idx is not None: + col._index = idx + idx.target = col + return self + + + + +class VectorDataMixin(BaseModel): + """ + Mixin class to give VectorData indexing abilities + """ + _index: Optional["VectorIndex"] = None + + # redefined in `VectorData`, but included here for testing and type checking + array: Optional[NDArray] = None + + def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any: + if self._index: + # Following hdmf, VectorIndex is the thing that knows how to do the slicing + return self._index[item] + else: + return self.array[item] + + def __setitem__(self, key, value) -> None: + if self._index: + # Following hdmf, VectorIndex is the thing that knows how to do the slicing + self._index[key] = value + else: + self.array[key] = value + + +class VectorIndexMixin(BaseModel): + """ + Mixin class to give VectorIndex indexing abilities + """ + # redefined in `VectorData`, but included here for testing and type checking + array: Optional[NDArray] = None + target: Optional["VectorData"] = None + + def _getitem_helper(self, arg: int): + """ + Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper` + """ + + start = 0 if arg == 0 else self.array[arg - 1] + end = self.array[arg] + return self.target.array[slice(start, end)] + + def __getitem__(self, item: Union[int, slice]) -> Any: + if self.target is None: + return self.array[item] + elif type(self.target).__name__ == "VectorData": + if isinstance(item, int): + return self._getitem_helper(item) + else: + idx = range(*item.indices(len(self.array))) + return [self._getitem_helper(i) for i in idx] + else: + raise NotImplementedError("DynamicTableRange not supported yet") + + + def __setitem__(self, key, value) -> None: + if self._index: + # VectorIndex is the thing that knows how to do the slicing + self._index[key] = value + else: + self.array[key] = value + + +DYNAMIC_TABLE_IMPORTS = Imports( + imports = [ + Import(module="pandas", objects=[ObjectImport(name="DataFrame")]), + Import(module="typing", objects=[ObjectImport(name="ClassVar"), ObjectImport(name="overload"), ObjectImport(name="Tuple")]), + Import(module='numpydantic', objects=[ObjectImport(name='NDArray')]), + Import(module="pydantic", objects=[ObjectImport(name="model_validator")]) + ] +) +""" +Imports required for the dynamic table mixin + +VectorData is purposefully excluded as an import or an inject so that it will be +resolved to the VectorData definition in the generated module +""" +DYNAMIC_TABLE_INJECTS = [VectorDataMixin, VectorIndexMixin, DynamicTableMixin] # class VectorDataMixin(BaseModel): # index: Optional[BaseModel] = None diff --git a/nwb_linkml/src/nwb_linkml/io/schema.py b/nwb_linkml/src/nwb_linkml/io/schema.py index 3e2a76e..d5ce5c8 100644 --- a/nwb_linkml/src/nwb_linkml/io/schema.py +++ b/nwb_linkml/src/nwb_linkml/io/schema.py @@ -120,7 +120,7 @@ def load_namespace_adapter( return adapter -def load_nwb_core(core_version: str = "2.7.0", hdmf_version: str = "1.8.0") -> NamespacesAdapter: +def load_nwb_core(core_version: str = "2.7.0", hdmf_version: str = "1.8.0", hdmf_only:bool=False) -> NamespacesAdapter: """ Convenience function for loading the NWB core schema + hdmf-common as a namespace adapter. @@ -136,14 +136,18 @@ def load_nwb_core(core_version: str = "2.7.0", hdmf_version: str = "1.8.0") -> N Args: core_version (str): an entry in :attr:`.NWB_CORE_REPO.versions` hdmf_version (str): an entry in :attr:`.NWB_CORE_REPO.versions` + hdmf_only (bool): Only return the hdmf common schema Returns: """ # First get hdmf-common: hdmf_schema = load_namespace_adapter(HDMF_COMMON_REPO, version=hdmf_version) - schema = load_namespace_adapter(NWB_CORE_REPO, version=core_version) + if hdmf_only: + schema = hdmf_schema + else: + schema = load_namespace_adapter(NWB_CORE_REPO, version=core_version) - schema.imported.append(hdmf_schema) + schema.imported.append(hdmf_schema) return schema diff --git a/nwb_linkml/src/nwb_linkml/providers/linkml.py b/nwb_linkml/src/nwb_linkml/providers/linkml.py index 831bd2c..f868de7 100644 --- a/nwb_linkml/src/nwb_linkml/providers/linkml.py +++ b/nwb_linkml/src/nwb_linkml/providers/linkml.py @@ -5,6 +5,7 @@ Provider for LinkML schema built from NWB schema import shutil from pathlib import Path from typing import Dict, Optional, TypedDict +from dataclasses import dataclass from linkml_runtime import SchemaView from linkml_runtime.dumpers import yaml_dumper @@ -19,7 +20,8 @@ from nwb_linkml.ui import AdapterProgress from nwb_schema_language import Namespaces -class LinkMLSchemaBuild(TypedDict): +@dataclass +class LinkMLSchemaBuild: """Build result from :meth:`.LinkMLProvider.build`""" version: str diff --git a/nwb_linkml/tests/fixtures.py b/nwb_linkml/tests/fixtures.py index e4b8fae..092ba60 100644 --- a/nwb_linkml/tests/fixtures.py +++ b/nwb_linkml/tests/fixtures.py @@ -1,6 +1,7 @@ import shutil from dataclasses import dataclass, field from pathlib import Path +from types import ModuleType from typing import Dict, Optional import pytest @@ -14,6 +15,8 @@ from linkml_runtime.linkml_model import ( ) from nwb_linkml.adapters.namespaces import NamespacesAdapter +from nwb_linkml.providers import LinkMLProvider, PydanticProvider +from nwb_linkml.providers.linkml import LinkMLSchemaBuild from nwb_linkml.io import schema as io from nwb_schema_language import Attribute, Dataset, Group @@ -87,6 +90,26 @@ def nwb_core_fixture(request) -> NamespacesAdapter: return nwb_core +@pytest.fixture(scope="session") +def nwb_core_linkml(nwb_core_fixture, tmp_output_dir) -> LinkMLSchemaBuild: + provider = LinkMLProvider(tmp_output_dir, allow_repo=False, verbose=False) + result = provider.build(ns_adapter=nwb_core_fixture, force=True) + return result['core'] + + +@pytest.fixture(scope="session") +def nwb_core_module(nwb_core_linkml: LinkMLSchemaBuild, tmp_output_dir) -> ModuleType: + """ + Generated pydantic namespace from nwb core + """ + provider = PydanticProvider(tmp_output_dir, verbose=False) + result = provider.build(nwb_core_linkml.namespace, force=True) + mod = provider.get('core', version=nwb_core_linkml.version, allow_repo=False) + return mod + + + + @pytest.fixture(scope="session") def data_dir() -> Path: diff --git a/nwb_linkml/tests/test_includes/test_hdmf.py b/nwb_linkml/tests/test_includes/test_hdmf.py index 0024917..572a651 100644 --- a/nwb_linkml/tests/test_includes/test_hdmf.py +++ b/nwb_linkml/tests/test_includes/test_hdmf.py @@ -1,16 +1,19 @@ -from typing import Tuple +from typing import Tuple, TYPE_CHECKING +from types import ModuleType import numpy as np import pytest +# FIXME: Make this just be the output of the provider by patching into import machinery from nwb_linkml.models.pydantic.core.v2_7_0.namespace import ( ElectricalSeries, + ElectrodeGroup, NWBFileGeneralExtracellularEphysElectrodes, ) @pytest.fixture() -def electrical_series() -> Tuple[ElectricalSeries, NWBFileGeneralExtracellularEphysElectrodes]: +def electrical_series() -> Tuple["ElectricalSeries", "NWBFileGeneralExtracellularEphysElectrodes"]: """ Demo electrical series with adjoining electrodes """ @@ -19,9 +22,16 @@ def electrical_series() -> Tuple[ElectricalSeries, NWBFileGeneralExtracellularEp data = np.arange(0, n_electrodes * n_times).reshape(n_times, n_electrodes) timestamps = np.linspace(0, 1, n_times) + # electrode group is the physical description of the electrodes + electrode_group = ElectrodeGroup( + name="GroupA", + ) + # make electrodes tables electrodes = NWBFileGeneralExtracellularEphysElectrodes( id=np.arange(0, n_electrodes), x=np.arange(0, n_electrodes), y=np.arange(n_electrodes, n_electrodes * 2), + group=[electrode_group]*n_electrodes, + ) diff --git a/scripts/generate_core.py b/scripts/generate_core.py index e711e0e..b447e00 100644 --- a/scripts/generate_core.py +++ b/scripts/generate_core.py @@ -14,13 +14,13 @@ from rich import print from nwb_linkml.generators.pydantic import NWBPydanticGenerator from nwb_linkml.providers import LinkMLProvider, PydanticProvider -from nwb_linkml.providers.git import NWB_CORE_REPO, GitRepo +from nwb_linkml.providers.git import NWB_CORE_REPO, HDMF_COMMON_REPO, GitRepo from nwb_linkml.io import schema as io -def generate_core_yaml(output_path:Path, dry_run:bool=False): +def generate_core_yaml(output_path:Path, dry_run:bool=False, hdmf_only:bool=False): """Just build the latest version of the core schema""" - core = io.load_nwb_core() + core = io.load_nwb_core(hdmf_only=hdmf_only) built_schemas = core.build().schemas for schema in built_schemas: output_file = output_path / (schema.name + '.yaml') @@ -45,11 +45,10 @@ def generate_core_pydantic(yaml_path:Path, output_path:Path, dry_run:bool=False) with open(pydantic_file, 'w') as pfile: pfile.write(gen_pydantic) -def generate_versions(yaml_path:Path, pydantic_path:Path, dry_run:bool=False): +def generate_versions(yaml_path:Path, pydantic_path:Path, dry_run:bool=False, repo:GitRepo=NWB_CORE_REPO, hdmf_only=False): """ Generate linkml models for all versions """ - repo = GitRepo(NWB_CORE_REPO) #repo.clone(force=True) repo.clone() @@ -81,7 +80,7 @@ def generate_versions(yaml_path:Path, pydantic_path:Path, dry_run:bool=False): linkml_task = None pydantic_task = None - for version in NWB_CORE_REPO.versions: + for version in repo.namespace.versions: # build linkml try: # check out the version (this should also refresh the hdmf-common schema) @@ -91,9 +90,11 @@ def generate_versions(yaml_path:Path, pydantic_path:Path, dry_run:bool=False): # first load the core namespace core_ns = io.load_namespace_adapter(repo.namespace_file) - # then the hdmf-common namespace - hdmf_common_ns = io.load_namespace_adapter(repo.temp_directory / 'hdmf-common-schema' / 'common' / 'namespace.yaml') - core_ns.imported.append(hdmf_common_ns) + if repo.namespace == NWB_CORE_REPO: + # then the hdmf-common namespace + hdmf_common_ns = io.load_namespace_adapter(repo.temp_directory / 'hdmf-common-schema' / 'common' / 'namespace.yaml') + core_ns.imported.append(hdmf_common_ns) + build_progress.update(linkml_task, advance=1, action="Build LinkML") @@ -101,7 +102,7 @@ def generate_versions(yaml_path:Path, pydantic_path:Path, dry_run:bool=False): build_progress.update(linkml_task, advance=1, action="Built LinkML") # build pydantic - ns_files = [res['namespace'] for res in linkml_res.values()] + ns_files = [res.namespace for res in linkml_res.values()] pydantic_task = build_progress.add_task('', name=version, action='', total=len(ns_files)) for schema in ns_files: @@ -129,10 +130,20 @@ def generate_versions(yaml_path:Path, pydantic_path:Path, dry_run:bool=False): pydantic_task = None if not dry_run: - shutil.rmtree(yaml_path / 'linkml') - shutil.rmtree(pydantic_path / 'pydantic') - shutil.move(tmp_dir / 'linkml', yaml_path) - shutil.move(tmp_dir / 'pydantic', pydantic_path) + if hdmf_only: + shutil.rmtree(yaml_path / 'linkml' / 'hdmf_common') + shutil.rmtree(yaml_path / 'linkml' / 'hdmf_experimental') + shutil.rmtree(pydantic_path / 'pydantic' / 'hdmf_common') + shutil.rmtree(pydantic_path / 'pydantic' / 'hdmf_experimental') + shutil.move(tmp_dir / 'linkml' / 'hdmf_common', yaml_path / 'linkml') + shutil.move(tmp_dir / 'linkml' / 'hdmf_experimental', yaml_path / 'linkml') + shutil.move(tmp_dir / 'pydantic' / 'hdmf_common', pydantic_path / 'pydantic') + shutil.move(tmp_dir / 'pydantic' / 'hdmf_experimental', pydantic_path / 'pydantic') + else: + shutil.rmtree(yaml_path / 'linkml') + shutil.rmtree(pydantic_path / 'pydantic') + shutil.move(tmp_dir / 'linkml', yaml_path) + shutil.move(tmp_dir / 'pydantic', pydantic_path) # import the most recent version of the schemaz we built latest_version = sorted((pydantic_path / 'pydantic' / 'core').iterdir(), key=os.path.getmtime)[-1] @@ -167,6 +178,11 @@ def parser() -> ArgumentParser: type=Path, default=Path(__file__).parent.parent / 'nwb_linkml' / 'src' / 'nwb_linkml' / 'models' ) + parser.add_argument( + '--hdmf', + help="Only generate the HDMF namespaces", + action="store_true" + ) parser.add_argument( '--latest', help="Only generate the latest version of the core schemas.", @@ -182,14 +198,19 @@ def parser() -> ArgumentParser: def main(): args = parser().parse_args() + if args.hdmf: + repo = GitRepo(HDMF_COMMON_REPO) + else: + repo = GitRepo(NWB_CORE_REPO) + if not args.dry_run: args.yaml.mkdir(exist_ok=True) args.pydantic.mkdir(exist_ok=True) if args.latest: - generate_core_yaml(args.yaml, args.dry_run) + generate_core_yaml(args.yaml, args.dry_run, args.hdmf) generate_core_pydantic(args.yaml, args.pydantic, args.dry_run) else: - generate_versions(args.yaml, args.pydantic, args.dry_run) + generate_versions(args.yaml, args.pydantic, args.dry_run, repo, args.hdmf) if __name__ == "__main__": main()