diff --git a/nwb_linkml/adapters/namespaces.py b/nwb_linkml/adapters/namespaces.py index d939953..9ab7af3 100644 --- a/nwb_linkml/adapters/namespaces.py +++ b/nwb_linkml/adapters/namespaces.py @@ -21,11 +21,13 @@ class NamespacesAdapter(Adapter): schemas: List[SchemaAdapter] imported: List['NamespacesAdapter'] = Field(default_factory=list) - _imports_populated = PrivateAttr(False) + _imports_populated: bool = PrivateAttr(False) + _split: bool = PrivateAttr(False) def __init__(self, **kwargs): super(NamespacesAdapter, self).__init__(**kwargs) self._populate_schema_namespaces() + self.split = self._split def build(self) -> BuildResult: if not self._imports_populated: @@ -51,12 +53,36 @@ class NamespacesAdapter(Adapter): id = ns.name, description = ns.doc, version = ns.version, - imports=[sch.name for sch in ns_schemas] + imports=[sch.name for sch in ns_schemas], + annotations=[{'tag': 'namespace', 'value': True}] ) sch_result.schemas.append(ns_schema) return sch_result + @property + def split(self) -> bool: + """ + Sets the :attr:`.SchemaAdapter.split` attribute for all contained and imported schema + + Args: + split (bool): Set the generated schema to be split or not + + Returns: + bool: whether the schema are set to be split! + """ + return self._split + + @split.setter + def split(self, split): + for sch in self.schemas: + sch.split = split + for ns in self.imported: + for sch in ns.schemas: + sch.split = split + + self._split = split + def _populate_schema_namespaces(self): # annotate for each schema which namespace imports it for sch in self.schemas: diff --git a/nwb_linkml/adapters/schema.py b/nwb_linkml/adapters/schema.py index d96097f..2d05a9b 100644 --- a/nwb_linkml/adapters/schema.py +++ b/nwb_linkml/adapters/schema.py @@ -35,7 +35,7 @@ class SchemaAdapter(Adapter): None, description="""String of containing namespace. Populated by NamespacesAdapter""") split: bool = Field( - True, + False, description="Split anonymous subclasses into a separate schema file" ) _created_classes: List[Type[Group | Dataset]] = PrivateAttr(default_factory=list) diff --git a/nwb_linkml/generators/pydantic.py b/nwb_linkml/generators/pydantic.py index 8a85731..4adce1b 100644 --- a/nwb_linkml/generators/pydantic.py +++ b/nwb_linkml/generators/pydantic.py @@ -17,7 +17,7 @@ The `serialize` method """ import pdb -from typing import List, Dict, Set +from typing import List, Dict, Set, Tuple, Optional from copy import deepcopy import warnings @@ -174,6 +174,52 @@ class NWBPydanticGenerator(PydanticGenerator): SKIP_ENUM=('FlatDType',) + def _locate_imports( + self, + needed_classes:List[str], + sv:SchemaView + ) -> Dict[str, List[str]]: + """ + Given a list of class names, find the python modules that need to be imported + """ + imports = {} + + # These classes are not generated by pydantic! + skips = ('AnyType',) + + for cls in needed_classes: + if cls in skips: + continue + # Find module that contains class + module_name = sv.element_by_schema_map()[ElementName(cls)] + # Don't get classes that are defined in this schema! + if module_name == self.schema.name: + continue + + local_mod_name = '.' + module_name.replace('.', '_').replace('-', '_') + if local_mod_name not in imports: + imports[local_mod_name] = [camelcase(cls)] + else: + imports[local_mod_name].append(camelcase(cls)) + return imports + + def _get_namespace_imports(self, sv:SchemaView) -> Dict[str, List[str]]: + """ + Get imports for namespace packages. For these we import all + the tree_root classes, ie. all the classes that are top-level classes rather than + rather than nested classes + """ + all_classes = sv.all_classes(imports=True) + needed_classes = [] + for clsname, cls in all_classes.items(): + if cls.tree_root: + needed_classes.append(clsname) + + imports = self._locate_imports(needed_classes, sv) + return imports + + + def _get_imports(self, sv:SchemaView) -> Dict[str, List[str]]: all_classes = sv.all_classes(imports=True) local_classes = sv.all_classes(imports=False) @@ -190,26 +236,10 @@ class NWBPydanticGenerator(PydanticGenerator): if any_slot_range.range in all_classes: needed_classes.append(any_slot_range.range) - needed_classes = [cls for cls in set(needed_classes) if cls is not None] - imports = {} + needed_classes = [cls for cls in set(needed_classes) if cls is not None and cls != 'Arraylike'] + needed_classes = [cls for cls in needed_classes if sv.get_class(cls).is_a != 'Arraylike'] - # These classes are not generated by pydantic! - skips = ('AnyType',) - - for cls in needed_classes: - if cls in skips: - continue - # Find module that contains class - module_name = sv.element_by_schema_map()[ElementName(cls)] - # Don't get classes that are defined in this schema! - if module_name == self.schema.name: - continue - - local_mod_name = '.' + module_name.replace('.', '_').replace('-','_') - if local_mod_name not in imports: - imports[local_mod_name] = [camelcase(cls)] - else: - imports[local_mod_name].append(camelcase(cls)) + imports = self._locate_imports(needed_classes, sv) return imports @@ -219,8 +249,11 @@ class NWBPydanticGenerator(PydanticGenerator): imported_classes = [] for classes in imports.values(): imported_classes.extend(classes) - # pdb.set_trace() - sorted_classes = self.sort_classes(list(module_classes), imported_classes) + + module_classes = [c for c in list(module_classes) if c.is_a != 'Arraylike'] + imported_classes = [c for c in imported_classes if sv.get_class(c).is_a != 'Arraylike'] + + sorted_classes = self.sort_classes(module_classes, imported_classes) self.sorted_class_names = [camelcase(cname) for cname in imported_classes] self.sorted_class_names += [camelcase(c.name) for c in sorted_classes] @@ -361,7 +394,7 @@ class NWBPydanticGenerator(PydanticGenerator): else: return super().get_class_slot_range(slot_range, inlined, inlined_as_list) - def get_class_isa_plus_mixins(self) -> Dict[str, List[str]]: + def get_class_isa_plus_mixins(self, classes:Optional[List[ClassDefinition]] = None) -> Dict[str, List[str]]: """ Generate the inheritance list for each class from is_a plus mixins @@ -370,8 +403,11 @@ class NWBPydanticGenerator(PydanticGenerator): :return: """ sv = self.schemaview + if classes is None: + classes = sv.all_classes(imports=False).values() + parents = {} - for class_def in sv.all_classes(imports=False).values(): + for class_def in classes: class_parents = [] if class_def.is_a: class_parents.append(camelcase(class_def.is_a)) @@ -404,14 +440,23 @@ class NWBPydanticGenerator(PydanticGenerator): enums = {k:v for k,v in enums.items() if k not in self.SKIP_ENUM} # import from local references, rather than serializing every class in every file - imports = self._get_imports(sv) + if 'namespace' in schema.annotations.keys() and schema.annotations['namespace']['value'] == 'True': + imports = self._get_namespace_imports(sv) + else: + imports = self._get_imports(sv) sorted_classes = self._get_classes(sv, imports) for class_original in sorted_classes: # Generate class definition class_def = self._build_class(class_original) - pyschema.classes[class_def.name] = class_def + + if class_def.is_a != "Arraylike": + # skip actually generating arraylike classes, just use them to generate + # the npytyping annotations + pyschema.classes[class_def.name] = class_def + else: + continue # Not sure why this happens for attribute in list(class_def.attributes.keys()): @@ -474,6 +519,6 @@ class NWBPydanticGenerator(PydanticGenerator): allow_extra=self.allow_extra, metamodel_version=self.schema.metamodel_version, version=self.schema.version, - class_isa_plus_mixins=self.get_class_isa_plus_mixins(), + class_isa_plus_mixins=self.get_class_isa_plus_mixins(sorted_classes), ) return code \ No newline at end of file diff --git a/tests/test_generate.py b/tests/test_generate.py index b029993..112e3de 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -52,3 +52,7 @@ def test_generate_pydantic(tmp_output_dir): with open(pydantic_file, 'w') as pfile: pfile.write(gen_pydantic) + + # make __init__.py + with open(tmp_output_dir / 'models' / '__init__.py', 'w') as initfile: + initfile.write('# Autogenerated module indicator')