From 27b5dddfddff8a93372d9075d16d3ed0e504e760 Mon Sep 17 00:00:00 2001
From: sneakers-the-rat
Date: Wed, 4 Sep 2024 00:04:21 -0700
Subject: [PATCH] updating model generation methods, still some models being
 converted to str instead of being inlined, but almost there

---
 nwb_linkml/src/nwb_linkml/adapters/classes.py | 24 ++++++-
 nwb_linkml/src/nwb_linkml/adapters/dataset.py | 30 ++++++---
 nwb_linkml/src/nwb_linkml/adapters/group.py   | 62 +++++++++++++------
 .../src/nwb_linkml/generators/pydantic.py     | 47 ++++++++++---
 nwb_linkml/src/nwb_linkml/includes/base.py    | 13 +++-
 nwb_linkml/src/nwb_linkml/includes/hdmf.py    | 20 +++---
 nwb_linkml/src/nwb_linkml/maps/dtype.py       | 28 ++++++---
 nwb_linkml/tests/test_io/test_io_nwb.py       |  1 +
 scripts/generate_core.py                      | 11 ++++
 9 files changed, 179 insertions(+), 57 deletions(-)

diff --git a/nwb_linkml/src/nwb_linkml/adapters/classes.py b/nwb_linkml/src/nwb_linkml/adapters/classes.py
index 0097e47..be8f336 100644
--- a/nwb_linkml/src/nwb_linkml/adapters/classes.py
+++ b/nwb_linkml/src/nwb_linkml/adapters/classes.py
@@ -2,6 +2,7 @@
 Adapters to linkML classes
 """
 
+import os
 from abc import abstractmethod
 from typing import List, Optional, Type, TypeVar
 
@@ -32,6 +33,14 @@ class ClassAdapter(Adapter):
     cls: TI
     parent: Optional["ClassAdapter"] = None
 
+    _debug: Optional[bool] = None
+
+    @property
+    def debug(self) -> bool:
+        if self._debug is None:
+            self._debug = bool(os.environ.get("NWB_LINKML_DEBUG", False))
+        return self._debug
+
     @field_validator("cls", mode="before")
     @classmethod
     def cast_from_string(cls, value: str | TI) -> TI:
@@ -92,6 +101,13 @@ class ClassAdapter(Adapter):
         # Get vanilla top-level attributes
         kwargs["attributes"].extend(self.build_attrs(self.cls))
 
+        if self.debug:
+            kwargs["annotations"] = {}
+            kwargs["annotations"]["group_adapter"] = {
+                "tag": "group_adapter",
+                "value": "container_slot",
+            }
+
         if extra_attrs is not None:
             if isinstance(extra_attrs, SlotDefinition):
                 extra_attrs = [extra_attrs]
@@ -230,18 +246,22 @@ class ClassAdapter(Adapter):
                 ifabsent=f"string({name})",
                 equals_string=equals_string,
                 range="string",
+                identifier=True,
             )
         else:
-            name_slot = SlotDefinition(name="name", required=True, range="string")
+            name_slot = SlotDefinition(name="name", required=True, range="string", identifier=True)
         return name_slot
 
     def build_self_slot(self) -> SlotDefinition:
         """
         If we are a child class, we make a slot so our parent can refer to us
         """
-        return SlotDefinition(
+        slot = SlotDefinition(
             name=self._get_slot_name(),
             description=self.cls.doc,
             range=self._get_full_name(),
             **QUANTITY_MAP[self.cls.quantity],
         )
+        if self.debug:
+            slot.annotations["group_adapter"] = {"tag": "group_adapter", "value": "self_slot"}
+        return slot
diff --git a/nwb_linkml/src/nwb_linkml/adapters/dataset.py b/nwb_linkml/src/nwb_linkml/adapters/dataset.py
index a60c440..ed1818a 100644
--- a/nwb_linkml/src/nwb_linkml/adapters/dataset.py
+++ b/nwb_linkml/src/nwb_linkml/adapters/dataset.py
@@ -11,7 +11,7 @@ from nwb_linkml.adapters.adapter import BuildResult, has_attrs, is_1d, is_compound
 from nwb_linkml.adapters.array import ArrayAdapter
 from nwb_linkml.adapters.classes import ClassAdapter
 from nwb_linkml.maps import QUANTITY_MAP, Map
-from nwb_linkml.maps.dtype import flat_to_linkml, handle_dtype
+from nwb_linkml.maps.dtype import flat_to_linkml, handle_dtype, inlined
 from nwb_linkml.maps.naming import camel_to_snake
 from nwb_schema_language import Dataset
 
@@ -299,6 +299,8 @@ class MapListlike(DatasetMap):
             description=cls.doc,
             required=cls.quantity not in ("*", "?"),
annotations=[{"source_type": "reference"}], + inlined=True, + inlined_as_list=True, ) res.classes[0].attributes["value"] = slot return res @@ -544,7 +546,9 @@ class MapArrayLikeAttributes(DatasetMap): array_adapter = ArrayAdapter(cls.dims, cls.shape) expressions = array_adapter.make_slot() # make a slot for the arraylike class - array_slot = SlotDefinition(name="value", range=handle_dtype(cls.dtype), **expressions) + array_slot = SlotDefinition( + name="value", range=handle_dtype(cls.dtype), inlined=inlined(cls.dtype), **expressions + ) res.classes[0].attributes.update({"value": array_slot}) return res @@ -583,6 +587,7 @@ class MapClassRange(DatasetMap): description=cls.doc, range=f"{cls.neurodata_type_inc}", annotations=[{"named": True}, {"source_type": "neurodata_type_inc"}], + inlined=True, **QUANTITY_MAP[cls.quantity], ) res = BuildResult(slots=[this_slot]) @@ -799,6 +804,7 @@ class MapCompoundDtype(DatasetMap): description=a_dtype.doc, range=handle_dtype(a_dtype.dtype), array=ArrayExpression(exact_number_dimensions=1), + inlined=inlined(a_dtype.dtype), **QUANTITY_MAP[cls.quantity], ) res.classes[0].attributes.update(slots) @@ -830,6 +836,8 @@ class DatasetAdapter(ClassAdapter): if map is not None: res = map.apply(self.cls, res, self._get_full_name()) + if self.debug: + res = self._amend_debug(res, map) return res def match(self) -> Optional[Type[DatasetMap]]: @@ -855,11 +863,13 @@ class DatasetAdapter(ClassAdapter): else: return matches[0] - def special_cases(self, res: BuildResult) -> BuildResult: - """ - Apply special cases to build result - """ - res = self._datetime_or_str(res) - - def _datetime_or_str(self, res: BuildResult) -> BuildResult: - """HDF5 doesn't support datetime, so""" + def _amend_debug(self, res: BuildResult, map: Optional[Type[DatasetMap]] = None) -> BuildResult: + if map is None: + map_name = "None" + else: + map_name = map.__name__ + for cls in res.classes: + cls.annotations["dataset_map"] = {"tag": "dataset_map", "value": map_name} + for slot in res.slots: + slot.annotations["dataset_map"] = {"tag": "dataset_map", "value": map_name} + return res diff --git a/nwb_linkml/src/nwb_linkml/adapters/group.py b/nwb_linkml/src/nwb_linkml/adapters/group.py index 13a03b7..ba5d004 100644 --- a/nwb_linkml/src/nwb_linkml/adapters/group.py +++ b/nwb_linkml/src/nwb_linkml/adapters/group.py @@ -111,6 +111,9 @@ class GroupAdapter(ClassAdapter): inlined_as_list=False, ) + if self.debug: + slot.annotations["group_adapter"] = {"tag": "group_adapter", "value": "container_group"} + if self.parent is not None: # if we have a parent, # just return the slot itself without the class @@ -144,17 +147,20 @@ class GroupAdapter(ClassAdapter): """ name = camel_to_snake(self.cls.neurodata_type_inc) if not self.cls.name else cls.name - return BuildResult( - slots=[ - SlotDefinition( - name=name, - range=self.cls.neurodata_type_inc, - description=self.cls.doc, - **QUANTITY_MAP[cls.quantity], - ) - ] + slot = SlotDefinition( + name=name, + range=self.cls.neurodata_type_inc, + description=self.cls.doc, + inlined=True, + inlined_as_list=False, + **QUANTITY_MAP[cls.quantity], ) + if self.debug: + slot.annotations["group_adapter"] = {"tag": "group_adapter", "value": "container_slot"} + + return BuildResult(slots=[slot]) + def build_subclasses(self) -> BuildResult: """ Build nested groups and datasets @@ -166,20 +172,9 @@ class GroupAdapter(ClassAdapter): # for creating slots vs. 
         dataset_res = BuildResult()
         for dset in self.cls.datasets:
-            # if dset.name == 'timestamps':
-            #     pdb.set_trace()
             dset_adapter = DatasetAdapter(cls=dset, parent=self)
             dataset_res += dset_adapter.build()
 
-        # Actually i'm not sure we have to special case this, we could handle it in
-        # i/o instead
-
-        # Groups are a bit more complicated because they can also behave like
-        # range declarations:
-        # eg. a group can have multiple groups with `neurodata_type_inc`, no name,
-        # and quantity of *,
-        # the group can then contain any number of groups of those included types as direct children
         group_res = BuildResult()
 
         for group in self.cls.groups:
@@ -190,6 +185,33 @@ class GroupAdapter(ClassAdapter):
 
         return res
 
+    def build_self_slot(self) -> SlotDefinition:
+        """
+        If we are a child class, we make a slot so our parent can refer to us
+
+        Groups are a bit more complicated because they can also behave like
+        range declarations:
+        e.g. a group can have multiple groups with `neurodata_type_inc`, no name,
+        and quantity of *,
+        the group can then contain any number of groups of those included types as direct children
+
+        We make sure that we're inlined as a dict so our parent class can refer to us like::
+
+            parent.{slot_name}[{name}] = self
+
+        """
+        slot = SlotDefinition(
+            name=self._get_slot_name(),
+            description=self.cls.doc,
+            range=self._get_full_name(),
+            inlined=True,
+            inlined_as_list=True,
+            **QUANTITY_MAP[self.cls.quantity],
+        )
+        if self.debug:
+            slot.annotations["group_adapter"] = {"tag": "group_adapter", "value": "container_slot"}
+        return slot
+
     def _check_if_container(self, group: Group) -> bool:
         """
         Check if a given subgroup is a container subgroup,
diff --git a/nwb_linkml/src/nwb_linkml/generators/pydantic.py b/nwb_linkml/src/nwb_linkml/generators/pydantic.py
index 20619d2..adffd45 100644
--- a/nwb_linkml/src/nwb_linkml/generators/pydantic.py
+++ b/nwb_linkml/src/nwb_linkml/generators/pydantic.py
@@ -27,7 +27,11 @@
 from linkml_runtime.utils.compile_python import file_text
 from linkml_runtime.utils.formatutils import remove_empty_items
 from linkml_runtime.utils.schemaview import SchemaView
-from nwb_linkml.includes.base import BASEMODEL_GETITEM, BASEMODEL_COERCE_VALUE
+from nwb_linkml.includes.base import (
+    BASEMODEL_GETITEM,
+    BASEMODEL_COERCE_VALUE,
+    BASEMODEL_COERCE_CHILD,
+)
 from nwb_linkml.includes.hdmf import (
     DYNAMIC_TABLE_IMPORTS,
     DYNAMIC_TABLE_INJECTS,
@@ -36,7 +40,7 @@ from nwb_linkml.includes.hdmf import (
 )
 from nwb_linkml.includes.types import ModelTypeString, NamedImports, NamedString, _get_name
 
-OPTIONAL_PATTERN = re.compile(r"Optional\[([\w\.]*)\]")
+OPTIONAL_PATTERN = re.compile(r"Optional\[(.*)\]")
 
 
 @dataclass
@@ -53,6 +57,7 @@ class NWBPydanticGenerator(PydanticGenerator):
         'object_id: Optional[str] = Field(None, description="Unique UUID for each object")',
         BASEMODEL_GETITEM,
         BASEMODEL_COERCE_VALUE,
+        BASEMODEL_COERCE_CHILD,
     )
     split: bool = True
     imports: list[Import] = field(default_factory=lambda: [Import(module="numpy", alias="np")])
@@ -134,6 +139,7 @@ class NWBPydanticGenerator(PydanticGenerator):
         cls = AfterGenerateClass.inject_dynamictable(cls)
         cls = AfterGenerateClass.wrap_dynamictable_columns(cls, sv)
         cls = AfterGenerateClass.inject_elementidentifiers(cls, sv, self._get_element_import)
+        cls = AfterGenerateClass.strip_vector_data_slots(cls, sv)
         return cls
 
     def before_render_template(self, template: PydanticModule, sv: SchemaView) -> PydanticModule:
@@ -227,7 +233,8 @@ class AfterGenerateSlot:
         """
 
         if "named" in slot.source.annotations and slot.source.annotations["named"].value:
-            slot.attribute.range = f"Named[{slot.attribute.range}]"
+
+            slot.attribute.range = wrap_preserving_optional(slot.attribute.range, "Named")
             named_injects = [ModelTypeString, _get_name, NamedString]
             if slot.injected_classes is None:
                 slot.injected_classes = named_injects
@@ -325,7 +332,9 @@ class AfterGenerateClass:
                 else:
                     wrap_cls = "VectorData"
 
-                cls.cls.attributes[an_attr].range = "".join([wrap_cls, "[", slot_range, "]"])
+                cls.cls.attributes[an_attr].range = wrap_preserving_optional(
+                    slot_range, wrap_cls
+                )
 
         return cls
 
@@ -343,6 +352,15 @@ class AfterGenerateClass:
             cls.imports += [imp]
         return cls
 
+    @staticmethod
+    def strip_vector_data_slots(cls: ClassResult, sv: SchemaView) -> ClassResult:
+        """
+        Remove spurious ``vector_data`` slots from DynamicTables
+        """
+        if "vector_data" in cls.cls.attributes:
+            del cls.cls.attributes["vector_data"]
+        return cls
+
 
 def compile_python(
     text_or_fn: str, package_path: Path = None, module_name: str = "test"
@@ -364,3 +382,24 @@ def compile_python(
     exec(spec, module.__dict__)
     sys.modules[module_name] = module
     return module
+
+
+def wrap_preserving_optional(annotation: str, wrap: str) -> str:
+    """
+    Add a wrapping type to a type annotation string,
+    preserving any `Optional[]` annotation, bumping it to the outside
+
+    Examples:
+
+        >>> wrap_preserving_optional('Optional[list[str]]', 'NewType')
+        'Optional[NewType[list[str]]]'
+
+    """
+
+    is_optional = OPTIONAL_PATTERN.match(annotation)
+    if is_optional:
+        annotation = is_optional.groups()[0]
+        annotation = f"Optional[{wrap}[{annotation}]]"
+    else:
+        annotation = f"{wrap}[{annotation}]"
+    return annotation
diff --git a/nwb_linkml/src/nwb_linkml/includes/base.py b/nwb_linkml/src/nwb_linkml/includes/base.py
index cb446c6..9c7896b 100644
--- a/nwb_linkml/src/nwb_linkml/includes/base.py
+++ b/nwb_linkml/src/nwb_linkml/includes/base.py
@@ -23,8 +23,11 @@ BASEMODEL_COERCE_VALUE = """
     except Exception as e1:
         try:
             return handler(v.value)
-        except:
-            raise e1
+        except AttributeError:
+            try:
+                return handler(v["value"])
+            except (KeyError, TypeError):
+                raise e1
 """
 
 BASEMODEL_COERCE_CHILD = """
@@ -32,6 +35,10 @@ BASEMODEL_COERCE_CHILD = """
     @classmethod
    def coerce_subclass(cls, v: Any, info) -> Any:
         \"\"\"Recast parent classes into child classes\"\"\"
+        if isinstance(v, BaseModel):
+            annotation = cls.model_fields[info.field_name].annotation
+            annotation = annotation.__args__[0] if hasattr(annotation, "__args__") else annotation
+            if issubclass(annotation, type(v)) and annotation is not type(v):
+                v = annotation(**{**v.__dict__, **v.__pydantic_extra__})
         return v
-        pdb.set_trace()
 """
diff --git a/nwb_linkml/src/nwb_linkml/includes/hdmf.py b/nwb_linkml/src/nwb_linkml/includes/hdmf.py
index c05213c..278f133 100644
--- a/nwb_linkml/src/nwb_linkml/includes/hdmf.py
+++ b/nwb_linkml/src/nwb_linkml/includes/hdmf.py
@@ -249,6 +249,7 @@ class DynamicTableMixin(BaseModel):
                 if k not in cls.NON_COLUMN_FIELDS
                 and not k.endswith("_index")
                 and not isinstance(model[k], VectorIndexMixin)
+                and model[k] is not None
             ]
             model["colnames"] = colnames
         else:
@@ -264,6 +265,7 @@ class DynamicTableMixin(BaseModel):
                     and not k.endswith("_index")
                     and k not in model["colnames"]
                     and not isinstance(model[k], VectorIndexMixin)
+                    and model[k] is not None
                 ]
             )
             model["colnames"] = colnames
@@ -336,9 +338,9 @@ class DynamicTableMixin(BaseModel):
         """
         Ensure that all columns are equal length
         """
-        lengths = [len(v) for v in self._columns.values()] + [len(self.id)]
+        lengths = [len(v) for v in self._columns.values() if v is not None] + [len(self.id)]
         assert all([length == lengths[0] for length in lengths]), (
-            "Columns are not of equal length! "
+            "DynamicTable columns are not of equal length! "
" f"Got colnames:\n{self.colnames}\nand lengths: {lengths}" ) return self @@ -386,11 +388,6 @@ class VectorDataMixin(BaseModel, Generic[T]): # redefined in `VectorData`, but included here for testing and type checking value: Optional[T] = None - # def __init__(self, value: Optional[NDArray] = None, **kwargs): - # if value is not None and "value" not in kwargs: - # kwargs["value"] = value - # super().__init__(**kwargs) - def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any: if self._index: # Following hdmf, VectorIndex is the thing that knows how to do the slicing @@ -587,6 +584,7 @@ class AlignedDynamicTableMixin(BaseModel): __pydantic_extra__: Dict[str, Union["DynamicTableMixin", "VectorDataMixin", "VectorIndexMixin"]] NON_CATEGORY_FIELDS: ClassVar[tuple[str]] = ( + "id", "name", "categories", "colnames", @@ -622,7 +620,7 @@ class AlignedDynamicTableMixin(BaseModel): elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], str): # get a slice of a single table return self._categories[item[1]][item[0]] - elif isinstance(item, (int, slice, Iterable)): + elif isinstance(item, (int, slice, Iterable, np.int_)): # get a slice of all the tables ids = self.id[item] if not isinstance(ids, Iterable): @@ -634,9 +632,9 @@ class AlignedDynamicTableMixin(BaseModel): if isinstance(table, pd.DataFrame): table = table.reset_index() elif isinstance(table, np.ndarray): - table = pd.DataFrame({category_name: [table]}) + table = pd.DataFrame({category_name: [table]}, index=ids.index) elif isinstance(table, Iterable): - table = pd.DataFrame({category_name: table}) + table = pd.DataFrame({category_name: table}, index=ids.index) else: raise ValueError( f"Don't know how to construct category table for {category_name}" @@ -756,7 +754,7 @@ class AlignedDynamicTableMixin(BaseModel): """ lengths = [len(v) for v in self._categories.values()] + [len(self.id)] assert all([length == lengths[0] for length in lengths]), ( - "Columns are not of equal length! " + "AlignedDynamicTableColumns are not of equal length! " f"Got colnames:\n{self.categories}\nand lengths: {lengths}" ) return self diff --git a/nwb_linkml/src/nwb_linkml/maps/dtype.py b/nwb_linkml/src/nwb_linkml/maps/dtype.py index d618dbe..2497a65 100644 --- a/nwb_linkml/src/nwb_linkml/maps/dtype.py +++ b/nwb_linkml/src/nwb_linkml/maps/dtype.py @@ -3,7 +3,7 @@ Dtype mappings """ from datetime import datetime -from typing import Any +from typing import Any, Optional import numpy as np @@ -160,14 +160,28 @@ def handle_dtype(dtype: DTypeType | None) -> str: elif isinstance(dtype, FlatDtype): return dtype.value elif isinstance(dtype, list) and isinstance(dtype[0], CompoundDtype): - # there is precisely one class that uses compound dtypes: - # TimeSeriesReferenceVectorData - # compoundDtypes are able to define a ragged table according to the schema - # but are used in this single case equivalently to attributes. - # so we'll... uh... treat them as slots. - # TODO + # Compound Dtypes are handled by the MapCompoundDtype dataset map, + # but this function is also used within ``check`` methods, so we should always + # return something from it rather than raise return "AnyType" else: # flat dtype return dtype + + +def inlined(dtype: DTypeType | None) -> Optional[bool]: + """ + Check if a slot should be inlined based on its dtype + + for now that is equivalent to checking whether that dtype is another a reference dtype, + but the function remains semantically reserved for answering this question w.r.t. dtype. 
+
+    Returns ``None`` if not inlined to not clutter generated models with unnecessary props
+    """
+    return (
+        True
+        if isinstance(dtype, ReferenceDtype)
+        or (isinstance(dtype, CompoundDtype) and isinstance(dtype.dtype, ReferenceDtype))
+        else None
+    )
diff --git a/nwb_linkml/tests/test_io/test_io_nwb.py b/nwb_linkml/tests/test_io/test_io_nwb.py
index 54f4d0f..d2c5d73 100644
--- a/nwb_linkml/tests/test_io/test_io_nwb.py
+++ b/nwb_linkml/tests/test_io/test_io_nwb.py
@@ -10,6 +10,7 @@ def test_read_from_nwbfile(nwb_file):
     Read data from a pynwb HDF5 NWB file
     """
     res = HDF5IO(nwb_file).read()
+    res.model_dump_json()
 
 
 def test_read_from_yaml(nwb_file):
diff --git a/scripts/generate_core.py b/scripts/generate_core.py
index aa07dd5..4aeb21a 100644
--- a/scripts/generate_core.py
+++ b/scripts/generate_core.py
@@ -222,6 +222,11 @@ def parser() -> ArgumentParser:
         ),
         action="store_true",
     )
+    parser.add_argument(
+        "--debug",
+        help="Add annotations to generated schema that indicate how they were generated",
+        action="store_true",
+    )
     parser.add_argument("--pdb", help="Launch debugger on an error", action="store_true")
     return parser
 
@@ -229,6 +234,12 @@ def parser() -> ArgumentParser:
 def main():
     args = parser().parse_args()
+    if args.debug:
+        os.environ["NWB_LINKML_DEBUG"] = "true"
+    else:
+        if "NWB_LINKML_DEBUG" in os.environ:
+            del os.environ["NWB_LINKML_DEBUG"]
+
     tmp_dir = make_tmp_dir(clear=True)
     git_dir = tmp_dir / "git"
     git_dir.mkdir(exist_ok=True)
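
A quick smoke test for the two helpers this patch introduces, and for the debug
switch wired through NWB_LINKML_DEBUG. This is a sketch, not part of the patch:
it assumes an installed nwb_linkml checkout, and the ReferenceDtype arguments
are illustrative rather than taken from a real schema:

    import os

    # what `scripts/generate_core.py --debug` sets for the adapters to pick up
    os.environ["NWB_LINKML_DEBUG"] = "true"

    from nwb_linkml.generators.pydantic import wrap_preserving_optional
    from nwb_linkml.maps.dtype import inlined
    from nwb_schema_language import ReferenceDtype

    # Optional[] is bumped outside the wrapping type
    assert wrap_preserving_optional("Optional[list[str]]", "Named") == "Optional[Named[list[str]]]"
    assert wrap_preserving_optional("VectorData", "Named") == "Named[VectorData]"

    # reference dtypes mark a slot as inlined; anything else returns None,
    # so the generated slot omits the property entirely
    assert inlined(ReferenceDtype(target_type="TimeSeries", reftype="object")) is True
    assert inlined("int32") is None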