get ting there, working rolldown of extra attributes, but something still funny in patchclampseries children w.r.t. losing attributes in data

This commit is contained in:
sneakers-the-rat 2024-09-13 23:05:34 -07:00
parent 749703e077
commit cad57554fd
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
7 changed files with 264 additions and 165 deletions

View file

@ -8,7 +8,6 @@ for extracting information and generating translated schema
import contextlib import contextlib
from copy import copy from copy import copy
from pathlib import Path from pathlib import Path
from pprint import pformat
from typing import Dict, Generator, List, Optional from typing import Dict, Generator, List, Optional
from linkml_runtime.dumpers import yaml_dumper from linkml_runtime.dumpers import yaml_dumper
@ -19,7 +18,6 @@ from nwb_linkml.adapters.adapter import Adapter, BuildResult
from nwb_linkml.adapters.schema import SchemaAdapter from nwb_linkml.adapters.schema import SchemaAdapter
from nwb_linkml.lang_elements import NwbLangSchema from nwb_linkml.lang_elements import NwbLangSchema
from nwb_linkml.ui import AdapterProgress from nwb_linkml.ui import AdapterProgress
from nwb_linkml.util import merge_dicts
from nwb_schema_language import Dataset, Group, Namespaces from nwb_schema_language import Dataset, Group, Namespaces
@ -188,93 +186,105 @@ class NamespacesAdapter(Adapter):
if not cls.neurodata_type_inc: if not cls.neurodata_type_inc:
continue continue
# get parents parents = self._get_class_ancestors(cls, include_child=True)
# merge and cast
new_cls: dict = {}
for i, parent in enumerate(parents):
# if parent.neurodata_type_def == "PatchClampSeries":
# pdb.set_trace()
complete = True
if i == len(parents) - 1:
complete = False
new_cls = roll_down_nwb_class(new_cls, parent, complete=complete)
new_cls: Group | Dataset = type(cls)(**new_cls)
new_cls.parent = cls.parent
# reinsert
self._overwrite_class(new_cls, cls)
def _get_class_ancestors(
self, cls: Dataset | Group, include_child: bool = True
) -> list[Dataset | Group]:
"""
Get the chain of ancestor classes inherited via ``neurodata_type_inc``
Args:
cls (:class:`.Dataset` | :class:`.Group`): The class to get ancestors of
include_child (bool): If ``True`` (default), include ``cls`` in the output list
"""
parent = self.get(cls.neurodata_type_inc) parent = self.get(cls.neurodata_type_inc)
parents = [parent] parents = [parent]
while parent.neurodata_type_inc: while parent.neurodata_type_inc:
parent = self.get(parent.neurodata_type_inc) parent = self.get(parent.neurodata_type_inc)
parents.insert(0, parent) parents.insert(0, parent)
if include_child:
parents.append(cls) parents.append(cls)
# merge and cast return parents
# note that we don't want to exclude_none in the model dump here,
# if the child class has a field completely unset, we want to inherit it
# from the parent without rolling it down - we are only rolling down
# the things that need to be modified/merged in the child
new_cls: dict = {}
for parent in parents:
new_cls = merge_dicts(
new_cls,
parent.model_dump(exclude_unset=True),
list_key="name",
exclude=["neurodata_type_def"],
)
new_cls: Group | Dataset = type(cls)(**new_cls)
new_cls.parent = cls.parent
# reinsert def _overwrite_class(self, new_cls: Dataset | Group, old_cls: Dataset | Group):
if new_cls.parent: """
if isinstance(cls, Dataset): Overwrite the version of a dataset or group that is stored in our schemas
new_cls.parent.datasets[new_cls.parent.datasets.index(cls)] = new_cls """
if old_cls.parent:
if isinstance(old_cls, Dataset):
new_cls.parent.datasets[new_cls.parent.datasets.index(old_cls)] = new_cls
else: else:
new_cls.parent.groups[new_cls.parent.groups.index(cls)] = new_cls new_cls.parent.groups[new_cls.parent.groups.index(old_cls)] = new_cls
else: else:
# top level class, need to go and find it # top level class, need to go and find it
found = False schema = self.find_type_source(old_cls)
if isinstance(new_cls, Dataset):
schema.datasets[schema.datasets.index(old_cls)] = new_cls
else:
schema.groups[schema.groups.index(old_cls)] = new_cls
def find_type_source(self, cls: str | Dataset | Group, fast: bool = False) -> SchemaAdapter:
"""
Given some type (as `neurodata_type_def`), find the schema that it's defined in.
Rather than returning as soon as a match is found, ensure that duplicates are
not found within the primary schema, then so the same for all imported schemas.
Args:
cls (str | :class:`.Dataset` | :class:`.Group`): The ``neurodata_type_def``
to look for the source of. If a Dataset or Group, look for the object itself
(cls in schema.datasets), otherwise look for a class with a matching name.
fast (bool): If ``True``, return as soon as a match is found.
If ``False`, return after checking all schemas for duplicates.
Returns:
:class:`.SchemaAdapter`
Raises:
KeyError: if multiple schemas or no schemas are found
"""
matches = []
for schema in self.all_schemas(): for schema in self.all_schemas():
if isinstance(cls, Dataset): in_schema = False
if cls in schema.datasets: if isinstance(cls, str) and cls in [
schema.datasets[schema.datasets.index(cls)] = new_cls c.neurodata_type_def for c in schema.created_classes
found = True ]:
break in_schema = True
elif isinstance(cls, Dataset) and cls in schema.datasets:
in_schema = True
elif isinstance(cls, Group) and cls in schema.groups:
in_schema = True
if in_schema:
if fast:
return schema
else: else:
if cls in schema.groups: matches.append(schema)
schema.groups[schema.groups.index(cls)] = new_cls
found = True
break
if not found:
raise KeyError(
f"Unable to find source schema for {cls} when reinserting after rolling"
" down!"
)
def find_type_source(self, name: str) -> SchemaAdapter: if len(matches) > 1:
""" raise KeyError(f"Found multiple schemas in namespace that define {cls}:\n{matches}")
Given some neurodata_type_inc, find the schema that it's defined in. elif len(matches) == 1:
return matches[0]
Rather than returning as soon as a match is found, check all
"""
# First check within the main schema
internal_matches = []
for schema in self.schemas:
class_names = [cls.neurodata_type_def for cls in schema.created_classes]
if name in class_names:
internal_matches.append(schema)
if len(internal_matches) > 1:
raise KeyError(
f"Found multiple schemas in namespace that define {name}:\ninternal:"
f" {pformat(internal_matches)}\nimported:{pformat(internal_matches)}"
)
elif len(internal_matches) == 1:
return internal_matches[0]
import_matches = []
for imported_ns in self.imported:
for schema in imported_ns.schemas:
class_names = [cls.neurodata_type_def for cls in schema.created_classes]
if name in class_names:
import_matches.append(schema)
if len(import_matches) > 1:
raise KeyError(
f"Found multiple schemas in namespace that define {name}:\ninternal:"
f" {pformat(internal_matches)}\nimported:{pformat(import_matches)}"
)
elif len(import_matches) == 1:
return import_matches[0]
else: else:
raise KeyError(f"No schema found that define {name}") raise KeyError(f"No schema found that define {cls}")
def _populate_imports(self) -> "NamespacesAdapter": def _populate_imports(self) -> "NamespacesAdapter":
""" """
@ -378,3 +388,99 @@ class NamespacesAdapter(Adapter):
for imported in self.imported: for imported in self.imported:
for sch in imported.schemas: for sch in imported.schemas:
yield sch yield sch
def roll_down_nwb_class(
source: Group | Dataset | dict, target: Group | Dataset | dict, complete: bool = False
) -> dict:
"""
Merge an ancestor (via ``neurodata_type_inc`` ) source class with a
child ``target`` class.
On the first recurive pass, only those values that are set on the target are copied from the
source class - this isn't a true merging, what we are after is to recursively merge all the
values that are modified in the child class with those of the parent class below the top level,
the top-level attributes will be carried through via normal inheritance.
Rather than re-instantiating the child class, we return the dictionary so that this
function can be used in series to merge a whole ancestry chain within
:class:`.NamespacesAdapter` , but this isn't exposed in the function since
class definitions can be spread out over many schemas, and we need the orchestration
of the adapter to have them in all cases we'd be using this.
Args:
source (dict): source dictionary
target (dict): target dictionary (values merged over source)
complete (bool): (default ``False``)do a complete merge, merging everything
from source to target without trying to minimize redundancy.
Used to collapse ancestor classes before the terminal class.
References:
https://github.com/NeurodataWithoutBorders/pynwb/issues/1954
"""
if isinstance(source, (Group, Dataset)):
source = source.model_dump(exclude_unset=True, exclude_none=True)
if isinstance(target, (Group, Dataset)):
target = target.model_dump(exclude_unset=True, exclude_none=True)
exclude = ("neurodata_type_def",)
# if we are on the first recursion, we exclude top-level items that are not set in the target
if complete:
ret = {k: v for k, v in source.items() if k not in exclude}
else:
ret = {k: v for k, v in source.items() if k not in exclude and k in target}
for key, value in target.items():
if key not in ret:
ret[key] = value
elif isinstance(value, dict):
if key in ret:
ret[key] = roll_down_nwb_class(ret[key], value, complete=True)
else:
ret[key] = value
elif isinstance(value, list) and all([isinstance(v, dict) for v in value]):
src_keys = {v["name"]: ret[key].index(v) for v in ret.get(key, {}) if "name" in v}
target_keys = {v["name"]: value.index(v) for v in value if "name" in v}
new_val = []
# screwy double iteration to preserve dict order
# all dicts not in target, if in depth > 0
if complete:
new_val.extend(
[
ret[key][src_keys[k]]
for k in src_keys
if k in set(src_keys.keys()) - set(target_keys.keys())
]
)
# all dicts not in source
new_val.extend(
[
value[target_keys[k]]
for k in target_keys
if k in set(target_keys.keys()) - set(src_keys.keys())
]
)
# merge dicts in both
new_val.extend(
[
roll_down_nwb_class(ret[key][src_keys[k]], value[target_keys[k]], complete=True)
for k in target_keys
if k in set(src_keys.keys()).intersection(set(target_keys.keys()))
]
)
new_val = sorted(new_val, key=lambda i: i["name"])
# add any dicts that don't have the list_key
# they can't be merged since they can't be matched
if complete:
new_val.extend([v for v in ret.get(key, {}) if "name" not in v])
new_val.extend([v for v in value if "name" not in v])
ret[key] = new_val
else:
ret[key] = value
return ret

