parallel builds, remove monkeypatches

This commit is contained in:
sneakers-the-rat 2024-08-20 22:45:24 -07:00
parent 98a9975406
commit 582ac5f09c
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
5 changed files with 96 additions and 177 deletions

View file

@ -16,6 +16,8 @@ dependencies = [
"rich>=13.5.2", "rich>=13.5.2",
#"linkml>=1.7.10", #"linkml>=1.7.10",
"linkml @ git+https://github.com/sneakers-the-rat/linkml@nwb-linkml", "linkml @ git+https://github.com/sneakers-the-rat/linkml@nwb-linkml",
# until support for recursive imports gets released
"linkml-runtime @ git+https://github.com/linkml/linkml-runtime@main",
"pydantic>=2.3.0", "pydantic>=2.3.0",
"h5py>=3.9.0", "h5py>=3.9.0",
"pydantic-settings>=2.0.3", "pydantic-settings>=2.0.3",

View file

@ -16,6 +16,7 @@ from linkml.generators import PydanticGenerator
from linkml.generators.pydanticgen.array import ArrayRepresentation, NumpydanticArray from linkml.generators.pydanticgen.array import ArrayRepresentation, NumpydanticArray
from linkml.generators.pydanticgen.build import ClassResult, SlotResult from linkml.generators.pydanticgen.build import ClassResult, SlotResult
from linkml.generators.pydanticgen.template import Import, Imports, PydanticModule from linkml.generators.pydanticgen.template import Import, Imports, PydanticModule
from linkml.generators.pydanticgen.pydanticgen import SplitMode
from linkml_runtime.linkml_model.meta import ( from linkml_runtime.linkml_model.meta import (
ArrayExpression, ArrayExpression,
SchemaDefinition, SchemaDefinition,
@ -60,7 +61,7 @@ class NWBPydanticGenerator(PydanticGenerator):
array_representations: List[ArrayRepresentation] = field( array_representations: List[ArrayRepresentation] = field(
default_factory=lambda: [ArrayRepresentation.NUMPYDANTIC] default_factory=lambda: [ArrayRepresentation.NUMPYDANTIC]
) )
black: bool = True black: bool = False
inlined: bool = True inlined: bool = True
emit_metadata: bool = True emit_metadata: bool = True
gen_classvars: bool = True gen_classvars: bool = True
@ -94,6 +95,15 @@ class NWBPydanticGenerator(PydanticGenerator):
if not base_range_subsumes_any_of: if not base_range_subsumes_any_of:
raise ValueError("Slot cannot have both range and any_of defined") raise ValueError("Slot cannot have both range and any_of defined")
def render(self) -> PydanticModule:
    """
    Choose the split mode from the schema's annotations, then render the module.

    Looks up the ``is_namespace`` annotation on the schema being generated.
    When the annotation is present and its value is truthy, ``split_mode`` is
    set to ``SplitMode.FULL``; in every other case it is set to
    ``SplitMode.AUTO``. Any ``split_mode`` previously set on the generator is
    overwritten. Rendering itself is delegated to the parent class.

    Returns:
        PydanticModule: the result of the parent class's ``render``
    """
    annotation = self.schemaview.schema.annotations.get("is_namespace", None)
    # a missing (or falsy) annotation is treated as "not a namespace"
    namespace_flag = annotation.value if annotation else False
    self.split_mode = SplitMode.FULL if namespace_flag else SplitMode.AUTO
    return super().render()
def before_generate_slot(self, slot: SlotDefinition, sv: SchemaView) -> SlotDefinition: def before_generate_slot(self, slot: SlotDefinition, sv: SchemaView) -> SlotDefinition:
""" """
Force some properties to be optional Force some properties to be optional

View file

@ -5,67 +5,6 @@ Monkeypatches to external modules
# ruff: noqa: ANN001 - not well defined types for this module # ruff: noqa: ANN001 - not well defined types for this module
def patch_schemaview() -> None:
"""
Patch schemaview to correctly resolve multiple layers of relative imports.
References:
Returns:
"""
from functools import lru_cache
from typing import List
from linkml_runtime.linkml_model import SchemaDefinitionName
from linkml_runtime.utils.schemaview import SchemaView
@lru_cache
def imports_closure(
self, imports: bool = True, traverse=True, inject_metadata=True
) -> List[SchemaDefinitionName]:
"""
Return all imports
:param traverse: if true, traverse recursively
:return: all schema names in the transitive reflexive imports closure
"""
if not imports:
return [self.schema.name]
if self.schema_map is None:
self.schema_map = {self.schema.name: self.schema}
closure = []
visited = set()
todo = [self.schema.name]
if not traverse:
return todo
while len(todo) > 0:
sn = todo.pop()
visited.add(sn)
if sn not in self.schema_map:
imported_schema = self.load_import(sn)
self.schema_map[sn] = imported_schema
s = self.schema_map[sn]
if sn not in closure:
closure.append(sn)
for i in s.imports:
if sn.startswith(".") and ":" not in i:
# prepend the relative part
i = "/".join(sn.split("/")[:-1]) + "/" + i
if i not in visited:
todo.append(i)
if inject_metadata:
for s in self.schema_map.values():
for x in {**s.classes, **s.enums, **s.slots, **s.subsets, **s.types}.values():
x.from_schema = s.id
for c in s.classes.values():
for a in c.attributes.values():
a.from_schema = s.id
return closure
SchemaView.imports_closure = imports_closure
def patch_array_expression() -> None: def patch_array_expression() -> None:
""" """
Allow SlotDefinitions to use `any_of` with `array` Allow SlotDefinitions to use `any_of` with `array`
@ -75,7 +14,7 @@ def patch_array_expression() -> None:
from dataclasses import field, make_dataclass from dataclasses import field, make_dataclass
from typing import Optional from typing import Optional
from linkml_runtime.linkml_model import meta from linkml_runtime.linkml_model import meta, types
new_dataclass = make_dataclass( new_dataclass = make_dataclass(
"AnonymousSlotExpression", "AnonymousSlotExpression",
@ -83,84 +22,9 @@ def patch_array_expression() -> None:
bases=(meta.AnonymousSlotExpression,), bases=(meta.AnonymousSlotExpression,),
) )
meta.AnonymousSlotExpression = new_dataclass meta.AnonymousSlotExpression = new_dataclass
types.AnonymousSlotExpression = new_dataclass
def patch_pretty_print() -> None:
"""
Fix the godforsaken linkml dataclass reprs
See: https://github.com/linkml/linkml-runtime/pull/314
"""
import re
import textwrap
from dataclasses import field, is_dataclass, make_dataclass
from pprint import pformat
from typing import Any
from linkml_runtime.linkml_model import meta
from linkml_runtime.utils.formatutils import items
def _pformat(fields: dict, cls_name: str, indent: str = " ") -> str:
"""
pretty format the fields of the items of a ``YAMLRoot`` object without the wonky
indentation of pformat.
see ``YAMLRoot.__repr__``.
formatting is similar to black - items at similar levels of nesting have similar levels
of indentation,
rather than getting placed at essentially random levels of indentation depending on what
came before them.
"""
res = []
total_len = 0
for key, val in fields:
if val == [] or val == {} or val is None:
continue
# pformat handles everything else that isn't a YAMLRoot object,
# but it sure does look ugly
# use it to split lines and as the thing of last resort, but otherwise indent = 0,
# we'll do that
val_str = pformat(val, indent=0, compact=True, sort_dicts=False)
# now we indent everything except the first line by indenting
# and then using regex to remove just the first indent
val_str = re.sub(rf"\A{re.escape(indent)}", "", textwrap.indent(val_str, indent))
# now recombine with the key in a format that can be re-eval'd
# into an object if indent is just whitespace
val_str = f"'{key}': " + val_str
# count the total length of this string so we know if we need to linebreak or not later
total_len += len(val_str)
res.append(val_str)
if total_len > 80:
inside = ",\n".join(res)
# we indent twice - once for the inner contents of every inner object, and one to
# offset from the root element.
# that keeps us from needing to be recursive except for the
# single pformat call
inside = textwrap.indent(inside, indent)
return cls_name + "({\n" + inside + "\n})"
else:
return cls_name + "({" + ", ".join(res) + "})"
def __repr__(self) -> str:
return _pformat(items(self), self.__class__.__name__)
for cls_name in dir(meta):
cls = getattr(meta, cls_name)
if is_dataclass(cls):
new_dataclass = make_dataclass(
cls.__name__,
fields=[("__dummy__", Any, field(default=None))],
bases=(cls,),
repr=False,
)
new_dataclass.__repr__ = __repr__
new_dataclass.__str__ = __repr__
setattr(meta, cls.__name__, new_dataclass)
def apply_patches() -> None: def apply_patches() -> None:
"""Apply all monkeypatches""" """Apply all monkeypatches"""
patch_schemaview()
patch_array_expression() patch_array_expression()
patch_pretty_print()

View file

@ -9,16 +9,19 @@ from importlib.abc import MetaPathFinder
from importlib.machinery import ModuleSpec from importlib.machinery import ModuleSpec
from pathlib import Path from pathlib import Path
from types import ModuleType from types import ModuleType
from typing import List, Optional, Type from typing import List, Optional, Type, TYPE_CHECKING
import multiprocessing as mp
from linkml.generators.pydanticgen.pydanticgen import SplitMode, _ensure_inits, _import_to_path from linkml.generators.pydanticgen.pydanticgen import SplitMode, _ensure_inits, _import_to_path
from linkml_runtime.linkml_model.meta import SchemaDefinition
from pydantic import BaseModel from pydantic import BaseModel
from nwb_linkml.generators.pydantic import NWBPydanticGenerator from nwb_linkml.generators.pydantic import NWBPydanticGenerator
from nwb_linkml.maps.naming import module_case, version_module_case from nwb_linkml.maps.naming import module_case, version_module_case
from nwb_linkml.providers import LinkMLProvider, Provider from nwb_linkml.providers import LinkMLProvider, Provider
if TYPE_CHECKING:
from linkml_runtime.linkml_model.meta import SchemaDefinition
class PydanticProvider(Provider): class PydanticProvider(Provider):
""" """
@ -65,6 +68,7 @@ class PydanticProvider(Provider):
split: bool = True, split: bool = True,
dump: bool = True, dump: bool = True,
force: bool = False, force: bool = False,
parallel: bool = False,
**kwargs: dict, **kwargs: dict,
) -> str | List[str]: ) -> str | List[str]:
""" """
@ -88,6 +92,8 @@ class PydanticProvider(Provider):
otherwise just return the serialized string of built pydantic model otherwise just return the serialized string of built pydantic model
force (bool): If ``False`` (default), don't build the model if it already exists, force (bool): If ``False`` (default), don't build the model if it already exists,
if ``True`` , delete and rebuild any model if ``True`` , delete and rebuild any model
parallel (bool): If ``True``, build imported models using multiprocessing,
if ``False`` (default), don't.
**kwargs: Passed to :class:`.NWBPydanticGenerator` **kwargs: Passed to :class:`.NWBPydanticGenerator`
Returns: Returns:
@ -136,7 +142,9 @@ class PydanticProvider(Provider):
return serialized return serialized
def _build_split(self, path: Path, dump: bool, force: bool, **kwargs) -> List[str]: def _build_split(
self, path: Path, dump: bool, force: bool, parallel: bool = False, **kwargs
) -> List[str]:
# FIXME: This is messy as all fuck, we're just getting it to work again # FIXME: This is messy as all fuck, we're just getting it to work again
# so we can start iterating on the models themselves # so we can start iterating on the models themselves
res = [] res = []
@ -171,50 +179,83 @@ class PydanticProvider(Provider):
res.append(serialized) res.append(serialized)
# then each of the other schemas :) # then each of the other schemas :)
imported_schema: dict[str, SchemaDefinition] = { imported_schema: dict[str, "SchemaDefinition"] = {
gen.generate_module_import(sch): sch for sch in gen.schemaview.schema_map.values() gen.generate_module_import(sch): sch for sch in gen.schemaview.schema_map.values()
} }
for generated_import in [i for i in rendered.python_imports if i.schema]: generated_imports = [i for i in rendered.python_imports if i.schema]
import_file = (ns_file.parent / _import_to_path(generated_import.module)).resolve() # each task has an expected output file a corresponding SchemaDefinition
import_paths = [
(ns_file.parent / _import_to_path(an_import.module)).resolve()
for an_import in generated_imports
]
import_schemas = [
Path(path).parent / imported_schema[an_import.module].source_file
for an_import in generated_imports
]
tasks = [
(
import_path,
import_schema,
force,
self.SPLIT_PATTERN,
dump,
)
for import_path, import_schema in zip(import_paths, import_schemas)
]
if parallel:
with mp.Pool(min(mp.cpu_count(), len(tasks))) as pool:
mp_results = [pool.apply_async(self._generate_single, t) for t in tasks]
for result in mp_results:
res.append(result.get())
else:
for task in tasks:
res.append(self._generate_single(*task))
# make __init__.py files if we generated any files
if len(module_paths) > 0:
_ensure_inits(import_paths)
# then extra_inits that usually aren't generated bc we're one layer deeper
self._make_inits(ns_file)
return res
@staticmethod
def _generate_single(
import_file: Path,
# schema: "SchemaDefinition",
schema: Path,
force: bool,
split_pattern: str,
dump: bool,
) -> str:
"""
Interior generator method for _build_split to be called in parallel
.. TODO::
split up and consolidate this build behavior, very spaghetti.
"""
if not import_file.exists() or force: if not import_file.exists() or force:
import_file.parent.mkdir(exist_ok=True, parents=True) import_file.parent.mkdir(exist_ok=True, parents=True)
schema = imported_schema[generated_import.module]
is_namespace = False
ns_annotation = schema.annotations.get("is_namespace", None)
if ns_annotation:
is_namespace = ns_annotation.value
# fix schema source to absolute path so schemaview can find imports
schema.source_file = (
Path(gen.schemaview.schema.source_file).parent / schema.source_file
).resolve()
import_gen = NWBPydanticGenerator( import_gen = NWBPydanticGenerator(
schema, schema,
split=True, split=True,
split_pattern=self.SPLIT_PATTERN, split_pattern=split_pattern,
split_mode=SplitMode.FULL if is_namespace else SplitMode.AUTO,
) )
serialized = import_gen.serialize() serialized = import_gen.serialize()
if dump: if dump:
with open(import_file, "w") as ofile: with open(import_file, "w") as ofile:
ofile.write(serialized) ofile.write(serialized)
module_paths.append(import_file)
else: else:
with open(import_file) as ofile: with open(import_file) as ofile:
serialized = ofile.read() serialized = ofile.read()
return serialized
res.append(serialized)
# make __init__.py files if we generated any files
if len(module_paths) > 0:
_ensure_inits(module_paths)
# then extra_inits that usually aren't generated bc we're one layer deeper
self._make_inits(ns_file)
return res
def _make_inits(self, out_file: Path) -> None: def _make_inits(self, out_file: Path) -> None:
""" """

View file

@ -57,7 +57,7 @@ def make_tmp_dir(clear: bool = False) -> Path:
if tmp_dir.exists() and clear: if tmp_dir.exists() and clear:
for p in tmp_dir.iterdir(): for p in tmp_dir.iterdir():
if p.is_dir() and not p.name == "git": if p.is_dir() and not p.name == "git":
shutil.rmtree(tmp_dir) shutil.rmtree(p)
tmp_dir.mkdir(exist_ok=True) tmp_dir.mkdir(exist_ok=True)
return tmp_dir return tmp_dir
@ -139,7 +139,9 @@ def generate_versions(
for schema in ns_files: for schema in ns_files:
pbar_string = schema.parts[-3] pbar_string = schema.parts[-3]
build_progress.update(pydantic_task, action=pbar_string) build_progress.update(pydantic_task, action=pbar_string)
pydantic_provider.build(schema, versions=core_ns.versions, split=True) pydantic_provider.build(
schema, versions=core_ns.versions, split=True, parallel=True
)
build_progress.update(pydantic_task, advance=1) build_progress.update(pydantic_task, advance=1)
build_progress.update(pydantic_task, action="Built Pydantic") build_progress.update(pydantic_task, action="Built Pydantic")