Updating model generation methods to make both loader tests and hdmf include unit tests pass (pending following model update commit)

This commit is contained in:
sneakers-the-rat 2024-09-11 15:44:57 -07:00
parent 27b5dddfdd
commit 000ddde000
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
12 changed files with 88 additions and 140 deletions

View file

@ -2,6 +2,7 @@
Base class for adapters Base class for adapters
""" """
import os
import sys import sys
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
@ -101,6 +102,19 @@ class Adapter(BaseModel):
"""Abstract base class for adapters""" """Abstract base class for adapters"""
_logger: Optional[Logger] = None _logger: Optional[Logger] = None
_debug: Optional[bool] = None
@property
def debug(self) -> bool:
"""
Whether we are in debug mode, which adds extra metadata in generated elements.
Set explicitly via ``_debug`` , or else checks for the truthiness of the
environment variable ``NWB_LINKML_DEBUG``
"""
if self._debug is None:
self._debug = bool(os.environ.get("NWB_LINKML_DEBUG", False))
return self._debug
@property @property
def logger(self) -> Logger: def logger(self) -> Logger:

View file

@ -10,7 +10,7 @@ from linkml_runtime.linkml_model.meta import SlotDefinition
from nwb_linkml.adapters.adapter import Adapter, BuildResult, is_1d from nwb_linkml.adapters.adapter import Adapter, BuildResult, is_1d
from nwb_linkml.adapters.array import ArrayAdapter from nwb_linkml.adapters.array import ArrayAdapter
from nwb_linkml.maps import Map from nwb_linkml.maps import Map
from nwb_linkml.maps.dtype import handle_dtype from nwb_linkml.maps.dtype import handle_dtype, inlined
from nwb_schema_language import Attribute from nwb_schema_language import Attribute
@ -104,6 +104,7 @@ class MapScalar(AttributeMap):
range=handle_dtype(attr.dtype), range=handle_dtype(attr.dtype),
description=attr.doc, description=attr.doc,
required=attr.required, required=attr.required,
inlined=inlined(attr.dtype),
**cls.handle_defaults(attr), **cls.handle_defaults(attr),
) )
return BuildResult(slots=[slot]) return BuildResult(slots=[slot])
@ -151,6 +152,7 @@ class MapArray(AttributeMap):
multivalued=multivalued, multivalued=multivalued,
description=attr.doc, description=attr.doc,
required=attr.required, required=attr.required,
inlined=inlined(attr.dtype),
**expressions, **expressions,
**cls.handle_defaults(attr), **cls.handle_defaults(attr),
) )
@ -171,7 +173,10 @@ class AttributeAdapter(Adapter):
Build the slot definitions, every attribute should have a map. Build the slot definitions, every attribute should have a map.
""" """
map = self.match() map = self.match()
return map.apply(self.cls) res = map.apply(self.cls)
if self.debug:
res = self._amend_debug(res, map)
return res
def match(self) -> Optional[Type[AttributeMap]]: def match(self) -> Optional[Type[AttributeMap]]:
""" """
@ -195,3 +200,13 @@ class AttributeAdapter(Adapter):
return None return None
else: else:
return matches[0] return matches[0]
def _amend_debug(
self, res: BuildResult, map: Optional[Type[AttributeMap]] = None
) -> BuildResult:
map_name = "None" if map is None else map.__name__
for cls in res.classes:
cls.annotations["attribute_map"] = {"tag": "attribute_map", "value": map_name}
for slot in res.slots:
slot.annotations["attribute_map"] = {"tag": "attribute_map", "value": map_name}
return res

View file