View file

@ -136,7 +136,7 @@ class NWBPydanticGenerator(PydanticGenerator):
"""Customize dynamictable behavior""" """Customize dynamictable behavior"""
cls = AfterGenerateClass.inject_dynamictable(cls) cls = AfterGenerateClass.inject_dynamictable(cls)
cls = AfterGenerateClass.wrap_dynamictable_columns(cls, sv) cls = AfterGenerateClass.wrap_dynamictable_columns(cls, sv)
cls = AfterGenerateClass.inject_elementidentifiers(cls, sv, self._get_element_import) cls = AfterGenerateClass.inject_dynamictable_imports(cls, sv, self._get_element_import)
cls = AfterGenerateClass.strip_vector_data_slots(cls, sv) cls = AfterGenerateClass.strip_vector_data_slots(cls, sv)
return cls return cls
@ -346,19 +346,22 @@ class AfterGenerateClass:
return cls return cls
@staticmethod @staticmethod
def inject_elementidentifiers( def inject_dynamictable_imports(
cls: ClassResult, sv: SchemaView, import_method: Callable[[str], Import] cls: ClassResult, sv: SchemaView, import_method: Callable[[str], Import]
) -> ClassResult: ) -> ClassResult:
""" """
Inject ElementIdentifiers into module that define dynamictables - Ensure that schema that contain dynamictables have all the imports needed to use them
needed to handle ID columns
""" """
if ( if (
cls.source.is_a == "DynamicTable" cls.source.is_a == "DynamicTable"
or "DynamicTable" in sv.class_ancestors(cls.source.name) or "DynamicTable" in sv.class_ancestors(cls.source.name)
) and sv.schema.name != "hdmf-common.table": ) and sv.schema.name != "hdmf-common.table":
imp = import_method("ElementIdentifiers") imp = [
cls.imports += [imp] import_method("ElementIdentifiers"),
import_method("VectorData"),
import_method("VectorIndex"),
]
cls.imports += imp
return cls return cls
@staticmethod @staticmethod

