mirror of
https://github.com/p2p-ld/nwb-linkml.git
synced 2025-01-10 06:04:28 +00:00
- Don't split schema into includes and base
- Import top-level classes in namespaces pydantic file - Don't generate or import the arraylike class
This commit is contained in:
parent
9dd7304334
commit
bb9dda6e66
4 changed files with 105 additions and 30 deletions
|
@ -21,11 +21,13 @@ class NamespacesAdapter(Adapter):
|
|||
schemas: List[SchemaAdapter]
|
||||
imported: List['NamespacesAdapter'] = Field(default_factory=list)
|
||||
|
||||
_imports_populated = PrivateAttr(False)
|
||||
_imports_populated: bool = PrivateAttr(False)
|
||||
_split: bool = PrivateAttr(False)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(NamespacesAdapter, self).__init__(**kwargs)
|
||||
self._populate_schema_namespaces()
|
||||
self.split = self._split
|
||||
|
||||
def build(self) -> BuildResult:
|
||||
if not self._imports_populated:
|
||||
|
@ -51,12 +53,36 @@ class NamespacesAdapter(Adapter):
|
|||
id = ns.name,
|
||||
description = ns.doc,
|
||||
version = ns.version,
|
||||
imports=[sch.name for sch in ns_schemas]
|
||||
imports=[sch.name for sch in ns_schemas],
|
||||
annotations=[{'tag': 'namespace', 'value': True}]
|
||||
)
|
||||
sch_result.schemas.append(ns_schema)
|
||||
|
||||
return sch_result
|
||||
|
||||
@property
|
||||
def split(self) -> bool:
|
||||
"""
|
||||
Sets the :attr:`.SchemaAdapter.split` attribute for all contained and imported schema
|
||||
|
||||
Args:
|
||||
split (bool): Set the generated schema to be split or not
|
||||
|
||||
Returns:
|
||||
bool: whether the schema are set to be split!
|
||||
"""
|
||||
return self._split
|
||||
|
||||
@split.setter
|
||||
def split(self, split):
|
||||
for sch in self.schemas:
|
||||
sch.split = split
|
||||
for ns in self.imported:
|
||||
for sch in ns.schemas:
|
||||
sch.split = split
|
||||
|
||||
self._split = split
|
||||
|
||||
def _populate_schema_namespaces(self):
|
||||
# annotate for each schema which namespace imports it
|
||||
for sch in self.schemas:
|
||||
|
|
|
@ -35,7 +35,7 @@ class SchemaAdapter(Adapter):
|
|||
None,
|
||||
description="""String of containing namespace. Populated by NamespacesAdapter""")
|
||||
split: bool = Field(
|
||||
True,
|
||||
False,
|
||||
description="Split anonymous subclasses into a separate schema file"
|
||||
)
|
||||
_created_classes: List[Type[Group | Dataset]] = PrivateAttr(default_factory=list)
|
||||
|
|
|
@ -17,7 +17,7 @@ The `serialize` method
|
|||
|
||||
"""
|
||||
import pdb
|
||||
from typing import List, Dict, Set
|
||||
from typing import List, Dict, Set, Tuple, Optional
|
||||
from copy import deepcopy
|
||||
import warnings
|
||||
|
||||
|
@ -174,6 +174,52 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
|
||||
SKIP_ENUM=('FlatDType',)
|
||||
|
||||
def _locate_imports(
|
||||
self,
|
||||
needed_classes:List[str],
|
||||
sv:SchemaView
|
||||
) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Given a list of class names, find the python modules that need to be imported
|
||||
"""
|
||||
imports = {}
|
||||
|
||||
# These classes are not generated by pydantic!
|
||||
skips = ('AnyType',)
|
||||
|
||||
for cls in needed_classes:
|
||||
if cls in skips:
|
||||
continue
|
||||
# Find module that contains class
|
||||
module_name = sv.element_by_schema_map()[ElementName(cls)]
|
||||
# Don't get classes that are defined in this schema!
|
||||
if module_name == self.schema.name:
|
||||
continue
|
||||
|
||||
local_mod_name = '.' + module_name.replace('.', '_').replace('-', '_')
|
||||
if local_mod_name not in imports:
|
||||
imports[local_mod_name] = [camelcase(cls)]
|
||||
else:
|
||||
imports[local_mod_name].append(camelcase(cls))
|
||||
return imports
|
||||
|
||||
def _get_namespace_imports(self, sv:SchemaView) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Get imports for namespace packages. For these we import all
|
||||
the tree_root classes, ie. all the classes that are top-level classes rather than
|
||||
rather than nested classes
|
||||
"""
|
||||
all_classes = sv.all_classes(imports=True)
|
||||
needed_classes = []
|
||||
for clsname, cls in all_classes.items():
|
||||
if cls.tree_root:
|
||||
needed_classes.append(clsname)
|
||||
|
||||
imports = self._locate_imports(needed_classes, sv)
|
||||
return imports
|
||||
|
||||
|
||||
|
||||
def _get_imports(self, sv:SchemaView) -> Dict[str, List[str]]:
|
||||
all_classes = sv.all_classes(imports=True)
|
||||
local_classes = sv.all_classes(imports=False)
|
||||
|
@ -190,26 +236,10 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
if any_slot_range.range in all_classes:
|
||||
needed_classes.append(any_slot_range.range)
|
||||
|
||||
needed_classes = [cls for cls in set(needed_classes) if cls is not None]
|
||||
imports = {}
|
||||
needed_classes = [cls for cls in set(needed_classes) if cls is not None and cls != 'Arraylike']
|
||||
needed_classes = [cls for cls in needed_classes if sv.get_class(cls).is_a != 'Arraylike']
|
||||
|
||||
# These classes are not generated by pydantic!
|
||||
skips = ('AnyType',)
|
||||
|
||||
for cls in needed_classes:
|
||||
if cls in skips:
|
||||
continue
|
||||
# Find module that contains class
|
||||
module_name = sv.element_by_schema_map()[ElementName(cls)]
|
||||
# Don't get classes that are defined in this schema!
|
||||
if module_name == self.schema.name:
|
||||
continue
|
||||
|
||||
local_mod_name = '.' + module_name.replace('.', '_').replace('-','_')
|
||||
if local_mod_name not in imports:
|
||||
imports[local_mod_name] = [camelcase(cls)]
|
||||
else:
|
||||
imports[local_mod_name].append(camelcase(cls))
|
||||
imports = self._locate_imports(needed_classes, sv)
|
||||
|
||||
return imports
|
||||
|
||||
|
@ -219,8 +249,11 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
imported_classes = []
|
||||
for classes in imports.values():
|
||||
imported_classes.extend(classes)
|
||||
# pdb.set_trace()
|
||||
sorted_classes = self.sort_classes(list(module_classes), imported_classes)
|
||||
|
||||
module_classes = [c for c in list(module_classes) if c.is_a != 'Arraylike']
|
||||
imported_classes = [c for c in imported_classes if sv.get_class(c).is_a != 'Arraylike']
|
||||
|
||||
sorted_classes = self.sort_classes(module_classes, imported_classes)
|
||||
self.sorted_class_names = [camelcase(cname) for cname in imported_classes]
|
||||
self.sorted_class_names += [camelcase(c.name) for c in sorted_classes]
|
||||
|
||||
|
@ -361,7 +394,7 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
else:
|
||||
return super().get_class_slot_range(slot_range, inlined, inlined_as_list)
|
||||
|
||||
def get_class_isa_plus_mixins(self) -> Dict[str, List[str]]:
|
||||
def get_class_isa_plus_mixins(self, classes:Optional[List[ClassDefinition]] = None) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Generate the inheritance list for each class from is_a plus mixins
|
||||
|
||||
|
@ -370,8 +403,11 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
:return:
|
||||
"""
|
||||
sv = self.schemaview
|
||||
if classes is None:
|
||||
classes = sv.all_classes(imports=False).values()
|
||||
|
||||
parents = {}
|
||||
for class_def in sv.all_classes(imports=False).values():
|
||||
for class_def in classes:
|
||||
class_parents = []
|
||||
if class_def.is_a:
|
||||
class_parents.append(camelcase(class_def.is_a))
|
||||
|
@ -404,14 +440,23 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
enums = {k:v for k,v in enums.items() if k not in self.SKIP_ENUM}
|
||||
|
||||
# import from local references, rather than serializing every class in every file
|
||||
imports = self._get_imports(sv)
|
||||
if 'namespace' in schema.annotations.keys() and schema.annotations['namespace']['value'] == 'True':
|
||||
imports = self._get_namespace_imports(sv)
|
||||
else:
|
||||
imports = self._get_imports(sv)
|
||||
|
||||
sorted_classes = self._get_classes(sv, imports)
|
||||
|
||||
for class_original in sorted_classes:
|
||||
# Generate class definition
|
||||
class_def = self._build_class(class_original)
|
||||
pyschema.classes[class_def.name] = class_def
|
||||
|
||||
if class_def.is_a != "Arraylike":
|
||||
# skip actually generating arraylike classes, just use them to generate
|
||||
# the npytyping annotations
|
||||
pyschema.classes[class_def.name] = class_def
|
||||
else:
|
||||
continue
|
||||
|
||||
# Not sure why this happens
|
||||
for attribute in list(class_def.attributes.keys()):
|
||||
|
@ -474,6 +519,6 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
allow_extra=self.allow_extra,
|
||||
metamodel_version=self.schema.metamodel_version,
|
||||
version=self.schema.version,
|
||||
class_isa_plus_mixins=self.get_class_isa_plus_mixins(),
|
||||
class_isa_plus_mixins=self.get_class_isa_plus_mixins(sorted_classes),
|
||||
)
|
||||
return code
|
|
@ -52,3 +52,7 @@ def test_generate_pydantic(tmp_output_dir):
|
|||
|
||||
with open(pydantic_file, 'w') as pfile:
|
||||
pfile.write(gen_pydantic)
|
||||
|
||||
# make __init__.py
|
||||
with open(tmp_output_dir / 'models' / '__init__.py', 'w') as initfile:
|
||||
initfile.write('# Autogenerated module indicator')
|
||||
|
|
Loading…
Reference in a new issue