- Don't split schema into includes and base

- Import top-level classes in namespaces pydantic file
- Don't generate or import the arraylike class
This commit is contained in:
sneakers-the-rat 2023-09-04 14:56:36 -07:00
parent 9dd7304334
commit bb9dda6e66
4 changed files with 105 additions and 30 deletions

View file

@ -21,11 +21,13 @@ class NamespacesAdapter(Adapter):
schemas: List[SchemaAdapter] schemas: List[SchemaAdapter]
imported: List['NamespacesAdapter'] = Field(default_factory=list) imported: List['NamespacesAdapter'] = Field(default_factory=list)
_imports_populated = PrivateAttr(False) _imports_populated: bool = PrivateAttr(False)
_split: bool = PrivateAttr(False)
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(NamespacesAdapter, self).__init__(**kwargs) super(NamespacesAdapter, self).__init__(**kwargs)
self._populate_schema_namespaces() self._populate_schema_namespaces()
self.split = self._split
def build(self) -> BuildResult: def build(self) -> BuildResult:
if not self._imports_populated: if not self._imports_populated:
@ -51,12 +53,36 @@ class NamespacesAdapter(Adapter):
id = ns.name, id = ns.name,
description = ns.doc, description = ns.doc,
version = ns.version, version = ns.version,
imports=[sch.name for sch in ns_schemas] imports=[sch.name for sch in ns_schemas],
annotations=[{'tag': 'namespace', 'value': True}]
) )
sch_result.schemas.append(ns_schema) sch_result.schemas.append(ns_schema)
return sch_result return sch_result
@property
def split(self) -> bool:
"""
Sets the :attr:`.SchemaAdapter.split` attribute for all contained and imported schema
Args:
split (bool): Set the generated schema to be split or not
Returns:
bool: whether the schema are set to be split!
"""
return self._split
@split.setter
def split(self, split):
for sch in self.schemas:
sch.split = split
for ns in self.imported:
for sch in ns.schemas:
sch.split = split
self._split = split
def _populate_schema_namespaces(self): def _populate_schema_namespaces(self):
# annotate for each schema which namespace imports it # annotate for each schema which namespace imports it
for sch in self.schemas: for sch in self.schemas:

View file

@ -35,7 +35,7 @@ class SchemaAdapter(Adapter):
None, None,
description="""String of containing namespace. Populated by NamespacesAdapter""") description="""String of containing namespace. Populated by NamespacesAdapter""")
split: bool = Field( split: bool = Field(
True, False,
description="Split anonymous subclasses into a separate schema file" description="Split anonymous subclasses into a separate schema file"
) )
_created_classes: List[Type[Group | Dataset]] = PrivateAttr(default_factory=list) _created_classes: List[Type[Group | Dataset]] = PrivateAttr(default_factory=list)

View file

