mirror of
https://github.com/p2p-ld/nwb-linkml.git
synced 2024-11-12 17:54:29 +00:00
working pydantic generator
This commit is contained in:
parent
087064be48
commit
01cfb54a5a
2 changed files with 68 additions and 88 deletions
|
@ -13,7 +13,8 @@ dependencies = [
|
|||
"linkml-runtime>=1.7.7",
|
||||
"nwb-schema-language>=0.1.3",
|
||||
"rich>=13.5.2",
|
||||
"linkml>=1.7.10",
|
||||
#"linkml>=1.7.10",
|
||||
"linkml @ git+https://github.com/sneakers-the-rat/linkml@arrays-numpydantic",
|
||||
"nptyping>=2.5.0",
|
||||
"pydantic>=2.3.0",
|
||||
"h5py>=3.9.0",
|
||||
|
|
|
@ -29,22 +29,26 @@ The `serialize` method:
|
|||
# ruff: noqa
|
||||
|
||||
import inspect
|
||||
import pdb
|
||||
import sys
|
||||
import warnings
|
||||
from copy import copy
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
from jinja2 import Template
|
||||
from linkml.generators import PydanticGenerator
|
||||
from linkml.generators.pydanticgen.array import ArrayRepresentation
|
||||
from linkml.generators.common.type_designators import (
|
||||
get_type_designator_value,
|
||||
)
|
||||
from linkml.utils.ifabsent_functions import ifabsent_value_declaration
|
||||
from linkml_runtime.linkml_model.meta import (
|
||||
Annotation,
|
||||
AnonymousSlotExpression,
|
||||
ArrayExpression,
|
||||
ClassDefinition,
|
||||
ClassDefinitionName,
|
||||
ElementName,
|
||||
|
@ -53,8 +57,9 @@ from linkml_runtime.linkml_model.meta import (
|
|||
SlotDefinitionName,
|
||||
)
|
||||
from linkml_runtime.utils.compile_python import file_text
|
||||
from linkml_runtime.utils.formatutils import camelcase, underscore
|
||||
from linkml_runtime.utils.formatutils import camelcase, underscore, remove_empty_items
|
||||
from linkml_runtime.utils.schemaview import SchemaView
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from nwb_linkml.maps import flat_to_nptyping
|
||||
|
@ -268,6 +273,9 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
versions: dict = None
|
||||
"""See :meth:`.LinkMLProvider.build` for usage - a list of specific versions to import from"""
|
||||
pydantic_version = "2"
|
||||
array_representations: List[ArrayRepresentation] = field(
|
||||
default_factory=lambda: [ArrayRepresentation.NUMPYDANTIC])
|
||||
black: bool = True
|
||||
|
||||
def _locate_imports(self, needed_classes: List[str], sv: SchemaView) -> Dict[str, List[str]]:
|
||||
"""
|
||||
|
@ -427,7 +435,16 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
): # pragma: no cover
|
||||
# Confirm that the original slot range (ignoring the default that comes in from
|
||||
# induced_slot) isn't in addition to setting any_of
|
||||
allowed_keys = ('array',)
|
||||
|
||||
if len(s.any_of) > 0 and sv.get_slot(sn).range is not None:
|
||||
allowed = True
|
||||
for option in s.any_of:
|
||||
items = remove_empty_items(option)
|
||||
if not all([key in allowed_keys for key in items.keys()]):
|
||||
allowed=False
|
||||
if allowed:
|
||||
return
|
||||
base_range_subsumes_any_of = False
|
||||
base_range = sv.get_slot(sn).range
|
||||
base_range_cls = sv.get_class(base_range, strict=False)
|
||||
|
@ -436,75 +453,6 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
if not base_range_subsumes_any_of:
|
||||
raise ValueError("Slot cannot have both range and any_of defined")
|
||||
|
||||
def _make_npytyping_range(self, attrs: Dict[str, SlotDefinition]) -> str:
    """
    Build an nptyping-style ``NDArray[Shape["..."], dtype]`` annotation string.

    Every slot in ``attrs`` contributes one dimension to the shape, written as
    ``"<size> <name>"``: the size is the slot's ``maximum_cardinality`` when
    that is set (truthy) and ``*`` otherwise, and the name is the slot name
    with characters that can't appear in a shape expression substituted out.

    The dtype is looked up in ``flat_to_nptyping`` using the range of the
    *first* slot only. If that range has no entry, a warning is emitted and
    a ``List[range] | range`` annotation is returned instead of an array.

    Args:
        attrs: dimension slots, keyed by name, in dimension order

    Returns:
        The annotation as a string
    """
    # Applied strictly in this order: the "__" collapse runs after commas and
    # spaces have become underscores, but before "|" and "-" are replaced.
    substitutions = (
        (",", "_"),
        (" ", "_"),
        ("__", "_"),
        ("|", "_"),
        ("-", "_"),
        ("+", "plus"),
    )

    dims = []
    for slot in attrs.values():
        size = str(slot.maximum_cardinality) if slot.maximum_cardinality else "*"

        label = slot.name
        for old, new in substitutions:
            label = label.replace(old, new)

        dims.append(" ".join([size, label]))

    # all dimensions should be the same dtype, so only the first one is consulted
    first_range = list(attrs.values())[0].range
    try:
        dtype = flat_to_nptyping[first_range]
    except KeyError as e:  # pragma: no cover
        warnings.warn(str(e))
        return f"List[{first_range}] | {first_range}"

    return "".join(['NDArray[Shape["', ", ".join(dims), '"], ', dtype, "]"])
|
||||
|
||||
def _get_numpy_slot_range(self, cls: ClassDefinition) -> str:
    """
    Make the array annotation (or ``Union`` of annotations) for an array class.

    When every attribute of ``cls`` is required there is exactly one possible
    shape, and a single annotation is returned. Otherwise one annotation is
    made for the required dimensions alone (if there are any), plus one more
    for each optional dimension added back cumulatively in declaration order.
    Not all permutations are generated - the optional dimensions are expected
    to trail the required ones - and the results are wrapped in a multi-line
    ``Union[...]`` string.

    Args:
        cls: class whose attributes describe the array dimensions

    Returns:
        The annotation as a string
    """
    # split the dimensions into required and optional, keeping declaration order
    required_dims = {}
    optional_dims = []
    for key, slot in cls.attributes.items():
        if slot.required:
            required_dims[key] = slot
        else:
            optional_dims.append((key, slot))

    # nothing optional: just one possible array shape
    if not optional_dims:  # pragma: no cover
        return self._make_npytyping_range(cls.attributes)

    variants = []
    if required_dims:
        # base case: only the required dimensions
        variants.append(self._make_npytyping_range(required_dims))

    # then grow the shape by one optional dimension at a time
    growing = dict(required_dims)
    for key, slot in optional_dims:
        growing[key] = slot
        variants.append(self._make_npytyping_range(dict(growing)))

    # now combine with a union
    pad = " " * 8
    return "".join(["Union[\n", pad, (",\n" + pad).join(variants), "\n", " " * 4, "]"])
|
||||
|
||||
def _get_linkml_classvar(self, cls: ClassDefinition) -> SlotDefinition:
|
||||
"""A class variable that holds additional linkml attrs"""
|
||||
|
@ -566,17 +514,6 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
self.sorted_class_names += [camelcase(c.name) for c in slist]
|
||||
return slist
|
||||
|
||||
def get_class_slot_range(self, slot_range: str, inlined: bool, inlined_as_list: bool) -> str:
    """
    Monkeypatch to convert Array typed slots and classes into npytyped hints.

    The class named by ``slot_range`` is looked up in the schema view. Classes
    whose ``is_a`` is ``"Arraylike"`` get an array annotation from
    :meth:`._get_numpy_slot_range`; every other class is passed on unchanged
    to :meth:`._get_class_slot_range_origin`.
    """
    target_cls = self.schemaview.get_class(slot_range)
    if target_cls.is_a != "Arraylike":
        return self._get_class_slot_range_origin(slot_range, inlined, inlined_as_list)
    return self._get_numpy_slot_range(target_cls)
|
||||
|
||||
def _get_class_slot_range_origin(
|
||||
self, slot_range: str, inlined: bool, inlined_as_list: bool
|
||||
) -> str:
|
||||
|
@ -694,6 +631,36 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
|
||||
return slot_value
|
||||
|
||||
def generate_python_range(self, slot_range, slot_def: SlotDefinition, class_def: ClassDefinition) -> str:
    """
    Generate the python range for a slot range value.

    ``slot_range`` is handled according to its type:

    * ``ArrayExpression`` - the element type is generated from
      ``slot_def.range`` and wrapped in the first array representation
      returned by the parent generator.
    * ``AnonymousSlotExpression`` - its own ``range`` is used, falling back to
      ``slot_def.range`` when unset; if the expression also carries an
      ``array``, the result is wrapped in an array annotation as above.
    * anything else - if ``slot_def`` itself declares an ``array``, the
      range from ``slot_def.range`` is wrapped in an array annotation,
      otherwise the call is delegated unchanged to the parent generator.

    Args:
        slot_range: a range name, ``ArrayExpression`` or ``AnonymousSlotExpression``
        slot_def: the slot the range belongs to
        class_def: the class the slot belongs to

    Returns:
        The python annotation as a string

    Raises:
        TypeError: if ``slot_range`` is a plain ``dict``. This branch used to
            drop into ``pdb`` and then fall through returning ``None``; an
            un-parsed expression is now reported explicitly instead.
    """
    if isinstance(slot_range, ArrayExpression):
        # a bare array expression has no range of its own: the element type
        # comes from the slot definition
        temp_slot = SlotDefinition(name="array", array=slot_range)
        inner_range = super().generate_python_range(slot_def.range, slot_def, class_def)
        results = super().get_array_representations_range(temp_slot, inner_range)
        return results[0].annotation

    if isinstance(slot_range, AnonymousSlotExpression):
        base_range = slot_def.range if slot_range.range is None else slot_range.range
        inner_range = super().generate_python_range(base_range, slot_def, class_def)
        if slot_range.array is not None:
            temp_slot = SlotDefinition(name="array", array=slot_range.array)
            results = super().get_array_representations_range(temp_slot, inner_range)
            inner_range = results[0].annotation
        return inner_range

    if isinstance(slot_range, dict):
        # no debugger breakpoints in library code: fail loudly with context
        raise TypeError(
            f"Cannot generate a python range for slot {slot_def.name!r} from a raw dict "
            f"({slot_range!r}); expected a range name, ArrayExpression or AnonymousSlotExpression"
        )

    if slot_def.array is not None:
        inner_range = super().generate_python_range(slot_def.range, slot_def, class_def)
        results = super().get_array_representations_range(slot_def, inner_range)
        return results[0].annotation

    return super().generate_python_range(slot_range, slot_def, class_def)
|
||||
|
||||
def serialize(self) -> str:
|
||||
predefined_slot_values = {}
|
||||
"""splitting up parent class :meth:`.get_predefined_slot_values`"""
|
||||
|
@ -763,15 +730,22 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
s.description = s.description.replace('"', '\\"')
|
||||
class_def.attributes[s.name] = s
|
||||
|
||||
slot_ranges: List[str] = []
|
||||
slot_ranges: List[Union[str, ArrayExpression, AnonymousSlotExpression]] = []
|
||||
|
||||
self._check_anyof(s, sn, sv)
|
||||
|
||||
if s.any_of is not None and len(s.any_of) > 0:
|
||||
# list comprehension here is pulling ranges from within AnonymousSlotExpression
|
||||
slot_ranges.extend([r.range for r in s.any_of])
|
||||
if isinstance(s.any_of, dict):
|
||||
any_ofs = list(s.any_of.values())
|
||||
else:
|
||||
any_ofs = s.any_of
|
||||
slot_ranges.extend(any_ofs)
|
||||
else:
|
||||
slot_ranges.append(s.range)
|
||||
if s.array is not None:
|
||||
slot_ranges.append(s.array)
|
||||
else:
|
||||
slot_ranges.append(s.range)
|
||||
|
||||
pyranges = [
|
||||
self.generate_python_range(slot_range, s, class_def)
|
||||
|
@ -798,7 +772,12 @@ class NWBPydanticGenerator(PydanticGenerator):
|
|||
|
||||
if s.multivalued:
|
||||
if s.inlined or s.inlined_as_list:
|
||||
collection_key = self.generate_collection_key(slot_ranges, s, class_def)
|
||||
try:
|
||||
collection_key = self.generate_collection_key(slot_ranges, s, class_def)
|
||||
except TypeError:
|
||||
# from not being able to hash an anonymous slot expression.
|
||||
# hack, we can fix this by merging upstream pydantic generator cleanup
|
||||
collection_key = None
|
||||
else: # pragma: no cover
|
||||
collection_key = None
|
||||
if (
|
||||
|
|
Loading…
Reference in a new issue