This commit is contained in:
sneakers-the-rat 2024-07-31 01:27:45 -07:00
parent abf1b0e6c0
commit 38e8a6f7a0
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
5 changed files with 48 additions and 70 deletions

View file

@ -1,75 +1,43 @@
""" """
Subclass of :class:`linkml.generators.PydanticGenerator` Subclass of :class:`linkml.generators.PydanticGenerator`
customized to support NWB models.
The pydantic generator is a subclass of See class and module docstrings for details :)
- :class:`linkml.utils.generator.Generator`
- :class:`linkml.generators.oocodegen.OOCodeGenerator`
The default `__main__` method
- Instantiates the class
- Calls :meth:`~linkml.generators.PydanticGenerator.serialize`
The `serialize` method:
- Accepts an optional jinja-style template, otherwise it uses the default template
- Uses :class:`linkml_runtime.utils.schemaview.SchemaView` to interact with the schema
- Generates linkML Classes
- `generate_enums` runs first
.. note::
This module is heinous. We have mostly copied and pasted the existing :class:`linkml.generators.PydanticGenerator`
and overridden what we need to make this work for NWB, but the source is...
a little messy. We will be tidying this up and trying to pull changes upstream,
but for now this is just our hacky little secret.
""" """
# FIXME: Remove this after we refactor this generator
# ruff: noqa
import inspect
import pdb
import re import re
import sys import sys
import warnings
from copy import copy
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from types import ModuleType from types import ModuleType
from typing import ClassVar, Dict, List, Optional, Tuple, Type, Union from typing import ClassVar, Dict, List, Optional, Tuple
from linkml.generators import PydanticGenerator from linkml.generators import PydanticGenerator
from linkml.generators.pydanticgen.build import SlotResult, ClassResult
from linkml.generators.pydanticgen.array import ArrayRepresentation, NumpydanticArray from linkml.generators.pydanticgen.array import ArrayRepresentation, NumpydanticArray
from linkml.generators.pydanticgen.template import PydanticModule, Import, Imports from linkml.generators.pydanticgen.build import ClassResult, SlotResult
from linkml.generators.pydanticgen.template import Import, Imports, PydanticModule
from linkml_runtime.linkml_model.meta import ( from linkml_runtime.linkml_model.meta import (
Annotation,
AnonymousSlotExpression,
ArrayExpression, ArrayExpression,
ClassDefinition,
ClassDefinitionName,
ElementName,
SchemaDefinition, SchemaDefinition,
SlotDefinition, SlotDefinition,
SlotDefinitionName, SlotDefinitionName,
) )
from linkml_runtime.utils.compile_python import file_text from linkml_runtime.utils.compile_python import file_text
from linkml_runtime.utils.formatutils import camelcase, underscore, remove_empty_items from linkml_runtime.utils.formatutils import remove_empty_items
from linkml_runtime.utils.schemaview import SchemaView from linkml_runtime.utils.schemaview import SchemaView
from pydantic import BaseModel
from nwb_linkml.maps import flat_to_nptyping
from nwb_linkml.maps.naming import module_case, version_module_case
from nwb_linkml.includes.types import ModelTypeString, _get_name, NamedString, NamedImports
from nwb_linkml.includes.hdmf import DYNAMIC_TABLE_IMPORTS, DYNAMIC_TABLE_INJECTS from nwb_linkml.includes.hdmf import DYNAMIC_TABLE_IMPORTS, DYNAMIC_TABLE_INJECTS
from nwb_linkml.includes.types import ModelTypeString, NamedImports, NamedString, _get_name
OPTIONAL_PATTERN = re.compile(r"Optional\[([\w\.]*)\]") OPTIONAL_PATTERN = re.compile(r"Optional\[([\w\.]*)\]")
@dataclass @dataclass
class NWBPydanticGenerator(PydanticGenerator): class NWBPydanticGenerator(PydanticGenerator):
"""
Subclass of pydantic generator, custom behavior is in overridden lifecycle methods :)
"""
injected_fields: List[str] = ( injected_fields: List[str] = (
( (
@ -96,7 +64,7 @@ class NWBPydanticGenerator(PydanticGenerator):
def _check_anyof( def _check_anyof(
self, s: SlotDefinition, sn: SlotDefinitionName, sv: SchemaView self, s: SlotDefinition, sn: SlotDefinitionName, sv: SchemaView
): # pragma: no cover ) -> None: # pragma: no cover
""" """
Overridden to allow `array` in any_of Overridden to allow `array` in any_of
""" """
@ -108,7 +76,7 @@ class NWBPydanticGenerator(PydanticGenerator):
allowed = True allowed = True
for option in s.any_of: for option in s.any_of:
items = remove_empty_items(option) items = remove_empty_items(option)
if not all([key in allowed_keys for key in items.keys()]): if not all([key in allowed_keys for key in items]):
allowed = False allowed = False
if allowed: if allowed:
return return
@ -132,10 +100,14 @@ class NWBPydanticGenerator(PydanticGenerator):
return slot return slot
def after_generate_class(self, cls: ClassResult, sv: SchemaView) -> ClassResult: def after_generate_class(self, cls: ClassResult, sv: SchemaView) -> ClassResult:
"""Customize dynamictable behavior"""
cls = AfterGenerateClass.inject_dynamictable(cls) cls = AfterGenerateClass.inject_dynamictable(cls)
return cls return cls
def before_render_template(self, template: PydanticModule, sv: SchemaView) -> PydanticModule: def before_render_template(self, template: PydanticModule, sv: SchemaView) -> PydanticModule:
"""
Remove source file from metadata
"""
if "source_file" in template.meta: if "source_file" in template.meta:
del template.meta["source_file"] del template.meta["source_file"]
return template return template
@ -167,6 +139,9 @@ class AfterGenerateSlot:
@staticmethod @staticmethod
def skip_meta(slot: SlotResult, skip_meta: tuple[str]) -> SlotResult: def skip_meta(slot: SlotResult, skip_meta: tuple[str]) -> SlotResult:
"""
Skip additional metadata slots
"""
for key in skip_meta: for key in skip_meta:
if key in slot.attribute.meta: if key in slot.attribute.meta:
del slot.attribute.meta[key] del slot.attribute.meta[key]
@ -242,6 +217,14 @@ class AfterGenerateClass:
@staticmethod @staticmethod
def inject_dynamictable(cls: ClassResult) -> ClassResult: def inject_dynamictable(cls: ClassResult) -> ClassResult:
"""
Modify dynamictable class bases and inject needed objects :)
Args:
cls:
Returns:
"""
if cls.cls.name == "DynamicTable": if cls.cls.name == "DynamicTable":
cls.cls.bases = ["DynamicTableMixin"] cls.cls.bases = ["DynamicTableMixin"]
@ -269,7 +252,8 @@ def compile_python(
""" """
Compile the text or file and return the resulting module Compile the text or file and return the resulting module
@param text_or_fn: Python text or file name that references python file @param text_or_fn: Python text or file name that references python file
@param package_path: Root package path. If omitted and we've got a python file, the package is the containing @param package_path: Root package path. If omitted and we've got a python file,
the package is the containing
directory directory
@return: Compiled module @return: Compiled module
""" """

View file

@ -2,10 +2,9 @@
Special types for mimicking HDMF special case behavior Special types for mimicking HDMF special case behavior
""" """
from typing import Any, ClassVar, Dict, List, Optional, Union, Tuple, overload, TYPE_CHECKING from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple, Union, overload
from linkml.generators.pydanticgen.template import Import, Imports, ObjectImport
from linkml.generators.pydanticgen.template import Imports, Import, ObjectImport
from numpydantic import NDArray from numpydantic import NDArray
from pandas import DataFrame from pandas import DataFrame
from pydantic import BaseModel, ConfigDict, Field, model_validator from pydantic import BaseModel, ConfigDict, Field, model_validator
@ -133,7 +132,7 @@ class DynamicTableMixin(BaseModel):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def create_colnames(cls, model: Dict[str, Any]): def create_colnames(cls, model: Dict[str, Any]) -> None:
""" """
Construct colnames from arguments. Construct colnames from arguments.
@ -142,19 +141,17 @@ class DynamicTableMixin(BaseModel):
""" """
if "colnames" not in model: if "colnames" not in model:
colnames = [ colnames = [
k k for k in model if k not in cls.NON_COLUMN_FIELDS and not k.endswith("_index")
for k in model.keys()
if k not in cls.NON_COLUMN_FIELDS and not k.endswith("_index")
] ]
model["colnames"] = colnames model["colnames"] = colnames
else: else:
# add any columns not explicitly given an order at the end # add any columns not explicitly given an order at the end
colnames = [ colnames = [
k k
for k in model.keys() for k in model
if k not in cls.NON_COLUMN_FIELDS if k not in cls.NON_COLUMN_FIELDS
and not k.endswith("_index") and not k.endswith("_index")
and k not in model["colnames"].keys() and k not in model["colnames"]
] ]
model["colnames"].extend(colnames) model["colnames"].extend(colnames)
return model return model
@ -171,13 +168,11 @@ class DynamicTableMixin(BaseModel):
for field_name in self.model_fields_set: for field_name in self.model_fields_set:
# implicit name-based index # implicit name-based index
field = getattr(self, field_name) field = getattr(self, field_name)
if isinstance(field, VectorIndex): if isinstance(field, VectorIndex) and (
if field_name == f"{key}_index": field_name == f"{key}_index" or field.target is col
idx = field ):
break idx = field
elif field.target is col: break
idx = field
break
if idx is not None: if idx is not None:
col._index = idx col._index = idx
idx.target = col idx.target = col
@ -201,7 +196,7 @@ class VectorDataMixin(BaseModel):
else: else:
return self.array[item] return self.array[item]
def __setitem__(self, key, value) -> None: def __setitem__(self, key: Union[int, str, slice], value: Any) -> None:
if self._index: if self._index:
# Following hdmf, VectorIndex is the thing that knows how to do the slicing # Following hdmf, VectorIndex is the thing that knows how to do the slicing
self._index[key] = value self._index[key] = value
@ -218,7 +213,7 @@ class VectorIndexMixin(BaseModel):
array: Optional[NDArray] = None array: Optional[NDArray] = None
target: Optional["VectorData"] = None target: Optional["VectorData"] = None
def _getitem_helper(self, arg: int): def _getitem_helper(self, arg: int) -> Union[list, NDArray]:
""" """
Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper` Mimicking :func:`hdmf.common.table.VectorIndex.__getitem_helper`
""" """
@ -239,7 +234,7 @@ class VectorIndexMixin(BaseModel):
else: else:
raise NotImplementedError("DynamicTableRange not supported yet") raise NotImplementedError("DynamicTableRange not supported yet")
def __setitem__(self, key, value) -> None: def __setitem__(self, key: Union[int, slice], value: Any) -> None:
if self._index: if self._index:
# VectorIndex is the thing that knows how to do the slicing # VectorIndex is the thing that knows how to do the slicing
self._index[key] = value self._index[key] = value

View file

@ -3,9 +3,9 @@ Provider for LinkML schema built from NWB schema
""" """
import shutil import shutil
from pathlib import Path
from typing import Dict, Optional, TypedDict
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional
from linkml_runtime import SchemaView from linkml_runtime import SchemaView
from linkml_runtime.dumpers import yaml_dumper from linkml_runtime.dumpers import yaml_dumper

View file

@ -15,9 +15,9 @@ from linkml_runtime.linkml_model import (
) )
from nwb_linkml.adapters.namespaces import NamespacesAdapter from nwb_linkml.adapters.namespaces import NamespacesAdapter
from nwb_linkml.io import schema as io
from nwb_linkml.providers import LinkMLProvider, PydanticProvider from nwb_linkml.providers import LinkMLProvider, PydanticProvider
from nwb_linkml.providers.linkml import LinkMLSchemaBuild from nwb_linkml.providers.linkml import LinkMLSchemaBuild
from nwb_linkml.io import schema as io
from nwb_schema_language import Attribute, Dataset, Group from nwb_schema_language import Attribute, Dataset, Group
__all__ = [ __all__ = [

View file

@ -1,5 +1,4 @@
from typing import Tuple, TYPE_CHECKING from typing import Tuple
from types import ModuleType
import numpy as np import numpy as np
import pytest import pytest