Updating model generation methods to make both the loader tests and the hdmf includes unit tests pass (pending the following model update commit)

sneakers-the-rat 2024-09-11 15:44:57 -07:00
parent 27b5dddfdd
commit 000ddde000
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
12 changed files with 88 additions and 140 deletions

View file

@@ -2,6 +2,7 @@
Base class for adapters
"""
import os
import sys
from abc import abstractmethod
from dataclasses import dataclass, field
@@ -101,6 +102,19 @@ class Adapter(BaseModel):
"""Abstract base class for adapters"""
_logger: Optional[Logger] = None
_debug: Optional[bool] = None
@property
def debug(self) -> bool:
"""
Whether we are in debug mode, which adds extra metadata to generated elements.
Set explicitly via ``_debug``, or else checks for the truthiness of the
environment variable ``NWB_LINKML_DEBUG``
"""
if self._debug is None:
self._debug = bool(os.environ.get("NWB_LINKML_DEBUG", False))
return self._debug
@property
def logger(self) -> Logger:
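
Since the flag is read once and then cached, here is a minimal standalone sketch of how the new property behaves, mirroring the lookup above:

import os

os.environ["NWB_LINKML_DEBUG"] = "1"
# the property caches bool(os.environ.get("NWB_LINKML_DEBUG", False)) on first
# access, so the variable must be set before any adapter builds its result;
# note that any non-empty string, including "0", counts as truthy here
print(bool(os.environ.get("NWB_LINKML_DEBUG", False)))  # True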

View file

@@ -10,7 +10,7 @@ from linkml_runtime.linkml_model.meta import SlotDefinition
from nwb_linkml.adapters.adapter import Adapter, BuildResult, is_1d
from nwb_linkml.adapters.array import ArrayAdapter
from nwb_linkml.maps import Map
from nwb_linkml.maps.dtype import handle_dtype
from nwb_linkml.maps.dtype import handle_dtype, inlined
from nwb_schema_language import Attribute
@@ -104,6 +104,7 @@ class MapScalar(AttributeMap):
range=handle_dtype(attr.dtype),
description=attr.doc,
required=attr.required,
inlined=inlined(attr.dtype),
**cls.handle_defaults(attr),
)
return BuildResult(slots=[slot])
@@ -151,6 +152,7 @@ class MapArray(AttributeMap):
multivalued=multivalued,
description=attr.doc,
required=attr.required,
inlined=inlined(attr.dtype),
**expressions,
**cls.handle_defaults(attr),
)
@@ -171,7 +173,10 @@ class AttributeAdapter(Adapter):
Build the slot definitions, every attribute should have a map.
"""
map = self.match()
return map.apply(self.cls)
res = map.apply(self.cls)
if self.debug:
res = self._amend_debug(res, map)
return res
def match(self) -> Optional[Type[AttributeMap]]:
"""
@@ -195,3 +200,13 @@ class AttributeAdapter(Adapter):
return None
else:
return matches[0]
def _amend_debug(
self, res: BuildResult, map: Optional[Type[AttributeMap]] = None
) -> BuildResult:
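"""Annotate built classes and slots with the name of the map that produced them"""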
map_name = "None" if map is None else map.__name__
for cls in res.classes:
cls.annotations["attribute_map"] = {"tag": "attribute_map", "value": map_name}
for slot in res.slots:
slot.annotations["attribute_map"] = {"tag": "attribute_map", "value": map_name}
return res
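
A minimal sketch of reading the debug provenance back from a built slot, assuming linkml's Annotations coerce dict values on assignment, as the adapter code above relies on:

from linkml_runtime.linkml_model.meta import SlotDefinition

slot = SlotDefinition(name="example_attr")  # hypothetical slot
# mirror what _amend_debug does in debug mode
slot.annotations["attribute_map"] = {"tag": "attribute_map", "value": "MapScalar"}
# the dict coerces to an Annotation; read the matched map's name back
print(slot.annotations["attribute_map"].value)  # MapScalar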

View file

@@ -2,7 +2,6 @@
Adapters to linkML classes
"""
import os
from abc import abstractmethod
from typing import List, Optional, Type, TypeVar
@@ -33,14 +32,6 @@ class ClassAdapter(Adapter):
cls: TI
parent: Optional["ClassAdapter"] = None
_debug: Optional[bool] = None
@property
def debug(self) -> bool:
if self._debug is None:
self._debug = bool(os.environ.get("NWB_LINKML_DEBUG", False))
return self._debug
@field_validator("cls", mode="before")
@classmethod
def cast_from_string(cls, value: str | TI) -> TI:
@@ -260,6 +251,7 @@ class ClassAdapter(Adapter):
name=self._get_slot_name(),
description=self.cls.doc,
range=self._get_full_name(),
inlined=True,
**QUANTITY_MAP[self.cls.quantity],
)
if self.debug:
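
For context, `inlined=True` on a child-class slot directs LinkML generators to nest the child object in place rather than reference it by key; a hedged sketch of the kind of slot this branch now builds (slot name, range, and quantity expansion all hypothetical):

from linkml_runtime.linkml_model.meta import SlotDefinition

slot = SlotDefinition(
    name="starting_time",            # hypothetical child slot name
    description="doc text from the nwb class",
    range="TimeSeriesStartingTime",  # hypothetical range from _get_full_name()
    inlined=True,                    # nest the child object rather than reference it
    required=True,                   # example expansion of QUANTITY_MAP[quantity]
)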

View file

@@ -147,6 +147,7 @@ class MapScalarAttributes(DatasetMap):
name:
name: name
ifabsent: string(starting_time)
identifier: true
range: string
required: true
equals_string: starting_time
@@ -245,6 +246,7 @@ class MapListlike(DatasetMap):
attributes:
name:
name: name
identifier: true
range: string
required: true
value:
@@ -257,6 +259,8 @@ class MapListlike(DatasetMap):
range: Image
required: true
multivalued: true
inlined: true
inlined_as_list: true
tree_root: true
"""
@@ -386,13 +390,11 @@ class MapArraylike(DatasetMap):
- ``False``
"""
dtype = handle_dtype(cls.dtype)
return (
cls.name
and (all([cls.dims, cls.shape]) or cls.neurodata_type_inc == "VectorData")
and not has_attrs(cls)
and not is_compound(cls)
and dtype in flat_to_linkml
)
@classmethod
@@ -420,6 +422,7 @@ class MapArraylike(DatasetMap):
range=handle_dtype(cls.dtype),
description=cls.doc,
required=cls.quantity not in ("*", "?"),
inlined=inlined(cls.dtype),
**expressions,
)
]
@@ -484,6 +487,7 @@ class MapArrayLikeAttributes(DatasetMap):
attributes:
name:
name: name
identifier: true
range: string
required: true
resolution:
@@ -598,103 +602,6 @@ class MapClassRange(DatasetMap):
# DynamicTable special cases
# --------------------------------------------------
class MapVectorClassRange(DatasetMap):
"""
Map a ``VectorData`` class that is a reference to another class as simply
a multivalued slot range, rather than an independent class
"""
@classmethod
def check(c, cls: Dataset) -> bool:
"""
Check that we are a VectorData object without any additional attributes
with a dtype that refers to another class
"""
dtype = handle_dtype(cls.dtype)
return (
cls.neurodata_type_inc == "VectorData"
and cls.name
and not has_attrs(cls)
and not (cls.shape or cls.dims)
and not is_compound(cls)
and dtype not in flat_to_linkml
)
@classmethod
def apply(
c, cls: Dataset, res: Optional[BuildResult] = None, name: Optional[str] = None
) -> BuildResult:
"""
Create a slot that replaces the base class just as a list[ClassRef]
"""
this_slot = SlotDefinition(
name=cls.name,
description=cls.doc,
multivalued=True,
range=handle_dtype(cls.dtype),
required=cls.quantity not in ("*", "?"),
)
res = BuildResult(slots=[this_slot])
return res
#
# class Map1DVector(DatasetMap):
# """
# ``VectorData`` is subclassed with a name but without dims or attributes,
# treat this as a normal 1D array slot that replaces any class that would be built for this
#
# eg. all the datasets in epoch.TimeIntervals:
#
# .. code-block:: yaml
#
# groups:
# - neurodata_type_def: TimeIntervals
# neurodata_type_inc: DynamicTable
# doc: A container for aggregating epoch data and the TimeSeries that each epoch applies
# to.
# datasets:
# - name: start_time
# neurodata_type_inc: VectorData
# dtype: float32
# doc: Start time of epoch, in seconds.
#
# """
#
# @classmethod
# def check(c, cls: Dataset) -> bool:
# """
# Check that we're a 1d VectorData class
# """
# return (
# cls.neurodata_type_inc == "VectorData"
# and not cls.dims
# and not cls.shape
# and not cls.attributes
# and not cls.neurodata_type_def
# and not is_compound(cls)
# and cls.name
# )
#
# @classmethod
# def apply(
# c, cls: Dataset, res: Optional[BuildResult] = None, name: Optional[str] = None
# ) -> BuildResult:
# """
# Return a simple multivalued slot
# """
# this_slot = SlotDefinition(
# name=cls.name,
# description=cls.doc,
# range=handle_dtype(cls.dtype),
# multivalued=True,
# )
# # No need to make a class for us, so we replace the existing build results
# res = BuildResult(slots=[this_slot])
# return res
class MapNVectors(DatasetMap):
"""
An unnamed container that indicates an arbitrary quantity of some other neurodata type.
@@ -864,10 +771,7 @@ class DatasetAdapter(ClassAdapter):
return matches[0]
def _amend_debug(self, res: BuildResult, map: Optional[Type[DatasetMap]] = None) -> BuildResult:
if map is None:
map_name = "None"
else:
map_name = map.__name__
map_name = "None" if map is None else map.__name__
for cls in res.classes:
cls.annotations["dataset_map"] = {"tag": "dataset_map", "value": map_name}
for slot in res.slots:
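
The deleted MapVectorClassRange (and the long-commented Map1DVector) are superseded by setting `inlined` from the dtype in the generic maps above; a hedged sketch of the slot MapArraylike now emits for a named VectorData column whose dtype references another class (column name and range hypothetical):

from linkml_runtime.linkml_model.meta import SlotDefinition

slot = SlotDefinition(
    name="electrodes",           # hypothetical named VectorData column
    range="DynamicTableRegion",  # what handle_dtype would resolve a reference dtype to
    description="doc text from the dataset",
    required=True,               # i.e. quantity not in ("*", "?")
    inlined=True,                # what inlined() is expected to return for a class range
    multivalued=True,
)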

View file

@@ -68,11 +68,17 @@ class GroupAdapter(ClassAdapter):
if not self.cls.links:
return []
annotations = [{"tag": "source_type", "value": "link"}]
if self.debug:
annotations.append({"tag": "group_adapter", "value": "link"})
slots = [
SlotDefinition(
name=link.name,
any_of=[{"range": link.target_type}, {"range": "string"}],
annotations=[{"tag": "source_type", "value": "link"}],
annotations=annotations,
inlined=True,
**QUANTITY_MAP[link.quantity],
)
for link in self.cls.links
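
A hedged sketch of one generated link slot after this change (link name and target type hypothetical):

from linkml_runtime.linkml_model.meta import SlotDefinition

slot = SlotDefinition(
    name="timeseries",  # hypothetical link name
    # a link may resolve to the inlined target object or to a string reference
    any_of=[{"range": "TimeSeries"}, {"range": "string"}],
    annotations=[{"tag": "source_type", "value": "link"}],
    inlined=True,
    required=False,  # example expansion of QUANTITY_MAP[link.quantity]
)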

View file

@@ -48,7 +48,16 @@ class NamespacesAdapter(Adapter):
need_imports = []
for needed in ns_adapter.needed_imports.values():
need_imports.extend([n for n in needed if n not in ns_adapter.needed_imports])
# try to locate imports implied by the namespace schema
# that are not provided by the current namespace
# and not already provided in `imported` by the loader function
need_imports.extend(
[
n
for n in needed
if n not in ns_adapter.needed_imports and n not in ns_adapter.versions
]
)
for needed in need_imports:
if needed in DEFAULT_REPOS:
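
A tiny standalone illustration of the extended filter's set logic (sample names made up):

needed = ["hdmf-common", "core", "hdmf-experimental"]
needed_imports = {"core": []}        # imports already tracked on the adapter
versions = {"hdmf-common": "1.8.0"}  # schemas this namespace already provides

need_imports = [n for n in needed if n not in needed_imports and n not in versions]
print(need_imports)  # ['hdmf-experimental']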

View file

@@ -11,7 +11,7 @@ import sys
from dataclasses import dataclass, field
from pathlib import Path
from types import ModuleType
from typing import ClassVar, Dict, List, Optional, Tuple, Literal
from typing import Callable, ClassVar, Dict, List, Literal, Optional, Tuple
from linkml.generators import PydanticGenerator
from linkml.generators.pydanticgen.array import ArrayRepresentation, NumpydanticArray
@@ -29,9 +29,9 @@ from linkml_runtime.utils.formatutils import remove_empty_items
from linkml_runtime.utils.schemaview import SchemaView
from nwb_linkml.includes.base import (
BASEMODEL_GETITEM,
BASEMODEL_COERCE_VALUE,
BASEMODEL_COERCE_CHILD,
BASEMODEL_COERCE_VALUE,
BASEMODEL_GETITEM,
)
from nwb_linkml.includes.hdmf import (
DYNAMIC_TABLE_IMPORTS,
@@ -265,7 +265,7 @@
Returns:
"""
if cls.cls.name in "DynamicTable":
if cls.cls.name == "DynamicTable":
cls.cls.bases = ["DynamicTableMixin", "ConfiguredBaseModel"]
if cls.injected_classes is None:
@@ -328,10 +328,7 @@ class AfterGenerateClass:
cls.cls.attributes[an_attr].range = "ElementIdentifiers"
return cls
if an_attr.endswith("_index"):
wrap_cls = "VectorIndex"
else:
wrap_cls = "VectorData"
wrap_cls = "VectorIndex" if an_attr.endswith("_index") else "VectorData"
cls.cls.attributes[an_attr].range = wrap_preserving_optional(
slot_range, wrap_cls
@@ -340,7 +337,9 @@
return cls
@staticmethod
def inject_elementidentifiers(cls: ClassResult, sv: SchemaView, import_method) -> ClassResult:
def inject_elementidentifiers(
cls: ClassResult, sv: SchemaView, import_method: Callable[[str], Import]
) -> ClassResult:
"""
Inject ElementIdentifiers into module that define dynamictables -
needed to handle ID columns
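
For context, `wrap_preserving_optional` (defined elsewhere in this module) wraps a slot's range in VectorData or VectorIndex while keeping any outer Optional; a standalone sketch of the intended transformation on stringified ranges, not the module's own implementation:

def wrap_preserving_optional_sketch(annotation: str, wrap_cls: str) -> str:
    """Wrap a range in wrap_cls, keeping an outer Optional[...] if present."""
    if annotation.startswith("Optional[") and annotation.endswith("]"):
        inner = annotation[len("Optional["):-1]
        return f"Optional[{wrap_cls}[{inner}]]"
    return f"{wrap_cls}[{annotation}]"

print(wrap_preserving_optional_sketch("Optional[NDArray]", "VectorData"))
# Optional[VectorData[NDArray]]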

View file

@@ -26,7 +26,7 @@ BASEMODEL_COERCE_VALUE = """
except AttributeError:
try:
return handler(v["value"])
except (KeyError, TypeError):
except (IndexError, KeyError, TypeError):
raise e1
"""
@@ -37,8 +37,13 @@ BASEMODEL_COERCE_CHILD = """
\"\"\"Recast parent classes into child classes\"\"\"
if isinstance(v, BaseModel):
annotation = cls.model_fields[info.field_name].annotation
annotation = annotation.__args__[0] if hasattr(annotation, "__args__") else annotation
if issubclass(annotation, type(v)) and annotation is not type(v):
v = annotation(**{**v.__dict__, **v.__pydantic_extra__})
while hasattr(annotation, "__args__"):
annotation = annotation.__args__[0]
try:
if issubclass(annotation, type(v)) and annotation is not type(v):
v = annotation(**{**v.__dict__, **v.__pydantic_extra__})
except TypeError:
# fine, annotation is a non-class type like a TypeVar
pass
return v
"""

View file

@@ -288,14 +288,11 @@ class DynamicTableMixin(BaseModel):
continue
if not isinstance(val, (VectorData, VectorIndex)):
try:
if key.endswith("_index"):
to_cast = VectorIndex
else:
to_cast = VectorData
to_cast = VectorIndex if key.endswith("_index") else VectorData
if isinstance(val, dict):
model[key] = to_cast(**val)
else:
model[key] = VectorIndex(name=key, description="", value=val)
model[key] = to_cast(name=key, description="", value=val)
except ValidationError as e: # pragma: no cover
raise ValidationError.from_exception_data(
title=f"field {key} cannot be cast to VectorData from {val}",
@@ -388,6 +385,11 @@ class VectorDataMixin(BaseModel, Generic[T]):
# redefined in `VectorData`, but included here for testing and type checking
value: Optional[T] = None
def __init__(self, value: Optional[T] = None, **kwargs):
if value is not None and "value" not in kwargs:
kwargs["value"] = value
super().__init__(**kwargs)
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
if self._index:
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
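
A hedged usage sketch of the new positional init, assuming the mixin is directly instantiable for testing with no other required fields, as the comment above suggests (import path assumed):

from nwb_linkml.includes.hdmf import VectorDataMixin

# the positional argument is forwarded into the pydantic `value` field
vd = VectorDataMixin([1, 2, 3])
assert vd.value == [1, 2, 3]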

View file

@@ -151,7 +151,7 @@ def test_name_slot():
assert slot.name == "name"
assert slot.required
assert slot.range == "string"
assert slot.identifier is None
assert slot.identifier
assert slot.ifabsent is None
assert slot.equals_string is None
@@ -160,7 +160,7 @@ def test_name_slot():
assert slot.name == "name"
assert slot.required
assert slot.range == "string"
assert slot.identifier is None
assert slot.identifier
assert slot.ifabsent == "string(FixedName)"
assert slot.equals_string == "FixedName"

View file

@@ -284,14 +284,14 @@ def test_dynamictable_assert_equal_length():
"existing_col": np.arange(10),
"new_col_1": hdmf.VectorData(value=np.arange(11)),
}
with pytest.raises(ValidationError, match="Columns are not of equal length"):
with pytest.raises(ValidationError, match="columns are not of equal length"):
_ = MyDT(**cols)
cols = {
"existing_col": np.arange(11),
"new_col_1": hdmf.VectorData(value=np.arange(10)),
}
with pytest.raises(ValidationError, match="Columns are not of equal length"):
with pytest.raises(ValidationError, match="columns are not of equal length"):
_ = MyDT(**cols)
# wrong lengths are fine as long as the index is good
@@ -308,7 +308,7 @@ def test_dynamictable_assert_equal_length():
"new_col_1": hdmf.VectorData(value=np.arange(100)),
"new_col_1_index": hdmf.VectorIndex(value=np.arange(0, 100, 5) + 5),
}
with pytest.raises(ValidationError, match="Columns are not of equal length"):
with pytest.raises(ValidationError, match="columns are not of equal length"):
_ = MyDT(**cols)

View file

@@ -8,9 +8,11 @@ from nwb_linkml.io.hdf5 import HDF5IO
def test_read_from_nwbfile(nwb_file):
"""
Read data from a pynwb HDF5 NWB file
Placeholder that just ensures that reads work and all pydantic models validate;
testing the correctness of the read will happen elsewhere.
"""
res = HDF5IO(nwb_file).read()
res.model_dump_json()
def test_read_from_yaml(nwb_file):