working pydantic generator

This commit is contained in:
sneakers-the-rat 2024-07-03 01:34:24 -07:00
parent 087064be48
commit 01cfb54a5a
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
2 changed files with 68 additions and 88 deletions

View file

@ -13,7 +13,8 @@ dependencies = [
"linkml-runtime>=1.7.7",
"nwb-schema-language>=0.1.3",
"rich>=13.5.2",
"linkml>=1.7.10",
#"linkml>=1.7.10",
"linkml @ git+https://github.com/sneakers-the-rat/linkml@arrays-numpydantic",
"nptyping>=2.5.0",
"pydantic>=2.3.0",
"h5py>=3.9.0",

View file

@ -29,22 +29,26 @@ The `serialize` method:
# ruff: noqa
import inspect
import pdb
import sys
import warnings
from copy import copy
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
from types import ModuleType
from typing import Dict, List, Optional, Tuple, Type
from typing import Dict, List, Optional, Tuple, Type, Union
from jinja2 import Template
from linkml.generators import PydanticGenerator
from linkml.generators.pydanticgen.array import ArrayRepresentation
from linkml.generators.common.type_designators import (
get_type_designator_value,
)
from linkml.utils.ifabsent_functions import ifabsent_value_declaration
from linkml_runtime.linkml_model.meta import (
Annotation,
AnonymousSlotExpression,
ArrayExpression,
ClassDefinition,
ClassDefinitionName,
ElementName,
@ -53,8 +57,9 @@ from linkml_runtime.linkml_model.meta import (
SlotDefinitionName,
)
from linkml_runtime.utils.compile_python import file_text
from linkml_runtime.utils.formatutils import camelcase, underscore
from linkml_runtime.utils.formatutils import camelcase, underscore, remove_empty_items
from linkml_runtime.utils.schemaview import SchemaView
from pydantic import BaseModel
from nwb_linkml.maps import flat_to_nptyping
@ -268,6 +273,9 @@ class NWBPydanticGenerator(PydanticGenerator):
versions: dict = None
"""See :meth:`.LinkMLProvider.build` for usage - a list of specific versions to import from"""
pydantic_version = "2"
array_representations: List[ArrayRepresentation] = field(
default_factory=lambda: [ArrayRepresentation.NUMPYDANTIC])
black: bool = True
def _locate_imports(self, needed_classes: List[str], sv: SchemaView) -> Dict[str, List[str]]:
"""
@ -427,7 +435,16 @@ class NWBPydanticGenerator(PydanticGenerator):
): # pragma: no cover
# Confirm that the original slot range (ignoring the default that comes in from
# induced_slot) isn't in addition to setting any_of
allowed_keys = ('array',)
if len(s.any_of) > 0 and sv.get_slot(sn).range is not None:
allowed = True
for option in s.any_of:
items = remove_empty_items(option)
if not all([key in allowed_keys for key in items.keys()]):
allowed=False
if allowed:
return
base_range_subsumes_any_of = False
base_range = sv.get_slot(sn).range
base_range_cls = sv.get_class(base_range, strict=False)
@ -436,75 +453,6 @@ class NWBPydanticGenerator(PydanticGenerator):
if not base_range_subsumes_any_of:
raise ValueError("Slot cannot have both range and any_of defined")
def _make_npytyping_range(self, attrs: Dict[str, SlotDefinition]) -> str:
    """
    Render a set of dimension attributes as a single nptyping annotation string.

    Every attribute in ``attrs`` contributes one ``"<size> <name>"`` entry to the
    ``Shape["..."]`` expression, where ``<size>`` is the attribute's
    ``maximum_cardinality`` when it is set (truthy) and ``*`` otherwise.
    The dtype is looked up in ``flat_to_nptyping`` using the ``range`` of the
    first attribute only - all dimensions are expected to share one dtype.

    :param attrs: mapping of attribute name to the slot describing that dimension
    :return: ``NDArray[Shape["..."], <dtype>]``, or ``List[<range>] | <range>``
        (after emitting a warning) when the range has no entry in ``flat_to_nptyping``
    """
    # Substitutions that turn a dimension name into something usable inside a
    # shape expression. They are applied strictly in this order: each one
    # operates on the output of the previous one.
    substitutions = (
        (",", "_"),
        (" ", "_"),
        ("__", "_"),
        ("|", "_"),
        ("-", "_"),
        ("+", "plus"),
    )
    dimensions = list(attrs.values())

    shape_terms = []
    for dim in dimensions:
        size = str(dim.maximum_cardinality) if dim.maximum_cardinality else "*"
        label = dim.name
        for old, new in substitutions:
            label = label.replace(old, new)
        shape_terms.append(size + " " + label)
    shape = ", ".join(shape_terms)

    # the dtype is taken from the first dimension only
    element_range = dimensions[0].range
    try:
        dtype = flat_to_nptyping[element_range]
    except KeyError as e:  # pragma: no cover
        # no nptyping equivalent known: warn and fall back to a plain list-or-scalar hint
        warnings.warn(str(e))
        return f"List[{element_range}] | {element_range}"

    return 'NDArray[Shape["' + shape + '"], ' + dtype + "]"
def _get_numpy_slot_range(self, cls: ClassDefinition) -> str:
    """
    Build the array annotation for a class whose attributes describe array dimensions.

    When every attribute is required there is exactly one possible shape and a
    single annotation from :meth:`._make_npytyping_range` is returned.

    Otherwise a ``Union[...]`` of annotations is produced: first the shape made of
    only the required dimensions (if there are any), then one more shape per
    optional dimension, each adding the next optional dimension to the previous
    shape. Not every permutation is generated - optional dimensions are only ever
    added in declaration order.

    :param cls: class whose ``attributes`` are the dimensions of the array
    :return: a single annotation string, or a multi-line ``Union[...]`` string
    """
    dims = cls.attributes
    optional_dims = [(name, dim) for name, dim in dims.items() if not dim.required]

    # no optional dimensions -> one shape, no union needed
    if not optional_dims:  # pragma: no cover
        return self._make_npytyping_range(dims)

    required_dims = {name: dim for name, dim in dims.items() if dim.required}

    variants = []
    # base case: only the required dimensions
    if required_dims:
        variants.append(self._make_npytyping_range(required_dims))

    # then grow the shape by one optional dimension at a time
    cumulative = dict(required_dims)
    for name, dim in optional_dims:
        cumulative[name] = dim
        # hand over a copy so every call receives its own mapping
        variants.append(self._make_npytyping_range(dict(cumulative)))

    # combine into a union, one variant per line
    item_indent = " " * 8
    closing_indent = " " * 4
    return (
        "Union[\n"
        + item_indent
        + (",\n" + item_indent).join(variants)
        + "\n"
        + closing_indent
        + "]"
    )
def _get_linkml_classvar(self, cls: ClassDefinition) -> SlotDefinition:
"""A class variable that holds additional linkml attrs"""
@ -566,17 +514,6 @@ class NWBPydanticGenerator(PydanticGenerator):
self.sorted_class_names += [camelcase(c.name) for c in slist]
return slist
def get_class_slot_range(self, slot_range: str, inlined: bool, inlined_as_list: bool) -> str:
    """
    Monkeypatch that turns Array-typed classes into nptyping hints.

    The class named by ``slot_range`` is looked up in the schema view. If its
    ``is_a`` parent is ``"Arraylike"`` the range is produced by
    :meth:`._get_numpy_slot_range`; every other class is delegated unchanged to
    :meth:`._get_class_slot_range_origin`.

    :param slot_range: name of the class used as the slot's range
    :param inlined: passed through to the original implementation
    :param inlined_as_list: passed through to the original implementation
    :return: the python range string for the slot
    """
    target_cls = self.schemaview.get_class(slot_range)

    # array-like classes get the numpy treatment ...
    if target_cls.is_a == "Arraylike":
        return self._get_numpy_slot_range(target_cls)

    # ... everything else goes through the original implementation
    return self._get_class_slot_range_origin(slot_range, inlined, inlined_as_list)
def _get_class_slot_range_origin(
self, slot_range: str, inlined: bool, inlined_as_list: bool
) -> str:
@ -694,6 +631,36 @@ class NWBPydanticGenerator(PydanticGenerator):
return slot_value
def generate_python_range(self, slot_range, slot_def: SlotDefinition, class_def: ClassDefinition) -> str:
    """
    Generate the python range for a slot range value.

    Extends the parent implementation so that array specifications are rendered
    through the parent's ``get_array_representations_range``. Only the first
    representation it returns is used.

    ``slot_range`` may be:

    * an ``ArrayExpression`` - wrapped around the python range of ``slot_def.range``
    * an ``AnonymousSlotExpression`` - its own ``range`` is used, falling back to
      ``slot_def.range`` when unset, and wrapped in an array if the expression
      carries an ``array``
    * anything else - wrapped in an array if ``slot_def.array`` is set, otherwise
      handed to the parent implementation unchanged

    :param slot_range: the range value to render (see above)
    :param slot_def: the slot the range belongs to
    :param class_def: the class the slot belongs to
    :return: the python annotation string for the range
    :raises TypeError: if ``slot_range`` is a plain ``dict``, which is not supported
    """
    if isinstance(slot_range, ArrayExpression):
        # a bare array expression: the element type comes from the slot itself
        temp_slot = SlotDefinition(name='array', array=slot_range)
        inner_range = super().generate_python_range(slot_def.range, slot_def, class_def)
        results = super().get_array_representations_range(temp_slot, inner_range)
        return results[0].annotation
    elif isinstance(slot_range, AnonymousSlotExpression):
        # e.g. a member of any_of: prefer its own range, else inherit the slot's
        if slot_range.range is None:
            inner_range = slot_def.range
        else:
            inner_range = slot_range.range
        inner_range = super().generate_python_range(inner_range, slot_def, class_def)
        if slot_range.array is not None:
            temp_slot = SlotDefinition(name='array', array=slot_range.array)
            results = super().get_array_representations_range(temp_slot, inner_range)
            inner_range = results[0].annotation
        return inner_range
    elif isinstance(slot_range, dict):
        # Previously a leftover debugger breakpoint that then fell through and
        # returned None. Fail loudly instead of hanging or emitting a bogus range.
        raise TypeError(
            f"Cannot generate a python range for slot '{slot_def.name}' "
            f"from a plain dict: {slot_range!r}"
        )
    elif slot_def.array is not None:
        inner_range = super().generate_python_range(slot_def.range, slot_def, class_def)
        results = super().get_array_representations_range(slot_def, inner_range)
        return results[0].annotation
    else:
        return super().generate_python_range(slot_range, slot_def, class_def)
def serialize(self) -> str:
predefined_slot_values = {}
"""splitting up parent class :meth:`.get_predefined_slot_values`"""
@ -763,13 +730,20 @@ class NWBPydanticGenerator(PydanticGenerator):
s.description = s.description.replace('"', '\\"')
class_def.attributes[s.name] = s
slot_ranges: List[str] = []
slot_ranges: List[Union[str, ArrayExpression, AnonymousSlotExpression]] = []
self._check_anyof(s, sn, sv)
if s.any_of is not None and len(s.any_of) > 0:
# list comprehension here is pulling ranges from within AnonymousSlotExpression
slot_ranges.extend([r.range for r in s.any_of])
if isinstance(s.any_of, dict):
any_ofs = list(s.any_of.values())
else:
any_ofs = s.any_of
slot_ranges.extend(any_ofs)
else:
if s.array is not None:
slot_ranges.append(s.array)
else:
slot_ranges.append(s.range)
@ -798,7 +772,12 @@ class NWBPydanticGenerator(PydanticGenerator):
if s.multivalued:
if s.inlined or s.inlined_as_list:
try:
collection_key = self.generate_collection_key(slot_ranges, s, class_def)
except TypeError:
# from not being able to hash an anonymous slot expression.
# hack, we can fix this by merging upstream pydantic generator cleanup
collection_key = None
else: # pragma: no cover
collection_key = None
if (