diff --git a/nwb_linkml/pyproject.toml b/nwb_linkml/pyproject.toml index 9c8e2af..c7611e2 100644 --- a/nwb_linkml/pyproject.toml +++ b/nwb_linkml/pyproject.toml @@ -16,6 +16,8 @@ dependencies = [ "rich>=13.5.2", #"linkml>=1.7.10", "linkml @ git+https://github.com/sneakers-the-rat/linkml@nwb-linkml", + # until recursive imports gets released + "linkml-runtime @ git+https://github.com/linkml/linkml-runtime@main", "pydantic>=2.3.0", "h5py>=3.9.0", "pydantic-settings>=2.0.3", diff --git a/nwb_linkml/src/nwb_linkml/generators/pydantic.py b/nwb_linkml/src/nwb_linkml/generators/pydantic.py index 0cdfd23..5dec150 100644 --- a/nwb_linkml/src/nwb_linkml/generators/pydantic.py +++ b/nwb_linkml/src/nwb_linkml/generators/pydantic.py @@ -16,6 +16,7 @@ from linkml.generators import PydanticGenerator from linkml.generators.pydanticgen.array import ArrayRepresentation, NumpydanticArray from linkml.generators.pydanticgen.build import ClassResult, SlotResult from linkml.generators.pydanticgen.template import Import, Imports, PydanticModule +from linkml.generators.pydanticgen.pydanticgen import SplitMode from linkml_runtime.linkml_model.meta import ( ArrayExpression, SchemaDefinition, @@ -60,7 +61,7 @@ class NWBPydanticGenerator(PydanticGenerator): array_representations: List[ArrayRepresentation] = field( default_factory=lambda: [ArrayRepresentation.NUMPYDANTIC] ) - black: bool = True + black: bool = False inlined: bool = True emit_metadata: bool = True gen_classvars: bool = True @@ -94,6 +95,15 @@ class NWBPydanticGenerator(PydanticGenerator): if not base_range_subsumes_any_of: raise ValueError("Slot cannot have both range and any_of defined") + def render(self) -> PydanticModule: + is_namespace = False + ns_annotation = self.schemaview.schema.annotations.get("is_namespace", None) + if ns_annotation: + is_namespace = ns_annotation.value + self.split_mode = SplitMode.FULL if is_namespace else SplitMode.AUTO + + return super().render() + def before_generate_slot(self, slot: SlotDefinition, sv: SchemaView) -> SlotDefinition: """ Force some properties to be optional diff --git a/nwb_linkml/src/nwb_linkml/monkeypatch.py b/nwb_linkml/src/nwb_linkml/monkeypatch.py index 6222089..661ec7a 100644 --- a/nwb_linkml/src/nwb_linkml/monkeypatch.py +++ b/nwb_linkml/src/nwb_linkml/monkeypatch.py @@ -5,67 +5,6 @@ Monkeypatches to external modules # ruff: noqa: ANN001 - not well defined types for this module -def patch_schemaview() -> None: - """ - Patch schemaview to correctly resolve multiple layers of relative imports. - - References: - - Returns: - - """ - from functools import lru_cache - from typing import List - - from linkml_runtime.linkml_model import SchemaDefinitionName - from linkml_runtime.utils.schemaview import SchemaView - - @lru_cache - def imports_closure( - self, imports: bool = True, traverse=True, inject_metadata=True - ) -> List[SchemaDefinitionName]: - """ - Return all imports - - :param traverse: if true, traverse recursively - :return: all schema names in the transitive reflexive imports closure - """ - if not imports: - return [self.schema.name] - if self.schema_map is None: - self.schema_map = {self.schema.name: self.schema} - closure = [] - visited = set() - todo = [self.schema.name] - if not traverse: - return todo - while len(todo) > 0: - sn = todo.pop() - visited.add(sn) - if sn not in self.schema_map: - imported_schema = self.load_import(sn) - self.schema_map[sn] = imported_schema - s = self.schema_map[sn] - if sn not in closure: - closure.append(sn) - for i in s.imports: - if sn.startswith(".") and ":" not in i: - # prepend the relative part - i = "/".join(sn.split("/")[:-1]) + "/" + i - if i not in visited: - todo.append(i) - if inject_metadata: - for s in self.schema_map.values(): - for x in {**s.classes, **s.enums, **s.slots, **s.subsets, **s.types}.values(): - x.from_schema = s.id - for c in s.classes.values(): - for a in c.attributes.values(): - a.from_schema = s.id - return closure - - SchemaView.imports_closure = imports_closure - - def patch_array_expression() -> None: """ Allow SlotDefinitions to use `any_of` with `array` @@ -75,7 +14,7 @@ def patch_array_expression() -> None: from dataclasses import field, make_dataclass from typing import Optional - from linkml_runtime.linkml_model import meta + from linkml_runtime.linkml_model import meta, types new_dataclass = make_dataclass( "AnonymousSlotExpression", @@ -83,84 +22,9 @@ def patch_array_expression() -> None: bases=(meta.AnonymousSlotExpression,), ) meta.AnonymousSlotExpression = new_dataclass - - -def patch_pretty_print() -> None: - """ - Fix the godforsaken linkml dataclass reprs - - See: https://github.com/linkml/linkml-runtime/pull/314 - """ - import re - import textwrap - from dataclasses import field, is_dataclass, make_dataclass - from pprint import pformat - from typing import Any - - from linkml_runtime.linkml_model import meta - from linkml_runtime.utils.formatutils import items - - def _pformat(fields: dict, cls_name: str, indent: str = " ") -> str: - """ - pretty format the fields of the items of a ``YAMLRoot`` object without the wonky - indentation of pformat. - see ``YAMLRoot.__repr__``. - formatting is similar to black - items at similar levels of nesting have similar levels - of indentation, - rather than getting placed at essentially random levels of indentation depending on what - came before them. - """ - res = [] - total_len = 0 - for key, val in fields: - if val == [] or val == {} or val is None: - continue - # pformat handles everything else that isn't a YAMLRoot object, - # but it sure does look ugly - # use it to split lines and as the thing of last resort, but otherwise indent = 0, - # we'll do that - val_str = pformat(val, indent=0, compact=True, sort_dicts=False) - # now we indent everything except the first line by indenting - # and then using regex to remove just the first indent - val_str = re.sub(rf"\A{re.escape(indent)}", "", textwrap.indent(val_str, indent)) - # now recombine with the key in a format that can be re-eval'd - # into an object if indent is just whitespace - val_str = f"'{key}': " + val_str - - # count the total length of this string so we know if we need to linebreak or not later - total_len += len(val_str) - res.append(val_str) - - if total_len > 80: - inside = ",\n".join(res) - # we indent twice - once for the inner contents of every inner object, and one to - # offset from the root element. - # that keeps us from needing to be recursive except for the - # single pformat call - inside = textwrap.indent(inside, indent) - return cls_name + "({\n" + inside + "\n})" - else: - return cls_name + "({" + ", ".join(res) + "})" - - def __repr__(self) -> str: - return _pformat(items(self), self.__class__.__name__) - - for cls_name in dir(meta): - cls = getattr(meta, cls_name) - if is_dataclass(cls): - new_dataclass = make_dataclass( - cls.__name__, - fields=[("__dummy__", Any, field(default=None))], - bases=(cls,), - repr=False, - ) - new_dataclass.__repr__ = __repr__ - new_dataclass.__str__ = __repr__ - setattr(meta, cls.__name__, new_dataclass) + types.AnonymousSlotExpression = new_dataclass def apply_patches() -> None: """Apply all monkeypatches""" - patch_schemaview() patch_array_expression() - patch_pretty_print() diff --git a/nwb_linkml/src/nwb_linkml/providers/pydantic.py b/nwb_linkml/src/nwb_linkml/providers/pydantic.py index 3e379dd..21af04f 100644 --- a/nwb_linkml/src/nwb_linkml/providers/pydantic.py +++ b/nwb_linkml/src/nwb_linkml/providers/pydantic.py @@ -9,16 +9,19 @@ from importlib.abc import MetaPathFinder from importlib.machinery import ModuleSpec from pathlib import Path from types import ModuleType -from typing import List, Optional, Type +from typing import List, Optional, Type, TYPE_CHECKING +import multiprocessing as mp from linkml.generators.pydanticgen.pydanticgen import SplitMode, _ensure_inits, _import_to_path -from linkml_runtime.linkml_model.meta import SchemaDefinition from pydantic import BaseModel from nwb_linkml.generators.pydantic import NWBPydanticGenerator from nwb_linkml.maps.naming import module_case, version_module_case from nwb_linkml.providers import LinkMLProvider, Provider +if TYPE_CHECKING: + from linkml_runtime.linkml_model.meta import SchemaDefinition + class PydanticProvider(Provider): """ @@ -65,6 +68,7 @@ class PydanticProvider(Provider): split: bool = True, dump: bool = True, force: bool = False, + parallel: bool = False, **kwargs: dict, ) -> str | List[str]: """ @@ -88,6 +92,8 @@ class PydanticProvider(Provider): otherwise just return the serialized string of built pydantic model force (bool): If ``False`` (default), don't build the model if it already exists, if ``True`` , delete and rebuild any model + parallel (bool): If ``True``, build imported models using multiprocessing, + if ``False`` (default), don't. **kwargs: Passed to :class:`.NWBPydanticGenerator` Returns: @@ -136,7 +142,9 @@ class PydanticProvider(Provider): return serialized - def _build_split(self, path: Path, dump: bool, force: bool, **kwargs) -> List[str]: + def _build_split( + self, path: Path, dump: bool, force: bool, parallel: bool = False, **kwargs + ) -> List[str]: # FIXME: This is messy as all fuck, we're just getting it to work again # so we can start iterating on the models themselves res = [] @@ -171,51 +179,84 @@ class PydanticProvider(Provider): res.append(serialized) # then each of the other schemas :) - imported_schema: dict[str, SchemaDefinition] = { + imported_schema: dict[str, "SchemaDefinition"] = { gen.generate_module_import(sch): sch for sch in gen.schemaview.schema_map.values() } - for generated_import in [i for i in rendered.python_imports if i.schema]: - import_file = (ns_file.parent / _import_to_path(generated_import.module)).resolve() + generated_imports = [i for i in rendered.python_imports if i.schema] + # each task has an expected output file a corresponding SchemaDefinition + import_paths = [ + (ns_file.parent / _import_to_path(an_import.module)).resolve() + for an_import in generated_imports + ] + import_schemas = [ + Path(path).parent / imported_schema[an_import.module].source_file + for an_import in generated_imports + ] - if not import_file.exists() or force: - import_file.parent.mkdir(exist_ok=True, parents=True) - schema = imported_schema[generated_import.module] - is_namespace = False - ns_annotation = schema.annotations.get("is_namespace", None) - if ns_annotation: - is_namespace = ns_annotation.value + tasks = [ + ( + import_path, + import_schema, + force, + self.SPLIT_PATTERN, + dump, + ) + for import_path, import_schema in zip(import_paths, import_schemas) + ] - # fix schema source to absolute path so schemaview can find imports - schema.source_file = ( - Path(gen.schemaview.schema.source_file).parent / schema.source_file - ).resolve() - - import_gen = NWBPydanticGenerator( - schema, - split=True, - split_pattern=self.SPLIT_PATTERN, - split_mode=SplitMode.FULL if is_namespace else SplitMode.AUTO, - ) - serialized = import_gen.serialize() - if dump: - with open(import_file, "w") as ofile: - ofile.write(serialized) - module_paths.append(import_file) - - else: - with open(import_file) as ofile: - serialized = ofile.read() - - res.append(serialized) + if parallel: + with mp.Pool(min(mp.cpu_count(), len(tasks))) as pool: + mp_results = [pool.apply_async(self._generate_single, t) for t in tasks] + for result in mp_results: + res.append(result.get()) + else: + for task in tasks: + res.append(self._generate_single(*task)) # make __init__.py files if we generated any files if len(module_paths) > 0: - _ensure_inits(module_paths) + _ensure_inits(import_paths) # then extra_inits that usually aren't generated bc we're one layer deeper self._make_inits(ns_file) return res + @staticmethod + def _generate_single( + import_file: Path, + # schema: "SchemaDefinition", + schema: Path, + force: bool, + split_pattern: str, + dump: bool, + ) -> str: + """ + Interior generator method for _build_split to be called in parallel + + .. TODO:: + + split up and consolidate this build behavior, very spaghetti. + + """ + + if not import_file.exists() or force: + import_file.parent.mkdir(exist_ok=True, parents=True) + + import_gen = NWBPydanticGenerator( + schema, + split=True, + split_pattern=split_pattern, + ) + serialized = import_gen.serialize() + if dump: + with open(import_file, "w") as ofile: + ofile.write(serialized) + + else: + with open(import_file) as ofile: + serialized = ofile.read() + return serialized + def _make_inits(self, out_file: Path) -> None: """ Make __init__.py files for the directory a model is output to and its immediate parent. diff --git a/scripts/generate_core.py b/scripts/generate_core.py index ad593cf..1e03897 100644 --- a/scripts/generate_core.py +++ b/scripts/generate_core.py @@ -57,7 +57,7 @@ def make_tmp_dir(clear: bool = False) -> Path: if tmp_dir.exists() and clear: for p in tmp_dir.iterdir(): if p.is_dir() and not p.name == "git": - shutil.rmtree(tmp_dir) + shutil.rmtree(p) tmp_dir.mkdir(exist_ok=True) return tmp_dir @@ -139,7 +139,9 @@ def generate_versions( for schema in ns_files: pbar_string = schema.parts[-3] build_progress.update(pydantic_task, action=pbar_string) - pydantic_provider.build(schema, versions=core_ns.versions, split=True) + pydantic_provider.build( + schema, versions=core_ns.versions, split=True, parallel=True + ) build_progress.update(pydantic_task, advance=1) build_progress.update(pydantic_task, action="Built Pydantic")