perf enhancements to pydantic generation

This commit is contained in:
sneakers-the-rat 2023-09-06 22:48:47 -07:00
parent 7b97b749ef
commit 2b828082e0
6 changed files with 179 additions and 65 deletions

1
.gitignore vendored
View file

@@ -161,3 +161,4 @@ cython_debug/
nwb.schema.json nwb.schema.json
__tmp__ __tmp__
prof

32
nwb_linkml/poetry.lock generated
View file

@@ -492,6 +492,17 @@ files = [
[package.extras] [package.extras]
rewrite = ["tokenize-rt (>=3)"] rewrite = ["tokenize-rt (>=3)"]
[[package]]
name = "gprof2dot"
version = "2022.7.29"
description = "Generate a dot graph from the output of several profilers."
optional = false
python-versions = ">=2.7"
files = [
{file = "gprof2dot-2022.7.29-py2.py3-none-any.whl", hash = "sha256:f165b3851d3c52ee4915eb1bd6cca571e5759823c2cd0f71a79bda93c2dc85d6"},
{file = "gprof2dot-2022.7.29.tar.gz", hash = "sha256:45b4d298bd36608fccf9511c3fd88a773f7a1abc04d6cd39445b11ba43133ec5"},
]
[[package]] [[package]]
name = "graphviz" name = "graphviz"
version = "0.20.1" version = "0.20.1"
@@ -1554,6 +1565,25 @@ files = [
[package.dependencies] [package.dependencies]
pytest = ">=4.2.1" pytest = ">=4.2.1"
[[package]]
name = "pytest-profiling"
version = "1.7.0"
description = "Profiling plugin for py.test"
optional = false
python-versions = "*"
files = [
{file = "pytest-profiling-1.7.0.tar.gz", hash = "sha256:93938f147662225d2b8bd5af89587b979652426a8a6ffd7e73ec4a23e24b7f29"},
{file = "pytest_profiling-1.7.0-py2.py3-none-any.whl", hash = "sha256:999cc9ac94f2e528e3f5d43465da277429984a1c237ae9818f8cfd0b06acb019"},
]
[package.dependencies]
gprof2dot = "*"
pytest = "*"
six = "*"
[package.extras]
tests = ["pytest-virtualenv"]
[[package]] [[package]]
name = "python-dateutil" name = "python-dateutil"
version = "2.8.2" version = "2.8.2"
@@ -2342,4 +2372,4 @@ tests = ["coverage", "coveralls", "pytest", "pytest-cov", "pytest-depends", "pyt
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.11" python-versions = "^3.11"
content-hash = "5427416a9edebc2ab2c4f7f7c9779b2b9c7e4c1c2da5dcc0968ee633110973fa" content-hash = "7ae9160a401b3bfa2f4535696ecf15e33815e356f7757ee611893c701485d24f"

View file

@@ -29,11 +29,12 @@ pytest-md = {version = "^0.2.0", optional = true}
pytest-emoji = {version="^0.2.0", optional = true} pytest-emoji = {version="^0.2.0", optional = true}
pytest-cov = {version = "^4.1.0", optional = true} pytest-cov = {version = "^4.1.0", optional = true}
coveralls = {version = "^3.3.1", optional = true} coveralls = {version = "^3.3.1", optional = true}
pytest-profiling = {version = "^1.7.0", optional = true}
[tool.poetry.extras] [tool.poetry.extras]
tests = [ tests = [
"pytest", "pytest-depends", "coverage", "pytest-md", "pytest", "pytest-depends", "coverage", "pytest-md",
"pytest-emoji", "pytest-cov", "coveralls" "pytest-emoji", "pytest-cov", "coveralls", "pytest-profiling"
] ]
plot = ["dash", "dash-cytoscape"] plot = ["dash", "dash-cytoscape"]
@@ -49,6 +50,7 @@ pytest-md = "^0.2.0"
pytest-emoji = "^0.2.0" pytest-emoji = "^0.2.0"
pytest-cov = "^4.1.0" pytest-cov = "^4.1.0"
coveralls = "^3.3.1" coveralls = "^3.3.1"
pytest-profiling = "^1.7.0"
[tool.poetry.group.plot] [tool.poetry.group.plot]
optional = true optional = true
@@ -67,7 +69,8 @@ addopts = [
"--cov=nwb_linkml", "--cov=nwb_linkml",
"--cov-append", "--cov-append",
"--cov-config=.coveragerc", "--cov-config=.coveragerc",
"--emoji" "--emoji",
"--profile"
] ]
testpaths = [ testpaths = [
"tests", "tests",

View file

@@ -19,10 +19,10 @@ The `serialize` method
import pdb import pdb
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import List, Dict, Set, Tuple, Optional from typing import List, Dict, Set, Tuple, Optional, TypedDict
import os, sys import os, sys
from types import ModuleType from types import ModuleType
from copy import deepcopy from copy import deepcopy, copy
import warnings import warnings
from nwb_linkml.maps import flat_to_npytyping from nwb_linkml.maps import flat_to_npytyping
@@ -36,9 +36,16 @@ SlotDefinitionName,
TypeDefinition, TypeDefinition,
ElementName ElementName
) )
from linkml.generators.common.type_designators import (
get_accepted_type_designator_values,
get_type_designator_value,
)
from linkml_runtime.utils.formatutils import camelcase, underscore from linkml_runtime.utils.formatutils import camelcase, underscore
from linkml_runtime.utils.schemaview import SchemaView from linkml_runtime.utils.schemaview import SchemaView
from linkml_runtime.utils.compile_python import file_text from linkml_runtime.utils.compile_python import file_text
from linkml.utils.ifabsent_functions import ifabsent_value_declaration
from jinja2 import Template from jinja2 import Template
@@ -175,6 +182,7 @@ class {{ c.name }}
return template return template
@dataclass @dataclass
class NWBPydanticGenerator(PydanticGenerator): class NWBPydanticGenerator(PydanticGenerator):
@@ -184,6 +192,8 @@ class NWBPydanticGenerator(PydanticGenerator):
SKIP_CLASSES=('',) SKIP_CLASSES=('',)
# SKIP_CLASSES=('VectorData','VectorIndex') # SKIP_CLASSES=('VectorData','VectorIndex')
split:bool=True split:bool=True
schema_map:Dict[str, SchemaDefinition]=None
def _locate_imports( def _locate_imports(
self, self,
@@ -234,15 +244,16 @@ class NWBPydanticGenerator(PydanticGenerator):
self, self,
cls:ClassDefinition, cls:ClassDefinition,
sv:SchemaView, sv:SchemaView,
all_classes:dict[ClassDefinitionName, ClassDefinition]) -> List[str]: all_classes:dict[ClassDefinitionName, ClassDefinition],
class_slots:dict[str, List[SlotDefinition]]
) -> List[str]:
"""Get the imports needed for a single class""" """Get the imports needed for a single class"""
needed_classes = [] needed_classes = []
needed_classes.append(cls.is_a) needed_classes.append(cls.is_a)
# get needed classes used as ranges in class attributes # get needed classes used as ranges in class attributes
for slot_name in sv.class_slots(cls.name): for slot in class_slots[cls.name]:
if slot_name in self.SKIP_SLOTS: if slot.name in self.SKIP_SLOTS:
continue continue
slot = deepcopy(sv.induced_slot(slot_name, cls.name))
if slot.range in all_classes: if slot.range in all_classes:
needed_classes.append(slot.range) needed_classes.append(slot.range)
# handle when a range is a union of classes # handle when a range is a union of classes
@@ -253,15 +264,25 @@ class NWBPydanticGenerator(PydanticGenerator):
return needed_classes return needed_classes
def _get_imports(self, sv:SchemaView) -> Dict[str, List[str]]: def _get_imports(self,
sv:SchemaView,
local_classes: List[ClassDefinition],
class_slots: Dict[str, List[SlotDefinition]]) -> Dict[str, List[str]]:
# import from local references, rather than serializing every class in every file
if not self.split:
# we are compiling this whole thing in one big file so we don't import anything
return {}
if 'namespace' in sv.schema.annotations.keys() and sv.schema.annotations['namespace']['value'] == 'True':
return self._get_namespace_imports(sv)
all_classes = sv.all_classes(imports=True) all_classes = sv.all_classes(imports=True)
local_classes = sv.all_classes(imports=False) # local_classes = sv.all_classes(imports=False)
needed_classes = [] needed_classes = []
# find needed classes - is_a and slot ranges # find needed classes - is_a and slot ranges
for clsname, cls in local_classes.items(): for cls in local_classes:
# get imports for this class # get imports for this class
needed_classes.extend(self._get_class_imports(cls, sv, all_classes)) needed_classes.extend(self._get_class_imports(cls, sv, all_classes, class_slots))
# remove duplicates and arraylikes # remove duplicates and arraylikes
needed_classes = [cls for cls in set(needed_classes) if cls is not None and cls != 'Arraylike'] needed_classes = [cls for cls in set(needed_classes) if cls is not None and cls != 'Arraylike']
@@ -272,27 +293,29 @@ class NWBPydanticGenerator(PydanticGenerator):
return imports return imports
def _get_classes(self, sv:SchemaView, imports: Dict[str, List[str]]) -> List[ClassDefinition]: def _get_classes(self, sv:SchemaView) -> List[ClassDefinition]:
module_classes = sv.all_classes(imports=False).values() if self.split:
imported_classes = [] classes = sv.all_classes(imports=False).values()
for classes in imports.values(): else:
imported_classes.extend(classes) classes = sv.all_classes(imports=True).values()
module_classes = [c for c in list(module_classes) if c.is_a != 'Arraylike']
imported_classes = [c for c in imported_classes if sv.get_class(c) and sv.get_class(c).is_a != 'Arraylike']
sorted_classes = self.sort_classes(module_classes, imported_classes)
self.sorted_class_names = [camelcase(cname) for cname in imported_classes]
self.sorted_class_names += [camelcase(c.name) for c in sorted_classes]
# Don't want to generate classes when class_uri is linkml:Any, will # Don't want to generate classes when class_uri is linkml:Any, will
# just swap in typing.Any instead down below # just swap in typing.Any instead down below
sorted_classes = [c for c in sorted_classes if c.class_uri != "linkml:Any"] classes = [c for c in list(classes) if c.is_a != 'Arraylike' and c.class_uri != "linkml:Any"]
return sorted_classes
return classes
def _get_class_slots(self, sv:SchemaView, cls:ClassDefinition) -> List[SlotDefinition]:
slots = []
for slot_name in sv.class_slots(cls.name):
if slot_name in self.SKIP_SLOTS:
continue
slots.append(sv.induced_slot(slot_name, cls.name))
return slots
def _build_class(self, class_original:ClassDefinition) -> ClassDefinition: def _build_class(self, class_original:ClassDefinition) -> ClassDefinition:
class_def: ClassDefinition class_def: ClassDefinition
class_def = deepcopy(class_original) class_def = copy(class_original)
class_def.name = camelcase(class_original.name) class_def.name = camelcase(class_original.name)
if class_def.is_a: if class_def.is_a:
class_def.is_a = camelcase(class_def.is_a) class_def.is_a = camelcase(class_def.is_a)
@@ -375,7 +398,7 @@ class NWBPydanticGenerator(PydanticGenerator):
return union return union
def sort_classes(self, clist: List[ClassDefinition], imports:List[str]) -> List[ClassDefinition]: def sort_classes(self, clist: List[ClassDefinition], imports:Dict[str, List[str]]) -> List[ClassDefinition]:
""" """
sort classes such that if C is a child of P then C appears after P in the list sort classes such that if C is a child of P then C appears after P in the list
@@ -383,9 +406,14 @@ class NWBPydanticGenerator(PydanticGenerator):
Modified from original to allow for imported classes Modified from original to allow for imported classes
""" """
# unnest imports
imported_classes = []
for i in imports.values():
imported_classes.extend(i)
clist = list(clist) clist = list(clist)
clist = [c for c in clist if c.name not in self.SKIP_CLASSES] clist = [c for c in clist if c.name not in self.SKIP_CLASSES]
slist = [] # sorted slist = [] # type: List[ClassDefinition]
while len(clist) > 0: while len(clist) > 0:
can_add = False can_add = False
for i in range(len(clist)): for i in range(len(clist)):
@@ -399,7 +427,7 @@ class NWBPydanticGenerator(PydanticGenerator):
can_add = True can_add = True
else: else:
if set(candidates) <= set([p.name for p in slist] + imports): if set(candidates) <= set([p.name for p in slist] + imported_classes):
can_add = True can_add = True
if can_add: if can_add:
@@ -410,6 +438,9 @@ class NWBPydanticGenerator(PydanticGenerator):
raise ValueError( raise ValueError(
f"could not find suitable element in {clist} that does not ref {slist}" f"could not find suitable element in {clist} that does not ref {slist}"
) )
self.sorted_class_names = [camelcase(cname) for cname in imported_classes]
self.sorted_class_names += [camelcase(c.name) for c in slist]
return slist return slist
def get_class_slot_range(self, slot_range: str, inlined: bool, inlined_as_list: bool) -> str: def get_class_slot_range(self, slot_range: str, inlined: bool, inlined_as_list: bool) -> str:
@@ -449,7 +480,47 @@ class NWBPydanticGenerator(PydanticGenerator):
parents[camelcase(class_def.name)] = class_parents parents[camelcase(class_def.name)] = class_parents
return parents return parents
def get_predefined_slot_value(self, slot: SlotDefinition, class_def: ClassDefinition) -> Optional[str]:
"""
Modified from base pydantic generator to use already grabbed induced_slot from
already-grabbed and modified classes rather than doing a fresh iteration to
save time and respect changes already made elsewhere in the serialization routine
:return: Dictionary of dictionaries with predefined slot values for each class
"""
sv = self.schemaview
slot_value: Optional[str] = None
# for class_def in sv.all_classes().values():
# for slot_name in sv.class_slots(class_def.name):
# slot = sv.induced_slot(slot_name, class_def.name)
if slot.designates_type:
target_value = get_type_designator_value(sv, slot, class_def)
slot_value = f'"{target_value}"'
if slot.multivalued:
slot_value = (
"[" + slot_value + "]"
)
elif slot.ifabsent is not None:
value = ifabsent_value_declaration(slot.ifabsent, sv, class_def, slot)
slot_value = value
# Multivalued slots that are either not inlined (just an identifier) or are
# inlined as lists should get default_factory list, if they're inlined but
# not as a list, that means a dictionary
elif slot.multivalued:
# this is slow, needs to do additional induced slot calls
#has_identifier_slot = self.range_class_has_identifier_slot(slot)
if slot.inlined and not slot.inlined_as_list: # and has_identifier_slot:
slot_value = "default_factory=dict"
else:
slot_value = "default_factory=list"
return slot_value
def serialize(self) -> str: def serialize(self) -> str:
predefined_slot_values = {}
"""splitting up parent class :meth:`.get_predefined_slot_values`"""
if self.template_file is not None: if self.template_file is not None:
with open(self.template_file) as template_file: with open(self.template_file) as template_file:
template_obj = Template(template_file.read()) template_obj = Template(template_file.read())
@@ -458,34 +529,28 @@ class NWBPydanticGenerator(PydanticGenerator):
sv: SchemaView sv: SchemaView
sv = self.schemaview sv = self.schemaview
sv.schema_map = self.schema_map
schema = sv.schema schema = sv.schema
pyschema = SchemaDefinition( pyschema = SchemaDefinition(
id=schema.id, id=schema.id,
name=schema.name, name=schema.name,
description=schema.description.replace('"', '\\"') if schema.description else None, description=schema.description.replace('"', '\\"') if schema.description else None,
) )
# test caching if import closure
enums = self.generate_enums(sv.all_enums()) enums = self.generate_enums(sv.all_enums())
# filter skipped enums # filter skipped enums
enums = {k:v for k,v in enums.items() if k not in self.SKIP_ENUM} enums = {k:v for k,v in enums.items() if k not in self.SKIP_ENUM}
if self.split: classes = self._get_classes(sv)
# import from local references, rather than serializing every class in every file # just induce slots once because that turns out to be expensive
if 'namespace' in schema.annotations.keys() and schema.annotations['namespace']['value'] == 'True': class_slots = {} # type: Dict[str, List[SlotDefinition]]
imports = self._get_namespace_imports(sv) for aclass in classes:
else: class_slots[aclass.name] = self._get_class_slots(sv, aclass)
imports = self._get_imports(sv)
sorted_classes = self._get_classes(sv, imports)
else:
sorted_classes = self.sort_classes(list(sv.all_classes().values()), [])
imports = {}
# Don't want to generate classes when class_uri is linkml:Any, will
# just swap in typing.Any instead down below
sorted_classes = [c for c in sorted_classes if c.class_uri != "linkml:Any"]
self.sorted_class_names = [camelcase(c.name) for c in sorted_classes]
# figure out what classes we need to imports
imports = self._get_imports(sv, classes, class_slots)
sorted_classes = self.sort_classes(classes, imports)
for class_original in sorted_classes: for class_original in sorted_classes:
# Generate class definition # Generate class definition
@@ -503,11 +568,12 @@ class NWBPydanticGenerator(PydanticGenerator):
del class_def.attributes[attribute] del class_def.attributes[attribute]
class_name = class_original.name class_name = class_original.name
for sn in sv.class_slots(class_name): predefined_slot_values[camelcase(class_name)] = {}
if sn in self.SKIP_SLOTS: for s in class_slots[class_name]:
continue sn = SlotDefinitionName(s.name)
# TODO: fix runtime, copy should not be necessary predefined_slot_value = self.get_predefined_slot_value(s, class_def)
s = deepcopy(sv.induced_slot(sn, class_name)) if predefined_slot_value is not None:
predefined_slot_values[camelcase(class_name)][s.name] = predefined_slot_value
# logging.error(f'Induced slot {class_name}.{sn} == {s.name} {s.range}') # logging.error(f'Induced slot {class_name}.{sn} == {s.name} {s.range}')
s.name = underscore(s.name) s.name = underscore(s.name)
if s.description: if s.description:
@@ -557,7 +623,7 @@ class NWBPydanticGenerator(PydanticGenerator):
schema=pyschema, schema=pyschema,
underscore=underscore, underscore=underscore,
enums=enums, enums=enums,
predefined_slot_values=self.get_predefined_slot_values(), predefined_slot_values=predefined_slot_values,
allow_extra=self.allow_extra, allow_extra=self.allow_extra,
metamodel_version=self.schema.metamodel_version, metamodel_version=self.schema.metamodel_version,
version=self.schema.version, version=self.schema.version,

View file

@@ -1,11 +1,15 @@
import pdb import pdb
from pathlib import Path
from typing import Dict
import pytest import pytest
import warnings import warnings
from .fixtures import nwb_core_fixture, tmp_output_dir from .fixtures import nwb_core_fixture, tmp_output_dir
from linkml_runtime.dumpers import yaml_dumper from linkml_runtime.dumpers import yaml_dumper
from linkml_runtime.linkml_model import SchemaDefinition
from nwb_linkml.generators.pydantic import NWBPydanticGenerator from nwb_linkml.generators.pydantic import NWBPydanticGenerator
from linkml_runtime.loaders.yaml_loader import YAMLLoader
from nwb_linkml.lang_elements import NwbLangSchema from nwb_linkml.lang_elements import NwbLangSchema
@@ -22,29 +26,38 @@ def test_generate_core(nwb_core_fixture, tmp_output_dir):
output_file = tmp_output_dir / 'schema' / (schema.name + '.yaml') output_file = tmp_output_dir / 'schema' / (schema.name + '.yaml')
yaml_dumper.dump(schema, output_file) yaml_dumper.dump(schema, output_file)
def load_schema_files(path: Path) -> Dict[str, SchemaDefinition]:
yaml_loader = YAMLLoader()
sch: SchemaDefinition
preloaded_schema = {}
for schema_path in (path / 'schema').glob('*.yaml'):
sch = yaml_loader.load(str(schema_path), target_class=SchemaDefinition)
preloaded_schema[sch.name] = sch
return preloaded_schema
@pytest.mark.depends(on=['test_generate_core']) @pytest.mark.depends(on=['test_generate_core'])
def test_generate_pydantic(tmp_output_dir): def test_generate_pydantic(tmp_output_dir):
# core_file = tmp_output_dir / 'core.yaml'
# pydantic_file = tmp_output_dir / 'core.py'
(tmp_output_dir / 'models').mkdir(exist_ok=True) (tmp_output_dir / 'models').mkdir(exist_ok=True)
for schema in (tmp_output_dir / 'schema').glob('*.yaml'): preloaded_schema = load_schema_files(tmp_output_dir)
if not schema.exists():
for schema_path in (tmp_output_dir / 'schema').glob('*.yaml'):
if not schema_path.exists():
continue continue
# python friendly name # python friendly name
python_name = schema.stem.replace('.', '_').replace('-','_') python_name = schema_path.stem.replace('.', '_').replace('-','_')
pydantic_file = (schema.parent.parent / 'models' / python_name).with_suffix('.py') pydantic_file = (schema_path.parent.parent / 'models' / python_name).with_suffix('.py')
generator = NWBPydanticGenerator( generator = NWBPydanticGenerator(
str(schema), str(schema_path),
pydantic_version='2', pydantic_version='2',
emit_metadata=True, emit_metadata=True,
gen_classvars=True, gen_classvars=True,
gen_slots=True gen_slots=True,
schema_map=preloaded_schema
) )
gen_pydantic = generator.serialize() gen_pydantic = generator.serialize()

View file

@@ -5,6 +5,7 @@ from nwb_linkml.io.hdf5 import HDF5IO
def test_hdf_read(): def test_hdf_read():
NWBFILE = Path('/Users/jonny/Dropbox/lab/p2p_ld/data/nwb/sub-738651046_ses-760693773.nwb') NWBFILE = Path('/Users/jonny/Dropbox/lab/p2p_ld/data/nwb/sub-738651046_ses-760693773.nwb')
if not NWBFILE.exists():
return
io = HDF5IO(path=NWBFILE) io = HDF5IO(path=NWBFILE)
model = io.read('/general') model = io.read('/general')