mirror of
https://github.com/p2p-ld/nwb-linkml.git
synced 2024-11-12 17:54:29 +00:00
get ting there, working rolldown of extra attributes, but something still funny in patchclampseries children w.r.t. losing attributes in data
This commit is contained in:
parent
749703e077
commit
cad57554fd
7 changed files with 264 additions and 165 deletions
|
@ -8,7 +8,6 @@ for extracting information and generating translated schema
|
|||
import contextlib
|
||||
from copy import copy
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Dict, Generator, List, Optional
|
||||
|
||||
from linkml_runtime.dumpers import yaml_dumper
|
||||
|
@ -19,7 +18,6 @@ from nwb_linkml.adapters.adapter import Adapter, BuildResult
|
|||
from nwb_linkml.adapters.schema import SchemaAdapter
|
||||
from nwb_linkml.lang_elements import NwbLangSchema
|
||||
from nwb_linkml.ui import AdapterProgress
|
||||
from nwb_linkml.util import merge_dicts
|
||||
from nwb_schema_language import Dataset, Group, Namespaces
|
||||
|
||||
|
||||
|
@ -188,93 +186,105 @@ class NamespacesAdapter(Adapter):
|
|||
if not cls.neurodata_type_inc:
|
||||
continue
|
||||
|
||||
# get parents
|
||||
parent = self.get(cls.neurodata_type_inc)
|
||||
parents = [parent]
|
||||
while parent.neurodata_type_inc:
|
||||
parent = self.get(parent.neurodata_type_inc)
|
||||
parents.insert(0, parent)
|
||||
parents.append(cls)
|
||||
parents = self._get_class_ancestors(cls, include_child=True)
|
||||
|
||||
# merge and cast
|
||||
# note that we don't want to exclude_none in the model dump here,
|
||||
# if the child class has a field completely unset, we want to inherit it
|
||||
# from the parent without rolling it down - we are only rolling down
|
||||
# the things that need to be modified/merged in the child
|
||||
new_cls: dict = {}
|
||||
for parent in parents:
|
||||
new_cls = merge_dicts(
|
||||
new_cls,
|
||||
parent.model_dump(exclude_unset=True),
|
||||
list_key="name",
|
||||
exclude=["neurodata_type_def"],
|
||||
)
|
||||
for i, parent in enumerate(parents):
|
||||
# if parent.neurodata_type_def == "PatchClampSeries":
|
||||
# pdb.set_trace()
|
||||
complete = True
|
||||
if i == len(parents) - 1:
|
||||
complete = False
|
||||
new_cls = roll_down_nwb_class(new_cls, parent, complete=complete)
|
||||
new_cls: Group | Dataset = type(cls)(**new_cls)
|
||||
new_cls.parent = cls.parent
|
||||
|
||||
# reinsert
|
||||
if new_cls.parent:
|
||||
if isinstance(cls, Dataset):
|
||||
new_cls.parent.datasets[new_cls.parent.datasets.index(cls)] = new_cls
|
||||
else:
|
||||
new_cls.parent.groups[new_cls.parent.groups.index(cls)] = new_cls
|
||||
self._overwrite_class(new_cls, cls)
|
||||
|
||||
def _get_class_ancestors(
|
||||
self, cls: Dataset | Group, include_child: bool = True
|
||||
) -> list[Dataset | Group]:
|
||||
"""
|
||||
Get the chain of ancestor classes inherited via ``neurodata_type_inc``
|
||||
|
||||
Args:
|
||||
cls (:class:`.Dataset` | :class:`.Group`): The class to get ancestors of
|
||||
include_child (bool): If ``True`` (default), include ``cls`` in the output list
|
||||
"""
|
||||
parent = self.get(cls.neurodata_type_inc)
|
||||
parents = [parent]
|
||||
while parent.neurodata_type_inc:
|
||||
parent = self.get(parent.neurodata_type_inc)
|
||||
parents.insert(0, parent)
|
||||
|
||||
if include_child:
|
||||
parents.append(cls)
|
||||
|
||||
return parents
|
||||
|
||||
def _overwrite_class(self, new_cls: Dataset | Group, old_cls: Dataset | Group):
|
||||
"""
|
||||
Overwrite the version of a dataset or group that is stored in our schemas
|
||||
"""
|
||||
if old_cls.parent:
|
||||
if isinstance(old_cls, Dataset):
|
||||
new_cls.parent.datasets[new_cls.parent.datasets.index(old_cls)] = new_cls
|
||||
else:
|
||||
# top level class, need to go and find it
|
||||
found = False
|
||||
for schema in self.all_schemas():
|
||||
if isinstance(cls, Dataset):
|
||||
if cls in schema.datasets:
|
||||
schema.datasets[schema.datasets.index(cls)] = new_cls
|
||||
found = True
|
||||
break
|
||||
else:
|
||||
if cls in schema.groups:
|
||||
schema.groups[schema.groups.index(cls)] = new_cls
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
raise KeyError(
|
||||
f"Unable to find source schema for {cls} when reinserting after rolling"
|
||||
" down!"
|
||||
)
|
||||
|
||||
def find_type_source(self, name: str) -> SchemaAdapter:
|
||||
"""
|
||||
Given some neurodata_type_inc, find the schema that it's defined in.
|
||||
|
||||
Rather than returning as soon as a match is found, check all
|
||||
"""
|
||||
# First check within the main schema
|
||||
internal_matches = []
|
||||
for schema in self.schemas:
|
||||
class_names = [cls.neurodata_type_def for cls in schema.created_classes]
|
||||
if name in class_names:
|
||||
internal_matches.append(schema)
|
||||
|
||||
if len(internal_matches) > 1:
|
||||
raise KeyError(
|
||||
f"Found multiple schemas in namespace that define {name}:\ninternal:"
|
||||
f" {pformat(internal_matches)}\nimported:{pformat(internal_matches)}"
|
||||
)
|
||||
elif len(internal_matches) == 1:
|
||||
return internal_matches[0]
|
||||
|
||||
import_matches = []
|
||||
for imported_ns in self.imported:
|
||||
for schema in imported_ns.schemas:
|
||||
class_names = [cls.neurodata_type_def for cls in schema.created_classes]
|
||||
if name in class_names:
|
||||
import_matches.append(schema)
|
||||
|
||||
if len(import_matches) > 1:
|
||||
raise KeyError(
|
||||
f"Found multiple schemas in namespace that define {name}:\ninternal:"
|
||||
f" {pformat(internal_matches)}\nimported:{pformat(import_matches)}"
|
||||
)
|
||||
elif len(import_matches) == 1:
|
||||
return import_matches[0]
|
||||
new_cls.parent.groups[new_cls.parent.groups.index(old_cls)] = new_cls
|
||||
else:
|
||||
raise KeyError(f"No schema found that define {name}")
|
||||
# top level class, need to go and find it
|
||||
schema = self.find_type_source(old_cls)
|
||||
if isinstance(new_cls, Dataset):
|
||||
schema.datasets[schema.datasets.index(old_cls)] = new_cls
|
||||
else:
|
||||
schema.groups[schema.groups.index(old_cls)] = new_cls
|
||||
|
||||
def find_type_source(self, cls: str | Dataset | Group, fast: bool = False) -> SchemaAdapter:
|
||||
"""
|
||||
Given some type (as `neurodata_type_def`), find the schema that it's defined in.
|
||||
|
||||
Rather than returning as soon as a match is found, ensure that duplicates are
|
||||
not found within the primary schema, then so the same for all imported schemas.
|
||||
|
||||
Args:
|
||||
cls (str | :class:`.Dataset` | :class:`.Group`): The ``neurodata_type_def``
|
||||
to look for the source of. If a Dataset or Group, look for the object itself
|
||||
(cls in schema.datasets), otherwise look for a class with a matching name.
|
||||
fast (bool): If ``True``, return as soon as a match is found.
|
||||
If ``False`, return after checking all schemas for duplicates.
|
||||
|
||||
Returns:
|
||||
:class:`.SchemaAdapter`
|
||||
|
||||
Raises:
|
||||
KeyError: if multiple schemas or no schemas are found
|
||||
"""
|
||||
matches = []
|
||||
for schema in self.all_schemas():
|
||||
in_schema = False
|
||||
if isinstance(cls, str) and cls in [
|
||||
c.neurodata_type_def for c in schema.created_classes
|
||||
]:
|
||||
in_schema = True
|
||||
elif isinstance(cls, Dataset) and cls in schema.datasets:
|
||||
in_schema = True
|
||||
elif isinstance(cls, Group) and cls in schema.groups:
|
||||
in_schema = True
|
||||
|
||||
if in_schema:
|
||||
if fast:
|
||||
return schema
|
||||
else:
|
||||
matches.append(schema)
|
||||
|
||||
if len(matches) > 1:
|
||||
raise KeyError(f"Found multiple schemas in namespace that define {cls}:\n{matches}")
|
||||
elif len(matches) == 1:
|
||||
return matches[0]
|
||||
else:
|
||||
raise KeyError(f"No schema found that define {cls}")
|
||||
|
||||
def _populate_imports(self) -> "NamespacesAdapter":
|
||||
"""
|
||||
|
@ -378,3 +388,99 @@ class NamespacesAdapter(Adapter):
|
|||
for imported in self.imported:
|
||||
for sch in imported.schemas:
|
||||
yield sch
|
||||
|
||||
|
||||
def roll_down_nwb_class(
|
||||
source: Group | Dataset | dict, target: Group | Dataset | dict, complete: bool = False
|
||||
) -> dict:
|
||||
"""
|
||||
Merge an ancestor (via ``neurodata_type_inc`` ) source class with a
|
||||
child ``target`` class.
|
||||
|
||||
On the first recurive pass, only those values that are set on the target are copied from the
|
||||
source class - this isn't a true merging, what we are after is to recursively merge all the
|
||||
values that are modified in the child class with those of the parent class below the top level,
|
||||
the top-level attributes will be carried through via normal inheritance.
|
||||
|
||||
Rather than re-instantiating the child class, we return the dictionary so that this
|
||||
function can be used in series to merge a whole ancestry chain within
|
||||
:class:`.NamespacesAdapter` , but this isn't exposed in the function since
|
||||
class definitions can be spread out over many schemas, and we need the orchestration
|
||||
of the adapter to have them in all cases we'd be using this.
|
||||
|
||||
Args:
|
||||
source (dict): source dictionary
|
||||
target (dict): target dictionary (values merged over source)
|
||||
complete (bool): (default ``False``)do a complete merge, merging everything
|
||||
from source to target without trying to minimize redundancy.
|
||||
Used to collapse ancestor classes before the terminal class.
|
||||
|
||||
References:
|
||||
https://github.com/NeurodataWithoutBorders/pynwb/issues/1954
|
||||
|
||||
"""
|
||||
if isinstance(source, (Group, Dataset)):
|
||||
source = source.model_dump(exclude_unset=True, exclude_none=True)
|
||||
if isinstance(target, (Group, Dataset)):
|
||||
target = target.model_dump(exclude_unset=True, exclude_none=True)
|
||||
|
||||
exclude = ("neurodata_type_def",)
|
||||
|
||||
# if we are on the first recursion, we exclude top-level items that are not set in the target
|
||||
if complete:
|
||||
ret = {k: v for k, v in source.items() if k not in exclude}
|
||||
else:
|
||||
ret = {k: v for k, v in source.items() if k not in exclude and k in target}
|
||||
|
||||
for key, value in target.items():
|
||||
if key not in ret:
|
||||
ret[key] = value
|
||||
elif isinstance(value, dict):
|
||||
if key in ret:
|
||||
ret[key] = roll_down_nwb_class(ret[key], value, complete=True)
|
||||
else:
|
||||
ret[key] = value
|
||||
elif isinstance(value, list) and all([isinstance(v, dict) for v in value]):
|
||||
src_keys = {v["name"]: ret[key].index(v) for v in ret.get(key, {}) if "name" in v}
|
||||
target_keys = {v["name"]: value.index(v) for v in value if "name" in v}
|
||||
|
||||
new_val = []
|
||||
# screwy double iteration to preserve dict order
|
||||
# all dicts not in target, if in depth > 0
|
||||
if complete:
|
||||
new_val.extend(
|
||||
[
|
||||
ret[key][src_keys[k]]
|
||||
for k in src_keys
|
||||
if k in set(src_keys.keys()) - set(target_keys.keys())
|
||||
]
|
||||
)
|
||||
# all dicts not in source
|
||||
new_val.extend(
|
||||
[
|
||||
value[target_keys[k]]
|
||||
for k in target_keys
|
||||
if k in set(target_keys.keys()) - set(src_keys.keys())
|
||||
]
|
||||
)
|
||||
# merge dicts in both
|
||||
new_val.extend(
|
||||
[
|
||||
roll_down_nwb_class(ret[key][src_keys[k]], value[target_keys[k]], complete=True)
|
||||
for k in target_keys
|
||||
if k in set(src_keys.keys()).intersection(set(target_keys.keys()))
|
||||
]
|
||||
)
|
||||
new_val = sorted(new_val, key=lambda i: i["name"])
|
||||
# add any dicts that don't have the list_key
|
||||
# they can't be merged since they can't be matched
|
||||
if complete:
|
||||
new_val.extend([v for v in ret.get(key, {}) if "name" not in v])
|
||||
new_val.extend([v for v in value if "name" not in v])
|
||||
|
||||
ret[key] = new_val
|
||||
|
||||
else:
|
||||
ret[key] = value
|
||||
|
||||
return ret
|
||||
|
|
|
@ -136,7 +136,7 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
"""Customize dynamictable behavior"""
|
||||
cls = AfterGenerateClass.inject_dynamictable(cls)
|
||||
cls = AfterGenerateClass.wrap_dynamictable_columns(cls, sv)
|
||||
cls = AfterGenerateClass.inject_elementidentifiers(cls, sv, self._get_element_import)
|
||||
cls = AfterGenerateClass.inject_dynamictable_imports(cls, sv, self._get_element_import)
|
||||
cls = AfterGenerateClass.strip_vector_data_slots(cls, sv)
|
||||
return cls
|
||||
|
||||
|
@ -346,19 +346,22 @@ class AfterGenerateClass:
|
|||
return cls
|
||||
|
||||
@staticmethod
|
||||
def inject_elementidentifiers(
|
||||
def inject_dynamictable_imports(
|
||||
cls: ClassResult, sv: SchemaView, import_method: Callable[[str], Import]
|
||||
) -> ClassResult:
|
||||
"""
|
||||
Inject ElementIdentifiers into module that define dynamictables -
|
||||
needed to handle ID columns
|
||||
Ensure that schema that contain dynamictables have all the imports needed to use them
|
||||
"""
|
||||
if (
|
||||
cls.source.is_a == "DynamicTable"
|
||||
or "DynamicTable" in sv.class_ancestors(cls.source.name)
|
||||
) and sv.schema.name != "hdmf-common.table":
|
||||
imp = import_method("ElementIdentifiers")
|
||||
cls.imports += [imp]
|
||||
imp = [
|
||||
import_method("ElementIdentifiers"),
|
||||
import_method("VectorData"),
|
||||
import_method("VectorIndex"),
|
||||
]
|
||||
cls.imports += imp
|
||||
return cls
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -1,73 +0,0 @@
|
|||
"""
|
||||
The much maligned junk drawer
|
||||
"""
|
||||
|
||||
|
||||
def merge_dicts(
|
||||
source: dict, target: dict, list_key: str | None = None, exclude: list[str] | None = None
|
||||
) -> dict:
|
||||
"""
|
||||
Deeply merge nested dictionaries, replacing already-declared keys rather than
|
||||
e.g. merging lists as well
|
||||
|
||||
Args:
|
||||
source (dict): source dictionary
|
||||
target (dict): target dictionary (values merged over source)
|
||||
list_key (str | None): Optional: if present, merge lists of dicts using this to
|
||||
identify matching dicts
|
||||
exclude: (list[str] | None): Optional: if present, exclude keys from parent.
|
||||
|
||||
References:
|
||||
https://stackoverflow.com/a/20666342/13113166
|
||||
|
||||
"""
|
||||
if exclude is None:
|
||||
exclude = []
|
||||
ret = {k: v for k, v in source.items() if k not in exclude}
|
||||
for key, value in target.items():
|
||||
if key not in ret:
|
||||
ret[key] = value
|
||||
elif isinstance(value, dict):
|
||||
if key in ret:
|
||||
ret[key] = merge_dicts(ret[key], value, list_key, exclude)
|
||||
else:
|
||||
ret[key] = value
|
||||
elif isinstance(value, list) and list_key and all([isinstance(v, dict) for v in value]):
|
||||
src_keys = {v[list_key]: ret[key].index(v) for v in ret.get(key, {}) if list_key in v}
|
||||
target_keys = {v[list_key]: value.index(v) for v in value if list_key in v}
|
||||
|
||||
# all dicts not in target
|
||||
# screwy double iteration to preserve dict order
|
||||
new_val = [
|
||||
ret[key][src_keys[k]]
|
||||
for k in src_keys
|
||||
if k in set(src_keys.keys()) - set(target_keys.keys())
|
||||
]
|
||||
# all dicts not in source
|
||||
new_val.extend(
|
||||
[
|
||||
value[target_keys[k]]
|
||||
for k in target_keys
|
||||
if k in set(target_keys.keys()) - set(src_keys.keys())
|
||||
]
|
||||
)
|
||||
# merge dicts in both
|
||||
new_val.extend(
|
||||
[
|
||||
merge_dicts(ret[key][src_keys[k]], value[target_keys[k]], list_key, exclude)
|
||||
for k in target_keys
|
||||
if k in set(src_keys.keys()).intersection(set(target_keys.keys()))
|
||||
]
|
||||
)
|
||||
new_val = sorted(new_val, key=lambda i: i[list_key])
|
||||
# add any dicts that don't have the list_key
|
||||
# they can't be merged since they can't be matched
|
||||
new_val.extend([v for v in ret.get(key, {}) if list_key not in v])
|
||||
new_val.extend([v for v in value if list_key not in v])
|
||||
|
||||
ret[key] = new_val
|
||||
|
||||
else:
|
||||
ret[key] = value
|
||||
|
||||
return ret
|
|
@ -7,6 +7,7 @@ from decimal import Decimal
|
|||
from enum import Enum
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union
|
||||
|
||||
from nwb_schema_language.util import pformat
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
||||
|
||||
|
||||
|
@ -23,7 +24,12 @@ class ConfiguredBaseModel(BaseModel):
|
|||
use_enum_values=True,
|
||||
strict=False,
|
||||
)
|
||||
pass
|
||||
|
||||
def __repr__(self):
|
||||
return pformat(self.model_dump(exclude={"parent": True}), self.__class__.__name__)
|
||||
|
||||
def __str__(self):
|
||||
return repr(self)
|
||||
|
||||
|
||||
class LinkMLMeta(RootModel):
|
||||
|
@ -44,9 +50,10 @@ class LinkMLMeta(RootModel):
|
|||
|
||||
|
||||
class ParentizeMixin(BaseModel):
|
||||
"""Mixin to populate the parent field for nested datasets and groups"""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def parentize(self):
|
||||
def parentize(self) -> BaseModel:
|
||||
"""Set the parent attribute for all our fields they have one"""
|
||||
for field_name in self.model_fields:
|
||||
if field_name == "parent":
|
||||
|
|
|
@ -31,6 +31,22 @@ class ParentizeMixin(BaseModel):
|
|||
return self
|
||||
|
||||
|
||||
STR_METHOD = """
|
||||
def __repr__(self):
|
||||
return pformat(
|
||||
self.model_dump(
|
||||
exclude={"parent": True},
|
||||
exclude_unset=True,
|
||||
exclude_None=True
|
||||
),
|
||||
self.__class__.__name__
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return repr(self)
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class NWBSchemaLangGenerator(PydanticGenerator):
|
||||
"""
|
||||
|
@ -40,8 +56,10 @@ class NWBSchemaLangGenerator(PydanticGenerator):
|
|||
def __init__(self, *args, **kwargs):
|
||||
kwargs["injected_classes"] = [ParentizeMixin]
|
||||
kwargs["imports"] = [
|
||||
Import(module="pydantic", objects=[ObjectImport(name="model_validator")])
|
||||
Import(module="pydantic", objects=[ObjectImport(name="model_validator")]),
|
||||
Import(module="nwb_schema_language.util", objects=[ObjectImport(name="pformat")]),
|
||||
]
|
||||
kwargs["injected_fields"] = [STR_METHOD]
|
||||
kwargs["black"] = True
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
|
39
nwb_schema_language/src/nwb_schema_language/util.py
Normal file
39
nwb_schema_language/src/nwb_schema_language/util.py
Normal file
|
@ -0,0 +1,39 @@
|
|||
from pprint import pformat as _pformat
|
||||
import textwrap
|
||||
import re
|
||||
|
||||
|
||||
def pformat(fields: dict, cls_name: str, indent: str = " ") -> str:
|
||||
"""
|
||||
pretty format the fields of the items of a ``YAMLRoot`` object without the wonky indentation of pformat.
|
||||
see ``YAMLRoot.__repr__``.
|
||||
|
||||
formatting is similar to black - items at similar levels of nesting have similar levels of indentation,
|
||||
rather than getting placed at essentially random levels of indentation depending on what came before them.
|
||||
"""
|
||||
res = []
|
||||
total_len = 0
|
||||
for key, val in fields.items():
|
||||
if val == [] or val == {} or val is None:
|
||||
continue
|
||||
# pformat handles everything else that isn't a YAMLRoot object, but it sure does look ugly
|
||||
# use it to split lines and as the thing of last resort, but otherwise indent = 0, we'll do that
|
||||
val_str = _pformat(val, indent=0, compact=True, sort_dicts=False)
|
||||
# now we indent everything except the first line by indenting and then using regex to remove just the first indent
|
||||
val_str = re.sub(rf"\A{re.escape(indent)}", "", textwrap.indent(val_str, indent))
|
||||
# now recombine with the key in a format that can be re-eval'd into an object if indent is just whitespace
|
||||
val_str = f"'{key}': " + val_str
|
||||
|
||||
# count the total length of this string so we know if we need to linebreak or not later
|
||||
total_len += len(val_str)
|
||||
res.append(val_str)
|
||||
|
||||
if total_len > 80:
|
||||
inside = ",\n".join(res)
|
||||
# we indent twice - once for the inner contents of every inner object, and one to
|
||||
# offset from the root element. that keeps us from needing to be recursive except for the
|
||||
# single pformat call
|
||||
inside = textwrap.indent(inside, indent)
|
||||
return cls_name + "({\n" + inside + "\n})"
|
||||
else:
|
||||
return cls_name + "({" + ", ".join(res) + "})"
|
|
@ -67,7 +67,6 @@ def generate_versions(
|
|||
pydantic_path: Path,
|
||||
dry_run: bool = False,
|
||||
repo: GitRepo = NWB_CORE_REPO,
|
||||
hdmf_only=False,
|
||||
pdb=False,
|
||||
):
|
||||
"""
|
||||
|
@ -253,10 +252,10 @@ def main():
|
|||
args.yaml.mkdir(exist_ok=True)
|
||||
args.pydantic.mkdir(exist_ok=True)
|
||||
if args.latest:
|
||||
generate_core_yaml(args.yaml, args.dry_run, args.hdmf)
|
||||
generate_core_yaml(args.yaml, args.dry_run)
|
||||
generate_core_pydantic(args.yaml, args.pydantic, args.dry_run)
|
||||
else:
|
||||
generate_versions(args.yaml, args.pydantic, args.dry_run, repo, args.hdmf, pdb=args.pdb)
|
||||
generate_versions(args.yaml, args.pydantic, args.dry_run, repo, pdb=args.pdb)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in a new issue