@ -2,7 +2,6 @@
Adapters to linkML classes Adapters to linkML classes
""" """
import os
from abc import abstractmethod from abc import abstractmethod
from typing import List, Optional, Type, TypeVar from typing import List, Optional, Type, TypeVar
@ -33,14 +32,6 @@ class ClassAdapter(Adapter):
cls: TI cls: TI
parent: Optional["ClassAdapter"] = None parent: Optional["ClassAdapter"] = None
_debug: Optional[bool] = None
@property
def debug(self) -> bool:
if self._debug is None:
self._debug = bool(os.environ.get("NWB_LINKML_DEBUG", False))
return self._debug
@field_validator("cls", mode="before") @field_validator("cls", mode="before")
@classmethod @classmethod
def cast_from_string(cls, value: str | TI) -> TI: def cast_from_string(cls, value: str | TI) -> TI:
@ -260,6 +251,7 @@ class ClassAdapter(Adapter):
name=self._get_slot_name(), name=self._get_slot_name(),
description=self.cls.doc, description=self.cls.doc,
range=self._get_full_name(), range=self._get_full_name(),
inlined=True,
**QUANTITY_MAP[self.cls.quantity], **QUANTITY_MAP[self.cls.quantity],
) )
if self.debug: if self.debug:

View file

@ -147,6 +147,7 @@ class MapScalarAttributes(DatasetMap):
name: name:
name: name name: name
ifabsent: string(starting_time) ifabsent: string(starting_time)
identifier: true
range: string range: string
required: true required: true
equals_string: starting_time equals_string: starting_time
@ -245,6 +246,7 @@ class MapListlike(DatasetMap):
attributes: attributes:
name: name:
name: name name: name
identifier: true
range: string range: string
required: true required: true
value: value:
@ -257,6 +259,8 @@ class MapListlike(DatasetMap):
range: Image range: Image
required: true required: true
multivalued: true multivalued: true
inlined: true
inlined_as_list: true
tree_root: true tree_root: true
""" """
@ -386,13 +390,11 @@ class MapArraylike(DatasetMap):
- ``False`` - ``False``
""" """
dtype = handle_dtype(cls.dtype)
return ( return (
cls.name cls.name
and (all([cls.dims, cls.shape]) or cls.neurodata_type_inc == "VectorData") and (all([cls.dims, cls.shape]) or cls.neurodata_type_inc == "VectorData")
and not has_attrs(cls) and not has_attrs(cls)
and not is_compound(cls) and not is_compound(cls)
and dtype in flat_to_linkml
) )
@classmethod @classmethod
@ -420,6 +422,7 @@ class MapArraylike(DatasetMap):
range=handle_dtype(cls.dtype), range=handle_dtype(cls.dtype),
description=cls.doc, description=cls.doc,
required=cls.quantity not in ("*", "?"), required=cls.quantity not in ("*", "?"),
inlined=inlined(cls.dtype),
**expressions, **expressions,
) )
] ]
@ -484,6 +487,7 @@ class MapArrayLikeAttributes(DatasetMap):
attributes: attributes:
name: name:
name: name name: name
identifier: true
range: string range: string
required: true required: true
resolution: resolution:
@ -598,103 +602,6 @@ class MapClassRange(DatasetMap):
# DynamicTable special cases # DynamicTable special cases
# -------------------------------------------------- # --------------------------------------------------
class MapVectorClassRange(DatasetMap):
"""
Map a ``VectorData`` class that is a reference to another class as simply
a multivalued slot range, rather than an independent class
"""
@classmethod
def check(c, cls: Dataset) -> bool:
"""
Check that we are a VectorData object without any additional attributes
with a dtype that refers to another class
"""
dtype = handle_dtype(cls.dtype)
return (
cls.neurodata_type_inc == "VectorData"
and cls.name
and not has_attrs(cls)
and not (cls.shape or cls.dims)
and not is_compound(cls)
and dtype not in flat_to_linkml
)
@classmethod
def apply(
c, cls: Dataset, res: Optional[BuildResult] = None, name: Optional[str] = None
) -> BuildResult:
"""
Create a slot that replaces the base class just as a list[ClassRef]
"""
this_slot = SlotDefinition(
name=cls.name,
description=cls.doc,
multivalued=True,
range=handle_dtype(cls.dtype),
required=cls.quantity not in ("*", "?"),
)
res = BuildResult(slots=[this_slot])
return res
#
# class Map1DVector(DatasetMap):
# """
# ``VectorData`` is subclassed with a name but without dims or attributes,
# treat this as a normal 1D array slot that replaces any class that would be built for this
#
# eg. all the datasets in epoch.TimeIntervals:
#
# .. code-block:: yaml
#
# groups:
# - neurodata_type_def: TimeIntervals
# neurodata_type_inc: DynamicTable
# doc: A container for aggregating epoch data and the TimeSeries that each epoch applies
# to.
# datasets:
# - name: start_time
# neurodata_type_inc: VectorData
# dtype: float32
# doc: Start time of epoch, in seconds.
#
# """
#
# @classmethod
# def check(c, cls: Dataset) -> bool:
# """
# Check that we're a 1d VectorData class
# """
# return (
# cls.neurodata_type_inc == "VectorData"
# and not cls.dims
# and not cls.shape
# and not cls.attributes
# and not cls.neurodata_type_def
# and not is_compound(cls)
# and cls.name
# )
#
# @classmethod
# def apply(
# c, cls: Dataset, res: Optional[BuildResult] = None, name: Optional[str] = None
# ) -> BuildResult:
# """
# Return a simple multivalued slot
# """
# this_slot = SlotDefinition(
# name=cls.name,
# description=cls.doc,
# range=handle_dtype(cls.dtype),
# multivalued=True,
# )
# # No need to make a class for us, so we replace the existing build results
# res = BuildResult(slots=[this_slot])
# return res
class MapNVectors(DatasetMap): class MapNVectors(DatasetMap):
""" """
An unnamed container that indicates an arbitrary quantity of some other neurodata type. An unnamed container that indicates an arbitrary quantity of some other neurodata type.
@ -864,10 +771,7 @@ class DatasetAdapter(ClassAdapter):
return matches[0] return matches[0]
def _amend_debug(self, res: BuildResult, map: Optional[Type[DatasetMap]] = None) -> BuildResult: def _amend_debug(self, res: BuildResult, map: Optional[Type[DatasetMap]] = None) -> BuildResult:
if map is None: map_name = "None" if map is None else map.__name__
map_name = "None"
else:
map_name = map.__name__
for cls in res.classes: for cls in res.classes:
cls.annotations["dataset_map"] = {"tag": "dataset_map", "value": map_name} cls.annotations["dataset_map"] = {"tag": "dataset_map", "value": map_name}
for slot in res.slots: for slot in res.slots:

