perf enhancements to pydantic generation

This commit is contained in:
sneakers-the-rat 2023-09-06 22:48:47 -07:00
parent 7b97b749ef
commit 2b828082e0
6 changed files with 179 additions and 65 deletions

1
.gitignore vendored
View file

@ -161,3 +161,4 @@ cython_debug/
nwb.schema.json
__tmp__
prof

32
nwb_linkml/poetry.lock generated
View file

@ -492,6 +492,17 @@ files = [
[package.extras]
rewrite = ["tokenize-rt (>=3)"]
[[package]]
name = "gprof2dot"
version = "2022.7.29"
description = "Generate a dot graph from the output of several profilers."
optional = false
python-versions = ">=2.7"
files = [
{file = "gprof2dot-2022.7.29-py2.py3-none-any.whl", hash = "sha256:f165b3851d3c52ee4915eb1bd6cca571e5759823c2cd0f71a79bda93c2dc85d6"},
{file = "gprof2dot-2022.7.29.tar.gz", hash = "sha256:45b4d298bd36608fccf9511c3fd88a773f7a1abc04d6cd39445b11ba43133ec5"},
]
[[package]]
name = "graphviz"
version = "0.20.1"
@ -1554,6 +1565,25 @@ files = [
[package.dependencies]
pytest = ">=4.2.1"
[[package]]
name = "pytest-profiling"
version = "1.7.0"
description = "Profiling plugin for py.test"
optional = false
python-versions = "*"
files = [
{file = "pytest-profiling-1.7.0.tar.gz", hash = "sha256:93938f147662225d2b8bd5af89587b979652426a8a6ffd7e73ec4a23e24b7f29"},
{file = "pytest_profiling-1.7.0-py2.py3-none-any.whl", hash = "sha256:999cc9ac94f2e528e3f5d43465da277429984a1c237ae9818f8cfd0b06acb019"},
]
[package.dependencies]
gprof2dot = "*"
pytest = "*"
six = "*"
[package.extras]
tests = ["pytest-virtualenv"]
[[package]]
name = "python-dateutil"
version = "2.8.2"
@ -2342,4 +2372,4 @@ tests = ["coverage", "coveralls", "pytest", "pytest-cov", "pytest-depends", "pyt
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "5427416a9edebc2ab2c4f7f7c9779b2b9c7e4c1c2da5dcc0968ee633110973fa"
content-hash = "7ae9160a401b3bfa2f4535696ecf15e33815e356f7757ee611893c701485d24f"

View file

@ -29,11 +29,12 @@ pytest-md = {version = "^0.2.0", optional = true}
pytest-emoji = {version="^0.2.0", optional = true}
pytest-cov = {version = "^4.1.0", optional = true}
coveralls = {version = "^3.3.1", optional = true}
pytest-profiling = {version = "^1.7.0", optional = true}
[tool.poetry.extras]
tests = [
"pytest", "pytest-depends", "coverage", "pytest-md",
"pytest-emoji", "pytest-cov", "coveralls"
"pytest", "pytest-depends", "coverage", "pytest-md",
"pytest-emoji", "pytest-cov", "coveralls", "pytest-profiling"
]
plot = ["dash", "dash-cytoscape"]
@ -49,6 +50,7 @@ pytest-md = "^0.2.0"
pytest-emoji = "^0.2.0"
pytest-cov = "^4.1.0"
coveralls = "^3.3.1"
pytest-profiling = "^1.7.0"
[tool.poetry.group.plot]
optional = true
@ -67,7 +69,8 @@ addopts = [
"--cov=nwb_linkml",
"--cov-append",
"--cov-config=.coveragerc",
"--emoji"
"--emoji",
"--profile"
]
testpaths = [
"tests",

View file

@ -19,10 +19,10 @@ The `serialize` method
import pdb
from dataclasses import dataclass
from pathlib import Path
from typing import List, Dict, Set, Tuple, Optional
from typing import List, Dict, Set, Tuple, Optional, TypedDict
import os, sys
from types import ModuleType
from copy import deepcopy
from copy import deepcopy, copy
import warnings
from nwb_linkml.maps import flat_to_npytyping
@ -36,9 +36,16 @@ SlotDefinitionName,
TypeDefinition,
ElementName
)
from linkml.generators.common.type_designators import (
get_accepted_type_designator_values,
get_type_designator_value,
)
from linkml_runtime.utils.formatutils import camelcase, underscore
from linkml_runtime.utils.schemaview import SchemaView
from linkml_runtime.utils.compile_python import file_text
from linkml.utils.ifabsent_functions import ifabsent_value_declaration
from jinja2 import Template
@ -175,6 +182,7 @@ class {{ c.name }}
return template
@dataclass
class NWBPydanticGenerator(PydanticGenerator):
@ -184,6 +192,8 @@ class NWBPydanticGenerator(PydanticGenerator):
SKIP_CLASSES=('',)
# SKIP_CLASSES=('VectorData','VectorIndex')
split:bool=True
schema_map:Dict[str, SchemaDefinition]=None
def _locate_imports(
self,
@ -234,15 +244,16 @@ class NWBPydanticGenerator(PydanticGenerator):
self,
cls:ClassDefinition,
sv:SchemaView,
all_classes:dict[ClassDefinitionName, ClassDefinition]) -> List[str]:
all_classes:dict[ClassDefinitionName, ClassDefinition],
class_slots:dict[str, List[SlotDefinition]]
) -> List[str]:
"""Get the imports needed for a single class"""
needed_classes = []
needed_classes.append(cls.is_a)
# get needed classes used as ranges in class attributes
for slot_name in sv.class_slots(cls.name):
if slot_name in self.SKIP_SLOTS:
for slot in class_slots[cls.name]:
if slot.name in self.SKIP_SLOTS:
continue
slot = deepcopy(sv.induced_slot(slot_name, cls.name))
if slot.range in all_classes:
needed_classes.append(slot.range)
# handle when a range is a union of classes
@ -253,15 +264,25 @@ class NWBPydanticGenerator(PydanticGenerator):
return needed_classes
def _get_imports(self, sv:SchemaView) -> Dict[str, List[str]]:
def _get_imports(self,
sv:SchemaView,
local_classes: List[ClassDefinition],
class_slots: Dict[str, List[SlotDefinition]]) -> Dict[str, List[str]]:
# import from local references, rather than serializing every class in every file
if not self.split:
# we are compiling this whole thing in one big file so we don't import anything
return {}
if 'namespace' in sv.schema.annotations.keys() and sv.schema.annotations['namespace']['value'] == 'True':
return self._get_namespace_imports(sv)
all_classes = sv.all_classes(imports=True)
local_classes = sv.all_classes(imports=False)
# local_classes = sv.all_classes(imports=False)
needed_classes = []
# find needed classes - is_a and slot ranges
for clsname, cls in local_classes.items():
for cls in local_classes:
# get imports for this class
needed_classes.extend(self._get_class_imports(cls, sv, all_classes))
needed_classes.extend(self._get_class_imports(cls, sv, all_classes, class_slots))
# remove duplicates and arraylikes
needed_classes = [cls for cls in set(needed_classes) if cls is not None and cls != 'Arraylike']
@ -272,27 +293,29 @@ class NWBPydanticGenerator(PydanticGenerator):
return imports
def _get_classes(self, sv:SchemaView, imports: Dict[str, List[str]]) -> List[ClassDefinition]:
module_classes = sv.all_classes(imports=False).values()
imported_classes = []
for classes in imports.values():
imported_classes.extend(classes)
module_classes = [c for c in list(module_classes) if c.is_a != 'Arraylike']
imported_classes = [c for c in imported_classes if sv.get_class(c) and sv.get_class(c).is_a != 'Arraylike']
sorted_classes = self.sort_classes(module_classes, imported_classes)
self.sorted_class_names = [camelcase(cname) for cname in imported_classes]
self.sorted_class_names += [camelcase(c.name) for c in sorted_classes]
def _get_classes(self, sv:SchemaView) -> List[ClassDefinition]:
if self.split:
classes = sv.all_classes(imports=False).values()
else:
classes = sv.all_classes(imports=True).values()
# Don't want to generate classes when class_uri is linkml:Any, will
# just swap in typing.Any instead down below
sorted_classes = [c for c in sorted_classes if c.class_uri != "linkml:Any"]
return sorted_classes
classes = [c for c in list(classes) if c.is_a != 'Arraylike' and c.class_uri != "linkml:Any"]
return classes
def _get_class_slots(self, sv:SchemaView, cls:ClassDefinition) -> List[SlotDefinition]:
slots = []
for slot_name in sv.class_slots(cls.name):
if slot_name in self.SKIP_SLOTS:
continue
slots.append(sv.induced_slot(slot_name, cls.name))
return slots
def _build_class(self, class_original:ClassDefinition) -> ClassDefinition:
class_def: ClassDefinition
class_def = deepcopy(class_original)
class_def = copy(class_original)
class_def.name = camelcase(class_original.name)
if class_def.is_a:
class_def.is_a = camelcase(class_def.is_a)
@ -375,7 +398,7 @@ class NWBPydanticGenerator(PydanticGenerator):
return union
def sort_classes(self, clist: List[ClassDefinition], imports:List[str]) -> List[ClassDefinition]:
def sort_classes(self, clist: List[ClassDefinition], imports:Dict[str, List[str]]) -> List[ClassDefinition]:
"""
sort classes such that if C is a child of P then C appears after P in the list
@ -383,9 +406,14 @@ class NWBPydanticGenerator(PydanticGenerator):
Modified from original to allow for imported classes
"""
# unnest imports
imported_classes = []
for i in imports.values():
imported_classes.extend(i)
clist = list(clist)
clist = [c for c in clist if c.name not in self.SKIP_CLASSES]
slist = [] # sorted
slist = [] # type: List[ClassDefinition]
while len(clist) > 0:
can_add = False
for i in range(len(clist)):
@ -399,7 +427,7 @@ class NWBPydanticGenerator(PydanticGenerator):
can_add = True
else:
if set(candidates) <= set([p.name for p in slist] + imports):
if set(candidates) <= set([p.name for p in slist] + imported_classes):
can_add = True
if can_add:
@ -410,6 +438,9 @@ class NWBPydanticGenerator(PydanticGenerator):
raise ValueError(
f"could not find suitable element in {clist} that does not ref {slist}"
)
self.sorted_class_names = [camelcase(cname) for cname in imported_classes]
self.sorted_class_names += [camelcase(c.name) for c in slist]
return slist
def get_class_slot_range(self, slot_range: str, inlined: bool, inlined_as_list: bool) -> str:
@ -449,7 +480,47 @@ class NWBPydanticGenerator(PydanticGenerator):
parents[camelcase(class_def.name)] = class_parents
return parents
def get_predefined_slot_value(self, slot: SlotDefinition, class_def: ClassDefinition) -> Optional[str]:
    """
    Compute the predefined (default) value expression for a single slot.

    Modified from the base pydantic generator's whole-schema
    ``get_predefined_slot_values`` to operate on a single already-induced
    slot rather than re-iterating every class and re-inducing every slot,
    which saves time and respects changes already made elsewhere in the
    serialization routine.

    :param slot: An induced slot (already resolved against ``class_def``)
    :param class_def: The class the slot belongs to
    :return: A string to inline into the generated field definition
        (e.g. ``"default_factory=list"``), or ``None`` when the slot has
        no predefined value.
    """
    sv = self.schemaview
    slot_value: Optional[str] = None
    if slot.designates_type:
        # type-designator slots are pinned to the designator value for this class
        target_value = get_type_designator_value(sv, slot, class_def)
        slot_value = f'"{target_value}"'
        if slot.multivalued:
            slot_value = (
                "[" + slot_value + "]"
            )
    elif slot.ifabsent is not None:
        value = ifabsent_value_declaration(slot.ifabsent, sv, class_def, slot)
        slot_value = value
    # Multivalued slots that are either not inlined (just an identifier) or are
    # inlined as lists should get default_factory list, if they're inlined but
    # not as a list, that means a dictionary
    elif slot.multivalued:
        # NOTE: checking range_class_has_identifier_slot here would be more
        # precise, but it is slow — it needs additional induced-slot calls
        if slot.inlined and not slot.inlined_as_list:  # and has_identifier_slot:
            slot_value = "default_factory=dict"
        else:
            slot_value = "default_factory=list"
    return slot_value
def serialize(self) -> str:
predefined_slot_values = {}
"""splitting up parent class :meth:`.get_predefined_slot_values`"""
if self.template_file is not None:
with open(self.template_file) as template_file:
template_obj = Template(template_file.read())
@ -458,34 +529,28 @@ class NWBPydanticGenerator(PydanticGenerator):
sv: SchemaView
sv = self.schemaview
sv.schema_map = self.schema_map
schema = sv.schema
pyschema = SchemaDefinition(
id=schema.id,
name=schema.name,
description=schema.description.replace('"', '\\"') if schema.description else None,
)
# test caching if import closure
enums = self.generate_enums(sv.all_enums())
# filter skipped enums
enums = {k:v for k,v in enums.items() if k not in self.SKIP_ENUM}
if self.split:
# import from local references, rather than serializing every class in every file
if 'namespace' in schema.annotations.keys() and schema.annotations['namespace']['value'] == 'True':
imports = self._get_namespace_imports(sv)
else:
imports = self._get_imports(sv)
sorted_classes = self._get_classes(sv, imports)
else:
sorted_classes = self.sort_classes(list(sv.all_classes().values()), [])
imports = {}
# Don't want to generate classes when class_uri is linkml:Any, will
# just swap in typing.Any instead down below
sorted_classes = [c for c in sorted_classes if c.class_uri != "linkml:Any"]
self.sorted_class_names = [camelcase(c.name) for c in sorted_classes]
classes = self._get_classes(sv)
# just induce slots once because that turns out to be expensive
class_slots = {} # type: Dict[str, List[SlotDefinition]]
for aclass in classes:
class_slots[aclass.name] = self._get_class_slots(sv, aclass)
# figure out what classes we need to imports
imports = self._get_imports(sv, classes, class_slots)
sorted_classes = self.sort_classes(classes, imports)
for class_original in sorted_classes:
# Generate class definition
@ -503,11 +568,12 @@ class NWBPydanticGenerator(PydanticGenerator):
del class_def.attributes[attribute]
class_name = class_original.name
for sn in sv.class_slots(class_name):
if sn in self.SKIP_SLOTS:
continue
# TODO: fix runtime, copy should not be necessary
s = deepcopy(sv.induced_slot(sn, class_name))
predefined_slot_values[camelcase(class_name)] = {}
for s in class_slots[class_name]:
sn = SlotDefinitionName(s.name)
predefined_slot_value = self.get_predefined_slot_value(s, class_def)
if predefined_slot_value is not None:
predefined_slot_values[camelcase(class_name)][s.name] = predefined_slot_value
# logging.error(f'Induced slot {class_name}.{sn} == {s.name} {s.range}')
s.name = underscore(s.name)
if s.description:
@ -557,7 +623,7 @@ class NWBPydanticGenerator(PydanticGenerator):
schema=pyschema,
underscore=underscore,
enums=enums,
predefined_slot_values=self.get_predefined_slot_values(),
predefined_slot_values=predefined_slot_values,
allow_extra=self.allow_extra,
metamodel_version=self.schema.metamodel_version,
version=self.schema.version,

View file

@ -1,11 +1,15 @@
import pdb
from pathlib import Path
from typing import Dict
import pytest
import warnings
from .fixtures import nwb_core_fixture, tmp_output_dir
from linkml_runtime.dumpers import yaml_dumper
from linkml_runtime.linkml_model import SchemaDefinition
from nwb_linkml.generators.pydantic import NWBPydanticGenerator
from linkml_runtime.loaders.yaml_loader import YAMLLoader
from nwb_linkml.lang_elements import NwbLangSchema
@ -22,29 +26,38 @@ def test_generate_core(nwb_core_fixture, tmp_output_dir):
output_file = tmp_output_dir / 'schema' / (schema.name + '.yaml')
yaml_dumper.dump(schema, output_file)
def load_schema_files(path: Path) -> Dict[str, SchemaDefinition]:
    """
    Load every linkml schema YAML under ``path / 'schema'`` and return
    them keyed by schema name, for passing to the generator as a
    preloaded schema map.
    """
    loader = YAMLLoader()
    loaded = (
        loader.load(str(schema_file), target_class=SchemaDefinition)
        for schema_file in (path / 'schema').glob('*.yaml')
    )
    return {sch.name: sch for sch in loaded}
@pytest.mark.depends(on=['test_generate_core'])
def test_generate_pydantic(tmp_output_dir):
# core_file = tmp_output_dir / 'core.yaml'
# pydantic_file = tmp_output_dir / 'core.py'
(tmp_output_dir / 'models').mkdir(exist_ok=True)
for schema in (tmp_output_dir / 'schema').glob('*.yaml'):
if not schema.exists():
preloaded_schema = load_schema_files(tmp_output_dir)
for schema_path in (tmp_output_dir / 'schema').glob('*.yaml'):
if not schema_path.exists():
continue
# python friendly name
python_name = schema.stem.replace('.', '_').replace('-','_')
python_name = schema_path.stem.replace('.', '_').replace('-','_')
pydantic_file = (schema.parent.parent / 'models' / python_name).with_suffix('.py')
pydantic_file = (schema_path.parent.parent / 'models' / python_name).with_suffix('.py')
generator = NWBPydanticGenerator(
str(schema),
str(schema_path),
pydantic_version='2',
emit_metadata=True,
gen_classvars=True,
gen_slots=True
gen_slots=True,
schema_map=preloaded_schema
)
gen_pydantic = generator.serialize()

View file

@ -5,6 +5,7 @@ from nwb_linkml.io.hdf5 import HDF5IO
def test_hdf_read():
NWBFILE = Path('/Users/jonny/Dropbox/lab/p2p_ld/data/nwb/sub-738651046_ses-760693773.nwb')
if not NWBFILE.exists():
return
io = HDF5IO(path=NWBFILE)
model = io.read('/general')