mirror of
https://github.com/p2p-ld/nwb-linkml.git
synced 2025-01-09 05:34:28 +00:00
Updating model generation methods to make both loader tests and hdmf include unit tests pass (pending following model update commit)
This commit is contained in:
parent
27b5dddfdd
commit
000ddde000
12 changed files with 88 additions and 140 deletions
|
@ -2,6 +2,7 @@
|
|||
Base class for adapters
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
|
@ -101,6 +102,19 @@ class Adapter(BaseModel):
|
|||
"""Abstract base class for adapters"""
|
||||
|
||||
_logger: Optional[Logger] = None
|
||||
_debug: Optional[bool] = None
|
||||
|
||||
@property
|
||||
def debug(self) -> bool:
|
||||
"""
|
||||
Whether we are in debug mode, which adds extra metadata in generated elements.
|
||||
|
||||
Set explicitly via ``_debug`` , or else checks for the truthiness of the
|
||||
environment variable ``NWB_LINKML_DEBUG``
|
||||
"""
|
||||
if self._debug is None:
|
||||
self._debug = bool(os.environ.get("NWB_LINKML_DEBUG", False))
|
||||
return self._debug
|
||||
|
||||
@property
|
||||
def logger(self) -> Logger:
|
||||
|
|
|
@ -10,7 +10,7 @@ from linkml_runtime.linkml_model.meta import SlotDefinition
|
|||
from nwb_linkml.adapters.adapter import Adapter, BuildResult, is_1d
|
||||
from nwb_linkml.adapters.array import ArrayAdapter
|
||||
from nwb_linkml.maps import Map
|
||||
from nwb_linkml.maps.dtype import handle_dtype
|
||||
from nwb_linkml.maps.dtype import handle_dtype, inlined
|
||||
from nwb_schema_language import Attribute
|
||||
|
||||
|
||||
|
@ -104,6 +104,7 @@ class MapScalar(AttributeMap):
|
|||
range=handle_dtype(attr.dtype),
|
||||
description=attr.doc,
|
||||
required=attr.required,
|
||||
inlined=inlined(attr.dtype),
|
||||
**cls.handle_defaults(attr),
|
||||
)
|
||||
return BuildResult(slots=[slot])
|
||||
|
@ -151,6 +152,7 @@ class MapArray(AttributeMap):
|
|||
multivalued=multivalued,
|
||||
description=attr.doc,
|
||||
required=attr.required,
|
||||
inlined=inlined(attr.dtype),
|
||||
**expressions,
|
||||
**cls.handle_defaults(attr),
|
||||
)
|
||||
|
@ -171,7 +173,10 @@ class AttributeAdapter(Adapter):
|
|||
Build the slot definitions, every attribute should have a map.
|
||||
"""
|
||||
map = self.match()
|
||||
return map.apply(self.cls)
|
||||
res = map.apply(self.cls)
|
||||
if self.debug:
|
||||
res = self._amend_debug(res, map)
|
||||
return res
|
||||
|
||||
def match(self) -> Optional[Type[AttributeMap]]:
|
||||
"""
|
||||
|
@ -195,3 +200,13 @@ class AttributeAdapter(Adapter):
|
|||
return None
|
||||
else:
|
||||
return matches[0]
|
||||
|
||||
def _amend_debug(
|
||||
self, res: BuildResult, map: Optional[Type[AttributeMap]] = None
|
||||
) -> BuildResult:
|
||||
map_name = "None" if map is None else map.__name__
|
||||
for cls in res.classes:
|
||||
cls.annotations["attribute_map"] = {"tag": "attribute_map", "value": map_name}
|
||||
for slot in res.slots:
|
||||
slot.annotations["attribute_map"] = {"tag": "attribute_map", "value": map_name}
|
||||
return res
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
Adapters to linkML classes
|
||||
"""
|
||||
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
from typing import List, Optional, Type, TypeVar
|
||||
|
||||
|
@ -33,14 +32,6 @@ class ClassAdapter(Adapter):
|
|||
cls: TI
|
||||
parent: Optional["ClassAdapter"] = None
|
||||
|
||||
_debug: Optional[bool] = None
|
||||
|
||||
@property
|
||||
def debug(self) -> bool:
|
||||
if self._debug is None:
|
||||
self._debug = bool(os.environ.get("NWB_LINKML_DEBUG", False))
|
||||
return self._debug
|
||||
|
||||
@field_validator("cls", mode="before")
|
||||
@classmethod
|
||||
def cast_from_string(cls, value: str | TI) -> TI:
|
||||
|
@ -260,6 +251,7 @@ class ClassAdapter(Adapter):
|
|||
name=self._get_slot_name(),
|
||||
description=self.cls.doc,
|
||||
range=self._get_full_name(),
|
||||
inlined=True,
|
||||
**QUANTITY_MAP[self.cls.quantity],
|
||||
)
|
||||
if self.debug:
|
||||
|
|
|
@ -147,6 +147,7 @@ class MapScalarAttributes(DatasetMap):
|
|||
name:
|
||||
name: name
|
||||
ifabsent: string(starting_time)
|
||||
identifier: true
|
||||
range: string
|
||||
required: true
|
||||
equals_string: starting_time
|
||||
|
@ -245,6 +246,7 @@ class MapListlike(DatasetMap):
|
|||
attributes:
|
||||
name:
|
||||
name: name
|
||||
identifier: true
|
||||
range: string
|
||||
required: true
|
||||
value:
|
||||
|
@ -257,6 +259,8 @@ class MapListlike(DatasetMap):
|
|||
range: Image
|
||||
required: true
|
||||
multivalued: true
|
||||
inlined: true
|
||||
inlined_as_list: true
|
||||
tree_root: true
|
||||
|
||||
"""
|
||||
|
@ -386,13 +390,11 @@ class MapArraylike(DatasetMap):
|
|||
- ``False``
|
||||
|
||||
"""
|
||||
dtype = handle_dtype(cls.dtype)
|
||||
return (
|
||||
cls.name
|
||||
and (all([cls.dims, cls.shape]) or cls.neurodata_type_inc == "VectorData")
|
||||
and not has_attrs(cls)
|
||||
and not is_compound(cls)
|
||||
and dtype in flat_to_linkml
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -420,6 +422,7 @@ class MapArraylike(DatasetMap):
|
|||
range=handle_dtype(cls.dtype),
|
||||
description=cls.doc,
|
||||
required=cls.quantity not in ("*", "?"),
|
||||
inlined=inlined(cls.dtype),
|
||||
**expressions,
|
||||
)
|
||||
]
|
||||
|
@ -484,6 +487,7 @@ class MapArrayLikeAttributes(DatasetMap):
|
|||
attributes:
|
||||
name:
|
||||
name: name
|
||||
identifier: true
|
||||
range: string
|
||||
required: true
|
||||
resolution:
|
||||
|
@ -598,103 +602,6 @@ class MapClassRange(DatasetMap):
|
|||
# DynamicTable special cases
|
||||
# --------------------------------------------------
|
||||
|
||||
|
||||
class MapVectorClassRange(DatasetMap):
|
||||
"""
|
||||
Map a ``VectorData`` class that is a reference to another class as simply
|
||||
a multivalued slot range, rather than an independent class
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def check(c, cls: Dataset) -> bool:
|
||||
"""
|
||||
Check that we are a VectorData object without any additional attributes
|
||||
with a dtype that refers to another class
|
||||
"""
|
||||
dtype = handle_dtype(cls.dtype)
|
||||
return (
|
||||
cls.neurodata_type_inc == "VectorData"
|
||||
and cls.name
|
||||
and not has_attrs(cls)
|
||||
and not (cls.shape or cls.dims)
|
||||
and not is_compound(cls)
|
||||
and dtype not in flat_to_linkml
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def apply(
|
||||
c, cls: Dataset, res: Optional[BuildResult] = None, name: Optional[str] = None
|
||||
) -> BuildResult:
|
||||
"""
|
||||
Create a slot that replaces the base class just as a list[ClassRef]
|
||||
"""
|
||||
this_slot = SlotDefinition(
|
||||
name=cls.name,
|
||||
description=cls.doc,
|
||||
multivalued=True,
|
||||
range=handle_dtype(cls.dtype),
|
||||
required=cls.quantity not in ("*", "?"),
|
||||
)
|
||||
res = BuildResult(slots=[this_slot])
|
||||
return res
|
||||
|
||||
|
||||
#
|
||||
# class Map1DVector(DatasetMap):
|
||||
# """
|
||||
# ``VectorData`` is subclassed with a name but without dims or attributes,
|
||||
# treat this as a normal 1D array slot that replaces any class that would be built for this
|
||||
#
|
||||
# eg. all the datasets in epoch.TimeIntervals:
|
||||
#
|
||||
# .. code-block:: yaml
|
||||
#
|
||||
# groups:
|
||||
# - neurodata_type_def: TimeIntervals
|
||||
# neurodata_type_inc: DynamicTable
|
||||
# doc: A container for aggregating epoch data and the TimeSeries that each epoch applies
|
||||
# to.
|
||||
# datasets:
|
||||
# - name: start_time
|
||||
# neurodata_type_inc: VectorData
|
||||
# dtype: float32
|
||||
# doc: Start time of epoch, in seconds.
|
||||
#
|
||||
# """
|
||||
#
|
||||
# @classmethod
|
||||
# def check(c, cls: Dataset) -> bool:
|
||||
# """
|
||||
# Check that we're a 1d VectorData class
|
||||
# """
|
||||
# return (
|
||||
# cls.neurodata_type_inc == "VectorData"
|
||||
# and not cls.dims
|
||||
# and not cls.shape
|
||||
# and not cls.attributes
|
||||
# and not cls.neurodata_type_def
|
||||
# and not is_compound(cls)
|
||||
# and cls.name
|
||||
# )
|
||||
#
|
||||
# @classmethod
|
||||
# def apply(
|
||||
# c, cls: Dataset, res: Optional[BuildResult] = None, name: Optional[str] = None
|
||||
# ) -> BuildResult:
|
||||
# """
|
||||
# Return a simple multivalued slot
|
||||
# """
|
||||
# this_slot = SlotDefinition(
|
||||
# name=cls.name,
|
||||
# description=cls.doc,
|
||||
# range=handle_dtype(cls.dtype),
|
||||
# multivalued=True,
|
||||
# )
|
||||
# # No need to make a class for us, so we replace the existing build results
|
||||
# res = BuildResult(slots=[this_slot])
|
||||
# return res
|
||||
|
||||
|
||||
class MapNVectors(DatasetMap):
|
||||
"""
|
||||
An unnamed container that indicates an arbitrary quantity of some other neurodata type.
|
||||
|
@ -864,10 +771,7 @@ class DatasetAdapter(ClassAdapter):
|
|||
return matches[0]
|
||||
|
||||
def _amend_debug(self, res: BuildResult, map: Optional[Type[DatasetMap]] = None) -> BuildResult:
|
||||
if map is None:
|
||||
map_name = "None"
|
||||
else:
|
||||
map_name = map.__name__
|
||||
map_name = "None" if map is None else map.__name__
|
||||
for cls in res.classes:
|
||||
cls.annotations["dataset_map"] = {"tag": "dataset_map", "value": map_name}
|
||||
for slot in res.slots:
|
||||
|
|
|
@ -68,11 +68,17 @@ class GroupAdapter(ClassAdapter):
|
|||
if not self.cls.links:
|
||||
return []
|
||||
|
||||
annotations = [{"tag": "source_type", "value": "link"}]
|
||||
|
||||
if self.debug:
|
||||
annotations.append({"tag": "group_adapter", "value": "link"})
|
||||
|
||||
slots = [
|
||||
SlotDefinition(
|
||||
name=link.name,
|
||||
any_of=[{"range": link.target_type}, {"range": "string"}],
|
||||
annotations=[{"tag": "source_type", "value": "link"}],
|
||||
annotations=annotations,
|
||||
inlined=True,
|
||||
**QUANTITY_MAP[link.quantity],
|
||||
)
|
||||
for link in self.cls.links
|
||||
|
|
|
@ -48,7 +48,16 @@ class NamespacesAdapter(Adapter):
|
|||
|
||||
need_imports = []
|
||||
for needed in ns_adapter.needed_imports.values():
|
||||
need_imports.extend([n for n in needed if n not in ns_adapter.needed_imports])
|
||||
# try to locate imports implied by the namespace schema,
|
||||
# but are either not provided by the current namespace
|
||||
# or are otherwise already provided in `imported` by the loader function
|
||||
need_imports.extend(
|
||||
[
|
||||
n
|
||||
for n in needed
|
||||
if n not in ns_adapter.needed_imports and n not in ns_adapter.versions
|
||||
]
|
||||
)
|
||||
|
||||
for needed in need_imports:
|
||||
if needed in DEFAULT_REPOS:
|
||||
|
|
|
@ -11,7 +11,7 @@ import sys
|
|||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import ClassVar, Dict, List, Optional, Tuple, Literal
|
||||
from typing import Callable, ClassVar, Dict, List, Literal, Optional, Tuple
|
||||
|
||||
from linkml.generators import PydanticGenerator
|
||||
from linkml.generators.pydanticgen.array import ArrayRepresentation, NumpydanticArray
|
||||
|
@ -29,9 +29,9 @@ from linkml_runtime.utils.formatutils import remove_empty_items
|
|||
from linkml_runtime.utils.schemaview import SchemaView
|
||||
|
||||
from nwb_linkml.includes.base import (
|
||||
BASEMODEL_GETITEM,
|
||||
BASEMODEL_COERCE_VALUE,
|
||||
BASEMODEL_COERCE_CHILD,
|
||||
BASEMODEL_COERCE_VALUE,
|
||||
BASEMODEL_GETITEM,
|
||||
)
|
||||
from nwb_linkml.includes.hdmf import (
|
||||
DYNAMIC_TABLE_IMPORTS,
|
||||
|
@ -265,7 +265,7 @@ class AfterGenerateClass:
|
|||
Returns:
|
||||
|
||||
"""
|
||||
if cls.cls.name in "DynamicTable":
|
||||
if cls.cls.name == "DynamicTable":
|
||||
cls.cls.bases = ["DynamicTableMixin", "ConfiguredBaseModel"]
|
||||
|
||||
if cls.injected_classes is None:
|
||||
|
@ -328,10 +328,7 @@ class AfterGenerateClass:
|
|||
cls.cls.attributes[an_attr].range = "ElementIdentifiers"
|
||||
return cls
|
||||
|
||||
if an_attr.endswith("_index"):
|
||||
wrap_cls = "VectorIndex"
|
||||
else:
|
||||
wrap_cls = "VectorData"
|
||||
wrap_cls = "VectorIndex" if an_attr.endswith("_index") else "VectorData"
|
||||
|
||||
cls.cls.attributes[an_attr].range = wrap_preserving_optional(
|
||||
slot_range, wrap_cls
|
||||
|
@ -340,7 +337,9 @@ class AfterGenerateClass:
|
|||
return cls
|
||||
|
||||
@staticmethod
|
||||
def inject_elementidentifiers(cls: ClassResult, sv: SchemaView, import_method) -> ClassResult:
|
||||
def inject_elementidentifiers(
|
||||
cls: ClassResult, sv: SchemaView, import_method: Callable[[str], Import]
|
||||
) -> ClassResult:
|
||||
"""
|
||||
Inject ElementIdentifiers into module that define dynamictables -
|
||||
needed to handle ID columns
|
||||
|
|
|
@ -26,7 +26,7 @@ BASEMODEL_COERCE_VALUE = """
|
|||
except AttributeError:
|
||||
try:
|
||||
return handler(v["value"])
|
||||
except (KeyError, TypeError):
|
||||
except (IndexError, KeyError, TypeError):
|
||||
raise e1
|
||||
"""
|
||||
|
||||
|
@ -37,8 +37,13 @@ BASEMODEL_COERCE_CHILD = """
|
|||
\"\"\"Recast parent classes into child classes\"\"\"
|
||||
if isinstance(v, BaseModel):
|
||||
annotation = cls.model_fields[info.field_name].annotation
|
||||
annotation = annotation.__args__[0] if hasattr(annotation, "__args__") else annotation
|
||||
if issubclass(annotation, type(v)) and annotation is not type(v):
|
||||
v = annotation(**{**v.__dict__, **v.__pydantic_extra__})
|
||||
while hasattr(annotation, "__args__"):
|
||||
annotation = annotation.__args__[0]
|
||||
try:
|
||||
if issubclass(annotation, type(v)) and annotation is not type(v):
|
||||
v = annotation(**{**v.__dict__, **v.__pydantic_extra__})
|
||||
except TypeError:
|
||||
# fine, annotation is a non-class type like a TypeVar
|
||||
pass
|
||||
return v
|
||||
"""
|
||||
|
|
|
@ -288,14 +288,11 @@ class DynamicTableMixin(BaseModel):
|
|||
continue
|
||||
if not isinstance(val, (VectorData, VectorIndex)):
|
||||
try:
|
||||
if key.endswith("_index"):
|
||||
to_cast = VectorIndex
|
||||
else:
|
||||
to_cast = VectorData
|
||||
to_cast = VectorIndex if key.endswith("_index") else VectorData
|
||||
if isinstance(val, dict):
|
||||
model[key] = to_cast(**val)
|
||||
else:
|
||||
model[key] = VectorIndex(name=key, description="", value=val)
|
||||
model[key] = to_cast(name=key, description="", value=val)
|
||||
except ValidationError as e: # pragma: no cover
|
||||
raise ValidationError.from_exception_data(
|
||||
title=f"field {key} cannot be cast to VectorData from {val}",
|
||||
|
@ -388,6 +385,11 @@ class VectorDataMixin(BaseModel, Generic[T]):
|
|||
# redefined in `VectorData`, but included here for testing and type checking
|
||||
value: Optional[T] = None
|
||||
|
||||
def __init__(self, value: Optional[T] = None, **kwargs):
|
||||
if value is not None and "value" not in kwargs:
|
||||
kwargs["value"] = value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
|
||||
if self._index:
|
||||
# Following hdmf, VectorIndex is the thing that knows how to do the slicing
|
||||
|
|
|
@ -151,7 +151,7 @@ def test_name_slot():
|
|||
assert slot.name == "name"
|
||||
assert slot.required
|
||||
assert slot.range == "string"
|
||||
assert slot.identifier is None
|
||||
assert slot.identifier
|
||||
assert slot.ifabsent is None
|
||||
assert slot.equals_string is None
|
||||
|
||||
|
@ -160,7 +160,7 @@ def test_name_slot():
|
|||
assert slot.name == "name"
|
||||
assert slot.required
|
||||
assert slot.range == "string"
|
||||
assert slot.identifier is None
|
||||
assert slot.identifier
|
||||
assert slot.ifabsent == "string(FixedName)"
|
||||
assert slot.equals_string == "FixedName"
|
||||
|
||||
|
|
|
@ -284,14 +284,14 @@ def test_dynamictable_assert_equal_length():
|
|||
"existing_col": np.arange(10),
|
||||
"new_col_1": hdmf.VectorData(value=np.arange(11)),
|
||||
}
|
||||
with pytest.raises(ValidationError, match="Columns are not of equal length"):
|
||||
with pytest.raises(ValidationError, match="columns are not of equal length"):
|
||||
_ = MyDT(**cols)
|
||||
|
||||
cols = {
|
||||
"existing_col": np.arange(11),
|
||||
"new_col_1": hdmf.VectorData(value=np.arange(10)),
|
||||
}
|
||||
with pytest.raises(ValidationError, match="Columns are not of equal length"):
|
||||
with pytest.raises(ValidationError, match="columns are not of equal length"):
|
||||
_ = MyDT(**cols)
|
||||
|
||||
# wrong lengths are fine as long as the index is good
|
||||
|
@ -308,7 +308,7 @@ def test_dynamictable_assert_equal_length():
|
|||
"new_col_1": hdmf.VectorData(value=np.arange(100)),
|
||||
"new_col_1_index": hdmf.VectorIndex(value=np.arange(0, 100, 5) + 5),
|
||||
}
|
||||
with pytest.raises(ValidationError, match="Columns are not of equal length"):
|
||||
with pytest.raises(ValidationError, match="columns are not of equal length"):
|
||||
_ = MyDT(**cols)
|
||||
|
||||
|
||||
|
|
|
@ -8,9 +8,11 @@ from nwb_linkml.io.hdf5 import HDF5IO
|
|||
def test_read_from_nwbfile(nwb_file):
|
||||
"""
|
||||
Read data from a pynwb HDF5 NWB file
|
||||
|
||||
Placeholder that just ensures that reads work and all pydantic models validate,
|
||||
testing of correctness of read will happen elsewhere.
|
||||
"""
|
||||
res = HDF5IO(nwb_file).read()
|
||||
res.model_dump_json()
|
||||
|
||||
|
||||
def test_read_from_yaml(nwb_file):
|
||||
|
|
Loading…
Reference in a new issue