View file

@ -1,73 +0,0 @@
"""
The much maligned junk drawer
"""
def merge_dicts(
source: dict, target: dict, list_key: str | None = None, exclude: list[str] | None = None
) -> dict:
"""
Deeply merge nested dictionaries, replacing already-declared keys rather than
e.g. merging lists as well
Args:
source (dict): source dictionary
target (dict): target dictionary (values merged over source)
list_key (str | None): Optional: if present, merge lists of dicts using this to
identify matching dicts
exclude: (list[str] | None): Optional: if present, exclude keys from parent.
References:
https://stackoverflow.com/a/20666342/13113166
"""
if exclude is None:
exclude = []
ret = {k: v for k, v in source.items() if k not in exclude}
for key, value in target.items():
if key not in ret:
ret[key] = value
elif isinstance(value, dict):
if key in ret:
ret[key] = merge_dicts(ret[key], value, list_key, exclude)
else:
ret[key] = value
elif isinstance(value, list) and list_key and all([isinstance(v, dict) for v in value]):
src_keys = {v[list_key]: ret[key].index(v) for v in ret.get(key, {}) if list_key in v}
target_keys = {v[list_key]: value.index(v) for v in value if list_key in v}
# all dicts not in target
# screwy double iteration to preserve dict order
new_val = [
ret[key][src_keys[k]]
for k in src_keys
if k in set(src_keys.keys()) - set(target_keys.keys())
]
# all dicts not in source
new_val.extend(
[
value[target_keys[k]]
for k in target_keys
if k in set(target_keys.keys()) - set(src_keys.keys())
]
)
# merge dicts in both
new_val.extend(
[
merge_dicts(ret[key][src_keys[k]], value[target_keys[k]], list_key, exclude)
for k in target_keys
if k in set(src_keys.keys()).intersection(set(target_keys.keys()))
]
)
new_val = sorted(new_val, key=lambda i: i[list_key])
# add any dicts that don't have the list_key
# they can't be merged since they can't be matched
new_val.extend([v for v in ret.get(key, {}) if list_key not in v])
new_val.extend([v for v in value if list_key not in v])
ret[key] = new_val
else:
ret[key] = value
return ret

