diff --git a/.gitignore b/.gitignore index 644d9b0..d3e4baf 100644 --- a/.gitignore +++ b/.gitignore @@ -161,3 +161,4 @@ cython_debug/ nwb.schema.json __tmp__ +prof \ No newline at end of file diff --git a/nwb_linkml/poetry.lock b/nwb_linkml/poetry.lock index 0cc852c..520bad7 100644 --- a/nwb_linkml/poetry.lock +++ b/nwb_linkml/poetry.lock @@ -492,6 +492,17 @@ files = [ [package.extras] rewrite = ["tokenize-rt (>=3)"] +[[package]] +name = "gprof2dot" +version = "2022.7.29" +description = "Generate a dot graph from the output of several profilers." +optional = false +python-versions = ">=2.7" +files = [ + {file = "gprof2dot-2022.7.29-py2.py3-none-any.whl", hash = "sha256:f165b3851d3c52ee4915eb1bd6cca571e5759823c2cd0f71a79bda93c2dc85d6"}, + {file = "gprof2dot-2022.7.29.tar.gz", hash = "sha256:45b4d298bd36608fccf9511c3fd88a773f7a1abc04d6cd39445b11ba43133ec5"}, +] + [[package]] name = "graphviz" version = "0.20.1" @@ -1554,6 +1565,25 @@ files = [ [package.dependencies] pytest = ">=4.2.1" +[[package]] +name = "pytest-profiling" +version = "1.7.0" +description = "Profiling plugin for py.test" +optional = false +python-versions = "*" +files = [ + {file = "pytest-profiling-1.7.0.tar.gz", hash = "sha256:93938f147662225d2b8bd5af89587b979652426a8a6ffd7e73ec4a23e24b7f29"}, + {file = "pytest_profiling-1.7.0-py2.py3-none-any.whl", hash = "sha256:999cc9ac94f2e528e3f5d43465da277429984a1c237ae9818f8cfd0b06acb019"}, +] + +[package.dependencies] +gprof2dot = "*" +pytest = "*" +six = "*" + +[package.extras] +tests = ["pytest-virtualenv"] + [[package]] name = "python-dateutil" version = "2.8.2" @@ -2342,4 +2372,4 @@ tests = ["coverage", "coveralls", "pytest", "pytest-cov", "pytest-depends", "pyt [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "5427416a9edebc2ab2c4f7f7c9779b2b9c7e4c1c2da5dcc0968ee633110973fa" +content-hash = "7ae9160a401b3bfa2f4535696ecf15e33815e356f7757ee611893c701485d24f" diff --git a/nwb_linkml/pyproject.toml b/nwb_linkml/pyproject.toml index 518fffa..ccf43f1 100644 --- a/nwb_linkml/pyproject.toml +++ b/nwb_linkml/pyproject.toml @@ -29,11 +29,12 @@ pytest-md = {version = "^0.2.0", optional = true} pytest-emoji = {version="^0.2.0", optional = true} pytest-cov = {version = "^4.1.0", optional = true} coveralls = {version = "^3.3.1", optional = true} +pytest-profiling = {version = "^1.7.0", optional = true} [tool.poetry.extras] tests = [ - "pytest", "pytest-depends", "coverage", "pytest-md", - "pytest-emoji", "pytest-cov", "coveralls" + "pytest", "pytest-depends", "coverage", "pytest-md", + "pytest-emoji", "pytest-cov", "coveralls", "pytest-profiling" ] plot = ["dash", "dash-cytoscape"] @@ -49,6 +50,7 @@ pytest-md = "^0.2.0" pytest-emoji = "^0.2.0" pytest-cov = "^4.1.0" coveralls = "^3.3.1" +pytest-profiling = "^1.7.0" [tool.poetry.group.plot] optional = true @@ -67,7 +69,8 @@ addopts = [ "--cov=nwb_linkml", "--cov-append", "--cov-config=.coveragerc", - "--emoji" + "--emoji", + "--profile" ] testpaths = [ "tests", diff --git a/nwb_linkml/src/nwb_linkml/generators/pydantic.py b/nwb_linkml/src/nwb_linkml/generators/pydantic.py index e571397..2c64bc1 100644 --- a/nwb_linkml/src/nwb_linkml/generators/pydantic.py +++ b/nwb_linkml/src/nwb_linkml/generators/pydantic.py @@ -19,10 +19,10 @@ The `serialize` method import pdb from dataclasses import dataclass from pathlib import Path -from typing import List, Dict, Set, Tuple, Optional +from typing import List, Dict, Set, Tuple, Optional, TypedDict import os, sys from types import ModuleType -from copy import deepcopy +from copy import deepcopy, copy import warnings from nwb_linkml.maps import flat_to_npytyping @@ -36,9 +36,16 @@ SlotDefinitionName, TypeDefinition, ElementName ) +from linkml.generators.common.type_designators import ( + get_accepted_type_designator_values, + get_type_designator_value, +) from linkml_runtime.utils.formatutils import camelcase, underscore from linkml_runtime.utils.schemaview import SchemaView from linkml_runtime.utils.compile_python import file_text +from linkml.utils.ifabsent_functions import ifabsent_value_declaration + + from jinja2 import Template @@ -175,6 +182,7 @@ class {{ c.name }} return template + @dataclass class NWBPydanticGenerator(PydanticGenerator): @@ -184,6 +192,8 @@ class NWBPydanticGenerator(PydanticGenerator): SKIP_CLASSES=('',) # SKIP_CLASSES=('VectorData','VectorIndex') split:bool=True + schema_map:Dict[str, SchemaDefinition]=None + def _locate_imports( self, @@ -234,15 +244,16 @@ class NWBPydanticGenerator(PydanticGenerator): self, cls:ClassDefinition, sv:SchemaView, - all_classes:dict[ClassDefinitionName, ClassDefinition]) -> List[str]: + all_classes:dict[ClassDefinitionName, ClassDefinition], + class_slots:dict[str, List[SlotDefinition]] + ) -> List[str]: """Get the imports needed for a single class""" needed_classes = [] needed_classes.append(cls.is_a) # get needed classes used as ranges in class attributes - for slot_name in sv.class_slots(cls.name): - if slot_name in self.SKIP_SLOTS: + for slot in class_slots[cls.name]: + if slot.name in self.SKIP_SLOTS: continue - slot = deepcopy(sv.induced_slot(slot_name, cls.name)) if slot.range in all_classes: needed_classes.append(slot.range) # handle when a range is a union of classes @@ -253,15 +264,25 @@ class NWBPydanticGenerator(PydanticGenerator): return needed_classes - def _get_imports(self, sv:SchemaView) -> Dict[str, List[str]]: + def _get_imports(self, + sv:SchemaView, + local_classes: List[ClassDefinition], + class_slots: Dict[str, List[SlotDefinition]]) -> Dict[str, List[str]]: + # import from local references, rather than serializing every class in every file + if not self.split: + # we are compiling this whole thing in one big file so we don't import anything + return {} + if 'namespace' in sv.schema.annotations.keys() and sv.schema.annotations['namespace']['value'] == 'True': + return self._get_namespace_imports(sv) + all_classes = sv.all_classes(imports=True) - local_classes = sv.all_classes(imports=False) + # local_classes = sv.all_classes(imports=False) needed_classes = [] # find needed classes - is_a and slot ranges - for clsname, cls in local_classes.items(): + for cls in local_classes: # get imports for this class - needed_classes.extend(self._get_class_imports(cls, sv, all_classes)) + needed_classes.extend(self._get_class_imports(cls, sv, all_classes, class_slots)) # remove duplicates and arraylikes needed_classes = [cls for cls in set(needed_classes) if cls is not None and cls != 'Arraylike'] @@ -272,27 +293,29 @@ class NWBPydanticGenerator(PydanticGenerator): return imports - def _get_classes(self, sv:SchemaView, imports: Dict[str, List[str]]) -> List[ClassDefinition]: - module_classes = sv.all_classes(imports=False).values() - imported_classes = [] - for classes in imports.values(): - imported_classes.extend(classes) - - module_classes = [c for c in list(module_classes) if c.is_a != 'Arraylike'] - imported_classes = [c for c in imported_classes if sv.get_class(c) and sv.get_class(c).is_a != 'Arraylike'] - - sorted_classes = self.sort_classes(module_classes, imported_classes) - self.sorted_class_names = [camelcase(cname) for cname in imported_classes] - self.sorted_class_names += [camelcase(c.name) for c in sorted_classes] + def _get_classes(self, sv:SchemaView) -> List[ClassDefinition]: + if self.split: + classes = sv.all_classes(imports=False).values() + else: + classes = sv.all_classes(imports=True).values() # Don't want to generate classes when class_uri is linkml:Any, will # just swap in typing.Any instead down below - sorted_classes = [c for c in sorted_classes if c.class_uri != "linkml:Any"] - return sorted_classes + classes = [c for c in list(classes) if c.is_a != 'Arraylike' and c.class_uri != "linkml:Any"] + + return classes + + def _get_class_slots(self, sv:SchemaView, cls:ClassDefinition) -> List[SlotDefinition]: + slots = [] + for slot_name in sv.class_slots(cls.name): + if slot_name in self.SKIP_SLOTS: + continue + slots.append(sv.induced_slot(slot_name, cls.name)) + return slots def _build_class(self, class_original:ClassDefinition) -> ClassDefinition: class_def: ClassDefinition - class_def = deepcopy(class_original) + class_def = copy(class_original) class_def.name = camelcase(class_original.name) if class_def.is_a: class_def.is_a = camelcase(class_def.is_a) @@ -375,7 +398,7 @@ class NWBPydanticGenerator(PydanticGenerator): return union - def sort_classes(self, clist: List[ClassDefinition], imports:List[str]) -> List[ClassDefinition]: + def sort_classes(self, clist: List[ClassDefinition], imports:Dict[str, List[str]]) -> List[ClassDefinition]: """ sort classes such that if C is a child of P then C appears after P in the list @@ -383,9 +406,14 @@ class NWBPydanticGenerator(PydanticGenerator): Modified from original to allow for imported classes """ + # unnest imports + imported_classes = [] + for i in imports.values(): + imported_classes.extend(i) + clist = list(clist) clist = [c for c in clist if c.name not in self.SKIP_CLASSES] - slist = [] # sorted + slist = [] # type: List[ClassDefinition] while len(clist) > 0: can_add = False for i in range(len(clist)): @@ -399,7 +427,7 @@ class NWBPydanticGenerator(PydanticGenerator): can_add = True else: - if set(candidates) <= set([p.name for p in slist] + imports): + if set(candidates) <= set([p.name for p in slist] + imported_classes): can_add = True if can_add: @@ -410,6 +438,9 @@ class NWBPydanticGenerator(PydanticGenerator): raise ValueError( f"could not find suitable element in {clist} that does not ref {slist}" ) + + self.sorted_class_names = [camelcase(cname) for cname in imported_classes] + self.sorted_class_names += [camelcase(c.name) for c in slist] return slist def get_class_slot_range(self, slot_range: str, inlined: bool, inlined_as_list: bool) -> str: @@ -449,7 +480,47 @@ class NWBPydanticGenerator(PydanticGenerator): parents[camelcase(class_def.name)] = class_parents return parents + def get_predefined_slot_value(self, slot: SlotDefinition, class_def: ClassDefinition) -> Optional[str]: + """ + Modified from base pydantic generator to use already grabbed induced_slot from + already-grabbed and modified classes rather than doing a fresh iteration to + save time and respect changes already made elsewhere in the serialization routine + + :return: Dictionary of dictionaries with predefined slot values for each class + """ + sv = self.schemaview + slot_value: Optional[str] = None + # for class_def in sv.all_classes().values(): + # for slot_name in sv.class_slots(class_def.name): + # slot = sv.induced_slot(slot_name, class_def.name) + if slot.designates_type: + target_value = get_type_designator_value(sv, slot, class_def) + slot_value = f'"{target_value}"' + if slot.multivalued: + slot_value = ( + "[" + slot_value + "]" + ) + elif slot.ifabsent is not None: + value = ifabsent_value_declaration(slot.ifabsent, sv, class_def, slot) + slot_value = value + # Multivalued slots that are either not inlined (just an identifier) or are + # inlined as lists should get default_factory list, if they're inlined but + # not as a list, that means a dictionary + elif slot.multivalued: + # this is slow, needs to do additional induced slot calls + #has_identifier_slot = self.range_class_has_identifier_slot(slot) + + if slot.inlined and not slot.inlined_as_list: # and has_identifier_slot: + slot_value = "default_factory=dict" + else: + slot_value = "default_factory=list" + + return slot_value + def serialize(self) -> str: + predefined_slot_values = {} + """splitting up parent class :meth:`.get_predefined_slot_values`""" + if self.template_file is not None: with open(self.template_file) as template_file: template_obj = Template(template_file.read()) @@ -458,34 +529,28 @@ class NWBPydanticGenerator(PydanticGenerator): sv: SchemaView sv = self.schemaview + sv.schema_map = self.schema_map schema = sv.schema pyschema = SchemaDefinition( id=schema.id, name=schema.name, description=schema.description.replace('"', '\\"') if schema.description else None, ) + # test caching if import closure enums = self.generate_enums(sv.all_enums()) # filter skipped enums enums = {k:v for k,v in enums.items() if k not in self.SKIP_ENUM} - if self.split: - # import from local references, rather than serializing every class in every file - if 'namespace' in schema.annotations.keys() and schema.annotations['namespace']['value'] == 'True': - imports = self._get_namespace_imports(sv) - else: - imports = self._get_imports(sv) - - sorted_classes = self._get_classes(sv, imports) - else: - sorted_classes = self.sort_classes(list(sv.all_classes().values()), []) - imports = {} - - # Don't want to generate classes when class_uri is linkml:Any, will - # just swap in typing.Any instead down below - sorted_classes = [c for c in sorted_classes if c.class_uri != "linkml:Any"] - self.sorted_class_names = [camelcase(c.name) for c in sorted_classes] + classes = self._get_classes(sv) + # just induce slots once because that turns out to be expensive + class_slots = {} # type: Dict[str, List[SlotDefinition]] + for aclass in classes: + class_slots[aclass.name] = self._get_class_slots(sv, aclass) + # figure out what classes we need to imports + imports = self._get_imports(sv, classes, class_slots) + sorted_classes = self.sort_classes(classes, imports) for class_original in sorted_classes: # Generate class definition @@ -503,11 +568,12 @@ class NWBPydanticGenerator(PydanticGenerator): del class_def.attributes[attribute] class_name = class_original.name - for sn in sv.class_slots(class_name): - if sn in self.SKIP_SLOTS: - continue - # TODO: fix runtime, copy should not be necessary - s = deepcopy(sv.induced_slot(sn, class_name)) + predefined_slot_values[camelcase(class_name)] = {} + for s in class_slots[class_name]: + sn = SlotDefinitionName(s.name) + predefined_slot_value = self.get_predefined_slot_value(s, class_def) + if predefined_slot_value is not None: + predefined_slot_values[camelcase(class_name)][s.name] = predefined_slot_value # logging.error(f'Induced slot {class_name}.{sn} == {s.name} {s.range}') s.name = underscore(s.name) if s.description: @@ -557,7 +623,7 @@ class NWBPydanticGenerator(PydanticGenerator): schema=pyschema, underscore=underscore, enums=enums, - predefined_slot_values=self.get_predefined_slot_values(), + predefined_slot_values=predefined_slot_values, allow_extra=self.allow_extra, metamodel_version=self.schema.metamodel_version, version=self.schema.version, diff --git a/nwb_linkml/tests/test_generate.py b/nwb_linkml/tests/test_generate.py index cf8a2d6..d950291 100644 --- a/nwb_linkml/tests/test_generate.py +++ b/nwb_linkml/tests/test_generate.py @@ -1,11 +1,15 @@ import pdb +from pathlib import Path +from typing import Dict import pytest import warnings from .fixtures import nwb_core_fixture, tmp_output_dir from linkml_runtime.dumpers import yaml_dumper +from linkml_runtime.linkml_model import SchemaDefinition from nwb_linkml.generators.pydantic import NWBPydanticGenerator +from linkml_runtime.loaders.yaml_loader import YAMLLoader from nwb_linkml.lang_elements import NwbLangSchema @@ -22,29 +26,38 @@ def test_generate_core(nwb_core_fixture, tmp_output_dir): output_file = tmp_output_dir / 'schema' / (schema.name + '.yaml') yaml_dumper.dump(schema, output_file) +def load_schema_files(path: Path) -> Dict[str, SchemaDefinition]: + yaml_loader = YAMLLoader() + sch: SchemaDefinition + preloaded_schema = {} + for schema_path in (path / 'schema').glob('*.yaml'): + sch = yaml_loader.load(str(schema_path), target_class=SchemaDefinition) + preloaded_schema[sch.name] = sch + return preloaded_schema + + @pytest.mark.depends(on=['test_generate_core']) def test_generate_pydantic(tmp_output_dir): - - # core_file = tmp_output_dir / 'core.yaml' - # pydantic_file = tmp_output_dir / 'core.py' - (tmp_output_dir / 'models').mkdir(exist_ok=True) - for schema in (tmp_output_dir / 'schema').glob('*.yaml'): - if not schema.exists(): + preloaded_schema = load_schema_files(tmp_output_dir) + + for schema_path in (tmp_output_dir / 'schema').glob('*.yaml'): + if not schema_path.exists(): continue # python friendly name - python_name = schema.stem.replace('.', '_').replace('-','_') + python_name = schema_path.stem.replace('.', '_').replace('-','_') - pydantic_file = (schema.parent.parent / 'models' / python_name).with_suffix('.py') + pydantic_file = (schema_path.parent.parent / 'models' / python_name).with_suffix('.py') generator = NWBPydanticGenerator( - str(schema), + str(schema_path), pydantic_version='2', emit_metadata=True, gen_classvars=True, - gen_slots=True + gen_slots=True, + schema_map=preloaded_schema ) gen_pydantic = generator.serialize() diff --git a/nwb_linkml/tests/test_io/test_io_hdf5.py b/nwb_linkml/tests/test_io/test_io_hdf5.py index 65d1950..bae1578 100644 --- a/nwb_linkml/tests/test_io/test_io_hdf5.py +++ b/nwb_linkml/tests/test_io/test_io_hdf5.py @@ -5,6 +5,7 @@ from nwb_linkml.io.hdf5 import HDF5IO def test_hdf_read(): NWBFILE = Path('/Users/jonny/Dropbox/lab/p2p_ld/data/nwb/sub-738651046_ses-760693773.nwb') - + if not NWBFILE.exists(): + return io = HDF5IO(path=NWBFILE) model = io.read('/general')