updating model generation methods; some models are still being converted to str instead of being inlined, but almost there

This commit is contained in:
sneakers-the-rat 2024-09-04 00:04:21 -07:00
parent 8078492f90
commit 27b5dddfdd
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
9 changed files with 182 additions and 57 deletions

View file

@@ -2,6 +2,7 @@
Adapters to linkML classes
"""
import os
from abc import abstractmethod
from typing import List, Optional, Type, TypeVar
@@ -32,6 +33,14 @@ class ClassAdapter(Adapter):
    cls: TI
    parent: Optional["ClassAdapter"] = None
    _debug: Optional[bool] = None

    @property
    def debug(self) -> bool:
        if self._debug is None:
            self._debug = bool(os.environ.get("NWB_LINKML_DEBUG", False))
        return self._debug

    @field_validator("cls", mode="before")
    @classmethod
    def cast_from_string(cls, value: str | TI) -> TI:
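The `debug` property reads its flag lazily from the environment and caches the result on the adapter, so `os.environ` is consulted at most once per instance. A minimal sketch of the same pattern outside of pydantic (class name hypothetical):

import os
from typing import Optional

class EnvFlag:
    _debug: Optional[bool] = None

    @property
    def debug(self) -> bool:
        if self._debug is None:
            # bool() of any non-empty string is True, so even "false" would
            # enable debug mode; the CLI below only ever sets "true" or unsets it
            self._debug = bool(os.environ.get("NWB_LINKML_DEBUG", False))
        return self._debug

flag = EnvFlag()
os.environ["NWB_LINKML_DEBUG"] = "true"
print(flag.debug)  # True, and cached for later accesses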
@@ -92,6 +101,13 @@ class ClassAdapter(Adapter):
        # Get vanilla top-level attributes
        kwargs["attributes"].extend(self.build_attrs(self.cls))

        if self.debug:
            kwargs["annotations"] = {}
            kwargs["annotations"]["group_adapter"] = {
                "tag": "group_adapter",
                "value": "container_slot",
            }
        if extra_attrs is not None:
            if isinstance(extra_attrs, SlotDefinition):
                extra_attrs = [extra_attrs]
@@ -230,18 +246,22 @@
                ifabsent=f"string({name})",
                equals_string=equals_string,
                range="string",
                identifier=True,
            )
        else:
            name_slot = SlotDefinition(name="name", required=True, range="string")
            name_slot = SlotDefinition(name="name", required=True, range="string", identifier=True)
        return name_slot

    def build_self_slot(self) -> SlotDefinition:
        """
        If we are a child class, we make a slot so our parent can refer to us
        """
        return SlotDefinition(
        slot = SlotDefinition(
            name=self._get_slot_name(),
            description=self.cls.doc,
            range=self._get_full_name(),
            **QUANTITY_MAP[self.cls.quantity],
        )
        if self.debug:
            slot.annotations["group_adapter"] = {"tag": "group_adapter", "value": "self_slot"}
        return slot
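Marking `name` as a LinkML identifier is what lets the generators key on it. For classes with a fixed name, `ifabsent` and `equals_string` pin the value, which the pydantic generator renders roughly as a `Literal` with a default; a sketch of that output shape (class and value are hypothetical, not verbatim generator output):

from typing import Literal

from pydantic import BaseModel, Field

class FixedNameGroup(BaseModel):  # hypothetical generated class
    # equals_string -> Literal type; ifabsent="string(electrodes)" -> default
    name: Literal["electrodes"] = Field("electrodes")

Free-name classes keep `name: str`, but it is still required and still the identifier, so parents can index children by it.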

View file

@@ -11,7 +11,7 @@ from nwb_linkml.adapters.adapter import BuildResult, has_attrs, is_1d, is_compound
from nwb_linkml.adapters.array import ArrayAdapter
from nwb_linkml.adapters.classes import ClassAdapter
from nwb_linkml.maps import QUANTITY_MAP, Map
from nwb_linkml.maps.dtype import flat_to_linkml, handle_dtype
from nwb_linkml.maps.dtype import flat_to_linkml, handle_dtype, inlined
from nwb_linkml.maps.naming import camel_to_snake
from nwb_schema_language import Dataset
@@ -299,6 +299,8 @@ class MapListlike(DatasetMap):
            description=cls.doc,
            required=cls.quantity not in ("*", "?"),
            annotations=[{"source_type": "reference"}],
            inlined=True,
            inlined_as_list=True,
        )
        res.classes[0].attributes["value"] = slot
        return res
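MapListlike handles datasets that are lists of references to other objects. Before this change the generated slot rendered its targets as plain string references; with `inlined=True` and `inlined_as_list=True` the children are embedded as an actual list of models. A rough sketch of the resulting shape (model names hypothetical):

from typing import List, Optional

from pydantic import BaseModel

class Image(BaseModel):
    name: str

class Images(BaseModel):
    # inlined_as_list=True: a list of whole child objects,
    # rather than a dict keyed by name or a list of str references
    value: Optional[List[Image]] = None

Images(value=[Image(name="img0"), Image(name="img1")])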
@@ -544,7 +546,9 @@ class MapArrayLikeAttributes(DatasetMap):
        array_adapter = ArrayAdapter(cls.dims, cls.shape)
        expressions = array_adapter.make_slot()
        # make a slot for the arraylike class
        array_slot = SlotDefinition(name="value", range=handle_dtype(cls.dtype), **expressions)
        array_slot = SlotDefinition(
            name="value", range=handle_dtype(cls.dtype), inlined=inlined(cls.dtype), **expressions
        )
        res.classes[0].attributes.update({"value": array_slot})
        return res
@@ -583,6 +587,7 @@ class MapClassRange(DatasetMap):
            description=cls.doc,
            range=f"{cls.neurodata_type_inc}",
            annotations=[{"named": True}, {"source_type": "neurodata_type_inc"}],
            inlined=True,
            **QUANTITY_MAP[cls.quantity],
        )
        res = BuildResult(slots=[this_slot])
@@ -799,6 +804,7 @@ class MapCompoundDtype(DatasetMap):
                description=a_dtype.doc,
                range=handle_dtype(a_dtype.dtype),
                array=ArrayExpression(exact_number_dimensions=1),
                inlined=inlined(a_dtype.dtype),
                **QUANTITY_MAP[cls.quantity],
            )
        res.classes[0].attributes.update(slots)
@@ -830,6 +836,8 @@ class DatasetAdapter(ClassAdapter):
        if map is not None:
            res = map.apply(self.cls, res, self._get_full_name())

        if self.debug:
            res = self._amend_debug(res, map)

        return res

    def match(self) -> Optional[Type[DatasetMap]]:
@@ -855,11 +863,13 @@
        else:
            return matches[0]

    def special_cases(self, res: BuildResult) -> BuildResult:
        """
        Apply special cases to build result
        """
        res = self._datetime_or_str(res)

    def _datetime_or_str(self, res: BuildResult) -> BuildResult:
        """HDF5 doesn't support datetime, so"""

    def _amend_debug(self, res: BuildResult, map: Optional[Type[DatasetMap]] = None) -> BuildResult:
        if map is None:
            map_name = "None"
        else:
            map_name = map.__name__
        for cls in res.classes:
            cls.annotations["dataset_map"] = {"tag": "dataset_map", "value": map_name}
        for slot in res.slots:
            slot.annotations["dataset_map"] = {"tag": "dataset_map", "value": map_name}
        return res
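With debug mode on, `_amend_debug` stamps every generated class and slot with the name of the `DatasetMap` that produced it, so a wrongly generated model can be traced back to its mapping rule. A hypothetical inspection session (the schema object and its contents are assumed):

import os

from nwb_linkml.adapters.dataset import DatasetAdapter

os.environ["NWB_LINKML_DEBUG"] = "true"

adapter = DatasetAdapter(cls=some_dataset_schema)  # hypothetical Dataset instance
result = adapter.build()
for generated in result.classes:
    # prints e.g. "MapListlike", "MapClassRange", or "None" if no map matched
    print(generated.name, generated.annotations["dataset_map"].value)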

View file

@@ -111,6 +111,9 @@ class GroupAdapter(ClassAdapter):
            inlined_as_list=False,
        )

        if self.debug:
            slot.annotations["group_adapter"] = {"tag": "group_adapter", "value": "container_group"}

        if self.parent is not None:
            # if we have a parent,
            # just return the slot itself without the class
@@ -144,17 +147,20 @@
        """
        name = camel_to_snake(self.cls.neurodata_type_inc) if not self.cls.name else cls.name
        return BuildResult(
            slots=[
                SlotDefinition(
                    name=name,
                    range=self.cls.neurodata_type_inc,
                    description=self.cls.doc,
                    **QUANTITY_MAP[cls.quantity],
                )
            ]
        slot = SlotDefinition(
            name=name,
            range=self.cls.neurodata_type_inc,
            description=self.cls.doc,
            inlined=True,
            inlined_as_list=False,
            **QUANTITY_MAP[cls.quantity],
        )
        if self.debug:
            slot.annotations["group_adapter"] = {"tag": "group_adapter", "value": "container_slot"}
        return BuildResult(slots=[slot])
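Together with `identifier=True` on the `name` slot, `inlined=True` plus `inlined_as_list=False` turns container slots from lists of string references into children embedded under their names, which is exactly the "converted to str instead of being inlined" problem the commit message describes. A rough sketch of the resulting model shape (names borrowed from NWB purely for illustration):

from typing import Dict, Optional

from pydantic import BaseModel

class TimeSeries(BaseModel):
    name: str  # identifier slot supplies the dict key

class Acquisition(BaseModel):
    # one embedded child per name, rather than a list of str references
    time_series: Optional[Dict[str, TimeSeries]] = None

acq = Acquisition(time_series={"raw": TimeSeries(name="raw")})
print(acq.time_series["raw"].name)  # raw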
    def build_subclasses(self) -> BuildResult:
        """
        Build nested groups and datasets
@@ -166,20 +172,9 @@
        # for creating slots vs. classes is handled by the adapter class
        dataset_res = BuildResult()
        for dset in self.cls.datasets:
            # if dset.name == 'timestamps':
            #     pdb.set_trace()
            dset_adapter = DatasetAdapter(cls=dset, parent=self)
            dataset_res += dset_adapter.build()

        # Actually i'm not sure we have to special case this, we could handle it in
        # i/o instead

        # Groups are a bit more complicated because they can also behave like
        # range declarations:
        # eg. a group can have multiple groups with `neurodata_type_inc`, no name,
        # and quantity of *,
        # the group can then contain any number of groups of those included types as direct children
        group_res = BuildResult()
        for group in self.cls.groups:
@@ -190,6 +185,33 @@
        return res

    def build_self_slot(self) -> SlotDefinition:
        """
        If we are a child class, we make a slot so our parent can refer to us

        Groups are a bit more complicated because they can also behave like
        range declarations:
        eg. a group can have multiple groups with `neurodata_type_inc`, no name,
        and quantity of *,
        the group can then contain any number of groups of those included types as direct children

        We make sure that we're inlined as a dict so our parent class can refer to us like::

            parent.{slot_name}[{name}] = self
        """
        slot = SlotDefinition(
            name=self._get_slot_name(),
            description=self.cls.doc,
            range=self._get_full_name(),
            inlined=True,
            inlined_as_list=True,
            **QUANTITY_MAP[self.cls.quantity],
        )
        if self.debug:
            slot.annotations["group_adapter"] = {"tag": "group_adapter", "value": "container_slot"}
        return slot

    def _check_if_container(self, group: Group) -> bool:
        """
        Check if a given subgroup is a container subgroup,

View file

@@ -5,6 +5,7 @@ customized to support NWB models.
See class and module docstrings for details :)
"""
import pdb
import re
import sys
from dataclasses import dataclass, field
@@ -27,7 +28,11 @@ from linkml_runtime.utils.compile_python import file_text
from linkml_runtime.utils.formatutils import remove_empty_items
from linkml_runtime.utils.schemaview import SchemaView

from nwb_linkml.includes.base import BASEMODEL_GETITEM, BASEMODEL_COERCE_VALUE
from nwb_linkml.includes.base import (
    BASEMODEL_GETITEM,
    BASEMODEL_COERCE_VALUE,
    BASEMODEL_COERCE_CHILD,
)
from nwb_linkml.includes.hdmf import (
    DYNAMIC_TABLE_IMPORTS,
    DYNAMIC_TABLE_INJECTS,
@@ -36,7 +41,7 @@ from nwb_linkml.includes.hdmf import (
)
from nwb_linkml.includes.types import ModelTypeString, NamedImports, NamedString, _get_name

OPTIONAL_PATTERN = re.compile(r"Optional\[([\w\.]*)\]")
OPTIONAL_PATTERN = re.compile(r"Optional\[(.*)\]")
@dataclass
@@ -53,6 +58,7 @@ class NWBPydanticGenerator(PydanticGenerator):
        'object_id: Optional[str] = Field(None, description="Unique UUID for each object")',
        BASEMODEL_GETITEM,
        BASEMODEL_COERCE_VALUE,
        BASEMODEL_COERCE_CHILD,
    )
    split: bool = True
    imports: list[Import] = field(default_factory=lambda: [Import(module="numpy", alias="np")])
@@ -134,6 +140,7 @@ class NWBPydanticGenerator(PydanticGenerator):
        cls = AfterGenerateClass.inject_dynamictable(cls)
        cls = AfterGenerateClass.wrap_dynamictable_columns(cls, sv)
        cls = AfterGenerateClass.inject_elementidentifiers(cls, sv, self._get_element_import)
        cls = AfterGenerateClass.strip_vector_data_slots(cls, sv)
        return cls

    def before_render_template(self, template: PydanticModule, sv: SchemaView) -> PydanticModule:
@@ -227,7 +234,8 @@ class AfterGenerateSlot:
        """
        if "named" in slot.source.annotations and slot.source.annotations["named"].value:
            slot.attribute.range = f"Named[{slot.attribute.range}]"
            slot.attribute.range = wrap_preserving_optional(slot.attribute.range, "Named")

            named_injects = [ModelTypeString, _get_name, NamedString]
            if slot.injected_classes is None:
                slot.injected_classes = named_injects
@@ -325,7 +333,9 @@ class AfterGenerateClass:
                else:
                    wrap_cls = "VectorData"

                cls.cls.attributes[an_attr].range = "".join([wrap_cls, "[", slot_range, "]"])
                cls.cls.attributes[an_attr].range = wrap_preserving_optional(
                    slot_range, wrap_cls
                )

        return cls
@@ -343,6 +353,15 @@
            cls.imports += [imp]

        return cls

    @staticmethod
    def strip_vector_data_slots(cls: ClassResult, sv: SchemaView) -> ClassResult:
        """
        Remove spurious ``vector_data`` slots from DynamicTables
        """
        if "vector_data" in cls.cls.attributes:
            del cls.cls.attributes["vector_data"]
        return cls
def compile_python(
    text_or_fn: str, package_path: Path = None, module_name: str = "test"
@@ -364,3 +383,26 @@
    exec(spec, module.__dict__)
    sys.modules[module_name] = module
    return module


def wrap_preserving_optional(annotation: str, wrap: str) -> str:
    """
    Add a wrapping type to a type annotation string,
    preserving any `Optional[]` annotation, bumping it to the outside

    Examples:

        >>> wrap_preserving_optional('Optional[list[str]]', 'NewType')
        'Optional[NewType[list[str]]]'

    """
    is_optional = OPTIONAL_PATTERN.match(annotation)
    if is_optional:
        annotation = is_optional.groups()[0]
        annotation = f"Optional[{wrap}[{annotation}]]"
    else:
        if "Optional" in annotation:
            pdb.set_trace()
        annotation = f"{wrap}[{annotation}]"
    return annotation
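The loosened pattern matters because generated ranges routinely contain nested brackets, which the old character class `[\w\.]*` could not cross. A quick standalone comparison using only `re`:

import re

old = re.compile(r"Optional\[([\w\.]*)\]")
new = re.compile(r"Optional\[(.*)\]")

ann = "Optional[List[str]]"
print(old.match(ann))           # None: [\w.]* cannot pass the inner "["
print(new.match(ann).group(1))  # List[str]

The greedy `(.*)` backtracks past the final "]" on its own, so `wrap_preserving_optional` receives the full inner annotation.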

View file

@@ -23,8 +23,11 @@ BASEMODEL_COERCE_VALUE = """
        except Exception as e1:
            try:
                return handler(v.value)
            except:
                raise e1
            except AttributeError:
                try:
                    return handler(v["value"])
                except (KeyError, TypeError):
                    raise e1
"""
BASEMODEL_COERCE_CHILD = """
@@ -32,6 +35,10 @@ BASEMODEL_COERCE_CHILD = """
    @classmethod
    def coerce_subclass(cls, v: Any, info) -> Any:
        \"\"\"Recast parent classes into child classes\"\"\"
        if isinstance(v, BaseModel):
            annotation = cls.model_fields[info.field_name].annotation
            annotation = annotation.__args__[0] if hasattr(annotation, "__args__") else annotation
            if issubclass(annotation, type(v)) and annotation is not type(v):
                v = annotation(**{**v.__dict__, **v.__pydantic_extra__})
        return v
        pdb.set_trace()
"""

View file

@@ -249,6 +249,7 @@ class DynamicTableMixin(BaseModel):
                if k not in cls.NON_COLUMN_FIELDS
                and not k.endswith("_index")
                and not isinstance(model[k], VectorIndexMixin)
                and model[k] is not None
            ]
            model["colnames"] = colnames
        else:
@@ -264,6 +265,7 @@
                    and not k.endswith("_index")
                    and k not in model["colnames"]
                    and not isinstance(model[k], VectorIndexMixin)
                    and model[k] is not None
                ]
            )
            model["colnames"] = colnames
@@ -336,9 +338,9 @@
        """
        Ensure that all columns are equal length
        """
        lengths = [len(v) for v in self._columns.values()] + [len(self.id)]
        lengths = [len(v) for v in self._columns.values() if v is not None] + [len(self.id)]
        assert all([length == lengths[0] for length in lengths]), (
            "Columns are not of equal length! "
            "DynamicTable columns are not of equal length! "
            f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
        )
        return self
@@ -386,11 +388,6 @@ class VectorDataMixin(BaseModel, Generic[T]):
    # redefined in `VectorData`, but included here for testing and type checking
    value: Optional[T] = None

    # def __init__(self, value: Optional[NDArray] = None, **kwargs):
    #     if value is not None and "value" not in kwargs:
    #         kwargs["value"] = value
    #     super().__init__(**kwargs)

    def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
        if self._index:
            # Following hdmf, VectorIndex is the thing that knows how to do the slicing
@@ -587,6 +584,7 @@ class AlignedDynamicTableMixin(BaseModel):
    __pydantic_extra__: Dict[str, Union["DynamicTableMixin", "VectorDataMixin", "VectorIndexMixin"]]

    NON_CATEGORY_FIELDS: ClassVar[tuple[str]] = (
        "id",
        "name",
        "categories",
        "colnames",
@@ -622,7 +620,7 @@
        elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], str):
            # get a slice of a single table
            return self._categories[item[1]][item[0]]
        elif isinstance(item, (int, slice, Iterable)):
        elif isinstance(item, (int, slice, Iterable, np.int_)):
            # get a slice of all the tables
            ids = self.id[item]
            if not isinstance(ids, Iterable):
@@ -634,9 +632,9 @@
            if isinstance(table, pd.DataFrame):
                table = table.reset_index()
            elif isinstance(table, np.ndarray):
                table = pd.DataFrame({category_name: [table]})
                table = pd.DataFrame({category_name: [table]}, index=ids.index)
            elif isinstance(table, Iterable):
                table = pd.DataFrame({category_name: table})
                table = pd.DataFrame({category_name: table}, index=ids.index)
            else:
                raise ValueError(
                    f"Don't know how to construct category table for {category_name}"
@@ -756,7 +754,7 @@
        """
        lengths = [len(v) for v in self._categories.values()] + [len(self.id)]
        assert all([length == lengths[0] for length in lengths]), (
            "Columns are not of equal length! "
            "AlignedDynamicTable columns are not of equal length! "
            f"Got colnames:\n{self.categories}\nand lengths: {lengths}"
        )
        return self

View file

@@ -3,7 +3,7 @@ Dtype mappings
"""
from datetime import datetime
from typing import Any
from typing import Any, Optional
import numpy as np
@@ -160,14 +160,28 @@ def handle_dtype(dtype: DTypeType | None) -> str:
    elif isinstance(dtype, FlatDtype):
        return dtype.value
    elif isinstance(dtype, list) and isinstance(dtype[0], CompoundDtype):
        # there is precisely one class that uses compound dtypes:
        # TimeSeriesReferenceVectorData
        # compoundDtypes are able to define a ragged table according to the schema
        # but are used in this single case equivalently to attributes.
        # so we'll... uh... treat them as slots.
        # TODO
        # Compound Dtypes are handled by the MapCompoundDtype dataset map,
        # but this function is also used within ``check`` methods, so we should always
        # return something from it rather than raise
        return "AnyType"
    else:
        # flat dtype
        return dtype

def inlined(dtype: DTypeType | None) -> Optional[bool]:
    """
    Check if a slot should be inlined based on its dtype.

    For now that is equivalent to checking whether the dtype is a reference dtype,
    but the function remains semantically reserved for answering this question w.r.t. dtype.

    Returns ``None`` if not inlined, to avoid cluttering generated models with unnecessary props
    """
    return (
        True
        if isinstance(dtype, ReferenceDtype)
        or (isinstance(dtype, CompoundDtype) and isinstance(dtype.dtype, ReferenceDtype))
        else None
    )
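Returning `None` instead of `False` means slots that do not need inlining simply omit the property, keeping the generated schema uncluttered. Hypothetical usage with the schema-language dtype classes (field values assumed):

from nwb_schema_language import CompoundDtype, ReferenceDtype

ref = ReferenceDtype(target_type="TimeSeries", reftype="object")

print(inlined(ref))      # True: references should be inlined
print(inlined("int32"))  # None: flat dtypes stay by value

compound = CompoundDtype(name="timeseries", doc="...", dtype=ref)
print(inlined(compound))  # True: compound members holding references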

View file

@@ -10,6 +10,7 @@ def test_read_from_nwbfile(nwb_file):
    Read data from a pynwb HDF5 NWB file
    """
    res = HDF5IO(nwb_file).read()
    res.model_dump_json()


def test_read_from_yaml(nwb_file):

View file

@@ -222,6 +222,11 @@ def parser() -> ArgumentParser:
        ),
        action="store_true",
    )
    parser.add_argument(
        "--debug",
        help="Add annotations to generated schema that indicate how they were generated",
        action="store_true",
    )
    parser.add_argument("--pdb", help="Launch debugger on an error", action="store_true")
    return parser
@@ -229,6 +234,12 @@ def parser() -> ArgumentParser:
def main():
    args = parser().parse_args()

    if args.debug:
        os.environ["NWB_LINKML_DEBUG"] = "true"
    else:
        if "NWB_LINKML_DEBUG" in os.environ:
            del os.environ["NWB_LINKML_DEBUG"]

    tmp_dir = make_tmp_dir(clear=True)
    git_dir = tmp_dir / "git"
    git_dir.mkdir(exist_ok=True)
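Because the CLI and the adapters communicate only through `NWB_LINKML_DEBUG`, the flag also works for ad-hoc scripts that never touch argparse. As a side note, the unset branch could be the one-liner `os.environ.pop(...)`, which is a no-op when the variable is absent; a sketch of the equivalent logic:

import os

debug = True  # stand-in for args.debug from the parser above

if debug:
    os.environ["NWB_LINKML_DEBUG"] = "true"
else:
    os.environ.pop("NWB_LINKML_DEBUG", None)  # safe even if never set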