working pydantic generator

This commit is contained in:
sneakers-the-rat 2024-07-03 01:34:24 -07:00
parent 087064be48
commit 01cfb54a5a
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
2 changed files with 68 additions and 88 deletions

View file

@ -13,7 +13,8 @@ dependencies = [
"linkml-runtime>=1.7.7", "linkml-runtime>=1.7.7",
"nwb-schema-language>=0.1.3", "nwb-schema-language>=0.1.3",
"rich>=13.5.2", "rich>=13.5.2",
"linkml>=1.7.10", #"linkml>=1.7.10",
"linkml @ git+https://github.com/sneakers-the-rat/linkml@arrays-numpydantic",
"nptyping>=2.5.0", "nptyping>=2.5.0",
"pydantic>=2.3.0", "pydantic>=2.3.0",
"h5py>=3.9.0", "h5py>=3.9.0",

View file

@ -29,22 +29,26 @@ The `serialize` method:
# ruff: noqa # ruff: noqa
import inspect import inspect
import pdb
import sys import sys
import warnings import warnings
from copy import copy from copy import copy
from dataclasses import dataclass from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from types import ModuleType from types import ModuleType
from typing import Dict, List, Optional, Tuple, Type from typing import Dict, List, Optional, Tuple, Type, Union
from jinja2 import Template from jinja2 import Template
from linkml.generators import PydanticGenerator from linkml.generators import PydanticGenerator
from linkml.generators.pydanticgen.array import ArrayRepresentation
from linkml.generators.common.type_designators import ( from linkml.generators.common.type_designators import (
get_type_designator_value, get_type_designator_value,
) )
from linkml.utils.ifabsent_functions import ifabsent_value_declaration from linkml.utils.ifabsent_functions import ifabsent_value_declaration
from linkml_runtime.linkml_model.meta import ( from linkml_runtime.linkml_model.meta import (
Annotation, Annotation,
AnonymousSlotExpression,
ArrayExpression,
ClassDefinition, ClassDefinition,
ClassDefinitionName, ClassDefinitionName,
ElementName, ElementName,
@ -53,8 +57,9 @@ from linkml_runtime.linkml_model.meta import (
SlotDefinitionName, SlotDefinitionName,
) )
from linkml_runtime.utils.compile_python import file_text from linkml_runtime.utils.compile_python import file_text
from linkml_runtime.utils.formatutils import camelcase, underscore from linkml_runtime.utils.formatutils import camelcase, underscore, remove_empty_items
from linkml_runtime.utils.schemaview import SchemaView from linkml_runtime.utils.schemaview import SchemaView
from pydantic import BaseModel from pydantic import BaseModel
from nwb_linkml.maps import flat_to_nptyping from nwb_linkml.maps import flat_to_nptyping
@ -268,6 +273,9 @@ class NWBPydanticGenerator(PydanticGenerator):
versions: dict = None versions: dict = None
"""See :meth:`.LinkMLProvider.build` for usage - a list of specific versions to import from""" """See :meth:`.LinkMLProvider.build` for usage - a list of specific versions to import from"""
pydantic_version = "2" pydantic_version = "2"
array_representations: List[ArrayRepresentation] = field(
default_factory=lambda: [ArrayRepresentation.NUMPYDANTIC])
black: bool = True
def _locate_imports(self, needed_classes: List[str], sv: SchemaView) -> Dict[str, List[str]]: def _locate_imports(self, needed_classes: List[str], sv: SchemaView) -> Dict[str, List[str]]:
""" """
@ -427,7 +435,16 @@ class NWBPydanticGenerator(PydanticGenerator):
): # pragma: no cover ): # pragma: no cover
# Confirm that the original slot range (ignoring the default that comes in from # Confirm that the original slot range (ignoring the default that comes in from
# induced_slot) isn't in addition to setting any_of # induced_slot) isn't in addition to setting any_of
allowed_keys = ('array',)
if len(s.any_of) > 0 and sv.get_slot(sn).range is not None: if len(s.any_of) > 0 and sv.get_slot(sn).range is not None:
allowed = True
for option in s.any_of:
items = remove_empty_items(option)
if not all([key in allowed_keys for key in items.keys()]):
allowed=False
if allowed:
return
base_range_subsumes_any_of = False base_range_subsumes_any_of = False
base_range = sv.get_slot(sn).range base_range = sv.get_slot(sn).range
base_range_cls = sv.get_class(base_range, strict=False) base_range_cls = sv.get_class(base_range, strict=False)
@ -436,75 +453,6 @@ class NWBPydanticGenerator(PydanticGenerator):
if not base_range_subsumes_any_of: if not base_range_subsumes_any_of:
raise ValueError("Slot cannot have both range and any_of defined") raise ValueError("Slot cannot have both range and any_of defined")
def _make_npytyping_range(self, attrs: Dict[str, SlotDefinition]) -> str:
# slot always starts with...
prefix = "NDArray["
# and then we specify the shape:
shape_prefix = 'Shape["'
# using the cardinality from the attributes
dim_pieces = []
for attr in attrs.values():
if attr.maximum_cardinality:
shape_part = str(attr.maximum_cardinality)
else:
shape_part = "*"
# do this with the most heinous chain of string replacements rather than regex
# because i am still figuring out what needs to be subbed lol
name_part = (
attr.name.replace(",", "_")
.replace(" ", "_")
.replace("__", "_")
.replace("|", "_")
.replace("-", "_")
.replace("+", "plus")
)
dim_pieces.append(" ".join([shape_part, name_part]))
dimension = ", ".join(dim_pieces)
shape_suffix = '"], '
# all dimensions should be the same dtype
try:
dtype = flat_to_nptyping[list(attrs.values())[0].range]
except KeyError as e: # pragma: no cover
warnings.warn(str(e))
range = list(attrs.values())[0].range
return f"List[{range}] | {range}"
suffix = "]"
slot = "".join([prefix, shape_prefix, dimension, shape_suffix, dtype, suffix])
return slot
def _get_numpy_slot_range(self, cls: ClassDefinition) -> str:
# if none of the dimensions are optional, we just have one possible array shape
if all([s.required for s in cls.attributes.values()]): # pragma: no cover
return self._make_npytyping_range(cls.attributes)
# otherwise we need to make permutations
# but not all permutations, because we typically just want to be able to exclude the last possible dimensions
# the array classes should always be well-defined where the optional dimensions are at the end, so
requireds = {k: v for k, v in cls.attributes.items() if v.required}
optionals = [(k, v) for k, v in cls.attributes.items() if not v.required]
annotations = []
if len(requireds) > 0:
# first the base case
annotations.append(self._make_npytyping_range(requireds))
# then add back each optional dimension
for i in range(len(optionals)):
attrs = {**requireds, **{k: v for k, v in optionals[0 : i + 1]}}
annotations.append(self._make_npytyping_range(attrs))
# now combine with a union:
union = "Union[\n" + " " * 8
union += (",\n" + " " * 8).join(annotations)
union += "\n" + " " * 4 + "]"
return union
def _get_linkml_classvar(self, cls: ClassDefinition) -> SlotDefinition: def _get_linkml_classvar(self, cls: ClassDefinition) -> SlotDefinition:
"""A class variable that holds additional linkml attrs""" """A class variable that holds additional linkml attrs"""
@ -566,17 +514,6 @@ class NWBPydanticGenerator(PydanticGenerator):
self.sorted_class_names += [camelcase(c.name) for c in slist] self.sorted_class_names += [camelcase(c.name) for c in slist]
return slist return slist
def get_class_slot_range(self, slot_range: str, inlined: bool, inlined_as_list: bool) -> str:
"""
Monkeypatch to convert Array typed slots and classes into npytyped hints
"""
sv = self.schemaview
range_cls = sv.get_class(slot_range)
if range_cls.is_a == "Arraylike":
return self._get_numpy_slot_range(range_cls)
else:
return self._get_class_slot_range_origin(slot_range, inlined, inlined_as_list)
def _get_class_slot_range_origin( def _get_class_slot_range_origin(
self, slot_range: str, inlined: bool, inlined_as_list: bool self, slot_range: str, inlined: bool, inlined_as_list: bool
) -> str: ) -> str:
@ -694,6 +631,36 @@ class NWBPydanticGenerator(PydanticGenerator):
return slot_value return slot_value
def generate_python_range(self, slot_range, slot_def: SlotDefinition, class_def: ClassDefinition) -> str:
"""
Generate the python range for a slot range value
"""
if isinstance(slot_range, ArrayExpression):
temp_slot = SlotDefinition(name='array', array=slot_range)
inner_range = super().generate_python_range(slot_def.range, slot_def, class_def)
results = super().get_array_representations_range(temp_slot, inner_range)
return results[0].annotation
elif isinstance(slot_range, AnonymousSlotExpression):
if slot_range.range is None:
inner_range = slot_def.range
else:
inner_range = slot_range.range
inner_range = super().generate_python_range(inner_range, slot_def, class_def)
if slot_range.array is not None:
temp_slot = SlotDefinition(name='array', array=slot_range.array)
results = super().get_array_representations_range(temp_slot, inner_range)
inner_range = results[0].annotation
return inner_range
elif isinstance(slot_range, dict):
pdb.set_trace()
elif slot_def.array is not None:
inner_range = super().generate_python_range(slot_def.range, slot_def, class_def)
results = super().get_array_representations_range(slot_def, inner_range)
return results[0].annotation
else:
return super().generate_python_range(slot_range, slot_def, class_def)
def serialize(self) -> str: def serialize(self) -> str:
predefined_slot_values = {} predefined_slot_values = {}
"""splitting up parent class :meth:`.get_predefined_slot_values`""" """splitting up parent class :meth:`.get_predefined_slot_values`"""
@ -763,15 +730,22 @@ class NWBPydanticGenerator(PydanticGenerator):
s.description = s.description.replace('"', '\\"') s.description = s.description.replace('"', '\\"')
class_def.attributes[s.name] = s class_def.attributes[s.name] = s
slot_ranges: List[str] = [] slot_ranges: List[Union[str, ArrayExpression, AnonymousSlotExpression]] = []
self._check_anyof(s, sn, sv) self._check_anyof(s, sn, sv)
if s.any_of is not None and len(s.any_of) > 0: if s.any_of is not None and len(s.any_of) > 0:
# list comprehension here is pulling ranges from within AnonymousSlotExpression # list comprehension here is pulling ranges from within AnonymousSlotExpression
slot_ranges.extend([r.range for r in s.any_of]) if isinstance(s.any_of, dict):
any_ofs = list(s.any_of.values())
else:
any_ofs = s.any_of
slot_ranges.extend(any_ofs)
else: else:
slot_ranges.append(s.range) if s.array is not None:
slot_ranges.append(s.array)
else:
slot_ranges.append(s.range)
pyranges = [ pyranges = [
self.generate_python_range(slot_range, s, class_def) self.generate_python_range(slot_range, s, class_def)
@ -798,7 +772,12 @@ class NWBPydanticGenerator(PydanticGenerator):
if s.multivalued: if s.multivalued:
if s.inlined or s.inlined_as_list: if s.inlined or s.inlined_as_list:
collection_key = self.generate_collection_key(slot_ranges, s, class_def) try:
collection_key = self.generate_collection_key(slot_ranges, s, class_def)
except TypeError:
# from not being able to hash an anonymous slot expression.
# hack, we can fix this by merging upstream pydantic generator cleanup
collection_key = None
else: # pragma: no cover else: # pragma: no cover
collection_key = None collection_key = None
if ( if (