parallel builds, remove monkeypatches

This commit is contained in:
sneakers-the-rat 2024-08-20 22:45:24 -07:00
parent 98a9975406
commit 582ac5f09c
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
5 changed files with 96 additions and 177 deletions

View file

@ -16,6 +16,8 @@ dependencies = [
"rich>=13.5.2",
#"linkml>=1.7.10",
"linkml @ git+https://github.com/sneakers-the-rat/linkml@nwb-linkml",
# until recursive imports gets released
"linkml-runtime @ git+https://github.com/linkml/linkml-runtime@main",
"pydantic>=2.3.0",
"h5py>=3.9.0",
"pydantic-settings>=2.0.3",

View file

@ -16,6 +16,7 @@ from linkml.generators import PydanticGenerator
from linkml.generators.pydanticgen.array import ArrayRepresentation, NumpydanticArray
from linkml.generators.pydanticgen.build import ClassResult, SlotResult
from linkml.generators.pydanticgen.template import Import, Imports, PydanticModule
from linkml.generators.pydanticgen.pydanticgen import SplitMode
from linkml_runtime.linkml_model.meta import (
ArrayExpression,
SchemaDefinition,
@ -60,7 +61,7 @@ class NWBPydanticGenerator(PydanticGenerator):
array_representations: List[ArrayRepresentation] = field(
default_factory=lambda: [ArrayRepresentation.NUMPYDANTIC]
)
black: bool = True
black: bool = False
inlined: bool = True
emit_metadata: bool = True
gen_classvars: bool = True
@ -94,6 +95,15 @@ class NWBPydanticGenerator(PydanticGenerator):
if not base_range_subsumes_any_of:
raise ValueError("Slot cannot have both range and any_of defined")
def render(self) -> PydanticModule:
is_namespace = False
ns_annotation = self.schemaview.schema.annotations.get("is_namespace", None)
if ns_annotation:
is_namespace = ns_annotation.value
self.split_mode = SplitMode.FULL if is_namespace else SplitMode.AUTO
return super().render()
def before_generate_slot(self, slot: SlotDefinition, sv: SchemaView) -> SlotDefinition:
"""
Force some properties to be optional

View file

@ -5,67 +5,6 @@ Monkeypatches to external modules
# ruff: noqa: ANN001 - not well defined types for this module
def patch_schemaview() -> None:
"""
Patch schemaview to correctly resolve multiple layers of relative imports.
References:
Returns:
"""
from functools import lru_cache
from typing import List
from linkml_runtime.linkml_model import SchemaDefinitionName
from linkml_runtime.utils.schemaview import SchemaView
@lru_cache
def imports_closure(
self, imports: bool = True, traverse=True, inject_metadata=True
) -> List[SchemaDefinitionName]:
"""
Return all imports
:param traverse: if true, traverse recursively
:return: all schema names in the transitive reflexive imports closure
"""
if not imports:
return [self.schema.name]
if self.schema_map is None:
self.schema_map = {self.schema.name: self.schema}
closure = []
visited = set()
todo = [self.schema.name]
if not traverse:
return todo
while len(todo) > 0:
sn = todo.pop()
visited.add(sn)
if sn not in self.schema_map:
imported_schema = self.load_import(sn)
self.schema_map[sn] = imported_schema
s = self.schema_map[sn]
if sn not in closure:
closure.append(sn)
for i in s.imports:
if sn.startswith(".") and ":" not in i:
# prepend the relative part
i = "/".join(sn.split("/")[:-1]) + "/" + i
if i not in visited:
todo.append(i)
if inject_metadata:
for s in self.schema_map.values():
for x in {**s.classes, **s.enums, **s.slots, **s.subsets, **s.types}.values():
x.from_schema = s.id
for c in s.classes.values():
for a in c.attributes.values():
a.from_schema = s.id
return closure
SchemaView.imports_closure = imports_closure
def patch_array_expression() -> None:
"""
Allow SlotDefinitions to use `any_of` with `array`
@ -75,7 +14,7 @@ def patch_array_expression() -> None:
from dataclasses import field, make_dataclass
from typing import Optional
from linkml_runtime.linkml_model import meta
from linkml_runtime.linkml_model import meta, types
new_dataclass = make_dataclass(
"AnonymousSlotExpression",
@ -83,84 +22,9 @@ def patch_array_expression() -> None:
bases=(meta.AnonymousSlotExpression,),
)
meta.AnonymousSlotExpression = new_dataclass
def patch_pretty_print() -> None:
"""
Fix the godforsaken linkml dataclass reprs
See: https://github.com/linkml/linkml-runtime/pull/314
"""
import re
import textwrap
from dataclasses import field, is_dataclass, make_dataclass
from pprint import pformat
from typing import Any
from linkml_runtime.linkml_model import meta
from linkml_runtime.utils.formatutils import items
def _pformat(fields: dict, cls_name: str, indent: str = " ") -> str:
"""
pretty format the fields of the items of a ``YAMLRoot`` object without the wonky
indentation of pformat.
see ``YAMLRoot.__repr__``.
formatting is similar to black - items at similar levels of nesting have similar levels
of indentation,
rather than getting placed at essentially random levels of indentation depending on what
came before them.
"""
res = []
total_len = 0
for key, val in fields:
if val == [] or val == {} or val is None:
continue
# pformat handles everything else that isn't a YAMLRoot object,
# but it sure does look ugly
# use it to split lines and as the thing of last resort, but otherwise indent = 0,
# we'll do that
val_str = pformat(val, indent=0, compact=True, sort_dicts=False)
# now we indent everything except the first line by indenting
# and then using regex to remove just the first indent
val_str = re.sub(rf"\A{re.escape(indent)}", "", textwrap.indent(val_str, indent))
# now recombine with the key in a format that can be re-eval'd
# into an object if indent is just whitespace
val_str = f"'{key}': " + val_str
# count the total length of this string so we know if we need to linebreak or not later
total_len += len(val_str)
res.append(val_str)
if total_len > 80:
inside = ",\n".join(res)
# we indent twice - once for the inner contents of every inner object, and one to
# offset from the root element.
# that keeps us from needing to be recursive except for the
# single pformat call
inside = textwrap.indent(inside, indent)
return cls_name + "({\n" + inside + "\n})"
else:
return cls_name + "({" + ", ".join(res) + "})"
def __repr__(self) -> str:
return _pformat(items(self), self.__class__.__name__)
for cls_name in dir(meta):
cls = getattr(meta, cls_name)
if is_dataclass(cls):
new_dataclass = make_dataclass(
cls.__name__,
fields=[("__dummy__", Any, field(default=None))],
bases=(cls,),
repr=False,
)
new_dataclass.__repr__ = __repr__
new_dataclass.__str__ = __repr__
setattr(meta, cls.__name__, new_dataclass)
types.AnonymousSlotExpression = new_dataclass
def apply_patches() -> None:
"""Apply all monkeypatches"""
patch_schemaview()
patch_array_expression()
patch_pretty_print()

View file

@ -9,16 +9,19 @@ from importlib.abc import MetaPathFinder
from importlib.machinery import ModuleSpec
from pathlib import Path
from types import ModuleType
from typing import List, Optional, Type
from typing import List, Optional, Type, TYPE_CHECKING
import multiprocessing as mp
from linkml.generators.pydanticgen.pydanticgen import SplitMode, _ensure_inits, _import_to_path
from linkml_runtime.linkml_model.meta import SchemaDefinition
from pydantic import BaseModel
from nwb_linkml.generators.pydantic import NWBPydanticGenerator
from nwb_linkml.maps.naming import module_case, version_module_case
from nwb_linkml.providers import LinkMLProvider, Provider
if TYPE_CHECKING:
from linkml_runtime.linkml_model.meta import SchemaDefinition
class PydanticProvider(Provider):
"""
@ -65,6 +68,7 @@ class PydanticProvider(Provider):
split: bool = True,
dump: bool = True,
force: bool = False,
parallel: bool = False,
**kwargs: dict,
) -> str | List[str]:
"""
@ -88,6 +92,8 @@ class PydanticProvider(Provider):
otherwise just return the serialized string of built pydantic model
force (bool): If ``False`` (default), don't build the model if it already exists,
if ``True`` , delete and rebuild any model
parallel (bool): If ``True``, build imported models using multiprocessing,
if ``False`` (default), don't.
**kwargs: Passed to :class:`.NWBPydanticGenerator`
Returns:
@ -136,7 +142,9 @@ class PydanticProvider(Provider):
return serialized
def _build_split(self, path: Path, dump: bool, force: bool, **kwargs) -> List[str]:
def _build_split(
self, path: Path, dump: bool, force: bool, parallel: bool = False, **kwargs
) -> List[str]:
# FIXME: This is messy as all fuck, we're just getting it to work again
# so we can start iterating on the models themselves
res = []
@ -171,50 +179,83 @@ class PydanticProvider(Provider):
res.append(serialized)
# then each of the other schemas :)
imported_schema: dict[str, SchemaDefinition] = {
imported_schema: dict[str, "SchemaDefinition"] = {
gen.generate_module_import(sch): sch for sch in gen.schemaview.schema_map.values()
}
for generated_import in [i for i in rendered.python_imports if i.schema]:
import_file = (ns_file.parent / _import_to_path(generated_import.module)).resolve()
generated_imports = [i for i in rendered.python_imports if i.schema]
# each task has an expected output file a corresponding SchemaDefinition
import_paths = [
(ns_file.parent / _import_to_path(an_import.module)).resolve()
for an_import in generated_imports
]
import_schemas = [
Path(path).parent / imported_schema[an_import.module].source_file
for an_import in generated_imports
]
tasks = [
(
import_path,
import_schema,
force,
self.SPLIT_PATTERN,
dump,
)
for import_path, import_schema in zip(import_paths, import_schemas)
]
if parallel:
with mp.Pool(min(mp.cpu_count(), len(tasks))) as pool:
mp_results = [pool.apply_async(self._generate_single, t) for t in tasks]
for result in mp_results:
res.append(result.get())
else:
for task in tasks:
res.append(self._generate_single(*task))
# make __init__.py files if we generated any files
if len(module_paths) > 0:
_ensure_inits(import_paths)
# then extra_inits that usually aren't generated bc we're one layer deeper
self._make_inits(ns_file)
return res
@staticmethod
def _generate_single(
import_file: Path,
# schema: "SchemaDefinition",
schema: Path,
force: bool,
split_pattern: str,
dump: bool,
) -> str:
"""
Interior generator method for _build_split to be called in parallel
.. TODO::
split up and consolidate this build behavior, very spaghetti.
"""
if not import_file.exists() or force:
import_file.parent.mkdir(exist_ok=True, parents=True)
schema = imported_schema[generated_import.module]
is_namespace = False
ns_annotation = schema.annotations.get("is_namespace", None)
if ns_annotation:
is_namespace = ns_annotation.value
# fix schema source to absolute path so schemaview can find imports
schema.source_file = (
Path(gen.schemaview.schema.source_file).parent / schema.source_file
).resolve()
import_gen = NWBPydanticGenerator(
schema,
split=True,
split_pattern=self.SPLIT_PATTERN,
split_mode=SplitMode.FULL if is_namespace else SplitMode.AUTO,
split_pattern=split_pattern,
)
serialized = import_gen.serialize()
if dump:
with open(import_file, "w") as ofile:
ofile.write(serialized)
module_paths.append(import_file)
else:
with open(import_file) as ofile:
serialized = ofile.read()
res.append(serialized)
# make __init__.py files if we generated any files
if len(module_paths) > 0:
_ensure_inits(module_paths)
# then extra_inits that usually aren't generated bc we're one layer deeper
self._make_inits(ns_file)
return res
return serialized
def _make_inits(self, out_file: Path) -> None:
"""

View file

@ -57,7 +57,7 @@ def make_tmp_dir(clear: bool = False) -> Path:
if tmp_dir.exists() and clear:
for p in tmp_dir.iterdir():
if p.is_dir() and not p.name == "git":
shutil.rmtree(tmp_dir)
shutil.rmtree(p)
tmp_dir.mkdir(exist_ok=True)
return tmp_dir
@ -139,7 +139,9 @@ def generate_versions(
for schema in ns_files:
pbar_string = schema.parts[-3]
build_progress.update(pydantic_task, action=pbar_string)
pydantic_provider.build(schema, versions=core_ns.versions, split=True)
pydantic_provider.build(
schema, versions=core_ns.versions, split=True, parallel=True
)
build_progress.update(pydantic_task, advance=1)
build_progress.update(pydantic_task, action="Built Pydantic")