View file

@ -7,6 +7,7 @@ from decimal import Decimal
from enum import Enum from enum import Enum
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union from typing import Any, ClassVar, Dict, List, Literal, Optional, Union
from nwb_schema_language.util import pformat
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
@ -23,7 +24,12 @@ class ConfiguredBaseModel(BaseModel):
use_enum_values=True, use_enum_values=True,
strict=False, strict=False,
) )
pass
def __repr__(self):
return pformat(self.model_dump(exclude={"parent": True}), self.__class__.__name__)
def __str__(self):
return repr(self)
class LinkMLMeta(RootModel): class LinkMLMeta(RootModel):
@ -44,9 +50,10 @@ class LinkMLMeta(RootModel):
class ParentizeMixin(BaseModel): class ParentizeMixin(BaseModel):
"""Mixin to populate the parent field for nested datasets and groups"""
@model_validator(mode="after") @model_validator(mode="after")
def parentize(self): def parentize(self) -> BaseModel:
"""Set the parent attribute for all our fields they have one""" """Set the parent attribute for all our fields they have one"""
for field_name in self.model_fields: for field_name in self.model_fields:
if field_name == "parent": if field_name == "parent":

View file

@ -31,6 +31,22 @@ class ParentizeMixin(BaseModel):
return self return self
STR_METHOD = """
def __repr__(self):
return pformat(
self.model_dump(
exclude={"parent": True},
exclude_unset=True,
exclude_None=True
),
self.__class__.__name__
)
def __str__(self):
return repr(self)
"""
@dataclass @dataclass
class NWBSchemaLangGenerator(PydanticGenerator): class NWBSchemaLangGenerator(PydanticGenerator):
""" """
@ -40,8 +56,10 @@ class NWBSchemaLangGenerator(PydanticGenerator):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
kwargs["injected_classes"] = [ParentizeMixin] kwargs["injected_classes"] = [ParentizeMixin]
kwargs["imports"] = [ kwargs["imports"] = [
Import(module="pydantic", objects=[ObjectImport(name="model_validator")]) Import(module="pydantic", objects=[ObjectImport(name="model_validator")]),
Import(module="nwb_schema_language.util", objects=[ObjectImport(name="pformat")]),
] ]
kwargs["injected_fields"] = [STR_METHOD]
kwargs["black"] = True kwargs["black"] = True
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)

