diff --git a/nwb_linkml/src/nwb_linkml/adapters/adapter.py b/nwb_linkml/src/nwb_linkml/adapters/adapter.py index 9984ecb..2e49965 100644 --- a/nwb_linkml/src/nwb_linkml/adapters/adapter.py +++ b/nwb_linkml/src/nwb_linkml/adapters/adapter.py @@ -113,6 +113,28 @@ class Adapter(BaseModel): Generate the corresponding linkML element for this adapter """ + def get(self, name: str) -> Union[Group, Dataset]: + """ + Get the first item whose ``neurodata_type_def`` matches ``name`` + + Convenience wrapper around :meth:`.walk_field_values` + """ + return next(self.walk_field_values(self, 'neurodata_type_def', name)) + + def get_model_with_field(self, field: str) -> Generator[Union[Group, Dataset], None, None]: + """ + Yield models that have a non-None value in the given field. + + Useful during development to find all the ways that a given + field is used. + + Args: + field (str): Field to search for + """ + for model in self.walk_types(self, (Group, Dataset)): + if getattr(model, field, None) is not None: + yield model + def walk( self, input: Union[BaseModel, dict, list] ) -> Generator[Union[BaseModel, Any, None], None, None]: diff --git a/nwb_linkml/src/nwb_linkml/adapters/classes.py b/nwb_linkml/src/nwb_linkml/adapters/classes.py index 1f36643..c700d53 100644 --- a/nwb_linkml/src/nwb_linkml/adapters/classes.py +++ b/nwb_linkml/src/nwb_linkml/adapters/classes.py @@ -248,7 +248,7 @@ class ClassAdapter(Adapter): range="string", ) else: - name_slot = SlotDefinition(name="name", required=True, range="string", identifier=True) + name_slot = SlotDefinition(name="name", required=True, range="string") return name_slot def build_self_slot(self) -> SlotDefinition: diff --git a/nwb_linkml/src/nwb_linkml/adapters/dataset.py b/nwb_linkml/src/nwb_linkml/adapters/dataset.py index c1e72d0..deabe66 100644 --- a/nwb_linkml/src/nwb_linkml/adapters/dataset.py +++ b/nwb_linkml/src/nwb_linkml/adapters/dataset.py @@ -530,6 +530,45 @@ class MapArrayLikeAttributes(DatasetMap): return res +class MapClassRange(DatasetMap): + """ + Datasets that are a simple named reference to another type without any + additional modification to that type. + """ + @classmethod + def check(c, cls: Dataset) -> bool: + """ + Check that we are a dataset with a ``neurodata_type_inc`` and a name but nothing else + """ + return ( + cls.neurodata_type_inc + and not cls.neurodata_type_def + and not cls.attributes + and not cls.dims + and not cls.shape + and not cls.dtype + and cls.name + ) + + @classmethod + def apply( + c, cls: Dataset, res: Optional[BuildResult] = None, name: Optional[str] = None + ) -> BuildResult: + """ + Replace the base class with a slot with an annotation that indicates + it should use the :class:`.Named` generic when generated to pydantic + """ + this_slot = SlotDefinition( + name=cls.name, + description=cls.doc, + range=f"{cls.neurodata_type_inc}", + annotations=[{'named': True}], + **QUANTITY_MAP[cls.quantity], + ) + res = BuildResult(slots=[this_slot]) + return res + + # -------------------------------------------------- # DynamicTable special cases # -------------------------------------------------- diff --git a/nwb_linkml/src/nwb_linkml/generators/pydantic.py b/nwb_linkml/src/nwb_linkml/generators/pydantic.py index 3e6364e..1b15468 100644 --- a/nwb_linkml/src/nwb_linkml/generators/pydantic.py +++ b/nwb_linkml/src/nwb_linkml/generators/pydantic.py @@ -62,6 +62,7 @@ from pydantic import BaseModel from nwb_linkml.maps import flat_to_nptyping from nwb_linkml.maps.naming import module_case, version_module_case +from nwb_linkml.includes import ModelTypeString, _get_name, NamedString, NamedImports OPTIONAL_PATTERN = re.compile(r"Optional\[([\w\.]*)\]") @@ -119,35 +120,16 @@ class NWBPydanticGenerator(PydanticGenerator): - strip unwanted metadata - generate range with any_of """ - for key in self.skip_meta: - if key in slot.attribute.meta: - del slot.attribute.meta[key] - - # make array ranges in any_of - if "any_of" in slot.attribute.meta: - any_ofs = slot.attribute.meta["any_of"] - if all(["array" in expr for expr in any_ofs]): - ranges = [] - is_optional = False - for expr in any_ofs: - # remove optional from inner type - pyrange = slot.attribute.range - is_optional = OPTIONAL_PATTERN.match(pyrange) - if is_optional: - pyrange = is_optional.groups()[0] - range_generator = NumpydanticArray(ArrayExpression(**expr["array"]), pyrange) - ranges.append(range_generator.make().range) - - slot.attribute.range = "Union[" + ", ".join(ranges) + "]" - if is_optional: - slot.attribute.range = "Optional[" + slot.attribute.range + "]" - del slot.attribute.meta["any_of"] + slot = AfterGenerateSlot.skip_meta(slot, self.skip_meta) + slot = AfterGenerateSlot.make_array_anyofs(slot) + slot = AfterGenerateSlot.make_named_class_range(slot) return slot def before_render_template(self, template: PydanticModule, sv: SchemaView) -> PydanticModule: if "source_file" in template.meta: del template.meta["source_file"] + return template def compile_module( self, module_path: Path = None, module_name: str = "test", **kwargs @@ -169,6 +151,62 @@ class NWBPydanticGenerator(PydanticGenerator): raise e +class AfterGenerateSlot: + """ + Container class for slot-modification methods + """ + @staticmethod + def skip_meta(slot: SlotResult, skip_meta: tuple[str]) -> SlotResult: + for key in skip_meta: + if key in slot.attribute.meta: + del slot.attribute.meta[key] + return slot + + @staticmethod + def make_array_anyofs(slot: SlotResult) -> SlotResult: + """ + Make a Union of array ranges if multiple array types specified in ``any_of`` + """ + # make array ranges in any_of + if "any_of" in slot.attribute.meta: + any_ofs = slot.attribute.meta["any_of"] + if all(["array" in expr for expr in any_ofs]): + ranges = [] + is_optional = False + for expr in any_ofs: + # remove optional from inner type + pyrange = slot.attribute.range + is_optional = OPTIONAL_PATTERN.match(pyrange) + if is_optional: + pyrange = is_optional.groups()[0] + range_generator = NumpydanticArray(ArrayExpression(**expr["array"]), pyrange) + ranges.append(range_generator.make().range) + + slot.attribute.range = "Union[" + ", ".join(ranges) + "]" + if is_optional: + slot.attribute.range = "Optional[" + slot.attribute.range + "]" + del slot.attribute.meta["any_of"] + return slot + + @staticmethod + def make_named_class_range(slot: SlotResult) -> SlotResult: + """ + When a slot has a ``named`` annotation, wrap it in :class:`.Named` + """ + + if 'named' in slot.source.annotations and slot.source.annotations['named'].value: + slot.attribute.range = f"Named[{slot.attribute.range}]" + named_injects = [ModelTypeString, _get_name, NamedString] + if slot.injected_classes is None: + slot.injected_classes = named_injects + else: + slot.injected_classes.extend([ModelTypeString, _get_name, NamedString]) + if slot.imports: + slot.imports += NamedImports + else: + slot.imports = NamedImports + return slot + def compile_python( text_or_fn: str, package_path: Path = None, module_name: str = "test" ) -> ModuleType: diff --git a/nwb_linkml/src/nwb_linkml/includes.py b/nwb_linkml/src/nwb_linkml/includes.py new file mode 100644 index 0000000..5c63de4 --- /dev/null +++ b/nwb_linkml/src/nwb_linkml/includes.py @@ -0,0 +1,50 @@ +""" +Classes and types that are injected in generated pydantic modules, but have no +corresponding representation in linkml (ie. that don't belong in :mod:`.lang_elements` + +Used to customize behavior of pydantic classes either to match pynwb behavior or +reduce the verbosity of the generated models with convenience classes. +""" + +from pydantic import BaseModel, ValidationInfo, BeforeValidator +from typing import Annotated, TypeVar, Type + +from linkml.generators.pydanticgen.template import Imports, Import, ObjectImport + +ModelType = TypeVar("ModelType", bound=Type[BaseModel]) +# inspect.getsource() doesn't work for typevars because everything in the typing module +# doesn't behave like a normal python object +ModelTypeString = """ModelType = TypeVar("ModelType", bound=Type[BaseModel])""" + +def _get_name(item: BaseModel | dict, info: ValidationInfo): + assert isinstance(item, (BaseModel, dict)) + name = info.field_name + if isinstance(item, BaseModel): + item.name = name + else: + item['name'] = name + return item + +Named = Annotated[ModelType, BeforeValidator(_get_name)] +""" +Generic annotated type that sets the ``name`` field of a model +to the name of the field with this type. + +Examples: + + class ChildModel(BaseModel): + name: str + value: int + + class MyModel(BaseModel): + named_field: Named[ChildModel] + + instance = MyModel(named_field={'value': 1}) + instance.named_field.name == "named_field" +""" +NamedString = """Named = Annotated[ModelType, BeforeValidator(_get_name)]""" + +NamedImports = Imports(imports=[ + Import(module="typing", objects=[ObjectImport(name="Annotated"),ObjectImport(name="Type"), ObjectImport(name="TypeVar"), ]), + Import(module="pydantic", objects=[ObjectImport(name="ValidationInfo"), ObjectImport(name="BeforeValidator")]) +]) \ No newline at end of file diff --git a/nwb_linkml/src/nwb_linkml/io/schema.py b/nwb_linkml/src/nwb_linkml/io/schema.py index 45f599c..3e2a76e 100644 --- a/nwb_linkml/src/nwb_linkml/io/schema.py +++ b/nwb_linkml/src/nwb_linkml/io/schema.py @@ -120,7 +120,7 @@ def load_namespace_adapter( return adapter -def load_nwb_core(core_version: str = "2.6.0", hdmf_version: str = "1.5.0") -> NamespacesAdapter: +def load_nwb_core(core_version: str = "2.7.0", hdmf_version: str = "1.8.0") -> NamespacesAdapter: """ Convenience function for loading the NWB core schema + hdmf-common as a namespace adapter. diff --git a/nwb_linkml/tests/test_includes.py b/nwb_linkml/tests/test_includes.py new file mode 100644 index 0000000..ce59503 --- /dev/null +++ b/nwb_linkml/tests/test_includes.py @@ -0,0 +1,17 @@ +from pydantic import BaseModel +from nwb_linkml.includes import Named + +def test_named_generic(): + """ + the Named type should fill in the ``name`` field in a model from the field name + """ + class Child(BaseModel): + name: str + value: int + + class Parent(BaseModel): + field_name: Named[Child] + + # should instantiate correctly and have name set + instance = Parent(field_name={'value': 1}) + assert instance.field_name.name == 'field_name' \ No newline at end of file