mirror of
https://github.com/p2p-ld/nwb-linkml.git
synced 2025-01-09 21:54:27 +00:00
perf enhancements to pydantic generation
This commit is contained in:
parent
7b97b749ef
commit
2b828082e0
6 changed files with 179 additions and 65 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -161,3 +161,4 @@ cython_debug/
|
||||||
|
|
||||||
nwb.schema.json
|
nwb.schema.json
|
||||||
__tmp__
|
__tmp__
|
||||||
|
prof
|
32
nwb_linkml/poetry.lock
generated
32
nwb_linkml/poetry.lock
generated
|
@ -492,6 +492,17 @@ files = [
|
||||||
[package.extras]
|
[package.extras]
|
||||||
rewrite = ["tokenize-rt (>=3)"]
|
rewrite = ["tokenize-rt (>=3)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "gprof2dot"
|
||||||
|
version = "2022.7.29"
|
||||||
|
description = "Generate a dot graph from the output of several profilers."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=2.7"
|
||||||
|
files = [
|
||||||
|
{file = "gprof2dot-2022.7.29-py2.py3-none-any.whl", hash = "sha256:f165b3851d3c52ee4915eb1bd6cca571e5759823c2cd0f71a79bda93c2dc85d6"},
|
||||||
|
{file = "gprof2dot-2022.7.29.tar.gz", hash = "sha256:45b4d298bd36608fccf9511c3fd88a773f7a1abc04d6cd39445b11ba43133ec5"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "graphviz"
|
name = "graphviz"
|
||||||
version = "0.20.1"
|
version = "0.20.1"
|
||||||
|
@ -1554,6 +1565,25 @@ files = [
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
pytest = ">=4.2.1"
|
pytest = ">=4.2.1"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pytest-profiling"
|
||||||
|
version = "1.7.0"
|
||||||
|
description = "Profiling plugin for py.test"
|
||||||
|
optional = false
|
||||||
|
python-versions = "*"
|
||||||
|
files = [
|
||||||
|
{file = "pytest-profiling-1.7.0.tar.gz", hash = "sha256:93938f147662225d2b8bd5af89587b979652426a8a6ffd7e73ec4a23e24b7f29"},
|
||||||
|
{file = "pytest_profiling-1.7.0-py2.py3-none-any.whl", hash = "sha256:999cc9ac94f2e528e3f5d43465da277429984a1c237ae9818f8cfd0b06acb019"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
gprof2dot = "*"
|
||||||
|
pytest = "*"
|
||||||
|
six = "*"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
tests = ["pytest-virtualenv"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "python-dateutil"
|
name = "python-dateutil"
|
||||||
version = "2.8.2"
|
version = "2.8.2"
|
||||||
|
@ -2342,4 +2372,4 @@ tests = ["coverage", "coveralls", "pytest", "pytest-cov", "pytest-depends", "pyt
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.11"
|
python-versions = "^3.11"
|
||||||
content-hash = "5427416a9edebc2ab2c4f7f7c9779b2b9c7e4c1c2da5dcc0968ee633110973fa"
|
content-hash = "7ae9160a401b3bfa2f4535696ecf15e33815e356f7757ee611893c701485d24f"
|
||||||
|
|
|
@ -29,11 +29,12 @@ pytest-md = {version = "^0.2.0", optional = true}
|
||||||
pytest-emoji = {version="^0.2.0", optional = true}
|
pytest-emoji = {version="^0.2.0", optional = true}
|
||||||
pytest-cov = {version = "^4.1.0", optional = true}
|
pytest-cov = {version = "^4.1.0", optional = true}
|
||||||
coveralls = {version = "^3.3.1", optional = true}
|
coveralls = {version = "^3.3.1", optional = true}
|
||||||
|
pytest-profiling = {version = "^1.7.0", optional = true}
|
||||||
|
|
||||||
[tool.poetry.extras]
|
[tool.poetry.extras]
|
||||||
tests = [
|
tests = [
|
||||||
"pytest", "pytest-depends", "coverage", "pytest-md",
|
"pytest", "pytest-depends", "coverage", "pytest-md",
|
||||||
"pytest-emoji", "pytest-cov", "coveralls"
|
"pytest-emoji", "pytest-cov", "coveralls", "pytest-profiling"
|
||||||
]
|
]
|
||||||
plot = ["dash", "dash-cytoscape"]
|
plot = ["dash", "dash-cytoscape"]
|
||||||
|
|
||||||
|
@ -49,6 +50,7 @@ pytest-md = "^0.2.0"
|
||||||
pytest-emoji = "^0.2.0"
|
pytest-emoji = "^0.2.0"
|
||||||
pytest-cov = "^4.1.0"
|
pytest-cov = "^4.1.0"
|
||||||
coveralls = "^3.3.1"
|
coveralls = "^3.3.1"
|
||||||
|
pytest-profiling = "^1.7.0"
|
||||||
|
|
||||||
[tool.poetry.group.plot]
|
[tool.poetry.group.plot]
|
||||||
optional = true
|
optional = true
|
||||||
|
@ -67,7 +69,8 @@ addopts = [
|
||||||
"--cov=nwb_linkml",
|
"--cov=nwb_linkml",
|
||||||
"--cov-append",
|
"--cov-append",
|
||||||
"--cov-config=.coveragerc",
|
"--cov-config=.coveragerc",
|
||||||
"--emoji"
|
"--emoji",
|
||||||
|
"--profile"
|
||||||
]
|
]
|
||||||
testpaths = [
|
testpaths = [
|
||||||
"tests",
|
"tests",
|
||||||
|
|
|
@ -19,10 +19,10 @@ The `serialize` method
|
||||||
import pdb
|
import pdb
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Dict, Set, Tuple, Optional
|
from typing import List, Dict, Set, Tuple, Optional, TypedDict
|
||||||
import os, sys
|
import os, sys
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from copy import deepcopy
|
from copy import deepcopy, copy
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from nwb_linkml.maps import flat_to_npytyping
|
from nwb_linkml.maps import flat_to_npytyping
|
||||||
|
@ -36,9 +36,16 @@ SlotDefinitionName,
|
||||||
TypeDefinition,
|
TypeDefinition,
|
||||||
ElementName
|
ElementName
|
||||||
)
|
)
|
||||||
|
from linkml.generators.common.type_designators import (
|
||||||
|
get_accepted_type_designator_values,
|
||||||
|
get_type_designator_value,
|
||||||
|
)
|
||||||
from linkml_runtime.utils.formatutils import camelcase, underscore
|
from linkml_runtime.utils.formatutils import camelcase, underscore
|
||||||
from linkml_runtime.utils.schemaview import SchemaView
|
from linkml_runtime.utils.schemaview import SchemaView
|
||||||
from linkml_runtime.utils.compile_python import file_text
|
from linkml_runtime.utils.compile_python import file_text
|
||||||
|
from linkml.utils.ifabsent_functions import ifabsent_value_declaration
|
||||||
|
|
||||||
|
|
||||||
from jinja2 import Template
|
from jinja2 import Template
|
||||||
|
|
||||||
|
|
||||||
|
@ -175,6 +182,7 @@ class {{ c.name }}
|
||||||
return template
|
return template
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class NWBPydanticGenerator(PydanticGenerator):
|
class NWBPydanticGenerator(PydanticGenerator):
|
||||||
|
|
||||||
|
@ -184,6 +192,8 @@ class NWBPydanticGenerator(PydanticGenerator):
|
||||||
SKIP_CLASSES=('',)
|
SKIP_CLASSES=('',)
|
||||||
# SKIP_CLASSES=('VectorData','VectorIndex')
|
# SKIP_CLASSES=('VectorData','VectorIndex')
|
||||||
split:bool=True
|
split:bool=True
|
||||||
|
schema_map:Dict[str, SchemaDefinition]=None
|
||||||
|
|
||||||
|
|
||||||
def _locate_imports(
|
def _locate_imports(
|
||||||
self,
|
self,
|
||||||
|
@ -234,15 +244,16 @@ class NWBPydanticGenerator(PydanticGenerator):
|
||||||
self,
|
self,
|
||||||
cls:ClassDefinition,
|
cls:ClassDefinition,
|
||||||
sv:SchemaView,
|
sv:SchemaView,
|
||||||
all_classes:dict[ClassDefinitionName, ClassDefinition]) -> List[str]:
|
all_classes:dict[ClassDefinitionName, ClassDefinition],
|
||||||
|
class_slots:dict[str, List[SlotDefinition]]
|
||||||
|
) -> List[str]:
|
||||||
"""Get the imports needed for a single class"""
|
"""Get the imports needed for a single class"""
|
||||||
needed_classes = []
|
needed_classes = []
|
||||||
needed_classes.append(cls.is_a)
|
needed_classes.append(cls.is_a)
|
||||||
# get needed classes used as ranges in class attributes
|
# get needed classes used as ranges in class attributes
|
||||||
for slot_name in sv.class_slots(cls.name):
|
for slot in class_slots[cls.name]:
|
||||||
if slot_name in self.SKIP_SLOTS:
|
if slot.name in self.SKIP_SLOTS:
|
||||||
continue
|
continue
|
||||||
slot = deepcopy(sv.induced_slot(slot_name, cls.name))
|
|
||||||
if slot.range in all_classes:
|
if slot.range in all_classes:
|
||||||
needed_classes.append(slot.range)
|
needed_classes.append(slot.range)
|
||||||
# handle when a range is a union of classes
|
# handle when a range is a union of classes
|
||||||
|
@ -253,15 +264,25 @@ class NWBPydanticGenerator(PydanticGenerator):
|
||||||
|
|
||||||
return needed_classes
|
return needed_classes
|
||||||
|
|
||||||
def _get_imports(self, sv:SchemaView) -> Dict[str, List[str]]:
|
def _get_imports(self,
|
||||||
|
sv:SchemaView,
|
||||||
|
local_classes: List[ClassDefinition],
|
||||||
|
class_slots: Dict[str, List[SlotDefinition]]) -> Dict[str, List[str]]:
|
||||||
|
# import from local references, rather than serializing every class in every file
|
||||||
|
if not self.split:
|
||||||
|
# we are compiling this whole thing in one big file so we don't import anything
|
||||||
|
return {}
|
||||||
|
if 'namespace' in sv.schema.annotations.keys() and sv.schema.annotations['namespace']['value'] == 'True':
|
||||||
|
return self._get_namespace_imports(sv)
|
||||||
|
|
||||||
all_classes = sv.all_classes(imports=True)
|
all_classes = sv.all_classes(imports=True)
|
||||||
local_classes = sv.all_classes(imports=False)
|
# local_classes = sv.all_classes(imports=False)
|
||||||
needed_classes = []
|
needed_classes = []
|
||||||
# find needed classes - is_a and slot ranges
|
# find needed classes - is_a and slot ranges
|
||||||
|
|
||||||
for clsname, cls in local_classes.items():
|
for cls in local_classes:
|
||||||
# get imports for this class
|
# get imports for this class
|
||||||
needed_classes.extend(self._get_class_imports(cls, sv, all_classes))
|
needed_classes.extend(self._get_class_imports(cls, sv, all_classes, class_slots))
|
||||||
|
|
||||||
# remove duplicates and arraylikes
|
# remove duplicates and arraylikes
|
||||||
needed_classes = [cls for cls in set(needed_classes) if cls is not None and cls != 'Arraylike']
|
needed_classes = [cls for cls in set(needed_classes) if cls is not None and cls != 'Arraylike']
|
||||||
|
@ -272,27 +293,29 @@ class NWBPydanticGenerator(PydanticGenerator):
|
||||||
return imports
|
return imports
|
||||||
|
|
||||||
|
|
||||||
def _get_classes(self, sv:SchemaView, imports: Dict[str, List[str]]) -> List[ClassDefinition]:
|
def _get_classes(self, sv:SchemaView) -> List[ClassDefinition]:
|
||||||
module_classes = sv.all_classes(imports=False).values()
|
if self.split:
|
||||||
imported_classes = []
|
classes = sv.all_classes(imports=False).values()
|
||||||
for classes in imports.values():
|
else:
|
||||||
imported_classes.extend(classes)
|
classes = sv.all_classes(imports=True).values()
|
||||||
|
|
||||||
module_classes = [c for c in list(module_classes) if c.is_a != 'Arraylike']
|
|
||||||
imported_classes = [c for c in imported_classes if sv.get_class(c) and sv.get_class(c).is_a != 'Arraylike']
|
|
||||||
|
|
||||||
sorted_classes = self.sort_classes(module_classes, imported_classes)
|
|
||||||
self.sorted_class_names = [camelcase(cname) for cname in imported_classes]
|
|
||||||
self.sorted_class_names += [camelcase(c.name) for c in sorted_classes]
|
|
||||||
|
|
||||||
# Don't want to generate classes when class_uri is linkml:Any, will
|
# Don't want to generate classes when class_uri is linkml:Any, will
|
||||||
# just swap in typing.Any instead down below
|
# just swap in typing.Any instead down below
|
||||||
sorted_classes = [c for c in sorted_classes if c.class_uri != "linkml:Any"]
|
classes = [c for c in list(classes) if c.is_a != 'Arraylike' and c.class_uri != "linkml:Any"]
|
||||||
return sorted_classes
|
|
||||||
|
return classes
|
||||||
|
|
||||||
|
def _get_class_slots(self, sv:SchemaView, cls:ClassDefinition) -> List[SlotDefinition]:
|
||||||
|
slots = []
|
||||||
|
for slot_name in sv.class_slots(cls.name):
|
||||||
|
if slot_name in self.SKIP_SLOTS:
|
||||||
|
continue
|
||||||
|
slots.append(sv.induced_slot(slot_name, cls.name))
|
||||||
|
return slots
|
||||||
|
|
||||||
def _build_class(self, class_original:ClassDefinition) -> ClassDefinition:
|
def _build_class(self, class_original:ClassDefinition) -> ClassDefinition:
|
||||||
class_def: ClassDefinition
|
class_def: ClassDefinition
|
||||||
class_def = deepcopy(class_original)
|
class_def = copy(class_original)
|
||||||
class_def.name = camelcase(class_original.name)
|
class_def.name = camelcase(class_original.name)
|
||||||
if class_def.is_a:
|
if class_def.is_a:
|
||||||
class_def.is_a = camelcase(class_def.is_a)
|
class_def.is_a = camelcase(class_def.is_a)
|
||||||
|
@ -375,7 +398,7 @@ class NWBPydanticGenerator(PydanticGenerator):
|
||||||
return union
|
return union
|
||||||
|
|
||||||
|
|
||||||
def sort_classes(self, clist: List[ClassDefinition], imports:List[str]) -> List[ClassDefinition]:
|
def sort_classes(self, clist: List[ClassDefinition], imports:Dict[str, List[str]]) -> List[ClassDefinition]:
|
||||||
"""
|
"""
|
||||||
sort classes such that if C is a child of P then C appears after P in the list
|
sort classes such that if C is a child of P then C appears after P in the list
|
||||||
|
|
||||||
|
@ -383,9 +406,14 @@ class NWBPydanticGenerator(PydanticGenerator):
|
||||||
|
|
||||||
Modified from original to allow for imported classes
|
Modified from original to allow for imported classes
|
||||||
"""
|
"""
|
||||||
|
# unnest imports
|
||||||
|
imported_classes = []
|
||||||
|
for i in imports.values():
|
||||||
|
imported_classes.extend(i)
|
||||||
|
|
||||||
clist = list(clist)
|
clist = list(clist)
|
||||||
clist = [c for c in clist if c.name not in self.SKIP_CLASSES]
|
clist = [c for c in clist if c.name not in self.SKIP_CLASSES]
|
||||||
slist = [] # sorted
|
slist = [] # type: List[ClassDefinition]
|
||||||
while len(clist) > 0:
|
while len(clist) > 0:
|
||||||
can_add = False
|
can_add = False
|
||||||
for i in range(len(clist)):
|
for i in range(len(clist)):
|
||||||
|
@ -399,7 +427,7 @@ class NWBPydanticGenerator(PydanticGenerator):
|
||||||
can_add = True
|
can_add = True
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if set(candidates) <= set([p.name for p in slist] + imports):
|
if set(candidates) <= set([p.name for p in slist] + imported_classes):
|
||||||
can_add = True
|
can_add = True
|
||||||
|
|
||||||
if can_add:
|
if can_add:
|
||||||
|
@ -410,6 +438,9 @@ class NWBPydanticGenerator(PydanticGenerator):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"could not find suitable element in {clist} that does not ref {slist}"
|
f"could not find suitable element in {clist} that does not ref {slist}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.sorted_class_names = [camelcase(cname) for cname in imported_classes]
|
||||||
|
self.sorted_class_names += [camelcase(c.name) for c in slist]
|
||||||
return slist
|
return slist
|
||||||
|
|
||||||
def get_class_slot_range(self, slot_range: str, inlined: bool, inlined_as_list: bool) -> str:
|
def get_class_slot_range(self, slot_range: str, inlined: bool, inlined_as_list: bool) -> str:
|
||||||
|
@ -449,7 +480,47 @@ class NWBPydanticGenerator(PydanticGenerator):
|
||||||
parents[camelcase(class_def.name)] = class_parents
|
parents[camelcase(class_def.name)] = class_parents
|
||||||
return parents
|
return parents
|
||||||
|
|
||||||
|
def get_predefined_slot_value(self, slot: SlotDefinition, class_def: ClassDefinition) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Modified from base pydantic generator to use already grabbed induced_slot from
|
||||||
|
already-grabbed and modified classes rather than doing a fresh iteration to
|
||||||
|
save time and respect changes already made elsewhere in the serialization routine
|
||||||
|
|
||||||
|
:return: Dictionary of dictionaries with predefined slot values for each class
|
||||||
|
"""
|
||||||
|
sv = self.schemaview
|
||||||
|
slot_value: Optional[str] = None
|
||||||
|
# for class_def in sv.all_classes().values():
|
||||||
|
# for slot_name in sv.class_slots(class_def.name):
|
||||||
|
# slot = sv.induced_slot(slot_name, class_def.name)
|
||||||
|
if slot.designates_type:
|
||||||
|
target_value = get_type_designator_value(sv, slot, class_def)
|
||||||
|
slot_value = f'"{target_value}"'
|
||||||
|
if slot.multivalued:
|
||||||
|
slot_value = (
|
||||||
|
"[" + slot_value + "]"
|
||||||
|
)
|
||||||
|
elif slot.ifabsent is not None:
|
||||||
|
value = ifabsent_value_declaration(slot.ifabsent, sv, class_def, slot)
|
||||||
|
slot_value = value
|
||||||
|
# Multivalued slots that are either not inlined (just an identifier) or are
|
||||||
|
# inlined as lists should get default_factory list, if they're inlined but
|
||||||
|
# not as a list, that means a dictionary
|
||||||
|
elif slot.multivalued:
|
||||||
|
# this is slow, needs to do additional induced slot calls
|
||||||
|
#has_identifier_slot = self.range_class_has_identifier_slot(slot)
|
||||||
|
|
||||||
|
if slot.inlined and not slot.inlined_as_list: # and has_identifier_slot:
|
||||||
|
slot_value = "default_factory=dict"
|
||||||
|
else:
|
||||||
|
slot_value = "default_factory=list"
|
||||||
|
|
||||||
|
return slot_value
|
||||||
|
|
||||||
def serialize(self) -> str:
|
def serialize(self) -> str:
|
||||||
|
predefined_slot_values = {}
|
||||||
|
"""splitting up parent class :meth:`.get_predefined_slot_values`"""
|
||||||
|
|
||||||
if self.template_file is not None:
|
if self.template_file is not None:
|
||||||
with open(self.template_file) as template_file:
|
with open(self.template_file) as template_file:
|
||||||
template_obj = Template(template_file.read())
|
template_obj = Template(template_file.read())
|
||||||
|
@ -458,34 +529,28 @@ class NWBPydanticGenerator(PydanticGenerator):
|
||||||
|
|
||||||
sv: SchemaView
|
sv: SchemaView
|
||||||
sv = self.schemaview
|
sv = self.schemaview
|
||||||
|
sv.schema_map = self.schema_map
|
||||||
schema = sv.schema
|
schema = sv.schema
|
||||||
pyschema = SchemaDefinition(
|
pyschema = SchemaDefinition(
|
||||||
id=schema.id,
|
id=schema.id,
|
||||||
name=schema.name,
|
name=schema.name,
|
||||||
description=schema.description.replace('"', '\\"') if schema.description else None,
|
description=schema.description.replace('"', '\\"') if schema.description else None,
|
||||||
)
|
)
|
||||||
|
# test caching if import closure
|
||||||
enums = self.generate_enums(sv.all_enums())
|
enums = self.generate_enums(sv.all_enums())
|
||||||
# filter skipped enums
|
# filter skipped enums
|
||||||
enums = {k:v for k,v in enums.items() if k not in self.SKIP_ENUM}
|
enums = {k:v for k,v in enums.items() if k not in self.SKIP_ENUM}
|
||||||
|
|
||||||
if self.split:
|
classes = self._get_classes(sv)
|
||||||
# import from local references, rather than serializing every class in every file
|
# just induce slots once because that turns out to be expensive
|
||||||
if 'namespace' in schema.annotations.keys() and schema.annotations['namespace']['value'] == 'True':
|
class_slots = {} # type: Dict[str, List[SlotDefinition]]
|
||||||
imports = self._get_namespace_imports(sv)
|
for aclass in classes:
|
||||||
else:
|
class_slots[aclass.name] = self._get_class_slots(sv, aclass)
|
||||||
imports = self._get_imports(sv)
|
|
||||||
|
|
||||||
sorted_classes = self._get_classes(sv, imports)
|
|
||||||
else:
|
|
||||||
sorted_classes = self.sort_classes(list(sv.all_classes().values()), [])
|
|
||||||
imports = {}
|
|
||||||
|
|
||||||
# Don't want to generate classes when class_uri is linkml:Any, will
|
|
||||||
# just swap in typing.Any instead down below
|
|
||||||
sorted_classes = [c for c in sorted_classes if c.class_uri != "linkml:Any"]
|
|
||||||
self.sorted_class_names = [camelcase(c.name) for c in sorted_classes]
|
|
||||||
|
|
||||||
|
# figure out what classes we need to imports
|
||||||
|
imports = self._get_imports(sv, classes, class_slots)
|
||||||
|
|
||||||
|
sorted_classes = self.sort_classes(classes, imports)
|
||||||
|
|
||||||
for class_original in sorted_classes:
|
for class_original in sorted_classes:
|
||||||
# Generate class definition
|
# Generate class definition
|
||||||
|
@ -503,11 +568,12 @@ class NWBPydanticGenerator(PydanticGenerator):
|
||||||
del class_def.attributes[attribute]
|
del class_def.attributes[attribute]
|
||||||
|
|
||||||
class_name = class_original.name
|
class_name = class_original.name
|
||||||
for sn in sv.class_slots(class_name):
|
predefined_slot_values[camelcase(class_name)] = {}
|
||||||
if sn in self.SKIP_SLOTS:
|
for s in class_slots[class_name]:
|
||||||
continue
|
sn = SlotDefinitionName(s.name)
|
||||||
# TODO: fix runtime, copy should not be necessary
|
predefined_slot_value = self.get_predefined_slot_value(s, class_def)
|
||||||
s = deepcopy(sv.induced_slot(sn, class_name))
|
if predefined_slot_value is not None:
|
||||||
|
predefined_slot_values[camelcase(class_name)][s.name] = predefined_slot_value
|
||||||
# logging.error(f'Induced slot {class_name}.{sn} == {s.name} {s.range}')
|
# logging.error(f'Induced slot {class_name}.{sn} == {s.name} {s.range}')
|
||||||
s.name = underscore(s.name)
|
s.name = underscore(s.name)
|
||||||
if s.description:
|
if s.description:
|
||||||
|
@ -557,7 +623,7 @@ class NWBPydanticGenerator(PydanticGenerator):
|
||||||
schema=pyschema,
|
schema=pyschema,
|
||||||
underscore=underscore,
|
underscore=underscore,
|
||||||
enums=enums,
|
enums=enums,
|
||||||
predefined_slot_values=self.get_predefined_slot_values(),
|
predefined_slot_values=predefined_slot_values,
|
||||||
allow_extra=self.allow_extra,
|
allow_extra=self.allow_extra,
|
||||||
metamodel_version=self.schema.metamodel_version,
|
metamodel_version=self.schema.metamodel_version,
|
||||||
version=self.schema.version,
|
version=self.schema.version,
|
||||||
|
|
|
@ -1,11 +1,15 @@
|
||||||
import pdb
|
import pdb
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from .fixtures import nwb_core_fixture, tmp_output_dir
|
from .fixtures import nwb_core_fixture, tmp_output_dir
|
||||||
from linkml_runtime.dumpers import yaml_dumper
|
from linkml_runtime.dumpers import yaml_dumper
|
||||||
|
from linkml_runtime.linkml_model import SchemaDefinition
|
||||||
from nwb_linkml.generators.pydantic import NWBPydanticGenerator
|
from nwb_linkml.generators.pydantic import NWBPydanticGenerator
|
||||||
|
from linkml_runtime.loaders.yaml_loader import YAMLLoader
|
||||||
|
|
||||||
from nwb_linkml.lang_elements import NwbLangSchema
|
from nwb_linkml.lang_elements import NwbLangSchema
|
||||||
|
|
||||||
|
@ -22,29 +26,38 @@ def test_generate_core(nwb_core_fixture, tmp_output_dir):
|
||||||
output_file = tmp_output_dir / 'schema' / (schema.name + '.yaml')
|
output_file = tmp_output_dir / 'schema' / (schema.name + '.yaml')
|
||||||
yaml_dumper.dump(schema, output_file)
|
yaml_dumper.dump(schema, output_file)
|
||||||
|
|
||||||
|
def load_schema_files(path: Path) -> Dict[str, SchemaDefinition]:
|
||||||
|
yaml_loader = YAMLLoader()
|
||||||
|
sch: SchemaDefinition
|
||||||
|
preloaded_schema = {}
|
||||||
|
for schema_path in (path / 'schema').glob('*.yaml'):
|
||||||
|
sch = yaml_loader.load(str(schema_path), target_class=SchemaDefinition)
|
||||||
|
preloaded_schema[sch.name] = sch
|
||||||
|
return preloaded_schema
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.depends(on=['test_generate_core'])
|
@pytest.mark.depends(on=['test_generate_core'])
|
||||||
def test_generate_pydantic(tmp_output_dir):
|
def test_generate_pydantic(tmp_output_dir):
|
||||||
|
|
||||||
|
|
||||||
# core_file = tmp_output_dir / 'core.yaml'
|
|
||||||
# pydantic_file = tmp_output_dir / 'core.py'
|
|
||||||
|
|
||||||
(tmp_output_dir / 'models').mkdir(exist_ok=True)
|
(tmp_output_dir / 'models').mkdir(exist_ok=True)
|
||||||
|
|
||||||
for schema in (tmp_output_dir / 'schema').glob('*.yaml'):
|
preloaded_schema = load_schema_files(tmp_output_dir)
|
||||||
if not schema.exists():
|
|
||||||
|
for schema_path in (tmp_output_dir / 'schema').glob('*.yaml'):
|
||||||
|
if not schema_path.exists():
|
||||||
continue
|
continue
|
||||||
# python friendly name
|
# python friendly name
|
||||||
python_name = schema.stem.replace('.', '_').replace('-','_')
|
python_name = schema_path.stem.replace('.', '_').replace('-','_')
|
||||||
|
|
||||||
pydantic_file = (schema.parent.parent / 'models' / python_name).with_suffix('.py')
|
pydantic_file = (schema_path.parent.parent / 'models' / python_name).with_suffix('.py')
|
||||||
|
|
||||||
generator = NWBPydanticGenerator(
|
generator = NWBPydanticGenerator(
|
||||||
str(schema),
|
str(schema_path),
|
||||||
pydantic_version='2',
|
pydantic_version='2',
|
||||||
emit_metadata=True,
|
emit_metadata=True,
|
||||||
gen_classvars=True,
|
gen_classvars=True,
|
||||||
gen_slots=True
|
gen_slots=True,
|
||||||
|
schema_map=preloaded_schema
|
||||||
)
|
)
|
||||||
gen_pydantic = generator.serialize()
|
gen_pydantic = generator.serialize()
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ from nwb_linkml.io.hdf5 import HDF5IO
|
||||||
|
|
||||||
def test_hdf_read():
|
def test_hdf_read():
|
||||||
NWBFILE = Path('/Users/jonny/Dropbox/lab/p2p_ld/data/nwb/sub-738651046_ses-760693773.nwb')
|
NWBFILE = Path('/Users/jonny/Dropbox/lab/p2p_ld/data/nwb/sub-738651046_ses-760693773.nwb')
|
||||||
|
if not NWBFILE.exists():
|
||||||
|
return
|
||||||
io = HDF5IO(path=NWBFILE)
|
io = HDF5IO(path=NWBFILE)
|
||||||
model = io.read('/general')
|
model = io.read('/general')
|
||||||
|
|
Loading…
Reference in a new issue