From cad57554fd04095fa49ffecb1ca7d122258d7891 Mon Sep 17 00:00:00 2001
From: sneakers-the-rat
Date: Fri, 13 Sep 2024 23:05:34 -0700
Subject: [PATCH] getting there, working rolldown of extra attributes, but
 something still funny in patchclampseries children w.r.t. losing attributes
 in data

---
 .../src/nwb_linkml/adapters/namespaces.py    | 266 ++++++++++++------
 .../src/nwb_linkml/generators/pydantic.py    |  15 +-
 nwb_linkml/src/nwb_linkml/util.py            |  73 -----
 .../datamodel/nwb_schema_pydantic.py         |  11 +-
 .../src/nwb_schema_language/generator.py     |  20 +-
 .../src/nwb_schema_language/util.py          |  39 +++
 scripts/generate_core.py                     |   5 +-
 7 files changed, 264 insertions(+), 165 deletions(-)
 delete mode 100644 nwb_linkml/src/nwb_linkml/util.py
 create mode 100644 nwb_schema_language/src/nwb_schema_language/util.py

diff --git a/nwb_linkml/src/nwb_linkml/adapters/namespaces.py b/nwb_linkml/src/nwb_linkml/adapters/namespaces.py
index 78e3027..afbb82d 100644
--- a/nwb_linkml/src/nwb_linkml/adapters/namespaces.py
+++ b/nwb_linkml/src/nwb_linkml/adapters/namespaces.py
@@ -8,7 +8,6 @@ for extracting information and generating translated schema
 import contextlib
 from copy import copy
 from pathlib import Path
-from pprint import pformat
 from typing import Dict, Generator, List, Optional

 from linkml_runtime.dumpers import yaml_dumper
@@ -19,7 +18,6 @@ from nwb_linkml.adapters.adapter import Adapter, BuildResult
 from nwb_linkml.adapters.schema import SchemaAdapter
 from nwb_linkml.lang_elements import NwbLangSchema
 from nwb_linkml.ui import AdapterProgress
-from nwb_linkml.util import merge_dicts
 from nwb_schema_language import Dataset, Group, Namespaces


@@ -188,93 +186,105 @@ class NamespacesAdapter(Adapter):
             if not cls.neurodata_type_inc:
                 continue

-            # get parents
-            parent = self.get(cls.neurodata_type_inc)
-            parents = [parent]
-            while parent.neurodata_type_inc:
-                parent = self.get(parent.neurodata_type_inc)
-                parents.insert(0, parent)
-            parents.append(cls)
+            parents = self._get_class_ancestors(cls, include_child=True)

             # merge and cast
-            # note that we don't want to exclude_none in the model dump here,
-            # if the child class has a field completely unset, we want to inherit it
-            # from the parent without rolling it down - we are only rolling down
-            # the things that need to be modified/merged in the child
             new_cls: dict = {}
-            for parent in parents:
-                new_cls = merge_dicts(
-                    new_cls,
-                    parent.model_dump(exclude_unset=True),
-                    list_key="name",
-                    exclude=["neurodata_type_def"],
-                )
+            for i, parent in enumerate(parents):
+                # if parent.neurodata_type_def == "PatchClampSeries":
+                # pdb.set_trace()
+                complete = True
+                if i == len(parents) - 1:
+                    complete = False
+                new_cls = roll_down_nwb_class(new_cls, parent, complete=complete)
             new_cls: Group | Dataset = type(cls)(**new_cls)
             new_cls.parent = cls.parent

             # reinsert
-            if new_cls.parent:
-                if isinstance(cls, Dataset):
-                    new_cls.parent.datasets[new_cls.parent.datasets.index(cls)] = new_cls
-                else:
-                    new_cls.parent.groups[new_cls.parent.groups.index(cls)] = new_cls
+            self._overwrite_class(new_cls, cls)
+
+    def _get_class_ancestors(
+        self, cls: Dataset | Group, include_child: bool = True
+    ) -> list[Dataset | Group]:
+        """
+        Get the chain of ancestor classes inherited via ``neurodata_type_inc``
+
+        Args:
+            cls (:class:`.Dataset` | :class:`.Group`): The class to get ancestors of
+            include_child (bool): If ``True`` (default), include ``cls`` in the output list
+        """
+        parent = self.get(cls.neurodata_type_inc)
+        parents = [parent]
+        while parent.neurodata_type_inc:
+            parent = 
self.get(parent.neurodata_type_inc)
+            parents.insert(0, parent)
+
+        if include_child:
+            parents.append(cls)
+
+        return parents
+
+    def _overwrite_class(self, new_cls: Dataset | Group, old_cls: Dataset | Group):
+        """
+        Overwrite the version of a dataset or group that is stored in our schemas
+        """
+        if old_cls.parent:
+            if isinstance(old_cls, Dataset):
+                new_cls.parent.datasets[new_cls.parent.datasets.index(old_cls)] = new_cls
             else:
-                # top level class, need to go and find it
-                found = False
-                for schema in self.all_schemas():
-                    if isinstance(cls, Dataset):
-                        if cls in schema.datasets:
-                            schema.datasets[schema.datasets.index(cls)] = new_cls
-                            found = True
-                            break
-                    else:
-                        if cls in schema.groups:
-                            schema.groups[schema.groups.index(cls)] = new_cls
-                            found = True
-                            break
-                if not found:
-                    raise KeyError(
-                        f"Unable to find source schema for {cls} when reinserting after rolling"
-                        " down!"
-                    )
-
-    def find_type_source(self, name: str) -> SchemaAdapter:
-        """
-        Given some neurodata_type_inc, find the schema that it's defined in.
-
-        Rather than returning as soon as a match is found, check all
-        """
-        # First check within the main schema
-        internal_matches = []
-        for schema in self.schemas:
-            class_names = [cls.neurodata_type_def for cls in schema.created_classes]
-            if name in class_names:
-                internal_matches.append(schema)
-
-        if len(internal_matches) > 1:
-            raise KeyError(
-                f"Found multiple schemas in namespace that define {name}:\ninternal:"
-                f" {pformat(internal_matches)}\nimported:{pformat(internal_matches)}"
-            )
-        elif len(internal_matches) == 1:
-            return internal_matches[0]
-
-        import_matches = []
-        for imported_ns in self.imported:
-            for schema in imported_ns.schemas:
-                class_names = [cls.neurodata_type_def for cls in schema.created_classes]
-                if name in class_names:
-                    import_matches.append(schema)
-
-        if len(import_matches) > 1:
-            raise KeyError(
-                f"Found multiple schemas in namespace that define {name}:\ninternal:"
-                f" {pformat(internal_matches)}\nimported:{pformat(import_matches)}"
-            )
-        elif len(import_matches) == 1:
-            return import_matches[0]
+                new_cls.parent.groups[new_cls.parent.groups.index(old_cls)] = new_cls
         else:
-            raise KeyError(f"No schema found that define {name}")
+            # top level class, need to go and find it
+            schema = self.find_type_source(old_cls)
+            if isinstance(new_cls, Dataset):
+                schema.datasets[schema.datasets.index(old_cls)] = new_cls
+            else:
+                schema.groups[schema.groups.index(old_cls)] = new_cls
+
+    def find_type_source(self, cls: str | Dataset | Group, fast: bool = False) -> SchemaAdapter:
+        """
+        Given some type (as `neurodata_type_def`), find the schema that it's defined in.
+
+        Rather than returning as soon as a match is found, ensure that duplicates are
+        not found within the primary schema, then do the same for all imported schemas.
+
+        Args:
+            cls (str | :class:`.Dataset` | :class:`.Group`): The ``neurodata_type_def``
+                to look for the source of. If a Dataset or Group, look for the object itself
+                (cls in schema.datasets), otherwise look for a class with a matching name.
+            fast (bool): If ``True``, return as soon as a match is found.
+                If ``False``, return after checking all schemas for duplicates. 
+
+        Returns:
+            :class:`.SchemaAdapter`
+
+        Raises:
+            KeyError: if multiple schemas or no schemas are found
+        """
+        matches = []
+        for schema in self.all_schemas():
+            in_schema = False
+            if isinstance(cls, str) and cls in [
+                c.neurodata_type_def for c in schema.created_classes
+            ]:
+                in_schema = True
+            elif isinstance(cls, Dataset) and cls in schema.datasets:
+                in_schema = True
+            elif isinstance(cls, Group) and cls in schema.groups:
+                in_schema = True
+
+            if in_schema:
+                if fast:
+                    return schema
+                else:
+                    matches.append(schema)
+
+        if len(matches) > 1:
+            raise KeyError(f"Found multiple schemas in namespace that define {cls}:\n{matches}")
+        elif len(matches) == 1:
+            return matches[0]
+        else:
+            raise KeyError(f"No schema found that defines {cls}")

     def _populate_imports(self) -> "NamespacesAdapter":
         """
@@ -378,3 +388,99 @@ class NamespacesAdapter(Adapter):
         for imported in self.imported:
             for sch in imported.schemas:
                 yield sch
+
+
+def roll_down_nwb_class(
+    source: Group | Dataset | dict, target: Group | Dataset | dict, complete: bool = False
+) -> dict:
+    """
+    Merge an ancestor (via ``neurodata_type_inc`` ) source class with a
+    child ``target`` class.
+
+    On the first recursive pass, only those values that are set on the target are copied from the
+    source class - this isn't a true merge: what we are after is to recursively merge all the
+    values that are modified in the child class with those of the parent class below the top level;
+    the top-level attributes will be carried through via normal inheritance.
+
+    Rather than re-instantiating the child class, we return the dictionary so that this
+    function can be used in series to merge a whole ancestry chain within
+    :class:`.NamespacesAdapter` , but this isn't exposed in the function since
+    class definitions can be spread out over many schemas, and we need the orchestration
+    of the adapter to have them in all cases we'd be using this.
+
+    Args:
+        source (dict): source dictionary
+        target (dict): target dictionary (values merged over source)
+        complete (bool): (default ``False``) do a complete merge, merging everything
+            from source to target without trying to minimize redundancy.
+            Used to collapse ancestor classes before the terminal class. 
+ + References: + https://github.com/NeurodataWithoutBorders/pynwb/issues/1954 + + """ + if isinstance(source, (Group, Dataset)): + source = source.model_dump(exclude_unset=True, exclude_none=True) + if isinstance(target, (Group, Dataset)): + target = target.model_dump(exclude_unset=True, exclude_none=True) + + exclude = ("neurodata_type_def",) + + # if we are on the first recursion, we exclude top-level items that are not set in the target + if complete: + ret = {k: v for k, v in source.items() if k not in exclude} + else: + ret = {k: v for k, v in source.items() if k not in exclude and k in target} + + for key, value in target.items(): + if key not in ret: + ret[key] = value + elif isinstance(value, dict): + if key in ret: + ret[key] = roll_down_nwb_class(ret[key], value, complete=True) + else: + ret[key] = value + elif isinstance(value, list) and all([isinstance(v, dict) for v in value]): + src_keys = {v["name"]: ret[key].index(v) for v in ret.get(key, {}) if "name" in v} + target_keys = {v["name"]: value.index(v) for v in value if "name" in v} + + new_val = [] + # screwy double iteration to preserve dict order + # all dicts not in target, if in depth > 0 + if complete: + new_val.extend( + [ + ret[key][src_keys[k]] + for k in src_keys + if k in set(src_keys.keys()) - set(target_keys.keys()) + ] + ) + # all dicts not in source + new_val.extend( + [ + value[target_keys[k]] + for k in target_keys + if k in set(target_keys.keys()) - set(src_keys.keys()) + ] + ) + # merge dicts in both + new_val.extend( + [ + roll_down_nwb_class(ret[key][src_keys[k]], value[target_keys[k]], complete=True) + for k in target_keys + if k in set(src_keys.keys()).intersection(set(target_keys.keys())) + ] + ) + new_val = sorted(new_val, key=lambda i: i["name"]) + # add any dicts that don't have the list_key + # they can't be merged since they can't be matched + if complete: + new_val.extend([v for v in ret.get(key, {}) if "name" not in v]) + new_val.extend([v for v in value if "name" not in v]) + + ret[key] = new_val + + else: + ret[key] = value + + return ret diff --git a/nwb_linkml/src/nwb_linkml/generators/pydantic.py b/nwb_linkml/src/nwb_linkml/generators/pydantic.py index 1928cf5..927e9c2 100644 --- a/nwb_linkml/src/nwb_linkml/generators/pydantic.py +++ b/nwb_linkml/src/nwb_linkml/generators/pydantic.py @@ -136,7 +136,7 @@ class NWBPydanticGenerator(PydanticGenerator): """Customize dynamictable behavior""" cls = AfterGenerateClass.inject_dynamictable(cls) cls = AfterGenerateClass.wrap_dynamictable_columns(cls, sv) - cls = AfterGenerateClass.inject_elementidentifiers(cls, sv, self._get_element_import) + cls = AfterGenerateClass.inject_dynamictable_imports(cls, sv, self._get_element_import) cls = AfterGenerateClass.strip_vector_data_slots(cls, sv) return cls @@ -346,19 +346,22 @@ class AfterGenerateClass: return cls @staticmethod - def inject_elementidentifiers( + def inject_dynamictable_imports( cls: ClassResult, sv: SchemaView, import_method: Callable[[str], Import] ) -> ClassResult: """ - Inject ElementIdentifiers into module that define dynamictables - - needed to handle ID columns + Ensure that schema that contain dynamictables have all the imports needed to use them """ if ( cls.source.is_a == "DynamicTable" or "DynamicTable" in sv.class_ancestors(cls.source.name) ) and sv.schema.name != "hdmf-common.table": - imp = import_method("ElementIdentifiers") - cls.imports += [imp] + imp = [ + import_method("ElementIdentifiers"), + import_method("VectorData"), + import_method("VectorIndex"), + ] + cls.imports += 
imp return cls @staticmethod diff --git a/nwb_linkml/src/nwb_linkml/util.py b/nwb_linkml/src/nwb_linkml/util.py deleted file mode 100644 index ca85357..0000000 --- a/nwb_linkml/src/nwb_linkml/util.py +++ /dev/null @@ -1,73 +0,0 @@ -""" -The much maligned junk drawer -""" - - -def merge_dicts( - source: dict, target: dict, list_key: str | None = None, exclude: list[str] | None = None -) -> dict: - """ - Deeply merge nested dictionaries, replacing already-declared keys rather than - e.g. merging lists as well - - Args: - source (dict): source dictionary - target (dict): target dictionary (values merged over source) - list_key (str | None): Optional: if present, merge lists of dicts using this to - identify matching dicts - exclude: (list[str] | None): Optional: if present, exclude keys from parent. - - References: - https://stackoverflow.com/a/20666342/13113166 - - """ - if exclude is None: - exclude = [] - ret = {k: v for k, v in source.items() if k not in exclude} - for key, value in target.items(): - if key not in ret: - ret[key] = value - elif isinstance(value, dict): - if key in ret: - ret[key] = merge_dicts(ret[key], value, list_key, exclude) - else: - ret[key] = value - elif isinstance(value, list) and list_key and all([isinstance(v, dict) for v in value]): - src_keys = {v[list_key]: ret[key].index(v) for v in ret.get(key, {}) if list_key in v} - target_keys = {v[list_key]: value.index(v) for v in value if list_key in v} - - # all dicts not in target - # screwy double iteration to preserve dict order - new_val = [ - ret[key][src_keys[k]] - for k in src_keys - if k in set(src_keys.keys()) - set(target_keys.keys()) - ] - # all dicts not in source - new_val.extend( - [ - value[target_keys[k]] - for k in target_keys - if k in set(target_keys.keys()) - set(src_keys.keys()) - ] - ) - # merge dicts in both - new_val.extend( - [ - merge_dicts(ret[key][src_keys[k]], value[target_keys[k]], list_key, exclude) - for k in target_keys - if k in set(src_keys.keys()).intersection(set(target_keys.keys())) - ] - ) - new_val = sorted(new_val, key=lambda i: i[list_key]) - # add any dicts that don't have the list_key - # they can't be merged since they can't be matched - new_val.extend([v for v in ret.get(key, {}) if list_key not in v]) - new_val.extend([v for v in value if list_key not in v]) - - ret[key] = new_val - - else: - ret[key] = value - - return ret diff --git a/nwb_schema_language/src/nwb_schema_language/datamodel/nwb_schema_pydantic.py b/nwb_schema_language/src/nwb_schema_language/datamodel/nwb_schema_pydantic.py index ca7e8be..83f084c 100644 --- a/nwb_schema_language/src/nwb_schema_language/datamodel/nwb_schema_pydantic.py +++ b/nwb_schema_language/src/nwb_schema_language/datamodel/nwb_schema_pydantic.py @@ -7,6 +7,7 @@ from decimal import Decimal from enum import Enum from typing import Any, ClassVar, Dict, List, Literal, Optional, Union +from nwb_schema_language.util import pformat from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator @@ -23,7 +24,12 @@ class ConfiguredBaseModel(BaseModel): use_enum_values=True, strict=False, ) - pass + + def __repr__(self): + return pformat(self.model_dump(exclude={"parent": True}), self.__class__.__name__) + + def __str__(self): + return repr(self) class LinkMLMeta(RootModel): @@ -44,9 +50,10 @@ class LinkMLMeta(RootModel): class ParentizeMixin(BaseModel): + """Mixin to populate the parent field for nested datasets and groups""" @model_validator(mode="after") - def parentize(self): + def parentize(self) -> 
BaseModel:
         """Set the parent attribute for all our fields they have one"""
         for field_name in self.model_fields:
             if field_name == "parent":
diff --git a/nwb_schema_language/src/nwb_schema_language/generator.py b/nwb_schema_language/src/nwb_schema_language/generator.py
index 38519a4..7b0d289 100644
--- a/nwb_schema_language/src/nwb_schema_language/generator.py
+++ b/nwb_schema_language/src/nwb_schema_language/generator.py
@@ -31,6 +31,22 @@ class ParentizeMixin(BaseModel):
         return self


+STR_METHOD = """
+    def __repr__(self):
+        return pformat(
+            self.model_dump(
+                exclude={"parent": True},
+                exclude_unset=True,
+                exclude_none=True
+            ),
+            self.__class__.__name__
+        )
+
+    def __str__(self):
+        return repr(self)
+"""
+
+
 @dataclass
 class NWBSchemaLangGenerator(PydanticGenerator):
     """
@@ -40,8 +56,10 @@ class NWBSchemaLangGenerator(PydanticGenerator):
     def __init__(self, *args, **kwargs):
         kwargs["injected_classes"] = [ParentizeMixin]
         kwargs["imports"] = [
-            Import(module="pydantic", objects=[ObjectImport(name="model_validator")])
+            Import(module="pydantic", objects=[ObjectImport(name="model_validator")]),
+            Import(module="nwb_schema_language.util", objects=[ObjectImport(name="pformat")]),
         ]
+        kwargs["injected_fields"] = [STR_METHOD]
         kwargs["black"] = True
         super().__init__(*args, **kwargs)

diff --git a/nwb_schema_language/src/nwb_schema_language/util.py b/nwb_schema_language/src/nwb_schema_language/util.py
new file mode 100644
index 0000000..61bc5ed
--- /dev/null
+++ b/nwb_schema_language/src/nwb_schema_language/util.py
@@ -0,0 +1,39 @@
+from pprint import pformat as _pformat
+import textwrap
+import re
+
+
+def pformat(fields: dict, cls_name: str, indent: str = " ") -> str:
+    """
+    Pretty format the fields of the items of a ``YAMLRoot`` object without the wonky indentation of pformat.
+    See ``YAMLRoot.__repr__``.
+
+    Formatting is similar to black - items at similar levels of nesting have similar levels of indentation,
+    rather than getting placed at essentially random levels of indentation depending on what came before them.
+    """
+    res = []
+    total_len = 0
+    for key, val in fields.items():
+        if val == [] or val == {} or val is None:
+            continue
+        # pformat handles everything else that isn't a YAMLRoot object, but it sure does look ugly
+        # use it to split lines and as the thing of last resort, but otherwise indent = 0, we'll do that
+        val_str = _pformat(val, indent=0, compact=True, sort_dicts=False)
+        # now we indent everything except the first line by indenting and then using regex to remove just the first indent
+        val_str = re.sub(rf"\A{re.escape(indent)}", "", textwrap.indent(val_str, indent))
+        # now recombine with the key in a format that can be re-eval'd into an object if indent is just whitespace
+        val_str = f"'{key}': " + val_str
+
+        # count the total length of this string so we know if we need to linebreak or not later
+        total_len += len(val_str)
+        res.append(val_str)
+
+    if total_len > 80:
+        inside = ",\n".join(res)
+        # we indent twice - once for the inner contents of every inner object, and once to
+        # offset from the root element. 
that keeps us from needing to be recursive except for the + # single pformat call + inside = textwrap.indent(inside, indent) + return cls_name + "({\n" + inside + "\n})" + else: + return cls_name + "({" + ", ".join(res) + "})" diff --git a/scripts/generate_core.py b/scripts/generate_core.py index 4aeb21a..55fc94e 100644 --- a/scripts/generate_core.py +++ b/scripts/generate_core.py @@ -67,7 +67,6 @@ def generate_versions( pydantic_path: Path, dry_run: bool = False, repo: GitRepo = NWB_CORE_REPO, - hdmf_only=False, pdb=False, ): """ @@ -253,10 +252,10 @@ def main(): args.yaml.mkdir(exist_ok=True) args.pydantic.mkdir(exist_ok=True) if args.latest: - generate_core_yaml(args.yaml, args.dry_run, args.hdmf) + generate_core_yaml(args.yaml, args.dry_run) generate_core_pydantic(args.yaml, args.pydantic, args.dry_run) else: - generate_versions(args.yaml, args.pydantic, args.dry_run, repo, args.hdmf, pdb=args.pdb) + generate_versions(args.yaml, args.pydantic, args.dry_run, repo, pdb=args.pdb) if __name__ == "__main__":
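
As a rough illustration of the rolldown behavior this patch introduces (not part of the patch itself), here is a minimal sketch using made-up, simplified parent/child class dicts rather than the real NWB core schema; field names follow nwb_schema_language, "TimeSeries"/"MySeries" and their attributes are hypothetical, and roll_down_nwb_class is the module-level function added to nwb_linkml/src/nwb_linkml/adapters/namespaces.py above:

    from nwb_linkml.adapters.namespaces import roll_down_nwb_class

    # hypothetical, simplified classes -- not the actual core-schema definitions
    parent = {
        "neurodata_type_def": "TimeSeries",
        "doc": "General time series",
        "attributes": [
            {"name": "unit", "dtype": "text", "doc": "unit of measurement"},
            {"name": "resolution", "dtype": "float32", "doc": "smallest meaningful difference"},
        ],
    }
    child = {
        "neurodata_type_def": "MySeries",
        "neurodata_type_inc": "TimeSeries",
        "attributes": [{"name": "unit", "doc": "unit, fixed to volts"}],
    }

    merged = roll_down_nwb_class(parent, child, complete=False)

    # the child's own type name is never overwritten by the parent's
    assert merged["neurodata_type_def"] == "MySeries"
    # "unit" exists in both, so the attribute dicts are merged by name, child values winning
    unit = [a for a in merged["attributes"] if a["name"] == "unit"][0]
    assert unit == {"name": "unit", "dtype": "text", "doc": "unit, fixed to volts"}
    # with complete=False, attributes the child never touches ("resolution") are
    # left to normal inheritance rather than being rolled down
    assert all(a["name"] != "resolution" for a in merged["attributes"])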
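
Similarly, a small sketch of the repr that the new nwb_schema_language.util.pformat helper is wired into the generated models to produce; the field values here are hypothetical and the exact spacing comes from pprint, so treat the printed form as approximate:

    from nwb_schema_language.util import pformat

    # hypothetical field dict, as produced by model_dump() on a Dataset
    fields = {"name": "data", "dtype": "float32", "doc": "recorded values", "attributes": []}
    print(pformat(fields, "Dataset"))
    # empty collections and None values are dropped; because the combined field
    # text is under 80 characters this stays on one line, roughly:
    #   Dataset({'name': 'data', 'dtype': 'float32', 'doc': 'recorded values'})
    # longer field sets are instead broken onto indented lines inside "({ ... })"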