View file

@ -68,11 +68,17 @@ class GroupAdapter(ClassAdapter):
if not self.cls.links: if not self.cls.links:
return [] return []
annotations = [{"tag": "source_type", "value": "link"}]
if self.debug:
annotations.append({"tag": "group_adapter", "value": "link"})
slots = [ slots = [
SlotDefinition( SlotDefinition(
name=link.name, name=link.name,
any_of=[{"range": link.target_type}, {"range": "string"}], any_of=[{"range": link.target_type}, {"range": "string"}],
annotations=[{"tag": "source_type", "value": "link"}], annotations=annotations,
inlined=True,
**QUANTITY_MAP[link.quantity], **QUANTITY_MAP[link.quantity],
) )
for link in self.cls.links for link in self.cls.links

View file

@ -48,7 +48,16 @@ class NamespacesAdapter(Adapter):
need_imports = [] need_imports = []
for needed in ns_adapter.needed_imports.values(): for needed in ns_adapter.needed_imports.values():
need_imports.extend([n for n in needed if n not in ns_adapter.needed_imports]) # try to locate imports implied by the namespace schema,
# but are either not provided by the current namespace
# or are otherwise already provided in `imported` by the loader function
need_imports.extend(
[
n
for n in needed
if n not in ns_adapter.needed_imports and n not in ns_adapter.versions
]
)
for needed in need_imports: for needed in need_imports:
if needed in DEFAULT_REPOS: if needed in DEFAULT_REPOS:

View file

@ -11,7 +11,7 @@ import sys
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from types import ModuleType from types import ModuleType
from typing import ClassVar, Dict, List, Optional, Tuple, Literal from typing import Callable, ClassVar, Dict, List, Literal, Optional, Tuple
from linkml.generators import PydanticGenerator from linkml.generators import PydanticGenerator
from linkml.generators.pydanticgen.array import ArrayRepresentation, NumpydanticArray from linkml.generators.pydanticgen.array import ArrayRepresentation, NumpydanticArray
@ -29,9 +29,9 @@ from linkml_runtime.utils.formatutils import remove_empty_items
from linkml_runtime.utils.schemaview import SchemaView from linkml_runtime.utils.schemaview import SchemaView
from nwb_linkml.includes.base import ( from nwb_linkml.includes.base import (
BASEMODEL_GETITEM,
BASEMODEL_COERCE_VALUE,
BASEMODEL_COERCE_CHILD, BASEMODEL_COERCE_CHILD,
BASEMODEL_COERCE_VALUE,
BASEMODEL_GETITEM,
) )
from nwb_linkml.includes.hdmf import ( from nwb_linkml.includes.hdmf import (
DYNAMIC_TABLE_IMPORTS, DYNAMIC_TABLE_IMPORTS,
@ -265,7 +265,7 @@ class AfterGenerateClass:
Returns: Returns:
""" """
if cls.cls.name in "DynamicTable": if cls.cls.name == "DynamicTable":
cls.cls.bases = ["DynamicTableMixin", "ConfiguredBaseModel"] cls.cls.bases = ["DynamicTableMixin", "ConfiguredBaseModel"]
if cls.injected_classes is None: if cls.injected_classes is None:
@ -328,10 +328,7 @@ class AfterGenerateClass:
cls.cls.attributes[an_attr].range = "ElementIdentifiers" cls.cls.attributes[an_attr].range = "ElementIdentifiers"
return cls return cls
if an_attr.endswith("_index"): wrap_cls = "VectorIndex" if an_attr.endswith("_index") else "VectorData"
wrap_cls = "VectorIndex"
else:
wrap_cls = "VectorData"
cls.cls.attributes[an_attr].range = wrap_preserving_optional( cls.cls.attributes[an_attr].range = wrap_preserving_optional(
slot_range, wrap_cls slot_range, wrap_cls
@ -340,7 +337,9 @@ class AfterGenerateClass:
return cls return cls
@staticmethod @staticmethod
def inject_elementidentifiers(cls: ClassResult, sv: SchemaView, import_method) -> ClassResult: def inject_elementidentifiers(
cls: ClassResult, sv: SchemaView, import_method: Callable[[str], Import]
) -> ClassResult:
""" """
Inject ElementIdentifiers into module that define dynamictables - Inject ElementIdentifiers into module that define dynamictables -
needed to handle ID columns needed to handle ID columns

View file

@ -26,7 +26,7 @@ BASEMODEL_COERCE_VALUE = """
except AttributeError: except AttributeError:
try: try:
return handler(v["value"]) return handler(v["value"])
except (KeyError, TypeError): except (IndexError, KeyError, TypeError):
raise e1 raise e1
""" """
@ -37,8 +37,13 @@ BASEMODEL_COERCE_CHILD = """
\"\"\"Recast parent classes into child classes\"\"\" \"\"\"Recast parent classes into child classes\"\"\"
if isinstance(v, BaseModel): if isinstance(v, BaseModel):
annotation = cls.model_fields[info.field_name].annotation annotation = cls.model_fields[info.field_name].annotation
annotation = annotation.__args__[0] if hasattr(annotation, "__args__") else annotation while hasattr(annotation, "__args__"):
if issubclass(annotation, type(v)) and annotation is not type(v): annotation = annotation.__args__[0]
v = annotation(**{**v.__dict__, **v.__pydantic_extra__}) try:
if issubclass(annotation, type(v)) and annotation is not type(v):
v = annotation(**{**v.__dict__, **v.__pydantic_extra__})
except TypeError:
# fine, annotation is a non-class type like a TypeVar
pass
return v return v
""" """

View file

@ -288,14 +288,11 @@ class DynamicTableMixin(BaseModel):
continue continue
if not isinstance(val, (VectorData, VectorIndex)): if not isinstance(val, (VectorData, VectorIndex)):
try: try:
if key.endswith("_index"): to_cast = VectorIndex if key.endswith("_index") else VectorData
to_cast = VectorIndex
else:
to_cast = VectorData
if isinstance(val, dict): if isinstance(val, dict):
model[key] = to_cast(**val) model[key] = to_cast(**val)
else: else:
model[key] = VectorIndex(name=key, description="", value=val) model[key] = to_cast(name=key, description="", value=val)
except ValidationError as e: # pragma: no cover except ValidationError as e: # pragma: no cover
raise ValidationError.from_exception_data( raise ValidationError.from_exception_data(
title=f"field {key} cannot be cast to VectorData from {val}", title=f"field {key} cannot be cast to VectorData from {val}",
@ -388,6 +385,11 @@ class VectorDataMixin(BaseModel, Generic[T]):
# redefined in `VectorData`, but included here for testing and type checking # redefined in `VectorData`, but included here for testing and type checking
value: Optional[T] = None value: Optional[T] = None
def __init__(self, value: Optional[T] = None, **kwargs):
if value is not None and "value" not in kwargs:
kwargs["value"] = value
super().__init__(**kwargs)
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any: def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
if self._index: if self._index:
# Following hdmf, VectorIndex is the thing that knows how to do the slicing # Following hdmf, VectorIndex is the thing that knows how to do the slicing

View file

@ -151,7 +151,7 @@ def test_name_slot():
assert slot.name == "name" assert slot.name == "name"
assert slot.required assert slot.required
assert slot.range == "string" assert slot.range == "string"
assert slot.identifier is None assert slot.identifier
assert slot.ifabsent is None assert slot.ifabsent is None
assert slot.equals_string is None assert slot.equals_string is None
@ -160,7 +160,7 @@ def test_name_slot():
assert slot.name == "name" assert slot.name == "name"
assert slot.required assert slot.required
assert slot.range == "string" assert slot.range == "string"
assert slot.identifier is None assert slot.identifier
assert slot.ifabsent == "string(FixedName)" assert slot.ifabsent == "string(FixedName)"
assert slot.equals_string == "FixedName" assert slot.equals_string == "FixedName"

View file

@ -284,14 +284,14 @@ def test_dynamictable_assert_equal_length():
"existing_col": np.arange(10), "existing_col": np.arange(10),
"new_col_1": hdmf.VectorData(value=np.arange(11)), "new_col_1": hdmf.VectorData(value=np.arange(11)),
} }
with pytest.raises(ValidationError, match="Columns are not of equal length"): with pytest.raises(ValidationError, match="columns are not of equal length"):
_ = MyDT(**cols) _ = MyDT(**cols)
cols = { cols = {
"existing_col": np.arange(11), "existing_col": np.arange(11),
"new_col_1": hdmf.VectorData(value=np.arange(10)), "new_col_1": hdmf.VectorData(value=np.arange(10)),
} }
with pytest.raises(ValidationError, match="Columns are not of equal length"): with pytest.raises(ValidationError, match="columns are not of equal length"):
_ = MyDT(**cols) _ = MyDT(**cols)
# wrong lengths are fine as long as the index is good # wrong lengths are fine as long as the index is good
@ -308,7 +308,7 @@ def test_dynamictable_assert_equal_length():
"new_col_1": hdmf.VectorData(value=np.arange(100)), "new_col_1": hdmf.VectorData(value=np.arange(100)),
"new_col_1_index": hdmf.VectorIndex(value=np.arange(0, 100, 5) + 5), "new_col_1_index": hdmf.VectorIndex(value=np.arange(0, 100, 5) + 5),
} }
with pytest.raises(ValidationError, match="Columns are not of equal length"): with pytest.raises(ValidationError, match="columns are not of equal length"):
_ = MyDT(**cols) _ = MyDT(**cols)

View file

@ -8,9 +8,11 @@ from nwb_linkml.io.hdf5 import HDF5IO
def test_read_from_nwbfile(nwb_file): def test_read_from_nwbfile(nwb_file):
""" """
Read data from a pynwb HDF5 NWB file Read data from a pynwb HDF5 NWB file
Placeholder that just ensures that reads work and all pydantic models validate,
testing of correctness of read will happen elsewhere.
""" """
res = HDF5IO(nwb_file).read() res = HDF5IO(nwb_file).read()
res.model_dump_json()
def test_read_from_yaml(nwb_file): def test_read_from_yaml(nwb_file):