mirror of
https://github.com/p2p-ld/nwb-linkml.git
synced 2025-01-10 06:04:28 +00:00
parallel builds, remove monkeypatches
This commit is contained in:
parent
98a9975406
commit
582ac5f09c
5 changed files with 96 additions and 177 deletions
|
@ -16,6 +16,8 @@ dependencies = [
|
|||
"rich>=13.5.2",
|
||||
#"linkml>=1.7.10",
|
||||
"linkml @ git+https://github.com/sneakers-the-rat/linkml@nwb-linkml",
|
||||
# until recursive imports gets released
|
||||
"linkml-runtime @ git+https://github.com/linkml/linkml-runtime@main",
|
||||
"pydantic>=2.3.0",
|
||||
"h5py>=3.9.0",
|
||||
"pydantic-settings>=2.0.3",
|
||||
|
|
|
@ -16,6 +16,7 @@ from linkml.generators import PydanticGenerator
|
|||
from linkml.generators.pydanticgen.array import ArrayRepresentation, NumpydanticArray
|
||||
from linkml.generators.pydanticgen.build import ClassResult, SlotResult
|
||||
from linkml.generators.pydanticgen.template import Import, Imports, PydanticModule
|
||||
from linkml.generators.pydanticgen.pydanticgen import SplitMode
|
||||
from linkml_runtime.linkml_model.meta import (
|
||||
ArrayExpression,
|
||||
SchemaDefinition,
|
||||
|
@ -60,7 +61,7 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
array_representations: List[ArrayRepresentation] = field(
|
||||
default_factory=lambda: [ArrayRepresentation.NUMPYDANTIC]
|
||||
)
|
||||
black: bool = True
|
||||
black: bool = False
|
||||
inlined: bool = True
|
||||
emit_metadata: bool = True
|
||||
gen_classvars: bool = True
|
||||
|
@ -94,6 +95,15 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
if not base_range_subsumes_any_of:
|
||||
raise ValueError("Slot cannot have both range and any_of defined")
|
||||
|
||||
def render(self) -> PydanticModule:
|
||||
is_namespace = False
|
||||
ns_annotation = self.schemaview.schema.annotations.get("is_namespace", None)
|
||||
if ns_annotation:
|
||||
is_namespace = ns_annotation.value
|
||||
self.split_mode = SplitMode.FULL if is_namespace else SplitMode.AUTO
|
||||
|
||||
return super().render()
|
||||
|
||||
def before_generate_slot(self, slot: SlotDefinition, sv: SchemaView) -> SlotDefinition:
|
||||
"""
|
||||
Force some properties to be optional
|
||||
|
|
|
@ -5,67 +5,6 @@ Monkeypatches to external modules
|
|||
# ruff: noqa: ANN001 - not well defined types for this module
|
||||
|
||||
|
||||
def patch_schemaview() -> None:
|
||||
"""
|
||||
Patch schemaview to correctly resolve multiple layers of relative imports.
|
||||
|
||||
References:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
from functools import lru_cache
|
||||
from typing import List
|
||||
|
||||
from linkml_runtime.linkml_model import SchemaDefinitionName
|
||||
from linkml_runtime.utils.schemaview import SchemaView
|
||||
|
||||
@lru_cache
|
||||
def imports_closure(
|
||||
self, imports: bool = True, traverse=True, inject_metadata=True
|
||||
) -> List[SchemaDefinitionName]:
|
||||
"""
|
||||
Return all imports
|
||||
|
||||
:param traverse: if true, traverse recursively
|
||||
:return: all schema names in the transitive reflexive imports closure
|
||||
"""
|
||||
if not imports:
|
||||
return [self.schema.name]
|
||||
if self.schema_map is None:
|
||||
self.schema_map = {self.schema.name: self.schema}
|
||||
closure = []
|
||||
visited = set()
|
||||
todo = [self.schema.name]
|
||||
if not traverse:
|
||||
return todo
|
||||
while len(todo) > 0:
|
||||
sn = todo.pop()
|
||||
visited.add(sn)
|
||||
if sn not in self.schema_map:
|
||||
imported_schema = self.load_import(sn)
|
||||
self.schema_map[sn] = imported_schema
|
||||
s = self.schema_map[sn]
|
||||
if sn not in closure:
|
||||
closure.append(sn)
|
||||
for i in s.imports:
|
||||
if sn.startswith(".") and ":" not in i:
|
||||
# prepend the relative part
|
||||
i = "/".join(sn.split("/")[:-1]) + "/" + i
|
||||
if i not in visited:
|
||||
todo.append(i)
|
||||
if inject_metadata:
|
||||
for s in self.schema_map.values():
|
||||
for x in {**s.classes, **s.enums, **s.slots, **s.subsets, **s.types}.values():
|
||||
x.from_schema = s.id
|
||||
for c in s.classes.values():
|
||||
for a in c.attributes.values():
|
||||
a.from_schema = s.id
|
||||
return closure
|
||||
|
||||
SchemaView.imports_closure = imports_closure
|
||||
|
||||
|
||||
def patch_array_expression() -> None:
|
||||
"""
|
||||
Allow SlotDefinitions to use `any_of` with `array`
|
||||
|
@ -75,7 +14,7 @@ def patch_array_expression() -> None:
|
|||
from dataclasses import field, make_dataclass
|
||||
from typing import Optional
|
||||
|
||||
from linkml_runtime.linkml_model import meta
|
||||
from linkml_runtime.linkml_model import meta, types
|
||||
|
||||
new_dataclass = make_dataclass(
|
||||
"AnonymousSlotExpression",
|
||||
|
@ -83,84 +22,9 @@ def patch_array_expression() -> None:
|
|||
bases=(meta.AnonymousSlotExpression,),
|
||||
)
|
||||
meta.AnonymousSlotExpression = new_dataclass
|
||||
|
||||
|
||||
def patch_pretty_print() -> None:
|
||||
"""
|
||||
Fix the godforsaken linkml dataclass reprs
|
||||
|
||||
See: https://github.com/linkml/linkml-runtime/pull/314
|
||||
"""
|
||||
import re
|
||||
import textwrap
|
||||
from dataclasses import field, is_dataclass, make_dataclass
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
from linkml_runtime.linkml_model import meta
|
||||
from linkml_runtime.utils.formatutils import items
|
||||
|
||||
def _pformat(fields: dict, cls_name: str, indent: str = " ") -> str:
|
||||
"""
|
||||
pretty format the fields of the items of a ``YAMLRoot`` object without the wonky
|
||||
indentation of pformat.
|
||||
see ``YAMLRoot.__repr__``.
|
||||
formatting is similar to black - items at similar levels of nesting have similar levels
|
||||
of indentation,
|
||||
rather than getting placed at essentially random levels of indentation depending on what
|
||||
came before them.
|
||||
"""
|
||||
res = []
|
||||
total_len = 0
|
||||
for key, val in fields:
|
||||
if val == [] or val == {} or val is None:
|
||||
continue
|
||||
# pformat handles everything else that isn't a YAMLRoot object,
|
||||
# but it sure does look ugly
|
||||
# use it to split lines and as the thing of last resort, but otherwise indent = 0,
|
||||
# we'll do that
|
||||
val_str = pformat(val, indent=0, compact=True, sort_dicts=False)
|
||||
# now we indent everything except the first line by indenting
|
||||
# and then using regex to remove just the first indent
|
||||
val_str = re.sub(rf"\A{re.escape(indent)}", "", textwrap.indent(val_str, indent))
|
||||
# now recombine with the key in a format that can be re-eval'd
|
||||
# into an object if indent is just whitespace
|
||||
val_str = f"'{key}': " + val_str
|
||||
|
||||
# count the total length of this string so we know if we need to linebreak or not later
|
||||
total_len += len(val_str)
|
||||
res.append(val_str)
|
||||
|
||||
if total_len > 80:
|
||||
inside = ",\n".join(res)
|
||||
# we indent twice - once for the inner contents of every inner object, and one to
|
||||
# offset from the root element.
|
||||
# that keeps us from needing to be recursive except for the
|
||||
# single pformat call
|
||||
inside = textwrap.indent(inside, indent)
|
||||
return cls_name + "({\n" + inside + "\n})"
|
||||
else:
|
||||
return cls_name + "({" + ", ".join(res) + "})"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return _pformat(items(self), self.__class__.__name__)
|
||||
|
||||
for cls_name in dir(meta):
|
||||
cls = getattr(meta, cls_name)
|
||||
if is_dataclass(cls):
|
||||
new_dataclass = make_dataclass(
|
||||
cls.__name__,
|
||||
fields=[("__dummy__", Any, field(default=None))],
|
||||
bases=(cls,),
|
||||
repr=False,
|
||||
)
|
||||
new_dataclass.__repr__ = __repr__
|
||||
new_dataclass.__str__ = __repr__
|
||||
setattr(meta, cls.__name__, new_dataclass)
|
||||
types.AnonymousSlotExpression = new_dataclass
|
||||
|
||||
|
||||
def apply_patches() -> None:
|
||||
"""Apply all monkeypatches"""
|
||||
patch_schemaview()
|
||||
patch_array_expression()
|
||||
patch_pretty_print()
|
||||
|
|
|
@ -9,16 +9,19 @@ from importlib.abc import MetaPathFinder
|
|||
from importlib.machinery import ModuleSpec
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import List, Optional, Type
|
||||
from typing import List, Optional, Type, TYPE_CHECKING
|
||||
import multiprocessing as mp
|
||||
|
||||
from linkml.generators.pydanticgen.pydanticgen import SplitMode, _ensure_inits, _import_to_path
|
||||
from linkml_runtime.linkml_model.meta import SchemaDefinition
|
||||
from pydantic import BaseModel
|
||||
|
||||
from nwb_linkml.generators.pydantic import NWBPydanticGenerator
|
||||
from nwb_linkml.maps.naming import module_case, version_module_case
|
||||
from nwb_linkml.providers import LinkMLProvider, Provider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from linkml_runtime.linkml_model.meta import SchemaDefinition
|
||||
|
||||
|
||||
class PydanticProvider(Provider):
|
||||
"""
|
||||
|
@ -65,6 +68,7 @@ class PydanticProvider(Provider):
|
|||
split: bool = True,
|
||||
dump: bool = True,
|
||||
force: bool = False,
|
||||
parallel: bool = False,
|
||||
**kwargs: dict,
|
||||
) -> str | List[str]:
|
||||
"""
|
||||
|
@ -88,6 +92,8 @@ class PydanticProvider(Provider):
|
|||
otherwise just return the serialized string of built pydantic model
|
||||
force (bool): If ``False`` (default), don't build the model if it already exists,
|
||||
if ``True`` , delete and rebuild any model
|
||||
parallel (bool): If ``True``, build imported models using multiprocessing,
|
||||
if ``False`` (default), don't.
|
||||
**kwargs: Passed to :class:`.NWBPydanticGenerator`
|
||||
|
||||
Returns:
|
||||
|
@ -136,7 +142,9 @@ class PydanticProvider(Provider):
|
|||
|
||||
return serialized
|
||||
|
||||
def _build_split(self, path: Path, dump: bool, force: bool, **kwargs) -> List[str]:
|
||||
def _build_split(
|
||||
self, path: Path, dump: bool, force: bool, parallel: bool = False, **kwargs
|
||||
) -> List[str]:
|
||||
# FIXME: This is messy as all fuck, we're just getting it to work again
|
||||
# so we can start iterating on the models themselves
|
||||
res = []
|
||||
|
@ -171,50 +179,83 @@ class PydanticProvider(Provider):
|
|||
res.append(serialized)
|
||||
|
||||
# then each of the other schemas :)
|
||||
imported_schema: dict[str, SchemaDefinition] = {
|
||||
imported_schema: dict[str, "SchemaDefinition"] = {
|
||||
gen.generate_module_import(sch): sch for sch in gen.schemaview.schema_map.values()
|
||||
}
|
||||
for generated_import in [i for i in rendered.python_imports if i.schema]:
|
||||
import_file = (ns_file.parent / _import_to_path(generated_import.module)).resolve()
|
||||
generated_imports = [i for i in rendered.python_imports if i.schema]
|
||||
# each task has an expected output file a corresponding SchemaDefinition
|
||||
import_paths = [
|
||||
(ns_file.parent / _import_to_path(an_import.module)).resolve()
|
||||
for an_import in generated_imports
|
||||
]
|
||||
import_schemas = [
|
||||
Path(path).parent / imported_schema[an_import.module].source_file
|
||||
for an_import in generated_imports
|
||||
]
|
||||
|
||||
tasks = [
|
||||
(
|
||||
import_path,
|
||||
import_schema,
|
||||
force,
|
||||
self.SPLIT_PATTERN,
|
||||
dump,
|
||||
)
|
||||
for import_path, import_schema in zip(import_paths, import_schemas)
|
||||
]
|
||||
|
||||
if parallel:
|
||||
with mp.Pool(min(mp.cpu_count(), len(tasks))) as pool:
|
||||
mp_results = [pool.apply_async(self._generate_single, t) for t in tasks]
|
||||
for result in mp_results:
|
||||
res.append(result.get())
|
||||
else:
|
||||
for task in tasks:
|
||||
res.append(self._generate_single(*task))
|
||||
|
||||
# make __init__.py files if we generated any files
|
||||
if len(module_paths) > 0:
|
||||
_ensure_inits(import_paths)
|
||||
# then extra_inits that usually aren't generated bc we're one layer deeper
|
||||
self._make_inits(ns_file)
|
||||
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def _generate_single(
|
||||
import_file: Path,
|
||||
# schema: "SchemaDefinition",
|
||||
schema: Path,
|
||||
force: bool,
|
||||
split_pattern: str,
|
||||
dump: bool,
|
||||
) -> str:
|
||||
"""
|
||||
Interior generator method for _build_split to be called in parallel
|
||||
|
||||
.. TODO::
|
||||
|
||||
split up and consolidate this build behavior, very spaghetti.
|
||||
|
||||
"""
|
||||
|
||||
if not import_file.exists() or force:
|
||||
import_file.parent.mkdir(exist_ok=True, parents=True)
|
||||
schema = imported_schema[generated_import.module]
|
||||
is_namespace = False
|
||||
ns_annotation = schema.annotations.get("is_namespace", None)
|
||||
if ns_annotation:
|
||||
is_namespace = ns_annotation.value
|
||||
|
||||
# fix schema source to absolute path so schemaview can find imports
|
||||
schema.source_file = (
|
||||
Path(gen.schemaview.schema.source_file).parent / schema.source_file
|
||||
).resolve()
|
||||
|
||||
import_gen = NWBPydanticGenerator(
|
||||
schema,
|
||||
split=True,
|
||||
split_pattern=self.SPLIT_PATTERN,
|
||||
split_mode=SplitMode.FULL if is_namespace else SplitMode.AUTO,
|
||||
split_pattern=split_pattern,
|
||||
)
|
||||
serialized = import_gen.serialize()
|
||||
if dump:
|
||||
with open(import_file, "w") as ofile:
|
||||
ofile.write(serialized)
|
||||
module_paths.append(import_file)
|
||||
|
||||
else:
|
||||
with open(import_file) as ofile:
|
||||
serialized = ofile.read()
|
||||
|
||||
res.append(serialized)
|
||||
|
||||
# make __init__.py files if we generated any files
|
||||
if len(module_paths) > 0:
|
||||
_ensure_inits(module_paths)
|
||||
# then extra_inits that usually aren't generated bc we're one layer deeper
|
||||
self._make_inits(ns_file)
|
||||
|
||||
return res
|
||||
return serialized
|
||||
|
||||
def _make_inits(self, out_file: Path) -> None:
|
||||
"""
|
||||
|
|
|
@ -57,7 +57,7 @@ def make_tmp_dir(clear: bool = False) -> Path:
|
|||
if tmp_dir.exists() and clear:
|
||||
for p in tmp_dir.iterdir():
|
||||
if p.is_dir() and not p.name == "git":
|
||||
shutil.rmtree(tmp_dir)
|
||||
shutil.rmtree(p)
|
||||
tmp_dir.mkdir(exist_ok=True)
|
||||
return tmp_dir
|
||||
|
||||
|
@ -139,7 +139,9 @@ def generate_versions(
|
|||
for schema in ns_files:
|
||||
pbar_string = schema.parts[-3]
|
||||
build_progress.update(pydantic_task, action=pbar_string)
|
||||
pydantic_provider.build(schema, versions=core_ns.versions, split=True)
|
||||
pydantic_provider.build(
|
||||
schema, versions=core_ns.versions, split=True, parallel=True
|
||||
)
|
||||
build_progress.update(pydantic_task, advance=1)
|
||||
build_progress.update(pydantic_task, action="Built Pydantic")
|
||||
|
||||
|
|
Loading…
Reference in a new issue