diff --git a/nwb_linkml/src/nwb_linkml/adapters/adapter.py b/nwb_linkml/src/nwb_linkml/adapters/adapter.py index 1ceb7b5..07c5231 100644 --- a/nwb_linkml/src/nwb_linkml/adapters/adapter.py +++ b/nwb_linkml/src/nwb_linkml/adapters/adapter.py @@ -354,3 +354,40 @@ def defaults(cls: Dataset | Attribute) -> dict: ret["ifabsent"] = cls.default_value return ret + + +def is_container(group: Group) -> bool: + """ + Check if a group is a container group. + + i.e. a group that... + * has no name + * multivalued quantity + * has a ``neurodata_type_inc`` + * has no ``neurodata_type_def`` + * has no sub-groups + * has no datasets + * has no attributes + + Examples: + + .. code-block:: yaml + + - name: templates + groups: + - neurodata_type_inc: TimeSeries + doc: TimeSeries objects containing template data of presented stimuli. + quantity: '*' + - neurodata_type_inc: Images + doc: Images objects containing images of presented stimuli. + quantity: '*' + """ + return ( + not group.name + and group.quantity == "*" + and group.neurodata_type_inc + and not group.neurodata_type_def + and not group.datasets + and not group.groups + and not group.attributes + ) diff --git a/nwb_linkml/src/nwb_linkml/adapters/group.py b/nwb_linkml/src/nwb_linkml/adapters/group.py index fb919d0..f9ef07d 100644 --- a/nwb_linkml/src/nwb_linkml/adapters/group.py +++ b/nwb_linkml/src/nwb_linkml/adapters/group.py @@ -2,11 +2,11 @@ Adapter for NWB groups to linkml Classes """ -from typing import List, Type +from typing import Type from linkml_runtime.linkml_model import SlotDefinition -from nwb_linkml.adapters.adapter import BuildResult +from nwb_linkml.adapters.adapter import BuildResult, is_container from nwb_linkml.adapters.classes import ClassAdapter from nwb_linkml.adapters.dataset import DatasetAdapter from nwb_linkml.maps import QUANTITY_MAP @@ -45,19 +45,21 @@ class GroupAdapter(ClassAdapter): ): return self.handle_container_slot(self.cls) - nested_res = self.build_subclasses() - # add links - links = self.build_links() + nested_res = self.build_datasets() + nested_res += self.build_groups() + nested_res += self.build_links() + nested_res += self.build_containers() + nested_res += self.build_special_cases() # we don't propagate slots up to the next level since they are meant for this # level (ie. a way to refer to our children) - res = self.build_base(extra_attrs=nested_res.slots + links) + res = self.build_base(extra_attrs=nested_res.slots) # we do propagate classes tho res.classes.extend(nested_res.classes) return res - def build_links(self) -> List[SlotDefinition]: + def build_links(self) -> BuildResult: """ Build links specified in the ``links`` field as slots that refer to other classes, with an additional annotation specifying that they are in fact links. @@ -66,7 +68,7 @@ class GroupAdapter(ClassAdapter): file hierarchy as a string. """ if not self.cls.links: - return [] + return BuildResult() annotations = [{"tag": "source_type", "value": "link"}] @@ -83,7 +85,7 @@ class GroupAdapter(ClassAdapter): ) for link in self.cls.links ] - return slots + return BuildResult(slots=slots) def handle_container_group(self, cls: Group) -> BuildResult: """ @@ -129,7 +131,7 @@ class GroupAdapter(ClassAdapter): # We are a top-level container class like ProcessingModule base = self.build_base() # remove all the attributes and replace with child slot - base.classes[0].attributes.append(slot) + base.classes[0].attributes.update({slot.name: slot}) return base def handle_container_slot(self, cls: Group) -> BuildResult: @@ -167,30 +169,88 @@ class GroupAdapter(ClassAdapter): return BuildResult(slots=[slot]) - def build_subclasses(self) -> BuildResult: + def build_datasets(self) -> BuildResult: """ Build nested groups and datasets Create ClassDefinitions for each, but then also create SlotDefinitions that will be used as attributes linking the main class to the subclasses + + Datasets are simple, they are terminal classes, and all logic + for creating slots vs. classes is handled by the adapter class """ - # Datasets are simple, they are terminal classes, and all logic - # for creating slots vs. classes is handled by the adapter class dataset_res = BuildResult() if self.cls.datasets: for dset in self.cls.datasets: dset_adapter = DatasetAdapter(cls=dset, parent=self) dataset_res += dset_adapter.build() + return dataset_res + + def build_groups(self) -> BuildResult: + """ + Build subgroups, excluding pure container subgroups + """ group_res = BuildResult() if self.cls.groups: for group in self.cls.groups: + if is_container(group): + continue group_adapter = GroupAdapter(cls=group, parent=self) group_res += group_adapter.build() - res = dataset_res + group_res + return group_res + def build_containers(self) -> BuildResult: + """ + Build all container types into a single ``value`` slot + """ + res = BuildResult() + if not self.cls.groups: + return res + containers = [grp for grp in self.cls.groups if is_container(grp)] + if not containers: + return res + + if len(containers) == 1: + range = {"range": containers[0].neurodata_type_inc} + description = containers[0].doc + else: + range = {"any_of": [{"range": subcls.neurodata_type_inc} for subcls in containers]} + description = "\n\n".join([grp.doc for grp in containers]) + + slot = SlotDefinition( + name="value", + multivalued=True, + inlined=True, + inlined_as_list=False, + description=description, + **range, + ) + + if self.debug: # pragma: no cover - only used in development + slot.annotations["group_adapter"] = { + "tag": "slot_adapter", + "value": "container_value_slot", + } + res.slots = [slot] + return res + + def build_special_cases(self) -> BuildResult: + """ + Special cases, at this point just for NWBFile, which has + extra ``.specloc`` and ``specifications`` attrs + """ + res = BuildResult() + if self.cls.neurodata_type_def == "NWBFile": + res.slots = [ + SlotDefinition( + name="specifications", + range="dict", + description="Nested dictionary of schema specifications", + ), + ] return res def build_self_slot(self) -> SlotDefinition: diff --git a/nwb_linkml/src/nwb_linkml/generators/pydantic.py b/nwb_linkml/src/nwb_linkml/generators/pydantic.py index 336bbf8..4b3d412 100644 --- a/nwb_linkml/src/nwb_linkml/generators/pydantic.py +++ b/nwb_linkml/src/nwb_linkml/generators/pydantic.py @@ -15,7 +15,7 @@ from linkml.generators import PydanticGenerator from linkml.generators.pydanticgen.array import ArrayRepresentation, NumpydanticArray from linkml.generators.pydanticgen.build import ClassResult, SlotResult from linkml.generators.pydanticgen.pydanticgen import SplitMode -from linkml.generators.pydanticgen.template import Import, Imports, PydanticModule +from linkml.generators.pydanticgen.template import Import, Imports, PydanticModule, ObjectImport from linkml_runtime.linkml_model.meta import ( ArrayExpression, SchemaDefinition, @@ -30,6 +30,7 @@ from nwb_linkml.includes.base import ( BASEMODEL_COERCE_CHILD, BASEMODEL_COERCE_VALUE, BASEMODEL_GETITEM, + BASEMODEL_EXTRA_TO_VALUE, ) from nwb_linkml.includes.hdmf import ( DYNAMIC_TABLE_IMPORTS, @@ -58,9 +59,15 @@ class NWBPydanticGenerator(PydanticGenerator): BASEMODEL_COERCE_VALUE, BASEMODEL_CAST_WITH_VALUE, BASEMODEL_COERCE_CHILD, + BASEMODEL_EXTRA_TO_VALUE, ) split: bool = True - imports: list[Import] = field(default_factory=lambda: [Import(module="numpy", alias="np")]) + imports: list[Import] = field( + default_factory=lambda: [ + Import(module="numpy", alias="np"), + Import(module="pydantic", objects=[ObjectImport(name="model_validator")]), + ] + ) schema_map: Optional[Dict[str, SchemaDefinition]] = None """See :meth:`.LinkMLProvider.build` for usage - a list of specific versions to import from""" diff --git a/nwb_linkml/src/nwb_linkml/includes/base.py b/nwb_linkml/src/nwb_linkml/includes/base.py index c081587..6cad4a3 100644 --- a/nwb_linkml/src/nwb_linkml/includes/base.py +++ b/nwb_linkml/src/nwb_linkml/includes/base.py @@ -3,7 +3,7 @@ Modifications to the ConfiguredBaseModel used by all generated classes """ BASEMODEL_GETITEM = """ - def __getitem__(self, val: Union[int, slice]) -> Any: + def __getitem__(self, val: Union[int, slice, str]) -> Any: \"\"\"Try and get a value from value or "data" if we have it\"\"\" if hasattr(self, "value") and self.value is not None: return self.value[val] @@ -64,3 +64,23 @@ BASEMODEL_COERCE_CHILD = """ pass return v """ + +BASEMODEL_EXTRA_TO_VALUE = """ + @model_validator(mode="before") + @classmethod + def gather_extra_to_value(cls, v: Any, handler) -> Any: + \"\"\" + For classes that don't allow extra fields and have a value slot, + pack those extra kwargs into ``value`` + \"\"\" + if cls.model_config["extra"] == "forbid" and "value" in cls.model_fields and isinstance(v, dict): + extras = {key:val for key,val in v.items() if key not in cls.model_fields} + if extras: + for k in extras: + del v[k] + if "value" in v: + v["value"].update(extras) + else: + v["value"] = extras + return v +""" diff --git a/nwb_linkml/src/nwb_linkml/io/hdf5.py b/nwb_linkml/src/nwb_linkml/io/hdf5.py index 1691a46..d46465f 100644 --- a/nwb_linkml/src/nwb_linkml/io/hdf5.py +++ b/nwb_linkml/src/nwb_linkml/io/hdf5.py @@ -35,7 +35,7 @@ import h5py import networkx as nx import numpy as np from numpydantic.interface.hdf5 import H5ArrayPath -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel from tqdm import tqdm from nwb_linkml.maps.hdf5 import ( @@ -166,24 +166,28 @@ def _load_node( raise TypeError(f"Nodes can only be h5py Datasets and Groups, got {obj}") if "neurodata_type" in obj.attrs: + # SPECIAL CASE: ignore `.specloc` + if ".specloc" in args: + del args[".specloc"] + model = provider.get_class(obj.attrs["namespace"], obj.attrs["neurodata_type"]) - try: - return model(**args) - except ValidationError as e1: - # try to restack extra fields into ``value`` - if "value" in model.model_fields: - value_dict = { - key: val for key, val in args.items() if key not in model.model_fields - } - for k in value_dict: - del args[k] - args["value"] = value_dict - try: - return model(**args) - except Exception as e2: - raise e2 from e1 - else: - raise e1 + # try: + return model(**args) + # except ValidationError as e1: + # # try to restack extra fields into ``value`` + # if "value" in model.model_fields: + # value_dict = { + # key: val for key, val in args.items() if key not in model.model_fields + # } + # for k in value_dict: + # del args[k] + # args["value"] = value_dict + # try: + # return model(**args) + # except Exception as e2: + # raise e2 from e1 + # else: + # raise e1 else: if "name" in args: diff --git a/nwb_linkml/src/nwb_linkml/lang_elements.py b/nwb_linkml/src/nwb_linkml/lang_elements.py index fdde634..476e6e2 100644 --- a/nwb_linkml/src/nwb_linkml/lang_elements.py +++ b/nwb_linkml/src/nwb_linkml/lang_elements.py @@ -39,6 +39,10 @@ def _make_dtypes() -> List[TypeDefinition]: repr=linkml_reprs.get(nwbtype, None), ) DTypeTypes.append(atype) + + # a dict type! + DTypeTypes.append(TypeDefinition(name="dict", repr="dict")) + return DTypeTypes diff --git a/nwb_linkml/tests/test_io/test_io_nwb.py b/nwb_linkml/tests/test_io/test_io_nwb.py index 1ad51ed..32a50d1 100644 --- a/nwb_linkml/tests/test_io/test_io_nwb.py +++ b/nwb_linkml/tests/test_io/test_io_nwb.py @@ -80,7 +80,7 @@ def test_position(read_nwbfile, read_pynwb): py_trials = read_pynwb.trials.to_dataframe() pd.testing.assert_frame_equal(py_trials, trials) - spatial = read_nwbfile.processing["behavior"].Position.SpatialSeries + spatial = read_nwbfile.processing["behavior"]["Position"]["SpatialSeries"] py_spatial = read_pynwb.processing["behavior"]["Position"]["SpatialSeries"] _compare_attrs(spatial, py_spatial) assert np.array_equal(spatial[:], py_spatial.data[:]) diff --git a/scripts/generate_core.py b/scripts/generate_core.py index 55fc94e..413b85b 100644 --- a/scripts/generate_core.py +++ b/scripts/generate_core.py @@ -19,37 +19,6 @@ from nwb_linkml.providers import LinkMLProvider, PydanticProvider from nwb_linkml.providers.git import NWB_CORE_REPO, HDMF_COMMON_REPO, GitRepo from nwb_linkml.io import schema as io - -def generate_core_yaml(output_path: Path, dry_run: bool = False, hdmf_only: bool = False): - """Just build the latest version of the core schema""" - - core = io.load_nwb_core(hdmf_only=hdmf_only) - built_schemas = core.build().schemas - for schema in built_schemas: - output_file = output_path / (schema.name + ".yaml") - if not dry_run: - yaml_dumper.dump(schema, output_file) - - -def generate_core_pydantic(yaml_path: Path, output_path: Path, dry_run: bool = False): - """Just generate the latest version of the core schema""" - for schema in yaml_path.glob("*.yaml"): - python_name = schema.stem.replace(".", "_").replace("-", "_") - pydantic_file = (output_path / python_name).with_suffix(".py") - - generator = NWBPydanticGenerator( - str(schema), - pydantic_version="2", - emit_metadata=True, - gen_classvars=True, - gen_slots=True, - ) - gen_pydantic = generator.serialize() - if not dry_run: - with open(pydantic_file, "w") as pfile: - pfile.write(gen_pydantic) - - def make_tmp_dir(clear: bool = False) -> Path: # use a directory underneath this one as the temporary directory rather than # the default hidden one @@ -68,6 +37,7 @@ def generate_versions( dry_run: bool = False, repo: GitRepo = NWB_CORE_REPO, pdb=False, + latest: bool = False, ): """ Generate linkml models for all versions @@ -82,8 +52,13 @@ def generate_versions( failed_versions = {} + if latest: + versions = [repo.namespace.versions[-1]] + else: + versions = repo.namespace.versions + overall_progress = Progress() - overall_task = overall_progress.add_task("All Versions", total=len(NWB_CORE_REPO.versions)) + overall_task = overall_progress.add_task("All Versions", total=len(versions)) build_progress = Progress( TextColumn( @@ -100,7 +75,7 @@ def generate_versions( linkml_task = None pydantic_task = None - for version in repo.namespace.versions: + for version in versions: # build linkml try: # check out the version (this should also refresh the hdmf-common schema) @@ -251,11 +226,10 @@ def main(): if not args.dry_run: args.yaml.mkdir(exist_ok=True) args.pydantic.mkdir(exist_ok=True) - if args.latest: - generate_core_yaml(args.yaml, args.dry_run) - generate_core_pydantic(args.yaml, args.pydantic, args.dry_run) - else: - generate_versions(args.yaml, args.pydantic, args.dry_run, repo, pdb=args.pdb) + + generate_versions( + args.yaml, args.pydantic, args.dry_run, repo, pdb=args.pdb, latest=args.latest + ) if __name__ == "__main__":