diff --git a/nwb_linkml/pyproject.toml b/nwb_linkml/pyproject.toml index 5d64d38..d478065 100644 --- a/nwb_linkml/pyproject.toml +++ b/nwb_linkml/pyproject.toml @@ -13,7 +13,8 @@ dependencies = [ "linkml-runtime>=1.7.7", "nwb-schema-language>=0.1.3", "rich>=13.5.2", - "linkml>=1.7.10", + #"linkml>=1.7.10", + "linkml @ git+https://github.com/sneakers-the-rat/linkml@arrays-numpydantic", "nptyping>=2.5.0", "pydantic>=2.3.0", "h5py>=3.9.0", diff --git a/nwb_linkml/src/nwb_linkml/generators/pydantic.py b/nwb_linkml/src/nwb_linkml/generators/pydantic.py index 9cd63fa..285f4ce 100644 --- a/nwb_linkml/src/nwb_linkml/generators/pydantic.py +++ b/nwb_linkml/src/nwb_linkml/generators/pydantic.py @@ -29,22 +29,26 @@ The `serialize` method: # ruff: noqa import inspect +import pdb import sys import warnings from copy import copy -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path from types import ModuleType -from typing import Dict, List, Optional, Tuple, Type +from typing import Dict, List, Optional, Tuple, Type, Union from jinja2 import Template from linkml.generators import PydanticGenerator +from linkml.generators.pydanticgen.array import ArrayRepresentation from linkml.generators.common.type_designators import ( get_type_designator_value, ) from linkml.utils.ifabsent_functions import ifabsent_value_declaration from linkml_runtime.linkml_model.meta import ( Annotation, + AnonymousSlotExpression, + ArrayExpression, ClassDefinition, ClassDefinitionName, ElementName, @@ -53,8 +57,9 @@ from linkml_runtime.linkml_model.meta import ( SlotDefinitionName, ) from linkml_runtime.utils.compile_python import file_text -from linkml_runtime.utils.formatutils import camelcase, underscore +from linkml_runtime.utils.formatutils import camelcase, underscore, remove_empty_items from linkml_runtime.utils.schemaview import SchemaView + from pydantic import BaseModel from nwb_linkml.maps import flat_to_nptyping @@ -268,6 +273,9 @@ class NWBPydanticGenerator(PydanticGenerator): versions: dict = None """See :meth:`.LinkMLProvider.build` for usage - a list of specific versions to import from""" pydantic_version = "2" + array_representations: List[ArrayRepresentation] = field( + default_factory=lambda: [ArrayRepresentation.NUMPYDANTIC]) + black: bool = True def _locate_imports(self, needed_classes: List[str], sv: SchemaView) -> Dict[str, List[str]]: """ @@ -427,7 +435,16 @@ class NWBPydanticGenerator(PydanticGenerator): ): # pragma: no cover # Confirm that the original slot range (ignoring the default that comes in from # induced_slot) isn't in addition to setting any_of + allowed_keys = ('array',) + if len(s.any_of) > 0 and sv.get_slot(sn).range is not None: + allowed = True + for option in s.any_of: + items = remove_empty_items(option) + if not all([key in allowed_keys for key in items.keys()]): + allowed=False + if allowed: + return base_range_subsumes_any_of = False base_range = sv.get_slot(sn).range base_range_cls = sv.get_class(base_range, strict=False) @@ -436,75 +453,6 @@ class NWBPydanticGenerator(PydanticGenerator): if not base_range_subsumes_any_of: raise ValueError("Slot cannot have both range and any_of defined") - def _make_npytyping_range(self, attrs: Dict[str, SlotDefinition]) -> str: - # slot always starts with... - prefix = "NDArray[" - - # and then we specify the shape: - shape_prefix = 'Shape["' - - # using the cardinality from the attributes - dim_pieces = [] - for attr in attrs.values(): - - if attr.maximum_cardinality: - shape_part = str(attr.maximum_cardinality) - else: - shape_part = "*" - - # do this with the most heinous chain of string replacements rather than regex - # because i am still figuring out what needs to be subbed lol - name_part = ( - attr.name.replace(",", "_") - .replace(" ", "_") - .replace("__", "_") - .replace("|", "_") - .replace("-", "_") - .replace("+", "plus") - ) - - dim_pieces.append(" ".join([shape_part, name_part])) - - dimension = ", ".join(dim_pieces) - - shape_suffix = '"], ' - - # all dimensions should be the same dtype - try: - dtype = flat_to_nptyping[list(attrs.values())[0].range] - except KeyError as e: # pragma: no cover - warnings.warn(str(e)) - range = list(attrs.values())[0].range - return f"List[{range}] | {range}" - suffix = "]" - - slot = "".join([prefix, shape_prefix, dimension, shape_suffix, dtype, suffix]) - return slot - - def _get_numpy_slot_range(self, cls: ClassDefinition) -> str: - # if none of the dimensions are optional, we just have one possible array shape - if all([s.required for s in cls.attributes.values()]): # pragma: no cover - return self._make_npytyping_range(cls.attributes) - # otherwise we need to make permutations - # but not all permutations, because we typically just want to be able to exclude the last possible dimensions - # the array classes should always be well-defined where the optional dimensions are at the end, so - requireds = {k: v for k, v in cls.attributes.items() if v.required} - optionals = [(k, v) for k, v in cls.attributes.items() if not v.required] - - annotations = [] - if len(requireds) > 0: - # first the base case - annotations.append(self._make_npytyping_range(requireds)) - # then add back each optional dimension - for i in range(len(optionals)): - attrs = {**requireds, **{k: v for k, v in optionals[0 : i + 1]}} - annotations.append(self._make_npytyping_range(attrs)) - - # now combine with a union: - union = "Union[\n" + " " * 8 - union += (",\n" + " " * 8).join(annotations) - union += "\n" + " " * 4 + "]" - return union def _get_linkml_classvar(self, cls: ClassDefinition) -> SlotDefinition: """A class variable that holds additional linkml attrs""" @@ -566,17 +514,6 @@ class NWBPydanticGenerator(PydanticGenerator): self.sorted_class_names += [camelcase(c.name) for c in slist] return slist - def get_class_slot_range(self, slot_range: str, inlined: bool, inlined_as_list: bool) -> str: - """ - Monkeypatch to convert Array typed slots and classes into npytyped hints - """ - sv = self.schemaview - range_cls = sv.get_class(slot_range) - if range_cls.is_a == "Arraylike": - return self._get_numpy_slot_range(range_cls) - else: - return self._get_class_slot_range_origin(slot_range, inlined, inlined_as_list) - def _get_class_slot_range_origin( self, slot_range: str, inlined: bool, inlined_as_list: bool ) -> str: @@ -694,6 +631,36 @@ class NWBPydanticGenerator(PydanticGenerator): return slot_value + def generate_python_range(self, slot_range, slot_def: SlotDefinition, class_def: ClassDefinition) -> str: + """ + Generate the python range for a slot range value + """ + if isinstance(slot_range, ArrayExpression): + temp_slot = SlotDefinition(name='array', array=slot_range) + inner_range = super().generate_python_range(slot_def.range, slot_def, class_def) + results = super().get_array_representations_range(temp_slot, inner_range) + return results[0].annotation + elif isinstance(slot_range, AnonymousSlotExpression): + if slot_range.range is None: + inner_range = slot_def.range + else: + inner_range = slot_range.range + + inner_range = super().generate_python_range(inner_range, slot_def, class_def) + if slot_range.array is not None: + temp_slot = SlotDefinition(name='array', array=slot_range.array) + results = super().get_array_representations_range(temp_slot, inner_range) + inner_range = results[0].annotation + return inner_range + elif isinstance(slot_range, dict): + pdb.set_trace() + elif slot_def.array is not None: + inner_range = super().generate_python_range(slot_def.range, slot_def, class_def) + results = super().get_array_representations_range(slot_def, inner_range) + return results[0].annotation + else: + return super().generate_python_range(slot_range, slot_def, class_def) + def serialize(self) -> str: predefined_slot_values = {} """splitting up parent class :meth:`.get_predefined_slot_values`""" @@ -763,15 +730,22 @@ class NWBPydanticGenerator(PydanticGenerator): s.description = s.description.replace('"', '\\"') class_def.attributes[s.name] = s - slot_ranges: List[str] = [] + slot_ranges: List[Union[str, ArrayExpression, AnonymousSlotExpression]] = [] self._check_anyof(s, sn, sv) if s.any_of is not None and len(s.any_of) > 0: # list comprehension here is pulling ranges from within AnonymousSlotExpression - slot_ranges.extend([r.range for r in s.any_of]) + if isinstance(s.any_of, dict): + any_ofs = list(s.any_of.values()) + else: + any_ofs = s.any_of + slot_ranges.extend(any_ofs) else: - slot_ranges.append(s.range) + if s.array is not None: + slot_ranges.append(s.array) + else: + slot_ranges.append(s.range) pyranges = [ self.generate_python_range(slot_range, s, class_def) @@ -798,7 +772,12 @@ class NWBPydanticGenerator(PydanticGenerator): if s.multivalued: if s.inlined or s.inlined_as_list: - collection_key = self.generate_collection_key(slot_ranges, s, class_def) + try: + collection_key = self.generate_collection_key(slot_ranges, s, class_def) + except TypeError: + # from not being able to hash an anonymous slot expression. + # hack, we can fix this by merging upstream pydantic generator cleanup + collection_key = None else: # pragma: no cover collection_key = None if (