nwb-linkml/nwb_linkml/generators/pydantic.py
sneakers-the-rat be21325123 Cleaner code generation, npytyping type hints for arrays
- split off generated subclasses into "include" files, not sure if that's good, but in any case we have them separable now.
- more work on cleanly un-nesting scalar and 1D-vector data into attributes and lists, respectively
- brought the pydantic generator in-repo to do a bunch of overrides
2023-08-28 22:16:58 -07:00

447 lines
No EOL
16 KiB
Python

"""
Subclass of :class:`linkml.generators.PydanticGenerator`
The pydantic generator is a subclass of
- :class:`linkml.utils.generator.Generator`
- :class:`linkml.generators.oocodegen.OOCodeGenerator`
The default `__main__` method
- Instantiates the class
- Calls :meth:`~linkml.generators.PydanticGenerator.serialize`
The `serialize` method
- Accepts an optional jinja-style template, otherwise it uses the default template
- Uses :class:`linkml_runtime.utils.schemaview.SchemaView` to interact with the schema
- Generates linkML Classes
- `generate_enums` runs first
"""
import pdb
from typing import List, Dict, Set
from copy import deepcopy
import warnings
from nwb_linkml.maps.dtype import flat_to_npytyping
from linkml.generators import PydanticGenerator
from linkml_runtime.linkml_model.meta import (
Annotation,
ClassDefinition,
SchemaDefinition,
SlotDefinition,
SlotDefinitionName,
TypeDefinition,
ElementName
)
from linkml_runtime.utils.formatutils import camelcase, underscore
from linkml_runtime.utils.schemaview import SchemaView
from jinja2 import Template
def default_template(pydantic_ver: str = "1") -> str:
"""Constructs a default template for pydantic classes based on the version of pydantic"""
### HEADER ###
template = """
{#-
Jinja2 Template for a pydantic classes
-#}
from __future__ import annotations
from datetime import datetime, date
from enum import Enum
from typing import List, Dict, Optional, Any, Union
from pydantic import BaseModel as BaseModel, Field
from nptyping import NDArray, Shape, Float, Float32, Double, Float64, LongLong, Int64, Int, Int32, Int16, Short, Int8, UInt, UInt32, UInt16, UInt8, UInt64, Number, String, Unicode, Unicode, Unicode, String, Bool, Datetime64
import sys
if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal
{% for import_module, import_classes in imports.items() %}
from {{ import_module }} import (
{{ import_classes | join(',\n ') }}
)
{% endfor %}
metamodel_version = "{{metamodel_version}}"
version = "{{version if version else None}}"
"""
### BASE MODEL ###
if pydantic_ver == "1":
template += """
class WeakRefShimBaseModel(BaseModel):
__slots__ = '__weakref__'
class ConfiguredBaseModel(WeakRefShimBaseModel,
validate_assignment = True,
validate_all = True,
underscore_attrs_are_private = True,
extra = {% if allow_extra %}'allow'{% else %}'forbid'{% endif %},
arbitrary_types_allowed = True,
use_enum_values = True):
pass
"""
else:
template += """
class ConfiguredBaseModel(BaseModel,
validate_assignment = True,
validate_default = True,
extra = {% if allow_extra %}'allow'{% else %}'forbid'{% endif %},
arbitrary_types_allowed = True,
use_enum_values = True):
pass
"""
### ENUMS ###
template += """
{% for e in enums.values() %}
class {{ e.name }}(str, Enum):
{% if e.description -%}
\"\"\"
{{ e.description }}
\"\"\"
{%- endif %}
{% for _, pv in e['values'].items() -%}
{% if pv.description -%}
# {{pv.description}}
{%- endif %}
{{pv.label}} = "{{pv.value}}"
{% endfor %}
{% if not e['values'] -%}
dummy = "dummy"
{% endif %}
{% endfor %}
"""
### CLASSES ###
template += """
{%- for c in schema.classes.values() %}
class {{ c.name }}
{%- if class_isa_plus_mixins[c.name] -%}
({{class_isa_plus_mixins[c.name]|join(', ')}})
{%- else -%}
(ConfiguredBaseModel)
{%- endif -%}
:
{% if c.description -%}
\"\"\"
{{ c.description }}
\"\"\"
{%- endif %}
{% for attr in c.attributes.values() if c.attributes -%}
{{attr.name}}: {{ attr.annotations['python_range'].value }} = Field(
{%- if predefined_slot_values[c.name][attr.name] -%}
{{ predefined_slot_values[c.name][attr.name] }}
{%- elif attr.required -%}
...
{%- else -%}
None
{%- endif -%}
{%- if attr.title != None %}, title="{{attr.title}}"{% endif -%}
{%- if attr.description %}, description=\"\"\"{{attr.description}}\"\"\"{% endif -%}
{%- if attr.minimum_value != None %}, ge={{attr.minimum_value}}{% endif -%}
{%- if attr.maximum_value != None %}, le={{attr.maximum_value}}{% endif -%}
)
{% else -%}
None
{% endfor %}
{% endfor %}
"""
### FWD REFS / REBUILD MODEL ###
if pydantic_ver == "1":
template += """
# Update forward refs
# see https://pydantic-docs.helpmanual.io/usage/postponed_annotations/
{% for c in schema.classes.values() -%}
{{ c.name }}.update_forward_refs()
{% endfor %}
"""
else:
template += """
# Model rebuild
# see https://pydantic-docs.helpmanual.io/usage/models/#rebuilding-a-model
{% for c in schema.classes.values() -%}
{{ c.name }}.model_rebuild()
{% endfor %}
"""
return template
class NWBPydanticGenerator(PydanticGenerator):
SKIP_ENUM=('FlatDType',)
def _get_imports(self, sv:SchemaView) -> Dict[str, List[str]]:
all_classes = sv.all_classes(imports=True)
local_classes = sv.all_classes(imports=False)
needed_classes = []
# find needed classes - is_a and slot ranges
for clsname, cls in local_classes.items():
needed_classes.append(cls.is_a)
for slot_name, slot in cls.attributes.items():
if slot.range in all_classes:
needed_classes.append(slot.range)
needed_classes = [cls for cls in set(needed_classes) if cls is not None]
imports = {}
# These classes are not generated by pydantic!
skips = ('AnyType',)
for cls in needed_classes:
if cls in skips:
continue
# Find module that contains class
module_name = sv.element_by_schema_map()[ElementName(cls)]
# Don't get classes that are defined in this schema!
if module_name == self.schema.name:
continue
local_mod_name = '.' + module_name.replace('.', '_').replace('-','_')
if local_mod_name not in imports:
imports[local_mod_name] = [camelcase(cls)]
else:
imports[local_mod_name].append(camelcase(cls))
return imports
def _get_classes(self, sv:SchemaView, imports: Dict[str, List[str]]) -> List[ClassDefinition]:
module_classes = sv.all_classes(imports=False).values()
imported_classes = []
for classes in imports.values():
imported_classes.extend(classes)
# pdb.set_trace()
sorted_classes = self.sort_classes(list(module_classes), imported_classes)
self.sorted_class_names = [camelcase(cname) for cname in imported_classes]
self.sorted_class_names += [camelcase(c.name) for c in sorted_classes]
# Don't want to generate classes when class_uri is linkml:Any, will
# just swap in typing.Any instead down below
sorted_classes = [c for c in sorted_classes if c.class_uri != "linkml:Any"]
return sorted_classes
def _build_class(self, class_original:ClassDefinition) -> ClassDefinition:
class_def: ClassDefinition
class_def = deepcopy(class_original)
class_def.name = camelcase(class_original.name)
if class_def.is_a:
class_def.is_a = camelcase(class_def.is_a)
class_def.mixins = [camelcase(p) for p in class_def.mixins]
if class_def.description:
class_def.description = class_def.description.replace('"', '\\"')
return class_def
def _check_anyof(self, s:SlotDefinition, sn: SlotDefinitionName, sv:SchemaView):
# Confirm that the original slot range (ignoring the default that comes in from
# induced_slot) isn't in addition to setting any_of
if len(s.any_of) > 0 and sv.get_slot(sn).range is not None:
base_range_subsumes_any_of = False
base_range = sv.get_slot(sn).range
base_range_cls = sv.get_class(base_range, strict=False)
if base_range_cls is not None and base_range_cls.class_uri == "linkml:Any":
base_range_subsumes_any_of = True
if not base_range_subsumes_any_of:
raise ValueError("Slot cannot have both range and any_of defined")
def _get_numpy_slot_range(self, cls:ClassDefinition) -> str:
# slot always starts with...
prefix='NDArray['
# and then we specify the shape:
shape_prefix = 'Shape["'
# using the cardinality from the attributes
dim_pieces = []
for attr in cls.attributes.values():
if attr.maximum_cardinality:
shape_part = str(attr.maximum_cardinality)
else:
shape_part = "*"
# do this cheaply instead of using regex because i want to see if this works at all first...
name_part = attr.name.replace(',', '_').replace(' ', '_').replace('__', '_')
dim_pieces.append(' '.join([shape_part, name_part]))
dimension = ', '.join(dim_pieces)
shape_suffix = '"], '
# all dimensions should be the same dtype
try:
dtype = flat_to_npytyping[list(cls.attributes.values())[0].range]
except KeyError as e:
warnings.warn(e)
range = list(cls.attributes.values())[0].range
return f'List[{range}] | {range}'
suffix = "]"
slot = ''.join([prefix, shape_prefix, dimension, shape_suffix, dtype, suffix])
return slot
def sort_classes(self, clist: List[ClassDefinition], imports:List[str]) -> List[ClassDefinition]:
"""
sort classes such that if C is a child of P then C appears after P in the list
Overridden method include mixin classes
Modified from original to allow for imported classes
"""
clist = list(clist)
slist = [] # sorted
while len(clist) > 0:
can_add = False
for i in range(len(clist)):
candidate = clist[i]
can_add = False
if candidate.is_a:
candidates = [candidate.is_a] + candidate.mixins
else:
candidates = candidate.mixins
if not candidates:
can_add = True
else:
if set(candidates) <= set([p.name for p in slist] + imports):
can_add = True
if can_add:
slist = slist + [candidate]
del clist[i]
break
if not can_add:
raise ValueError(
f"could not find suitable element in {clist} that does not ref {slist}"
)
return slist
def get_class_slot_range(self, slot_range: str, inlined: bool, inlined_as_list: bool) -> str:
"""
Monkeypatch to convert Array typed slots and classes into npytyped hints
"""
sv = self.schemaview
range_cls = sv.get_class(slot_range)
if range_cls.is_a == "Arraylike":
return self._get_numpy_slot_range(range_cls)
else:
return super().get_class_slot_range(slot_range, inlined, inlined_as_list)
def get_class_isa_plus_mixins(self) -> Dict[str, List[str]]:
"""
Generate the inheritance list for each class from is_a plus mixins
Patched to only get local classes
:return:
"""
sv = self.schemaview
parents = {}
for class_def in sv.all_classes(imports=False).values():
class_parents = []
if class_def.is_a:
class_parents.append(camelcase(class_def.is_a))
if self.gen_mixin_inheritance and class_def.mixins:
class_parents.extend([camelcase(mixin) for mixin in class_def.mixins])
if len(class_parents) > 0:
# Use the sorted list of classes to order the parent classes, but reversed to match MRO needs
class_parents.sort(key=lambda x: self.sorted_class_names.index(x))
class_parents.reverse()
parents[camelcase(class_def.name)] = class_parents
return parents
def serialize(self) -> str:
if self.template_file is not None:
with open(self.template_file) as template_file:
template_obj = Template(template_file.read())
else:
template_obj = Template(default_template(self.pydantic_version))
sv: SchemaView
sv = self.schemaview
schema = sv.schema
pyschema = SchemaDefinition(
id=schema.id,
name=schema.name,
description=schema.description.replace('"', '\\"') if schema.description else None,
)
enums = self.generate_enums(sv.all_enums())
# filter skipped enums
enums = {k:v for k,v in enums.items() if k not in self.SKIP_ENUM}
# import from local references, rather than serializing every class in every file
imports = self._get_imports(sv)
sorted_classes = self._get_classes(sv, imports)
for class_original in sorted_classes:
# Generate class definition
class_def = self._build_class(class_original)
pyschema.classes[class_def.name] = class_def
# Not sure why this happens
for attribute in list(class_def.attributes.keys()):
del class_def.attributes[attribute]
class_name = class_original.name
for sn in sv.class_slots(class_name):
# TODO: fix runtime, copy should not be necessary
s = deepcopy(sv.induced_slot(sn, class_name))
# logging.error(f'Induced slot {class_name}.{sn} == {s.name} {s.range}')
s.name = underscore(s.name)
if s.description:
s.description = s.description.replace('"', '\\"')
class_def.attributes[s.name] = s
slot_ranges: List[str] = []
self._check_anyof(s, sn, sv)
if s.any_of is not None and len(s.any_of) > 0:
# list comprehension here is pulling ranges from within AnonymousSlotExpression
slot_ranges.extend([r.range for r in s.any_of])
else:
slot_ranges.append(s.range)
pyranges = [
self.generate_python_range(slot_range, s, class_def)
for slot_range in slot_ranges
]
pyranges = list(set(pyranges)) # remove duplicates
pyranges.sort()
if len(pyranges) == 1:
pyrange = pyranges[0]
elif len(pyranges) > 1:
pyrange = f"Union[{', '.join(pyranges)}]"
else:
raise Exception(f"Could not generate python range for {class_name}.{s.name}")
if s.multivalued:
if s.inlined or s.inlined_as_list:
collection_key = self.generate_collection_key(slot_ranges, s, class_def)
else:
collection_key = None
if s.inlined is False or collection_key is None or s.inlined_as_list is True:
pyrange = f"List[{pyrange}]"
else:
pyrange = f"Dict[{collection_key}, {pyrange}]"
if not s.required and not s.designates_type:
pyrange = f"Optional[{pyrange}]"
ann = Annotation("python_range", pyrange)
s.annotations[ann.tag] = ann
code = template_obj.render(
imports=imports,
schema=pyschema,
underscore=underscore,
enums=enums,
predefined_slot_values=self.get_predefined_slot_values(),
allow_extra=self.allow_extra,
metamodel_version=self.schema.metamodel_version,
version=self.schema.version,
class_isa_plus_mixins=self.get_class_isa_plus_mixins(),
)
return code