- Don't split schema into includes and base

- Import top-level classes in namespaces pydantic file
- Don't generate or import the arraylike class
This commit is contained in:
sneakers-the-rat 2023-09-04 14:56:36 -07:00
parent 9dd7304334
commit bb9dda6e66
4 changed files with 105 additions and 30 deletions

View file

@ -21,11 +21,13 @@ class NamespacesAdapter(Adapter):
schemas: List[SchemaAdapter]
imported: List['NamespacesAdapter'] = Field(default_factory=list)
_imports_populated = PrivateAttr(False)
_imports_populated: bool = PrivateAttr(False)
_split: bool = PrivateAttr(False)
def __init__(self, **kwargs):
super(NamespacesAdapter, self).__init__(**kwargs)
self._populate_schema_namespaces()
self.split = self._split
def build(self) -> BuildResult:
if not self._imports_populated:
@ -51,12 +53,36 @@ class NamespacesAdapter(Adapter):
id = ns.name,
description = ns.doc,
version = ns.version,
imports=[sch.name for sch in ns_schemas]
imports=[sch.name for sch in ns_schemas],
annotations=[{'tag': 'namespace', 'value': True}]
)
sch_result.schemas.append(ns_schema)
return sch_result
@property
def split(self) -> bool:
"""
Sets the :attr:`.SchemaAdapter.split` attribute for all contained and imported schema
Args:
split (bool): Set the generated schema to be split or not
Returns:
bool: whether the schema are set to be split!
"""
return self._split
@split.setter
def split(self, split):
for sch in self.schemas:
sch.split = split
for ns in self.imported:
for sch in ns.schemas:
sch.split = split
self._split = split
def _populate_schema_namespaces(self):
# annotate for each schema which namespace imports it
for sch in self.schemas:

View file

@ -35,7 +35,7 @@ class SchemaAdapter(Adapter):
None,
description="""String of containing namespace. Populated by NamespacesAdapter""")
split: bool = Field(
True,
False,
description="Split anonymous subclasses into a separate schema file"
)
_created_classes: List[Type[Group | Dataset]] = PrivateAttr(default_factory=list)

View file

