mirror of
https://github.com/p2p-ld/nwb-linkml.git
synced 2025-01-10 06:04:28 +00:00
get ting there, working rolldown of extra attributes, but something still funny in patchclampseries children w.r.t. losing attributes in data
This commit is contained in:
parent
749703e077
commit
cad57554fd
7 changed files with 264 additions and 165 deletions
|
@ -8,7 +8,6 @@ for extracting information and generating translated schema
|
||||||
import contextlib
|
import contextlib
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pprint import pformat
|
|
||||||
from typing import Dict, Generator, List, Optional
|
from typing import Dict, Generator, List, Optional
|
||||||
|
|
||||||
from linkml_runtime.dumpers import yaml_dumper
|
from linkml_runtime.dumpers import yaml_dumper
|
||||||
|
@ -19,7 +18,6 @@ from nwb_linkml.adapters.adapter import Adapter, BuildResult
|
||||||
from nwb_linkml.adapters.schema import SchemaAdapter
|
from nwb_linkml.adapters.schema import SchemaAdapter
|
||||||
from nwb_linkml.lang_elements import NwbLangSchema
|
from nwb_linkml.lang_elements import NwbLangSchema
|
||||||
from nwb_linkml.ui import AdapterProgress
|
from nwb_linkml.ui import AdapterProgress
|
||||||
from nwb_linkml.util import merge_dicts
|
|
||||||
from nwb_schema_language import Dataset, Group, Namespaces
|
from nwb_schema_language import Dataset, Group, Namespaces
|
||||||
|
|
||||||
|
|
||||||
|
@ -188,93 +186,105 @@ class NamespacesAdapter(Adapter):
|
||||||
if not cls.neurodata_type_inc:
|
if not cls.neurodata_type_inc:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# get parents
|
parents = self._get_class_ancestors(cls, include_child=True)
|
||||||
parent = self.get(cls.neurodata_type_inc)
|
|
||||||
parents = [parent]
|
|
||||||
while parent.neurodata_type_inc:
|
|
||||||
parent = self.get(parent.neurodata_type_inc)
|
|
||||||
parents.insert(0, parent)
|
|
||||||
parents.append(cls)
|
|
||||||
|
|
||||||
# merge and cast
|
# merge and cast
|
||||||
# note that we don't want to exclude_none in the model dump here,
|
|
||||||
# if the child class has a field completely unset, we want to inherit it
|
|
||||||
# from the parent without rolling it down - we are only rolling down
|
|
||||||
# the things that need to be modified/merged in the child
|
|
||||||
new_cls: dict = {}
|
new_cls: dict = {}
|
||||||
for parent in parents:
|
for i, parent in enumerate(parents):
|
||||||
new_cls = merge_dicts(
|
# if parent.neurodata_type_def == "PatchClampSeries":
|
||||||
new_cls,
|
# pdb.set_trace()
|
||||||
parent.model_dump(exclude_unset=True),
|
complete = True
|
||||||
list_key="name",
|
if i == len(parents) - 1:
|
||||||
exclude=["neurodata_type_def"],
|
complete = False
|
||||||
)
|
new_cls = roll_down_nwb_class(new_cls, parent, complete=complete)
|
||||||
new_cls: Group | Dataset = type(cls)(**new_cls)
|
new_cls: Group | Dataset = type(cls)(**new_cls)
|
||||||
new_cls.parent = cls.parent
|
new_cls.parent = cls.parent
|
||||||
|
|
||||||
# reinsert
|
# reinsert
|
||||||
if new_cls.parent:
|
self._overwrite_class(new_cls, cls)
|
||||||
if isinstance(cls, Dataset):
|
|
||||||
new_cls.parent.datasets[new_cls.parent.datasets.index(cls)] = new_cls
|
def _get_class_ancestors(
|
||||||
else:
|
self, cls: Dataset | Group, include_child: bool = True
|
||||||
new_cls.parent.groups[new_cls.parent.groups.index(cls)] = new_cls
|
) -> list[Dataset | Group]:
|
||||||
|
"""
|
||||||
|
Get the chain of ancestor classes inherited via ``neurodata_type_inc``
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cls (:class:`.Dataset` | :class:`.Group`): The class to get ancestors of
|
||||||
|
include_child (bool): If ``True`` (default), include ``cls`` in the output list
|
||||||
|
"""
|
||||||
|
parent = self.get(cls.neurodata_type_inc)
|
||||||
|
parents = [parent]
|
||||||
|
while parent.neurodata_type_inc:
|
||||||
|
parent = self.get(parent.neurodata_type_inc)
|
||||||
|
parents.insert(0, parent)
|
||||||
|
|
||||||
|
if include_child:
|
||||||
|
parents.append(cls)
|
||||||
|
|
||||||
|
return parents
|
||||||
|
|
||||||
|
def _overwrite_class(self, new_cls: Dataset | Group, old_cls: Dataset | Group):
|
||||||
|
"""
|
||||||
|
Overwrite the version of a dataset or group that is stored in our schemas
|
||||||
|
"""
|
||||||
|
if old_cls.parent:
|
||||||
|
if isinstance(old_cls, Dataset):
|
||||||
|
new_cls.parent.datasets[new_cls.parent.datasets.index(old_cls)] = new_cls
|
||||||
else:
|
else:
|
||||||
# top level class, need to go and find it
|
new_cls.parent.groups[new_cls.parent.groups.index(old_cls)] = new_cls
|
||||||
found = False
|
|
||||||
for schema in self.all_schemas():
|
|
||||||
if isinstance(cls, Dataset):
|
|
||||||
if cls in schema.datasets:
|
|
||||||
schema.datasets[schema.datasets.index(cls)] = new_cls
|
|
||||||
found = True
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
if cls in schema.groups:
|
|
||||||
schema.groups[schema.groups.index(cls)] = new_cls
|
|
||||||
found = True
|
|
||||||
break
|
|
||||||
if not found:
|
|
||||||
raise KeyError(
|
|
||||||
f"Unable to find source schema for {cls} when reinserting after rolling"
|
|
||||||
" down!"
|
|
||||||
)
|
|
||||||
|
|
||||||
def find_type_source(self, name: str) -> SchemaAdapter:
|
|
||||||
"""
|
|
||||||
Given some neurodata_type_inc, find the schema that it's defined in.
|
|
||||||
|
|
||||||
Rather than returning as soon as a match is found, check all
|
|
||||||
"""
|
|
||||||
# First check within the main schema
|
|
||||||
internal_matches = []
|
|
||||||
for schema in self.schemas:
|
|
||||||
class_names = [cls.neurodata_type_def for cls in schema.created_classes]
|
|
||||||
if name in class_names:
|
|
||||||
internal_matches.append(schema)
|
|
||||||
|
|
||||||
if len(internal_matches) > 1:
|
|
||||||
raise KeyError(
|
|
||||||
f"Found multiple schemas in namespace that define {name}:\ninternal:"
|
|
||||||
f" {pformat(internal_matches)}\nimported:{pformat(internal_matches)}"
|
|
||||||
)
|
|
||||||
elif len(internal_matches) == 1:
|
|
||||||
return internal_matches[0]
|
|
||||||
|
|
||||||
import_matches = []
|
|
||||||
for imported_ns in self.imported:
|
|
||||||
for schema in imported_ns.schemas:
|
|
||||||
class_names = [cls.neurodata_type_def for cls in schema.created_classes]
|
|
||||||
if name in class_names:
|
|
||||||
import_matches.append(schema)
|
|
||||||
|
|
||||||
if len(import_matches) > 1:
|
|
||||||
raise KeyError(
|
|
||||||
f"Found multiple schemas in namespace that define {name}:\ninternal:"
|
|
||||||
f" {pformat(internal_matches)}\nimported:{pformat(import_matches)}"
|
|
||||||
)
|
|
||||||
elif len(import_matches) == 1:
|
|
||||||
return import_matches[0]
|
|
||||||
else:
|
else:
|
||||||
raise KeyError(f"No schema found that define {name}")
|
# top level class, need to go and find it
|
||||||
|
schema = self.find_type_source(old_cls)
|
||||||
|
if isinstance(new_cls, Dataset):
|
||||||
|
schema.datasets[schema.datasets.index(old_cls)] = new_cls
|
||||||
|
else:
|
||||||
|
schema.groups[schema.groups.index(old_cls)] = new_cls
|
||||||
|
|
||||||
|
def find_type_source(self, cls: str | Dataset | Group, fast: bool = False) -> SchemaAdapter:
|
||||||
|
"""
|
||||||
|
Given some type (as `neurodata_type_def`), find the schema that it's defined in.
|
||||||
|
|
||||||
|
Rather than returning as soon as a match is found, ensure that duplicates are
|
||||||
|
not found within the primary schema, then so the same for all imported schemas.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cls (str | :class:`.Dataset` | :class:`.Group`): The ``neurodata_type_def``
|
||||||
|
to look for the source of. If a Dataset or Group, look for the object itself
|
||||||
|
(cls in schema.datasets), otherwise look for a class with a matching name.
|
||||||
|
fast (bool): If ``True``, return as soon as a match is found.
|
||||||
|
If ``False`, return after checking all schemas for duplicates.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:class:`.SchemaAdapter`
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: if multiple schemas or no schemas are found
|
||||||
|
"""
|
||||||
|
matches = []
|
||||||
|
for schema in self.all_schemas():
|
||||||
|
in_schema = False
|
||||||
|
if isinstance(cls, str) and cls in [
|
||||||
|
c.neurodata_type_def for c in schema.created_classes
|
||||||
|
]:
|
||||||
|
in_schema = True
|
||||||
|
elif isinstance(cls, Dataset) and cls in schema.datasets:
|
||||||
|
in_schema = True
|
||||||
|
elif isinstance(cls, Group) and cls in schema.groups:
|
||||||
|
in_schema = True
|
||||||
|
|
||||||
|
if in_schema:
|
||||||
|
if fast:
|
||||||
|
return schema
|
||||||
|
else:
|
||||||
|
matches.append(schema)
|
||||||
|
|
||||||
|
if len(matches) > 1:
|
||||||
|
raise KeyError(f"Found multiple schemas in namespace that define {cls}:\n{matches}")
|
||||||
|
elif len(matches) == 1:
|
||||||
|
return matches[0]
|
||||||
|
else:
|
||||||
|
raise KeyError(f"No schema found that define {cls}")
|
||||||
|
|
||||||
def _populate_imports(self) -> "NamespacesAdapter":
|
def _populate_imports(self) -> "NamespacesAdapter":
|
||||||
"""
|
"""
|
||||||
|
@ -378,3 +388,99 @@ class NamespacesAdapter(Adapter):
|
||||||
for imported in self.imported:
|
for imported in self.imported:
|
||||||
for sch in imported.schemas:
|
for sch in imported.schemas:
|
||||||
yield sch
|
yield sch
|
||||||
|
|
||||||
|
|
||||||
|
def roll_down_nwb_class(
|
||||||
|
source: Group | Dataset | dict, target: Group | Dataset | dict, complete: bool = False
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Merge an ancestor (via ``neurodata_type_inc`` ) source class with a
|
||||||
|
child ``target`` class.
|
||||||
|
|
||||||
|
On the first recurive pass, only those values that are set on the target are copied from the
|
||||||
|
source class - this isn't a true merging, what we are after is to recursively merge all the
|
||||||
|
values that are modified in the child class with those of the parent class below the top level,
|
||||||
|
the top-level attributes will be carried through via normal inheritance.
|
||||||
|
|
||||||
|
Rather than re-instantiating the child class, we return the dictionary so that this
|
||||||
|
function can be used in series to merge a whole ancestry chain within
|
||||||
|
:class:`.NamespacesAdapter` , but this isn't exposed in the function since
|
||||||
|
class definitions can be spread out over many schemas, and we need the orchestration
|
||||||
|
of the adapter to have them in all cases we'd be using this.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source (dict): source dictionary
|
||||||
|
target (dict): target dictionary (values merged over source)
|
||||||
|
complete (bool): (default ``False``)do a complete merge, merging everything
|
||||||
|
from source to target without trying to minimize redundancy.
|
||||||
|
Used to collapse ancestor classes before the terminal class.
|
||||||
|
|
||||||
|
References:
|
||||||
|
https://github.com/NeurodataWithoutBorders/pynwb/issues/1954
|
||||||
|
|
||||||
|
"""
|
||||||
|
if isinstance(source, (Group, Dataset)):
|
||||||
|
source = source.model_dump(exclude_unset=True, exclude_none=True)
|
||||||
|
if isinstance(target, (Group, Dataset)):
|
||||||
|
target = target.model_dump(exclude_unset=True, exclude_none=True)
|
||||||
|
|
||||||
|
exclude = ("neurodata_type_def",)
|
||||||
|
|
||||||
|
# if we are on the first recursion, we exclude top-level items that are not set in the target
|
||||||
|
if complete:
|
||||||
|
ret = {k: v for k, v in source.items() if k not in exclude}
|
||||||
|
else:
|
||||||
|
ret = {k: v for k, v in source.items() if k not in exclude and k in target}
|
||||||
|
|
||||||
|
for key, value in target.items():
|
||||||
|
if key not in ret:
|
||||||
|
ret[key] = value
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
if key in ret:
|
||||||
|
ret[key] = roll_down_nwb_class(ret[key], value, complete=True)
|
||||||
|
else:
|
||||||
|
ret[key] = value
|
||||||
|
elif isinstance(value, list) and all([isinstance(v, dict) for v in value]):
|
||||||
|
src_keys = {v["name"]: ret[key].index(v) for v in ret.get(key, {}) if "name" in v}
|
||||||
|
target_keys = {v["name"]: value.index(v) for v in value if "name" in v}
|
||||||
|
|
||||||
|
new_val = []
|
||||||
|
# screwy double iteration to preserve dict order
|
||||||
|
# all dicts not in target, if in depth > 0
|
||||||
|
if complete:
|
||||||
|
new_val.extend(
|
||||||
|
[
|
||||||
|
ret[key][src_keys[k]]
|
||||||
|
for k in src_keys
|
||||||
|
if k in set(src_keys.keys()) - set(target_keys.keys())
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# all dicts not in source
|
||||||
|
new_val.extend(
|
||||||
|
[
|
||||||
|
value[target_keys[k]]
|
||||||
|
for k in target_keys
|
||||||
|
if k in set(target_keys.keys()) - set(src_keys.keys())
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# merge dicts in both
|
||||||
|
new_val.extend(
|
||||||
|
[
|
||||||
|
roll_down_nwb_class(ret[key][src_keys[k]], value[target_keys[k]], complete=True)
|
||||||
|
for k in target_keys
|
||||||
|
if k in set(src_keys.keys()).intersection(set(target_keys.keys()))
|
||||||
|
]
|
||||||
|
)
|
||||||
|
new_val = sorted(new_val, key=lambda i: i["name"])
|
||||||
|
# add any dicts that don't have the list_key
|
||||||
|
# they can't be merged since they can't be matched
|
||||||
|
if complete:
|
||||||
|
new_val.extend([v for v in ret.get(key, {}) if "name" not in v])
|
||||||
|
new_val.extend([v for v in value if "name" not in v])
|
||||||
|
|
||||||
|
ret[key] = new_val
|
||||||
|
|
||||||
|
else:
|
||||||
|
ret[key] = value
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
|
@ -136,7 +136,7 @@ class NWBPydanticGenerator(PydanticGenerator):
|
||||||
"""Customize dynamictable behavior"""
|
"""Customize dynamictable behavior"""
|
||||||
cls = AfterGenerateClass.inject_dynamictable(cls)
|
cls = AfterGenerateClass.inject_dynamictable(cls)
|
||||||
cls = AfterGenerateClass.wrap_dynamictable_columns(cls, sv)
|
cls = AfterGenerateClass.wrap_dynamictable_columns(cls, sv)
|
||||||
cls = AfterGenerateClass.inject_elementidentifiers(cls, sv, self._get_element_import)
|
cls = AfterGenerateClass.inject_dynamictable_imports(cls, sv, self._get_element_import)
|
||||||
cls = AfterGenerateClass.strip_vector_data_slots(cls, sv)
|
cls = AfterGenerateClass.strip_vector_data_slots(cls, sv)
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
@ -346,19 +346,22 @@ class AfterGenerateClass:
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def inject_elementidentifiers(
|
def inject_dynamictable_imports(
|
||||||
cls: ClassResult, sv: SchemaView, import_method: Callable[[str], Import]
|
cls: ClassResult, sv: SchemaView, import_method: Callable[[str], Import]
|
||||||
) -> ClassResult:
|
) -> ClassResult:
|
||||||
"""
|
"""
|
||||||
Inject ElementIdentifiers into module that define dynamictables -
|
Ensure that schema that contain dynamictables have all the imports needed to use them
|
||||||
needed to handle ID columns
|
|
||||||
"""
|
"""
|
||||||
if (
|
if (
|
||||||
cls.source.is_a == "DynamicTable"
|
cls.source.is_a == "DynamicTable"
|
||||||
or "DynamicTable" in sv.class_ancestors(cls.source.name)
|
or "DynamicTable" in sv.class_ancestors(cls.source.name)
|
||||||
) and sv.schema.name != "hdmf-common.table":
|
) and sv.schema.name != "hdmf-common.table":
|
||||||
imp = import_method("ElementIdentifiers")
|
imp = [
|
||||||
cls.imports += [imp]
|
import_method("ElementIdentifiers"),
|
||||||
|
import_method("VectorData"),
|
||||||
|
import_method("VectorIndex"),
|
||||||
|
]
|
||||||
|
cls.imports += imp
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -1,73 +0,0 @@
|
||||||
"""
|
|
||||||
The much maligned junk drawer
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def merge_dicts(
|
|
||||||
source: dict, target: dict, list_key: str | None = None, exclude: list[str] | None = None
|
|
||||||
) -> dict:
|
|
||||||
"""
|
|
||||||
Deeply merge nested dictionaries, replacing already-declared keys rather than
|
|
||||||
e.g. merging lists as well
|
|
||||||
|
|
||||||
Args:
|
|
||||||
source (dict): source dictionary
|
|
||||||
target (dict): target dictionary (values merged over source)
|
|
||||||
list_key (str | None): Optional: if present, merge lists of dicts using this to
|
|
||||||
identify matching dicts
|
|
||||||
exclude: (list[str] | None): Optional: if present, exclude keys from parent.
|
|
||||||
|
|
||||||
References:
|
|
||||||
https://stackoverflow.com/a/20666342/13113166
|
|
||||||
|
|
||||||
"""
|
|
||||||
if exclude is None:
|
|
||||||
exclude = []
|
|
||||||
ret = {k: v for k, v in source.items() if k not in exclude}
|
|
||||||
for key, value in target.items():
|
|
||||||
if key not in ret:
|
|
||||||
ret[key] = value
|
|
||||||
elif isinstance(value, dict):
|
|
||||||
if key in ret:
|
|
||||||
ret[key] = merge_dicts(ret[key], value, list_key, exclude)
|
|
||||||
else:
|
|
||||||
ret[key] = value
|
|
||||||
elif isinstance(value, list) and list_key and all([isinstance(v, dict) for v in value]):
|
|
||||||
src_keys = {v[list_key]: ret[key].index(v) for v in ret.get(key, {}) if list_key in v}
|
|
||||||
target_keys = {v[list_key]: value.index(v) for v in value if list_key in v}
|
|
||||||
|
|
||||||
# all dicts not in target
|
|
||||||
# screwy double iteration to preserve dict order
|
|
||||||
new_val = [
|
|
||||||
ret[key][src_keys[k]]
|
|
||||||
for k in src_keys
|
|
||||||
if k in set(src_keys.keys()) - set(target_keys.keys())
|
|
||||||
]
|
|
||||||
# all dicts not in source
|
|
||||||
new_val.extend(
|
|
||||||
[
|
|
||||||
value[target_keys[k]]
|
|
||||||
for k in target_keys
|
|
||||||
if k in set(target_keys.keys()) - set(src_keys.keys())
|
|
||||||
]
|
|
||||||
)
|
|
||||||
# merge dicts in both
|
|
||||||
new_val.extend(
|
|
||||||
[
|
|
||||||
merge_dicts(ret[key][src_keys[k]], value[target_keys[k]], list_key, exclude)
|
|
||||||
for k in target_keys
|
|
||||||
if k in set(src_keys.keys()).intersection(set(target_keys.keys()))
|
|
||||||
]
|
|
||||||
)
|
|
||||||
new_val = sorted(new_val, key=lambda i: i[list_key])
|
|
||||||
# add any dicts that don't have the list_key
|
|
||||||
# they can't be merged since they can't be matched
|
|
||||||
new_val.extend([v for v in ret.get(key, {}) if list_key not in v])
|
|
||||||
new_val.extend([v for v in value if list_key not in v])
|
|
||||||
|
|
||||||
ret[key] = new_val
|
|
||||||
|
|
||||||
else:
|
|
||||||
ret[key] = value
|
|
||||||
|
|
||||||
return ret
|
|
|
@ -7,6 +7,7 @@ from decimal import Decimal
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union
|
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
|
from nwb_schema_language.util import pformat
|
||||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
|
||||||
|
|
||||||
|
|
||||||
|
@ -23,7 +24,12 @@ class ConfiguredBaseModel(BaseModel):
|
||||||
use_enum_values=True,
|
use_enum_values=True,
|
||||||
strict=False,
|
strict=False,
|
||||||
)
|
)
|
||||||
pass
|
|
||||||
|
def __repr__(self):
|
||||||
|
return pformat(self.model_dump(exclude={"parent": True}), self.__class__.__name__)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return repr(self)
|
||||||
|
|
||||||
|
|
||||||
class LinkMLMeta(RootModel):
|
class LinkMLMeta(RootModel):
|
||||||
|
@ -44,9 +50,10 @@ class LinkMLMeta(RootModel):
|
||||||
|
|
||||||
|
|
||||||
class ParentizeMixin(BaseModel):
|
class ParentizeMixin(BaseModel):
|
||||||
|
"""Mixin to populate the parent field for nested datasets and groups"""
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def parentize(self):
|
def parentize(self) -> BaseModel:
|
||||||
"""Set the parent attribute for all our fields they have one"""
|
"""Set the parent attribute for all our fields they have one"""
|
||||||
for field_name in self.model_fields:
|
for field_name in self.model_fields:
|
||||||
if field_name == "parent":
|
if field_name == "parent":
|
||||||
|
|
|
@ -31,6 +31,22 @@ class ParentizeMixin(BaseModel):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
STR_METHOD = """
|
||||||
|
def __repr__(self):
|
||||||
|
return pformat(
|
||||||
|
self.model_dump(
|
||||||
|
exclude={"parent": True},
|
||||||
|
exclude_unset=True,
|
||||||
|
exclude_None=True
|
||||||
|
),
|
||||||
|
self.__class__.__name__
|
||||||
|
)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return repr(self)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class NWBSchemaLangGenerator(PydanticGenerator):
|
class NWBSchemaLangGenerator(PydanticGenerator):
|
||||||
"""
|
"""
|
||||||
|
@ -40,8 +56,10 @@ class NWBSchemaLangGenerator(PydanticGenerator):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
kwargs["injected_classes"] = [ParentizeMixin]
|
kwargs["injected_classes"] = [ParentizeMixin]
|
||||||
kwargs["imports"] = [
|
kwargs["imports"] = [
|
||||||
Import(module="pydantic", objects=[ObjectImport(name="model_validator")])
|
Import(module="pydantic", objects=[ObjectImport(name="model_validator")]),
|
||||||
|
Import(module="nwb_schema_language.util", objects=[ObjectImport(name="pformat")]),
|
||||||
]
|
]
|
||||||
|
kwargs["injected_fields"] = [STR_METHOD]
|
||||||
kwargs["black"] = True
|
kwargs["black"] = True
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
39
nwb_schema_language/src/nwb_schema_language/util.py
Normal file
39
nwb_schema_language/src/nwb_schema_language/util.py
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
from pprint import pformat as _pformat
|
||||||
|
import textwrap
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
def pformat(fields: dict, cls_name: str, indent: str = " ") -> str:
|
||||||
|
"""
|
||||||
|
pretty format the fields of the items of a ``YAMLRoot`` object without the wonky indentation of pformat.
|
||||||
|
see ``YAMLRoot.__repr__``.
|
||||||
|
|
||||||
|
formatting is similar to black - items at similar levels of nesting have similar levels of indentation,
|
||||||
|
rather than getting placed at essentially random levels of indentation depending on what came before them.
|
||||||
|
"""
|
||||||
|
res = []
|
||||||
|
total_len = 0
|
||||||
|
for key, val in fields.items():
|
||||||
|
if val == [] or val == {} or val is None:
|
||||||
|
continue
|
||||||
|
# pformat handles everything else that isn't a YAMLRoot object, but it sure does look ugly
|
||||||
|
# use it to split lines and as the thing of last resort, but otherwise indent = 0, we'll do that
|
||||||
|
val_str = _pformat(val, indent=0, compact=True, sort_dicts=False)
|
||||||
|
# now we indent everything except the first line by indenting and then using regex to remove just the first indent
|
||||||
|
val_str = re.sub(rf"\A{re.escape(indent)}", "", textwrap.indent(val_str, indent))
|
||||||
|
# now recombine with the key in a format that can be re-eval'd into an object if indent is just whitespace
|
||||||
|
val_str = f"'{key}': " + val_str
|
||||||
|
|
||||||
|
# count the total length of this string so we know if we need to linebreak or not later
|
||||||
|
total_len += len(val_str)
|
||||||
|
res.append(val_str)
|
||||||
|
|
||||||
|
if total_len > 80:
|
||||||
|
inside = ",\n".join(res)
|
||||||
|
# we indent twice - once for the inner contents of every inner object, and one to
|
||||||
|
# offset from the root element. that keeps us from needing to be recursive except for the
|
||||||
|
# single pformat call
|
||||||
|
inside = textwrap.indent(inside, indent)
|
||||||
|
return cls_name + "({\n" + inside + "\n})"
|
||||||
|
else:
|
||||||
|
return cls_name + "({" + ", ".join(res) + "})"
|
|
@ -67,7 +67,6 @@ def generate_versions(
|
||||||
pydantic_path: Path,
|
pydantic_path: Path,
|
||||||
dry_run: bool = False,
|
dry_run: bool = False,
|
||||||
repo: GitRepo = NWB_CORE_REPO,
|
repo: GitRepo = NWB_CORE_REPO,
|
||||||
hdmf_only=False,
|
|
||||||
pdb=False,
|
pdb=False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -253,10 +252,10 @@ def main():
|
||||||
args.yaml.mkdir(exist_ok=True)
|
args.yaml.mkdir(exist_ok=True)
|
||||||
args.pydantic.mkdir(exist_ok=True)
|
args.pydantic.mkdir(exist_ok=True)
|
||||||
if args.latest:
|
if args.latest:
|
||||||
generate_core_yaml(args.yaml, args.dry_run, args.hdmf)
|
generate_core_yaml(args.yaml, args.dry_run)
|
||||||
generate_core_pydantic(args.yaml, args.pydantic, args.dry_run)
|
generate_core_pydantic(args.yaml, args.pydantic, args.dry_run)
|
||||||
else:
|
else:
|
||||||
generate_versions(args.yaml, args.pydantic, args.dry_run, repo, args.hdmf, pdb=args.pdb)
|
generate_versions(args.yaml, args.pydantic, args.dry_run, repo, pdb=args.pdb)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in a new issue