mirror of
https://github.com/p2p-ld/nwb-linkml.git
synced 2025-01-10 14:14:27 +00:00
- Don't split schema into includes and base
- Import top-level classes in namespaces pydantic file - Don't generate or import the arraylike class
This commit is contained in:
parent
9dd7304334
commit
bb9dda6e66
4 changed files with 105 additions and 30 deletions
|
@ -21,11 +21,13 @@ class NamespacesAdapter(Adapter):
|
||||||
schemas: List[SchemaAdapter]
|
schemas: List[SchemaAdapter]
|
||||||
imported: List['NamespacesAdapter'] = Field(default_factory=list)
|
imported: List['NamespacesAdapter'] = Field(default_factory=list)
|
||||||
|
|
||||||
_imports_populated = PrivateAttr(False)
|
_imports_populated: bool = PrivateAttr(False)
|
||||||
|
_split: bool = PrivateAttr(False)
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super(NamespacesAdapter, self).__init__(**kwargs)
|
super(NamespacesAdapter, self).__init__(**kwargs)
|
||||||
self._populate_schema_namespaces()
|
self._populate_schema_namespaces()
|
||||||
|
self.split = self._split
|
||||||
|
|
||||||
def build(self) -> BuildResult:
|
def build(self) -> BuildResult:
|
||||||
if not self._imports_populated:
|
if not self._imports_populated:
|
||||||
|
@ -51,12 +53,36 @@ class NamespacesAdapter(Adapter):
|
||||||
id = ns.name,
|
id = ns.name,
|
||||||
description = ns.doc,
|
description = ns.doc,
|
||||||
version = ns.version,
|
version = ns.version,
|
||||||
imports=[sch.name for sch in ns_schemas]
|
imports=[sch.name for sch in ns_schemas],
|
||||||
|
annotations=[{'tag': 'namespace', 'value': True}]
|
||||||
)
|
)
|
||||||
sch_result.schemas.append(ns_schema)
|
sch_result.schemas.append(ns_schema)
|
||||||
|
|
||||||
return sch_result
|
return sch_result
|
||||||
|
|
||||||
|
@property
|
||||||
|
def split(self) -> bool:
|
||||||
|
"""
|
||||||
|
Sets the :attr:`.SchemaAdapter.split` attribute for all contained and imported schema
|
||||||
|
|
||||||
|
Args:
|
||||||
|
split (bool): Set the generated schema to be split or not
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: whether the schema are set to be split!
|
||||||
|
"""
|
||||||
|
return self._split
|
||||||
|
|
||||||
|
@split.setter
|
||||||
|
def split(self, split):
|
||||||
|
for sch in self.schemas:
|
||||||
|
sch.split = split
|
||||||
|
for ns in self.imported:
|
||||||
|
for sch in ns.schemas:
|
||||||
|
sch.split = split
|
||||||
|
|
||||||
|
self._split = split
|
||||||
|
|
||||||
def _populate_schema_namespaces(self):
|
def _populate_schema_namespaces(self):
|
||||||
# annotate for each schema which namespace imports it
|
# annotate for each schema which namespace imports it
|
||||||
for sch in self.schemas:
|
for sch in self.schemas:
|
||||||
|
|
|
@ -35,7 +35,7 @@ class SchemaAdapter(Adapter):
|
||||||
None,
|
None,
|
||||||
description="""String of containing namespace. Populated by NamespacesAdapter""")
|
description="""String of containing namespace. Populated by NamespacesAdapter""")
|
||||||
split: bool = Field(
|
split: bool = Field(
|
||||||
True,
|
False,
|
||||||
description="Split anonymous subclasses into a separate schema file"
|
description="Split anonymous subclasses into a separate schema file"
|
||||||
)
|
)
|
||||||
_created_classes: List[Type[Group | Dataset]] = PrivateAttr(default_factory=list)
|
_created_classes: List[Type[Group | Dataset]] = PrivateAttr(default_factory=list)
|
||||||
|
|
|
@ -17,7 +17,7 @@ The `serialize` method
|
||||||
|
|
||||||
"""
|
"""
|
||||||
import pdb
|
import pdb
|
||||||
from typing import List, Dict, Set
|
from typing import List, Dict, Set, Tuple, Optional
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
@ -174,6 +174,52 @@ class NWBPydanticGenerator(PydanticGenerator):
|
||||||
|
|
||||||
SKIP_ENUM=('FlatDType',)
|
SKIP_ENUM=('FlatDType',)
|
||||||
|
|
||||||
|
def _locate_imports(
|
||||||
|
self,
|
||||||
|
needed_classes:List[str],
|
||||||
|
sv:SchemaView
|
||||||
|
) -> Dict[str, List[str]]:
|
||||||
|
"""
|
||||||
|
Given a list of class names, find the python modules that need to be imported
|
||||||
|
"""
|
||||||
|
imports = {}
|
||||||
|
|
||||||
|
# These classes are not generated by pydantic!
|
||||||
|
skips = ('AnyType',)
|
||||||
|
|
||||||
|
for cls in needed_classes:
|
||||||
|
if cls in skips:
|
||||||
|
continue
|
||||||
|
# Find module that contains class
|
||||||
|
module_name = sv.element_by_schema_map()[ElementName(cls)]
|
||||||
|
# Don't get classes that are defined in this schema!
|
||||||
|
if module_name == self.schema.name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
local_mod_name = '.' + module_name.replace('.', '_').replace('-', '_')
|
||||||
|
if local_mod_name not in imports:
|
||||||
|
imports[local_mod_name] = [camelcase(cls)]
|
||||||
|
else:
|
||||||
|
imports[local_mod_name].append(camelcase(cls))
|
||||||
|
return imports
|
||||||
|
|
||||||
|
def _get_namespace_imports(self, sv:SchemaView) -> Dict[str, List[str]]:
|
||||||
|
"""
|
||||||
|
Get imports for namespace packages. For these we import all
|
||||||
|
the tree_root classes, ie. all the classes that are top-level classes rather than
|
||||||
|
rather than nested classes
|
||||||
|
"""
|
||||||
|
all_classes = sv.all_classes(imports=True)
|
||||||
|
needed_classes = []
|
||||||
|
for clsname, cls in all_classes.items():
|
||||||
|
if cls.tree_root:
|
||||||
|
needed_classes.append(clsname)
|
||||||
|
|
||||||
|
imports = self._locate_imports(needed_classes, sv)
|
||||||
|
return imports
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _get_imports(self, sv:SchemaView) -> Dict[str, List[str]]:
|
def _get_imports(self, sv:SchemaView) -> Dict[str, List[str]]:
|
||||||
all_classes = sv.all_classes(imports=True)
|
all_classes = sv.all_classes(imports=True)
|
||||||
local_classes = sv.all_classes(imports=False)
|
local_classes = sv.all_classes(imports=False)
|
||||||
|
@ -190,26 +236,10 @@ class NWBPydanticGenerator(PydanticGenerator):
|
||||||
if any_slot_range.range in all_classes:
|
if any_slot_range.range in all_classes:
|
||||||
needed_classes.append(any_slot_range.range)
|
needed_classes.append(any_slot_range.range)
|
||||||
|
|
||||||
needed_classes = [cls for cls in set(needed_classes) if cls is not None]
|
needed_classes = [cls for cls in set(needed_classes) if cls is not None and cls != 'Arraylike']
|
||||||
imports = {}
|
needed_classes = [cls for cls in needed_classes if sv.get_class(cls).is_a != 'Arraylike']
|
||||||
|
|
||||||
# These classes are not generated by pydantic!
|
imports = self._locate_imports(needed_classes, sv)
|
||||||
skips = ('AnyType',)
|
|
||||||
|
|
||||||
for cls in needed_classes:
|
|
||||||
if cls in skips:
|
|
||||||
continue
|
|
||||||
# Find module that contains class
|
|
||||||
module_name = sv.element_by_schema_map()[ElementName(cls)]
|
|
||||||
# Don't get classes that are defined in this schema!
|
|
||||||
if module_name == self.schema.name:
|
|
||||||
continue
|
|
||||||
|
|
||||||
local_mod_name = '.' + module_name.replace('.', '_').replace('-','_')
|
|
||||||
if local_mod_name not in imports:
|
|
||||||
imports[local_mod_name] = [camelcase(cls)]
|
|
||||||
else:
|
|
||||||
imports[local_mod_name].append(camelcase(cls))
|
|
||||||
|
|
||||||
return imports
|
return imports
|
||||||
|
|
||||||
|
@ -219,8 +249,11 @@ class NWBPydanticGenerator(PydanticGenerator):
|
||||||
imported_classes = []
|
imported_classes = []
|
||||||
for classes in imports.values():
|
for classes in imports.values():
|
||||||
imported_classes.extend(classes)
|
imported_classes.extend(classes)
|
||||||
# pdb.set_trace()
|
|
||||||
sorted_classes = self.sort_classes(list(module_classes), imported_classes)
|
module_classes = [c for c in list(module_classes) if c.is_a != 'Arraylike']
|
||||||
|
imported_classes = [c for c in imported_classes if sv.get_class(c).is_a != 'Arraylike']
|
||||||
|
|
||||||
|
sorted_classes = self.sort_classes(module_classes, imported_classes)
|
||||||
self.sorted_class_names = [camelcase(cname) for cname in imported_classes]
|
self.sorted_class_names = [camelcase(cname) for cname in imported_classes]
|
||||||
self.sorted_class_names += [camelcase(c.name) for c in sorted_classes]
|
self.sorted_class_names += [camelcase(c.name) for c in sorted_classes]
|
||||||
|
|
||||||
|
@ -361,7 +394,7 @@ class NWBPydanticGenerator(PydanticGenerator):
|
||||||
else:
|
else:
|
||||||
return super().get_class_slot_range(slot_range, inlined, inlined_as_list)
|
return super().get_class_slot_range(slot_range, inlined, inlined_as_list)
|
||||||
|
|
||||||
def get_class_isa_plus_mixins(self) -> Dict[str, List[str]]:
|
def get_class_isa_plus_mixins(self, classes:Optional[List[ClassDefinition]] = None) -> Dict[str, List[str]]:
|
||||||
"""
|
"""
|
||||||
Generate the inheritance list for each class from is_a plus mixins
|
Generate the inheritance list for each class from is_a plus mixins
|
||||||
|
|
||||||
|
@ -370,8 +403,11 @@ class NWBPydanticGenerator(PydanticGenerator):
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
sv = self.schemaview
|
sv = self.schemaview
|
||||||
|
if classes is None:
|
||||||
|
classes = sv.all_classes(imports=False).values()
|
||||||
|
|
||||||
parents = {}
|
parents = {}
|
||||||
for class_def in sv.all_classes(imports=False).values():
|
for class_def in classes:
|
||||||
class_parents = []
|
class_parents = []
|
||||||
if class_def.is_a:
|
if class_def.is_a:
|
||||||
class_parents.append(camelcase(class_def.is_a))
|
class_parents.append(camelcase(class_def.is_a))
|
||||||
|
@ -404,14 +440,23 @@ class NWBPydanticGenerator(PydanticGenerator):
|
||||||
enums = {k:v for k,v in enums.items() if k not in self.SKIP_ENUM}
|
enums = {k:v for k,v in enums.items() if k not in self.SKIP_ENUM}
|
||||||
|
|
||||||
# import from local references, rather than serializing every class in every file
|
# import from local references, rather than serializing every class in every file
|
||||||
imports = self._get_imports(sv)
|
if 'namespace' in schema.annotations.keys() and schema.annotations['namespace']['value'] == 'True':
|
||||||
|
imports = self._get_namespace_imports(sv)
|
||||||
|
else:
|
||||||
|
imports = self._get_imports(sv)
|
||||||
|
|
||||||
sorted_classes = self._get_classes(sv, imports)
|
sorted_classes = self._get_classes(sv, imports)
|
||||||
|
|
||||||
for class_original in sorted_classes:
|
for class_original in sorted_classes:
|
||||||
# Generate class definition
|
# Generate class definition
|
||||||
class_def = self._build_class(class_original)
|
class_def = self._build_class(class_original)
|
||||||
pyschema.classes[class_def.name] = class_def
|
|
||||||
|
if class_def.is_a != "Arraylike":
|
||||||
|
# skip actually generating arraylike classes, just use them to generate
|
||||||
|
# the npytyping annotations
|
||||||
|
pyschema.classes[class_def.name] = class_def
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
# Not sure why this happens
|
# Not sure why this happens
|
||||||
for attribute in list(class_def.attributes.keys()):
|
for attribute in list(class_def.attributes.keys()):
|
||||||
|
@ -474,6 +519,6 @@ class NWBPydanticGenerator(PydanticGenerator):
|
||||||
allow_extra=self.allow_extra,
|
allow_extra=self.allow_extra,
|
||||||
metamodel_version=self.schema.metamodel_version,
|
metamodel_version=self.schema.metamodel_version,
|
||||||
version=self.schema.version,
|
version=self.schema.version,
|
||||||
class_isa_plus_mixins=self.get_class_isa_plus_mixins(),
|
class_isa_plus_mixins=self.get_class_isa_plus_mixins(sorted_classes),
|
||||||
)
|
)
|
||||||
return code
|
return code
|
|
@ -52,3 +52,7 @@ def test_generate_pydantic(tmp_output_dir):
|
||||||
|
|
||||||
with open(pydantic_file, 'w') as pfile:
|
with open(pydantic_file, 'w') as pfile:
|
||||||
pfile.write(gen_pydantic)
|
pfile.write(gen_pydantic)
|
||||||
|
|
||||||
|
# make __init__.py
|
||||||
|
with open(tmp_output_dir / 'models' / '__init__.py', 'w') as initfile:
|
||||||
|
initfile.write('# Autogenerated module indicator')
|
||||||
|
|
Loading…
Reference in a new issue