View file

@ -0,0 +1,39 @@
from pprint import pformat as _pformat
import textwrap
import re
def pformat(fields: dict, cls_name: str, indent: str = " ") -> str:
"""
pretty format the fields of the items of a ``YAMLRoot`` object without the wonky indentation of pformat.
see ``YAMLRoot.__repr__``.
formatting is similar to black - items at similar levels of nesting have similar levels of indentation,
rather than getting placed at essentially random levels of indentation depending on what came before them.
"""
res = []
total_len = 0
for key, val in fields.items():
if val == [] or val == {} or val is None:
continue
# pformat handles everything else that isn't a YAMLRoot object, but it sure does look ugly
# use it to split lines and as the thing of last resort, but otherwise indent = 0, we'll do that
val_str = _pformat(val, indent=0, compact=True, sort_dicts=False)
# now we indent everything except the first line by indenting and then using regex to remove just the first indent
val_str = re.sub(rf"\A{re.escape(indent)}", "", textwrap.indent(val_str, indent))
# now recombine with the key in a format that can be re-eval'd into an object if indent is just whitespace
val_str = f"'{key}': " + val_str
# count the total length of this string so we know if we need to linebreak or not later
total_len += len(val_str)
res.append(val_str)
if total_len > 80:
inside = ",\n".join(res)
# we indent twice - once for the inner contents of every inner object, and one to
# offset from the root element. that keeps us from needing to be recursive except for the
# single pformat call
inside = textwrap.indent(inside, indent)
return cls_name + "({\n" + inside + "\n})"
else:
return cls_name + "({" + ", ".join(res) + "})"

View file

@ -67,7 +67,6 @@ def generate_versions(
pydantic_path: Path, pydantic_path: Path,
dry_run: bool = False, dry_run: bool = False,
repo: GitRepo = NWB_CORE_REPO, repo: GitRepo = NWB_CORE_REPO,
hdmf_only=False,
pdb=False, pdb=False,
): ):
""" """
@ -253,10 +252,10 @@ def main():
args.yaml.mkdir(exist_ok=True) args.yaml.mkdir(exist_ok=True)
args.pydantic.mkdir(exist_ok=True) args.pydantic.mkdir(exist_ok=True)
if args.latest: if args.latest:
generate_core_yaml(args.yaml, args.dry_run, args.hdmf) generate_core_yaml(args.yaml, args.dry_run)
generate_core_pydantic(args.yaml, args.pydantic, args.dry_run) generate_core_pydantic(args.yaml, args.pydantic, args.dry_run)
else: else:
generate_versions(args.yaml, args.pydantic, args.dry_run, repo, args.hdmf, pdb=args.pdb) generate_versions(args.yaml, args.pydantic, args.dry_run, repo, pdb=args.pdb)
if __name__ == "__main__": if __name__ == "__main__":