@ -17,7 +17,7 @@ The `serialize` method
""" """
import pdb import pdb
from typing import List, Dict, Set from typing import List, Dict, Set, Tuple, Optional
from copy import deepcopy from copy import deepcopy
import warnings import warnings
@ -174,6 +174,52 @@ class NWBPydanticGenerator(PydanticGenerator):
SKIP_ENUM=('FlatDType',) SKIP_ENUM=('FlatDType',)
def _locate_imports(
self,
needed_classes:List[str],
sv:SchemaView
) -> Dict[str, List[str]]:
"""
Given a list of class names, find the python modules that need to be imported
"""
imports = {}
# These classes are not generated by pydantic!
skips = ('AnyType',)
for cls in needed_classes:
if cls in skips:
continue
# Find module that contains class
module_name = sv.element_by_schema_map()[ElementName(cls)]
# Don't get classes that are defined in this schema!
if module_name == self.schema.name:
continue
local_mod_name = '.' + module_name.replace('.', '_').replace('-', '_')
if local_mod_name not in imports:
imports[local_mod_name] = [camelcase(cls)]
else:
imports[local_mod_name].append(camelcase(cls))
return imports
def _get_namespace_imports(self, sv:SchemaView) -> Dict[str, List[str]]:
"""
Get imports for namespace packages. For these we import all
the tree_root classes, ie. all the classes that are top-level classes rather than
rather than nested classes
"""
all_classes = sv.all_classes(imports=True)
needed_classes = []
for clsname, cls in all_classes.items():
if cls.tree_root:
needed_classes.append(clsname)
imports = self._locate_imports(needed_classes, sv)
return imports
def _get_imports(self, sv:SchemaView) -> Dict[str, List[str]]: def _get_imports(self, sv:SchemaView) -> Dict[str, List[str]]:
all_classes = sv.all_classes(imports=True) all_classes = sv.all_classes(imports=True)
local_classes = sv.all_classes(imports=False) local_classes = sv.all_classes(imports=False)
@ -190,26 +236,10 @@ class NWBPydanticGenerator(PydanticGenerator):
if any_slot_range.range in all_classes: if any_slot_range.range in all_classes:
needed_classes.append(any_slot_range.range) needed_classes.append(any_slot_range.range)
needed_classes = [cls for cls in set(needed_classes) if cls is not None] needed_classes = [cls for cls in set(needed_classes) if cls is not None and cls != 'Arraylike']
imports = {} needed_classes = [cls for cls in needed_classes if sv.get_class(cls).is_a != 'Arraylike']
# These classes are not generated by pydantic! imports = self._locate_imports(needed_classes, sv)
skips = ('AnyType',)
for cls in needed_classes:
if cls in skips:
continue
# Find module that contains class
module_name = sv.element_by_schema_map()[ElementName(cls)]
# Don't get classes that are defined in this schema!
if module_name == self.schema.name:
continue
local_mod_name = '.' + module_name.replace('.', '_').replace('-','_')
if local_mod_name not in imports:
imports[local_mod_name] = [camelcase(cls)]
else:
imports[local_mod_name].append(camelcase(cls))
return imports return imports
@ -219,8 +249,11 @@ class NWBPydanticGenerator(PydanticGenerator):
imported_classes = [] imported_classes = []
for classes in imports.values(): for classes in imports.values():
imported_classes.extend(classes) imported_classes.extend(classes)
# pdb.set_trace()
sorted_classes = self.sort_classes(list(module_classes), imported_classes) module_classes = [c for c in list(module_classes) if c.is_a != 'Arraylike']
imported_classes = [c for c in imported_classes if sv.get_class(c).is_a != 'Arraylike']
sorted_classes = self.sort_classes(module_classes, imported_classes)
self.sorted_class_names = [camelcase(cname) for cname in imported_classes] self.sorted_class_names = [camelcase(cname) for cname in imported_classes]
self.sorted_class_names += [camelcase(c.name) for c in sorted_classes] self.sorted_class_names += [camelcase(c.name) for c in sorted_classes]
@ -361,7 +394,7 @@ class NWBPydanticGenerator(PydanticGenerator):
else: else:
return super().get_class_slot_range(slot_range, inlined, inlined_as_list) return super().get_class_slot_range(slot_range, inlined, inlined_as_list)
def get_class_isa_plus_mixins(self) -> Dict[str, List[str]]: def get_class_isa_plus_mixins(self, classes:Optional[List[ClassDefinition]] = None) -> Dict[str, List[str]]:
""" """
Generate the inheritance list for each class from is_a plus mixins Generate the inheritance list for each class from is_a plus mixins
@ -370,8 +403,11 @@ class NWBPydanticGenerator(PydanticGenerator):
:return: :return:
""" """
sv = self.schemaview sv = self.schemaview
if classes is None:
classes = sv.all_classes(imports=False).values()
parents = {} parents = {}
for class_def in sv.all_classes(imports=False).values(): for class_def in classes:
class_parents = [] class_parents = []
if class_def.is_a: if class_def.is_a:
class_parents.append(camelcase(class_def.is_a)) class_parents.append(camelcase(class_def.is_a))
@ -404,6 +440,9 @@ class NWBPydanticGenerator(PydanticGenerator):
enums = {k:v for k,v in enums.items() if k not in self.SKIP_ENUM} enums = {k:v for k,v in enums.items() if k not in self.SKIP_ENUM}
# import from local references, rather than serializing every class in every file # import from local references, rather than serializing every class in every file
if 'namespace' in schema.annotations.keys() and schema.annotations['namespace']['value'] == 'True':
imports = self._get_namespace_imports(sv)
else:
imports = self._get_imports(sv) imports = self._get_imports(sv)
sorted_classes = self._get_classes(sv, imports) sorted_classes = self._get_classes(sv, imports)
@ -411,7 +450,13 @@ class NWBPydanticGenerator(PydanticGenerator):
for class_original in sorted_classes: for class_original in sorted_classes:
# Generate class definition # Generate class definition
class_def = self._build_class(class_original) class_def = self._build_class(class_original)
if class_def.is_a != "Arraylike":
# skip actually generating arraylike classes, just use them to generate
# the npytyping annotations
pyschema.classes[class_def.name] = class_def pyschema.classes[class_def.name] = class_def
else:
continue
# Not sure why this happens # Not sure why this happens
for attribute in list(class_def.attributes.keys()): for attribute in list(class_def.attributes.keys()):
@ -474,6 +519,6 @@ class NWBPydanticGenerator(PydanticGenerator):
allow_extra=self.allow_extra, allow_extra=self.allow_extra,
metamodel_version=self.schema.metamodel_version, metamodel_version=self.schema.metamodel_version,
version=self.schema.version, version=self.schema.version,
class_isa_plus_mixins=self.get_class_isa_plus_mixins(), class_isa_plus_mixins=self.get_class_isa_plus_mixins(sorted_classes),
) )
return code return code

View file

@ -52,3 +52,7 @@ def test_generate_pydantic(tmp_output_dir):
with open(pydantic_file, 'w') as pfile: with open(pydantic_file, 'w') as pfile:
pfile.write(gen_pydantic) pfile.write(gen_pydantic)
# make __init__.py
with open(tmp_output_dir / 'models' / '__init__.py', 'w') as initfile:
initfile.write('# Autogenerated module indicator')