@ -17,7 +17,7 @@ The `serialize` method
"""
import pdb
from typing import List, Dict, Set
from typing import List, Dict, Set, Tuple, Optional
from copy import deepcopy
import warnings
@ -174,6 +174,52 @@ class NWBPydanticGenerator(PydanticGenerator):
SKIP_ENUM=('FlatDType',)
def _locate_imports(
self,
needed_classes:List[str],
sv:SchemaView
) -> Dict[str, List[str]]:
"""
Given a list of class names, find the python modules that need to be imported
"""
imports = {}
# These classes are not generated by pydantic!
skips = ('AnyType',)
for cls in needed_classes:
if cls in skips:
continue
# Find module that contains class
module_name = sv.element_by_schema_map()[ElementName(cls)]
# Don't get classes that are defined in this schema!
if module_name == self.schema.name:
continue
local_mod_name = '.' + module_name.replace('.', '_').replace('-', '_')
if local_mod_name not in imports:
imports[local_mod_name] = [camelcase(cls)]
else:
imports[local_mod_name].append(camelcase(cls))
return imports
def _get_namespace_imports(self, sv:SchemaView) -> Dict[str, List[str]]:
"""
Get imports for namespace packages. For these we import all
the tree_root classes, ie. all the classes that are top-level classes rather than
rather than nested classes
"""
all_classes = sv.all_classes(imports=True)
needed_classes = []
for clsname, cls in all_classes.items():
if cls.tree_root:
needed_classes.append(clsname)
imports = self._locate_imports(needed_classes, sv)
return imports
def _get_imports(self, sv:SchemaView) -> Dict[str, List[str]]:
all_classes = sv.all_classes(imports=True)
local_classes = sv.all_classes(imports=False)
@ -190,26 +236,10 @@ class NWBPydanticGenerator(PydanticGenerator):
if any_slot_range.range in all_classes:
needed_classes.append(any_slot_range.range)
needed_classes = [cls for cls in set(needed_classes) if cls is not None]
imports = {}
needed_classes = [cls for cls in set(needed_classes) if cls is not None and cls != 'Arraylike']
needed_classes = [cls for cls in needed_classes if sv.get_class(cls).is_a != 'Arraylike']
# These classes are not generated by pydantic!
skips = ('AnyType',)
for cls in needed_classes:
if cls in skips:
continue
# Find module that contains class
module_name = sv.element_by_schema_map()[ElementName(cls)]
# Don't get classes that are defined in this schema!
if module_name == self.schema.name:
continue
local_mod_name = '.' + module_name.replace('.', '_').replace('-','_')
if local_mod_name not in imports:
imports[local_mod_name] = [camelcase(cls)]
else:
imports[local_mod_name].append(camelcase(cls))
imports = self._locate_imports(needed_classes, sv)
return imports
@ -219,8 +249,11 @@ class NWBPydanticGenerator(PydanticGenerator):
imported_classes = []
for classes in imports.values():
imported_classes.extend(classes)
# pdb.set_trace()
sorted_classes = self.sort_classes(list(module_classes), imported_classes)
module_classes = [c for c in list(module_classes) if c.is_a != 'Arraylike']
imported_classes = [c for c in imported_classes if sv.get_class(c).is_a != 'Arraylike']
sorted_classes = self.sort_classes(module_classes, imported_classes)
self.sorted_class_names = [camelcase(cname) for cname in imported_classes]
self.sorted_class_names += [camelcase(c.name) for c in sorted_classes]
@ -361,7 +394,7 @@ class NWBPydanticGenerator(PydanticGenerator):
else:
return super().get_class_slot_range(slot_range, inlined, inlined_as_list)
def get_class_isa_plus_mixins(self) -> Dict[str, List[str]]:
def get_class_isa_plus_mixins(self, classes:Optional[List[ClassDefinition]] = None) -> Dict[str, List[str]]:
"""
Generate the inheritance list for each class from is_a plus mixins
@ -370,8 +403,11 @@ class NWBPydanticGenerator(PydanticGenerator):
:return:
"""
sv = self.schemaview
if classes is None:
classes = sv.all_classes(imports=False).values()
parents = {}
for class_def in sv.all_classes(imports=False).values():
for class_def in classes:
class_parents = []
if class_def.is_a:
class_parents.append(camelcase(class_def.is_a))
@ -404,14 +440,23 @@ class NWBPydanticGenerator(PydanticGenerator):
enums = {k:v for k,v in enums.items() if k not in self.SKIP_ENUM}
# import from local references, rather than serializing every class in every file
imports = self._get_imports(sv)
if 'namespace' in schema.annotations.keys() and schema.annotations['namespace']['value'] == 'True':
imports = self._get_namespace_imports(sv)
else:
imports = self._get_imports(sv)
sorted_classes = self._get_classes(sv, imports)
for class_original in sorted_classes:
# Generate class definition
class_def = self._build_class(class_original)
pyschema.classes[class_def.name] = class_def
if class_def.is_a != "Arraylike":
# skip actually generating arraylike classes, just use them to generate
# the npytyping annotations
pyschema.classes[class_def.name] = class_def
else:
continue
# Not sure why this happens
for attribute in list(class_def.attributes.keys()):
@ -474,6 +519,6 @@ class NWBPydanticGenerator(PydanticGenerator):
allow_extra=self.allow_extra,
metamodel_version=self.schema.metamodel_version,
version=self.schema.version,
class_isa_plus_mixins=self.get_class_isa_plus_mixins(),
class_isa_plus_mixins=self.get_class_isa_plus_mixins(sorted_classes),
)
return code

View file

@ -52,3 +52,7 @@ def test_generate_pydantic(tmp_output_dir):
with open(pydantic_file, 'w') as pfile:
pfile.write(gen_pydantic)
# make __init__.py
with open(tmp_output_dir / 'models' / '__init__.py', 'w') as initfile:
initfile.write('# Autogenerated module indicator')