mirror of
https://github.com/p2p-ld/nwb-linkml.git
synced 2024-11-12 17:54:29 +00:00
perf enhancements to pydantic generation
This commit is contained in:
parent
7b97b749ef
commit
2b828082e0
6 changed files with 179 additions and 65 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -161,3 +161,4 @@ cython_debug/
|
|||
|
||||
nwb.schema.json
|
||||
__tmp__
|
||||
prof
|
32
nwb_linkml/poetry.lock
generated
32
nwb_linkml/poetry.lock
generated
|
@ -492,6 +492,17 @@ files = [
|
|||
[package.extras]
|
||||
rewrite = ["tokenize-rt (>=3)"]
|
||||
|
||||
[[package]]
|
||||
name = "gprof2dot"
|
||||
version = "2022.7.29"
|
||||
description = "Generate a dot graph from the output of several profilers."
|
||||
optional = false
|
||||
python-versions = ">=2.7"
|
||||
files = [
|
||||
{file = "gprof2dot-2022.7.29-py2.py3-none-any.whl", hash = "sha256:f165b3851d3c52ee4915eb1bd6cca571e5759823c2cd0f71a79bda93c2dc85d6"},
|
||||
{file = "gprof2dot-2022.7.29.tar.gz", hash = "sha256:45b4d298bd36608fccf9511c3fd88a773f7a1abc04d6cd39445b11ba43133ec5"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "graphviz"
|
||||
version = "0.20.1"
|
||||
|
@ -1554,6 +1565,25 @@ files = [
|
|||
[package.dependencies]
|
||||
pytest = ">=4.2.1"
|
||||
|
||||
[[package]]
|
||||
name = "pytest-profiling"
|
||||
version = "1.7.0"
|
||||
description = "Profiling plugin for py.test"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "pytest-profiling-1.7.0.tar.gz", hash = "sha256:93938f147662225d2b8bd5af89587b979652426a8a6ffd7e73ec4a23e24b7f29"},
|
||||
{file = "pytest_profiling-1.7.0-py2.py3-none-any.whl", hash = "sha256:999cc9ac94f2e528e3f5d43465da277429984a1c237ae9818f8cfd0b06acb019"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
gprof2dot = "*"
|
||||
pytest = "*"
|
||||
six = "*"
|
||||
|
||||
[package.extras]
|
||||
tests = ["pytest-virtualenv"]
|
||||
|
||||
[[package]]
|
||||
name = "python-dateutil"
|
||||
version = "2.8.2"
|
||||
|
@ -2342,4 +2372,4 @@ tests = ["coverage", "coveralls", "pytest", "pytest-cov", "pytest-depends", "pyt
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.11"
|
||||
content-hash = "5427416a9edebc2ab2c4f7f7c9779b2b9c7e4c1c2da5dcc0968ee633110973fa"
|
||||
content-hash = "7ae9160a401b3bfa2f4535696ecf15e33815e356f7757ee611893c701485d24f"
|
||||
|
|
|
@ -29,11 +29,12 @@ pytest-md = {version = "^0.2.0", optional = true}
|
|||
pytest-emoji = {version="^0.2.0", optional = true}
|
||||
pytest-cov = {version = "^4.1.0", optional = true}
|
||||
coveralls = {version = "^3.3.1", optional = true}
|
||||
pytest-profiling = {version = "^1.7.0", optional = true}
|
||||
|
||||
[tool.poetry.extras]
|
||||
tests = [
|
||||
"pytest", "pytest-depends", "coverage", "pytest-md",
|
||||
"pytest-emoji", "pytest-cov", "coveralls"
|
||||
"pytest", "pytest-depends", "coverage", "pytest-md",
|
||||
"pytest-emoji", "pytest-cov", "coveralls", "pytest-profiling"
|
||||
]
|
||||
plot = ["dash", "dash-cytoscape"]
|
||||
|
||||
|
@ -49,6 +50,7 @@ pytest-md = "^0.2.0"
|
|||
pytest-emoji = "^0.2.0"
|
||||
pytest-cov = "^4.1.0"
|
||||
coveralls = "^3.3.1"
|
||||
pytest-profiling = "^1.7.0"
|
||||
|
||||
[tool.poetry.group.plot]
|
||||
optional = true
|
||||
|
@ -67,7 +69,8 @@ addopts = [
|
|||
"--cov=nwb_linkml",
|
||||
"--cov-append",
|
||||
"--cov-config=.coveragerc",
|
||||
"--emoji"
|
||||
"--emoji",
|
||||
"--profile"
|
||||
]
|
||||
testpaths = [
|
||||
"tests",
|
||||
|
|
|
@ -19,10 +19,10 @@ The `serialize` method
|
|||
import pdb
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Set, Tuple, Optional
|
||||
from typing import List, Dict, Set, Tuple, Optional, TypedDict
|
||||
import os, sys
|
||||
from types import ModuleType
|
||||
from copy import deepcopy
|
||||
from copy import deepcopy, copy
|
||||
import warnings
|
||||
|
||||
from nwb_linkml.maps import flat_to_npytyping
|
||||
|
@ -36,9 +36,16 @@ SlotDefinitionName,
|
|||
TypeDefinition,
|
||||
ElementName
|
||||
)
|
||||
from linkml.generators.common.type_designators import (
|
||||
get_accepted_type_designator_values,
|
||||
get_type_designator_value,
|
||||
)
|
||||
from linkml_runtime.utils.formatutils import camelcase, underscore
|
||||
from linkml_runtime.utils.schemaview import SchemaView
|
||||
from linkml_runtime.utils.compile_python import file_text
|
||||
from linkml.utils.ifabsent_functions import ifabsent_value_declaration
|
||||
|
||||
|
||||
from jinja2 import Template
|
||||
|
||||
|
||||
|
@ -175,6 +182,7 @@ class {{ c.name }}
|
|||
return template
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class NWBPydanticGenerator(PydanticGenerator):
|
||||
|
||||
|
@ -184,6 +192,8 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
SKIP_CLASSES=('',)
|
||||
# SKIP_CLASSES=('VectorData','VectorIndex')
|
||||
split:bool=True
|
||||
schema_map:Dict[str, SchemaDefinition]=None
|
||||
|
||||
|
||||
def _locate_imports(
|
||||
self,
|
||||
|
@ -234,15 +244,16 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
self,
|
||||
cls:ClassDefinition,
|
||||
sv:SchemaView,
|
||||
all_classes:dict[ClassDefinitionName, ClassDefinition]) -> List[str]:
|
||||
all_classes:dict[ClassDefinitionName, ClassDefinition],
|
||||
class_slots:dict[str, List[SlotDefinition]]
|
||||
) -> List[str]:
|
||||
"""Get the imports needed for a single class"""
|
||||
needed_classes = []
|
||||
needed_classes.append(cls.is_a)
|
||||
# get needed classes used as ranges in class attributes
|
||||
for slot_name in sv.class_slots(cls.name):
|
||||
if slot_name in self.SKIP_SLOTS:
|
||||
for slot in class_slots[cls.name]:
|
||||
if slot.name in self.SKIP_SLOTS:
|
||||
continue
|
||||
slot = deepcopy(sv.induced_slot(slot_name, cls.name))
|
||||
if slot.range in all_classes:
|
||||
needed_classes.append(slot.range)
|
||||
# handle when a range is a union of classes
|
||||
|
@ -253,15 +264,25 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
|
||||
return needed_classes
|
||||
|
||||
def _get_imports(self, sv:SchemaView) -> Dict[str, List[str]]:
|
||||
def _get_imports(self,
|
||||
sv:SchemaView,
|
||||
local_classes: List[ClassDefinition],
|
||||
class_slots: Dict[str, List[SlotDefinition]]) -> Dict[str, List[str]]:
|
||||
# import from local references, rather than serializing every class in every file
|
||||
if not self.split:
|
||||
# we are compiling this whole thing in one big file so we don't import anything
|
||||
return {}
|
||||
if 'namespace' in sv.schema.annotations.keys() and sv.schema.annotations['namespace']['value'] == 'True':
|
||||
return self._get_namespace_imports(sv)
|
||||
|
||||
all_classes = sv.all_classes(imports=True)
|
||||
local_classes = sv.all_classes(imports=False)
|
||||
# local_classes = sv.all_classes(imports=False)
|
||||
needed_classes = []
|
||||
# find needed classes - is_a and slot ranges
|
||||
|
||||
for clsname, cls in local_classes.items():
|
||||
for cls in local_classes:
|
||||
# get imports for this class
|
||||
needed_classes.extend(self._get_class_imports(cls, sv, all_classes))
|
||||
needed_classes.extend(self._get_class_imports(cls, sv, all_classes, class_slots))
|
||||
|
||||
# remove duplicates and arraylikes
|
||||
needed_classes = [cls for cls in set(needed_classes) if cls is not None and cls != 'Arraylike']
|
||||
|
@ -272,27 +293,29 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
return imports
|
||||
|
||||
|
||||
def _get_classes(self, sv:SchemaView, imports: Dict[str, List[str]]) -> List[ClassDefinition]:
|
||||
module_classes = sv.all_classes(imports=False).values()
|
||||
imported_classes = []
|
||||
for classes in imports.values():
|
||||
imported_classes.extend(classes)
|
||||
|
||||
module_classes = [c for c in list(module_classes) if c.is_a != 'Arraylike']
|
||||
imported_classes = [c for c in imported_classes if sv.get_class(c) and sv.get_class(c).is_a != 'Arraylike']
|
||||
|
||||
sorted_classes = self.sort_classes(module_classes, imported_classes)
|
||||
self.sorted_class_names = [camelcase(cname) for cname in imported_classes]
|
||||
self.sorted_class_names += [camelcase(c.name) for c in sorted_classes]
|
||||
def _get_classes(self, sv:SchemaView) -> List[ClassDefinition]:
|
||||
if self.split:
|
||||
classes = sv.all_classes(imports=False).values()
|
||||
else:
|
||||
classes = sv.all_classes(imports=True).values()
|
||||
|
||||
# Don't want to generate classes when class_uri is linkml:Any, will
|
||||
# just swap in typing.Any instead down below
|
||||
sorted_classes = [c for c in sorted_classes if c.class_uri != "linkml:Any"]
|
||||
return sorted_classes
|
||||
classes = [c for c in list(classes) if c.is_a != 'Arraylike' and c.class_uri != "linkml:Any"]
|
||||
|
||||
return classes
|
||||
|
||||
def _get_class_slots(self, sv:SchemaView, cls:ClassDefinition) -> List[SlotDefinition]:
|
||||
slots = []
|
||||
for slot_name in sv.class_slots(cls.name):
|
||||
if slot_name in self.SKIP_SLOTS:
|
||||
continue
|
||||
slots.append(sv.induced_slot(slot_name, cls.name))
|
||||
return slots
|
||||
|
||||
def _build_class(self, class_original:ClassDefinition) -> ClassDefinition:
|
||||
class_def: ClassDefinition
|
||||
class_def = deepcopy(class_original)
|
||||
class_def = copy(class_original)
|
||||
class_def.name = camelcase(class_original.name)
|
||||
if class_def.is_a:
|
||||
class_def.is_a = camelcase(class_def.is_a)
|
||||
|
@ -375,7 +398,7 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
return union
|
||||
|
||||
|
||||
def sort_classes(self, clist: List[ClassDefinition], imports:List[str]) -> List[ClassDefinition]:
|
||||
def sort_classes(self, clist: List[ClassDefinition], imports:Dict[str, List[str]]) -> List[ClassDefinition]:
|
||||
"""
|
||||
sort classes such that if C is a child of P then C appears after P in the list
|
||||
|
||||
|
@ -383,9 +406,14 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
|
||||
Modified from original to allow for imported classes
|
||||
"""
|
||||
# unnest imports
|
||||
imported_classes = []
|
||||
for i in imports.values():
|
||||
imported_classes.extend(i)
|
||||
|
||||
clist = list(clist)
|
||||
clist = [c for c in clist if c.name not in self.SKIP_CLASSES]
|
||||
slist = [] # sorted
|
||||
slist = [] # type: List[ClassDefinition]
|
||||
while len(clist) > 0:
|
||||
can_add = False
|
||||
for i in range(len(clist)):
|
||||
|
@ -399,7 +427,7 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
can_add = True
|
||||
|
||||
else:
|
||||
if set(candidates) <= set([p.name for p in slist] + imports):
|
||||
if set(candidates) <= set([p.name for p in slist] + imported_classes):
|
||||
can_add = True
|
||||
|
||||
if can_add:
|
||||
|
@ -410,6 +438,9 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
raise ValueError(
|
||||
f"could not find suitable element in {clist} that does not ref {slist}"
|
||||
)
|
||||
|
||||
self.sorted_class_names = [camelcase(cname) for cname in imported_classes]
|
||||
self.sorted_class_names += [camelcase(c.name) for c in slist]
|
||||
return slist
|
||||
|
||||
def get_class_slot_range(self, slot_range: str, inlined: bool, inlined_as_list: bool) -> str:
|
||||
|
@ -449,7 +480,47 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
parents[camelcase(class_def.name)] = class_parents
|
||||
return parents
|
||||
|
||||
def get_predefined_slot_value(self, slot: SlotDefinition, class_def: ClassDefinition) -> Optional[str]:
|
||||
"""
|
||||
Modified from the base pydantic generator to use the induced slot that was
|
||||
already retrieved (and possibly modified) for the class, rather than inducing it
|
||||
again, to save time and respect changes already made elsewhere in the serialization routine
|
||||
|
||||
:return: The predefined value declaration for this slot as a string, or None if the slot has none
|
||||
"""
|
||||
sv = self.schemaview
|
||||
slot_value: Optional[str] = None
|
||||
# for class_def in sv.all_classes().values():
|
||||
# for slot_name in sv.class_slots(class_def.name):
|
||||
# slot = sv.induced_slot(slot_name, class_def.name)
|
||||
if slot.designates_type:
|
||||
target_value = get_type_designator_value(sv, slot, class_def)
|
||||
slot_value = f'"{target_value}"'
|
||||
if slot.multivalued:
|
||||
slot_value = (
|
||||
"[" + slot_value + "]"
|
||||
)
|
||||
elif slot.ifabsent is not None:
|
||||
value = ifabsent_value_declaration(slot.ifabsent, sv, class_def, slot)
|
||||
slot_value = value
|
||||
# Multivalued slots that are either not inlined (just an identifier) or are
|
||||
# inlined as lists should get default_factory list, if they're inlined but
|
||||
# not as a list, that means a dictionary
|
||||
elif slot.multivalued:
|
||||
# this is slow, needs to do additional induced slot calls
|
||||
#has_identifier_slot = self.range_class_has_identifier_slot(slot)
|
||||
|
||||
if slot.inlined and not slot.inlined_as_list: # and has_identifier_slot:
|
||||
slot_value = "default_factory=dict"
|
||||
else:
|
||||
slot_value = "default_factory=list"
|
||||
|
||||
return slot_value
|
||||
|
||||
def serialize(self) -> str:
|
||||
predefined_slot_values = {}
|
||||
"""splitting up parent class :meth:`.get_predefined_slot_values`"""
|
||||
|
||||
if self.template_file is not None:
|
||||
with open(self.template_file) as template_file:
|
||||
template_obj = Template(template_file.read())
|
||||
|
@ -458,34 +529,28 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
|
||||
sv: SchemaView
|
||||
sv = self.schemaview
|
||||
sv.schema_map = self.schema_map
|
||||
schema = sv.schema
|
||||
pyschema = SchemaDefinition(
|
||||
id=schema.id,
|
||||
name=schema.name,
|
||||
description=schema.description.replace('"', '\\"') if schema.description else None,
|
||||
)
|
||||
# test caching if import closure
|
||||
enums = self.generate_enums(sv.all_enums())
|
||||
# filter skipped enums
|
||||
enums = {k:v for k,v in enums.items() if k not in self.SKIP_ENUM}
|
||||
|
||||
if self.split:
|
||||
# import from local references, rather than serializing every class in every file
|
||||
if 'namespace' in schema.annotations.keys() and schema.annotations['namespace']['value'] == 'True':
|
||||
imports = self._get_namespace_imports(sv)
|
||||
else:
|
||||
imports = self._get_imports(sv)
|
||||
|
||||
sorted_classes = self._get_classes(sv, imports)
|
||||
else:
|
||||
sorted_classes = self.sort_classes(list(sv.all_classes().values()), [])
|
||||
imports = {}
|
||||
|
||||
# Don't want to generate classes when class_uri is linkml:Any, will
|
||||
# just swap in typing.Any instead down below
|
||||
sorted_classes = [c for c in sorted_classes if c.class_uri != "linkml:Any"]
|
||||
self.sorted_class_names = [camelcase(c.name) for c in sorted_classes]
|
||||
classes = self._get_classes(sv)
|
||||
# just induce slots once because that turns out to be expensive
|
||||
class_slots = {} # type: Dict[str, List[SlotDefinition]]
|
||||
for aclass in classes:
|
||||
class_slots[aclass.name] = self._get_class_slots(sv, aclass)
|
||||
|
||||
# figure out what classes we need to import
|
||||
imports = self._get_imports(sv, classes, class_slots)
|
||||
|
||||
sorted_classes = self.sort_classes(classes, imports)
|
||||
|
||||
for class_original in sorted_classes:
|
||||
# Generate class definition
|
||||
|
@ -503,11 +568,12 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
del class_def.attributes[attribute]
|
||||
|
||||
class_name = class_original.name
|
||||
for sn in sv.class_slots(class_name):
|
||||
if sn in self.SKIP_SLOTS:
|
||||
continue
|
||||
# TODO: fix runtime, copy should not be necessary
|
||||
s = deepcopy(sv.induced_slot(sn, class_name))
|
||||
predefined_slot_values[camelcase(class_name)] = {}
|
||||
for s in class_slots[class_name]:
|
||||
sn = SlotDefinitionName(s.name)
|
||||
predefined_slot_value = self.get_predefined_slot_value(s, class_def)
|
||||
if predefined_slot_value is not None:
|
||||
predefined_slot_values[camelcase(class_name)][s.name] = predefined_slot_value
|
||||
# logging.error(f'Induced slot {class_name}.{sn} == {s.name} {s.range}')
|
||||
s.name = underscore(s.name)
|
||||
if s.description:
|
||||
|
@ -557,7 +623,7 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
schema=pyschema,
|
||||
underscore=underscore,
|
||||
enums=enums,
|
||||
predefined_slot_values=self.get_predefined_slot_values(),
|
||||
predefined_slot_values=predefined_slot_values,
|
||||
allow_extra=self.allow_extra,
|
||||
metamodel_version=self.schema.metamodel_version,
|
||||
version=self.schema.version,
|
||||
|
|
|
@ -1,11 +1,15 @@
|
|||
import pdb
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import pytest
|
||||
import warnings
|
||||
|
||||
from .fixtures import nwb_core_fixture, tmp_output_dir
|
||||
from linkml_runtime.dumpers import yaml_dumper
|
||||
from linkml_runtime.linkml_model import SchemaDefinition
|
||||
from nwb_linkml.generators.pydantic import NWBPydanticGenerator
|
||||
from linkml_runtime.loaders.yaml_loader import YAMLLoader
|
||||
|
||||
from nwb_linkml.lang_elements import NwbLangSchema
|
||||
|
||||
|
@ -22,29 +26,38 @@ def test_generate_core(nwb_core_fixture, tmp_output_dir):
|
|||
output_file = tmp_output_dir / 'schema' / (schema.name + '.yaml')
|
||||
yaml_dumper.dump(schema, output_file)
|
||||
|
||||
def load_schema_files(path: Path) -> Dict[str, SchemaDefinition]:
|
||||
yaml_loader = YAMLLoader()
|
||||
sch: SchemaDefinition
|
||||
preloaded_schema = {}
|
||||
for schema_path in (path / 'schema').glob('*.yaml'):
|
||||
sch = yaml_loader.load(str(schema_path), target_class=SchemaDefinition)
|
||||
preloaded_schema[sch.name] = sch
|
||||
return preloaded_schema
|
||||
|
||||
|
||||
@pytest.mark.depends(on=['test_generate_core'])
|
||||
def test_generate_pydantic(tmp_output_dir):
|
||||
|
||||
|
||||
# core_file = tmp_output_dir / 'core.yaml'
|
||||
# pydantic_file = tmp_output_dir / 'core.py'
|
||||
|
||||
(tmp_output_dir / 'models').mkdir(exist_ok=True)
|
||||
|
||||
for schema in (tmp_output_dir / 'schema').glob('*.yaml'):
|
||||
if not schema.exists():
|
||||
preloaded_schema = load_schema_files(tmp_output_dir)
|
||||
|
||||
for schema_path in (tmp_output_dir / 'schema').glob('*.yaml'):
|
||||
if not schema_path.exists():
|
||||
continue
|
||||
# python friendly name
|
||||
python_name = schema.stem.replace('.', '_').replace('-','_')
|
||||
python_name = schema_path.stem.replace('.', '_').replace('-','_')
|
||||
|
||||
pydantic_file = (schema.parent.parent / 'models' / python_name).with_suffix('.py')
|
||||
pydantic_file = (schema_path.parent.parent / 'models' / python_name).with_suffix('.py')
|
||||
|
||||
generator = NWBPydanticGenerator(
|
||||
str(schema),
|
||||
str(schema_path),
|
||||
pydantic_version='2',
|
||||
emit_metadata=True,
|
||||
gen_classvars=True,
|
||||
gen_slots=True
|
||||
gen_slots=True,
|
||||
schema_map=preloaded_schema
|
||||
)
|
||||
gen_pydantic = generator.serialize()
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ from nwb_linkml.io.hdf5 import HDF5IO
|
|||
|
||||
def test_hdf_read():
|
||||
NWBFILE = Path('/Users/jonny/Dropbox/lab/p2p_ld/data/nwb/sub-738651046_ses-760693773.nwb')
|
||||
|
||||
if not NWBFILE.exists():
|
||||
return
|
||||
io = HDF5IO(path=NWBFILE)
|
||||
model = io.read('/general')
|
||||
|
|
Loading…
Reference in a new issue