Mirror of https://github.com/p2p-ld/nwb-linkml.git, synced 2025-01-09 21:54:27 +00:00
updating model generation methods, still some models being converted to str instead of being inlined, but almost there
Commit 27b5dddfdd (parent 8078492f90)
9 changed files with 182 additions and 57 deletions
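For context on the commit message: in LinkML, a slot whose range is another class can either nest the target object under the slot ("inlined") or collapse it to a string identifier that points at it; the generated-model bug being chased here is the latter. A minimal sketch of the distinction, using linkml_runtime's SlotDefinition (the same class this diff builds throughout) and an NWB type name purely for illustration:

    from linkml_runtime.linkml_model import SlotDefinition

    # class-ranged slot without inlined=True: may serialize as a string
    # identifier referring to the target object elsewhere in the document
    by_reference = SlotDefinition(name="electrode_group", range="ElectrodeGroup")

    # with inlined=True, the target object is nested under the slot itself
    nested = SlotDefinition(name="electrode_group", range="ElectrodeGroup", inlined=True)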
@@ -2,6 +2,7 @@
 Adapters to linkML classes
 """
 
+import os
 from abc import abstractmethod
 from typing import List, Optional, Type, TypeVar
 
@@ -32,6 +33,14 @@ class ClassAdapter(Adapter):
     cls: TI
     parent: Optional["ClassAdapter"] = None
 
+    _debug: Optional[bool] = None
+
+    @property
+    def debug(self) -> bool:
+        if self._debug is None:
+            self._debug = bool(os.environ.get("NWB_LINKML_DEBUG", False))
+        return self._debug
+
     @field_validator("cls", mode="before")
     @classmethod
     def cast_from_string(cls, value: str | TI) -> TI:
@@ -92,6 +101,13 @@ class ClassAdapter(Adapter):
         # Get vanilla top-level attributes
         kwargs["attributes"].extend(self.build_attrs(self.cls))
 
+        if self.debug:
+            kwargs["annotations"] = {}
+            kwargs["annotations"]["group_adapter"] = {
+                "tag": "group_adapter",
+                "value": "container_slot",
+            }
+
         if extra_attrs is not None:
             if isinstance(extra_attrs, SlotDefinition):
                 extra_attrs = [extra_attrs]
@@ -230,18 +246,22 @@ class ClassAdapter(Adapter):
                 ifabsent=f"string({name})",
                 equals_string=equals_string,
                 range="string",
+                identifier=True,
             )
         else:
-            name_slot = SlotDefinition(name="name", required=True, range="string")
+            name_slot = SlotDefinition(name="name", required=True, range="string", identifier=True)
         return name_slot
 
     def build_self_slot(self) -> SlotDefinition:
         """
         If we are a child class, we make a slot so our parent can refer to us
         """
-        return SlotDefinition(
+        slot = SlotDefinition(
             name=self._get_slot_name(),
             description=self.cls.doc,
             range=self._get_full_name(),
             **QUANTITY_MAP[self.cls.quantity],
         )
+        if self.debug:
+            slot.annotations["group_adapter"] = {"tag": "group_adapter", "value": "self_slot"}
+        return slot
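A note on the debug gate added above, since the rest of the diff reuses it: the property reads NWB_LINKML_DEBUG once and caches it, and because it applies bool() to the raw string, any non-empty value (including "false") enables it. A sketch, with `dataset` standing in for a hypothetical nwb_schema_language Dataset instance:

    import os

    os.environ["NWB_LINKML_DEBUG"] = "true"
    adapter = DatasetAdapter(cls=dataset)   # concrete subclass of ClassAdapter
    assert adapter.debug                    # read from the environment, then cached in _debug

    os.environ["NWB_LINKML_DEBUG"] = "false"
    assert DatasetAdapter(cls=dataset).debug  # still True: any non-empty string is truthy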
@@ -11,7 +11,7 @@ from nwb_linkml.adapters.adapter import BuildResult, has_attrs, is_1d, is_compou
 from nwb_linkml.adapters.array import ArrayAdapter
 from nwb_linkml.adapters.classes import ClassAdapter
 from nwb_linkml.maps import QUANTITY_MAP, Map
-from nwb_linkml.maps.dtype import flat_to_linkml, handle_dtype
+from nwb_linkml.maps.dtype import flat_to_linkml, handle_dtype, inlined
 from nwb_linkml.maps.naming import camel_to_snake
 from nwb_schema_language import Dataset
 
@@ -299,6 +299,8 @@ class MapListlike(DatasetMap):
             description=cls.doc,
             required=cls.quantity not in ("*", "?"),
             annotations=[{"source_type": "reference"}],
+            inlined=True,
+            inlined_as_list=True,
         )
         res.classes[0].attributes["value"] = slot
         return res
@@ -544,7 +546,9 @@ class MapArrayLikeAttributes(DatasetMap):
         array_adapter = ArrayAdapter(cls.dims, cls.shape)
         expressions = array_adapter.make_slot()
         # make a slot for the arraylike class
-        array_slot = SlotDefinition(name="value", range=handle_dtype(cls.dtype), **expressions)
+        array_slot = SlotDefinition(
+            name="value", range=handle_dtype(cls.dtype), inlined=inlined(cls.dtype), **expressions
+        )
         res.classes[0].attributes.update({"value": array_slot})
         return res
 
@@ -583,6 +587,7 @@ class MapClassRange(DatasetMap):
             description=cls.doc,
             range=f"{cls.neurodata_type_inc}",
             annotations=[{"named": True}, {"source_type": "neurodata_type_inc"}],
+            inlined=True,
             **QUANTITY_MAP[cls.quantity],
         )
         res = BuildResult(slots=[this_slot])
@@ -799,6 +804,7 @@ class MapCompoundDtype(DatasetMap):
                 description=a_dtype.doc,
                 range=handle_dtype(a_dtype.dtype),
                 array=ArrayExpression(exact_number_dimensions=1),
+                inlined=inlined(a_dtype.dtype),
                 **QUANTITY_MAP[cls.quantity],
             )
         res.classes[0].attributes.update(slots)
@@ -830,6 +836,8 @@ class DatasetAdapter(ClassAdapter):
         if map is not None:
             res = map.apply(self.cls, res, self._get_full_name())
 
+        if self.debug:
+            res = self._amend_debug(res, map)
         return res
 
     def match(self) -> Optional[Type[DatasetMap]]:
@@ -855,11 +863,13 @@ class DatasetAdapter(ClassAdapter):
         else:
             return matches[0]
 
-    def special_cases(self, res: BuildResult) -> BuildResult:
-        """
-        Apply special cases to build result
-        """
-        res = self._datetime_or_str(res)
-
-    def _datetime_or_str(self, res: BuildResult) -> BuildResult:
-        """HDF5 doesn't support datetime, so"""
+    def _amend_debug(self, res: BuildResult, map: Optional[Type[DatasetMap]] = None) -> BuildResult:
+        if map is None:
+            map_name = "None"
+        else:
+            map_name = map.__name__
+        for cls in res.classes:
+            cls.annotations["dataset_map"] = {"tag": "dataset_map", "value": map_name}
+        for slot in res.slots:
+            slot.annotations["dataset_map"] = {"tag": "dataset_map", "value": map_name}
+        return res
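With debug enabled, _amend_debug stamps every class and slot a DatasetMap produces with that map's name, which is exactly the provenance needed to track down where a model is still "converted to str instead of being inlined". A hedged usage sketch (`dataset` again a hypothetical Dataset):

    res = DatasetAdapter(cls=dataset).build()
    for c in res.classes:
        print(c.name, c.annotations["dataset_map"])  # e.g. value "MapListlike"
    for s in res.slots:
        print(s.name, s.annotations["dataset_map"])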
@@ -111,6 +111,9 @@ class GroupAdapter(ClassAdapter):
             inlined_as_list=False,
         )
 
+        if self.debug:
+            slot.annotations["group_adapter"] = {"tag": "group_adapter", "value": "container_group"}
+
         if self.parent is not None:
             # if we have a parent,
             # just return the slot itself without the class
@@ -144,17 +147,20 @@ class GroupAdapter(ClassAdapter):
         """
         name = camel_to_snake(self.cls.neurodata_type_inc) if not self.cls.name else cls.name
 
-        return BuildResult(
-            slots=[
-                SlotDefinition(
-                    name=name,
-                    range=self.cls.neurodata_type_inc,
-                    description=self.cls.doc,
-                    **QUANTITY_MAP[cls.quantity],
-                )
-            ]
-        )
+        slot = SlotDefinition(
+            name=name,
+            range=self.cls.neurodata_type_inc,
+            description=self.cls.doc,
+            inlined=True,
+            inlined_as_list=False,
+            **QUANTITY_MAP[cls.quantity],
+        )
+
+        if self.debug:
+            slot.annotations["group_adapter"] = {"tag": "group_adapter", "value": "container_slot"}
+
+        return BuildResult(slots=[slot])
 
     def build_subclasses(self) -> BuildResult:
         """
         Build nested groups and datasets
@@ -166,20 +172,9 @@ class GroupAdapter(ClassAdapter):
         # for creating slots vs. classes is handled by the adapter class
         dataset_res = BuildResult()
         for dset in self.cls.datasets:
-            # if dset.name == 'timestamps':
-            # pdb.set_trace()
             dset_adapter = DatasetAdapter(cls=dset, parent=self)
             dataset_res += dset_adapter.build()
 
-        # Actually i'm not sure we have to special case this, we could handle it in
-        # i/o instead
-
-        # Groups are a bit more complicated because they can also behave like
-        # range declarations:
-        # eg. a group can have multiple groups with `neurodata_type_inc`, no name,
-        # and quantity of *,
-        # the group can then contain any number of groups of those included types as direct children
-
         group_res = BuildResult()
 
         for group in self.cls.groups:
@@ -190,6 +185,33 @@ class GroupAdapter(ClassAdapter):
 
         return res
 
+    def build_self_slot(self) -> SlotDefinition:
+        """
+        If we are a child class, we make a slot so our parent can refer to us
+
+        Groups are a bit more complicated because they can also behave like
+        range declarations:
+        eg. a group can have multiple groups with `neurodata_type_inc`, no name,
+        and quantity of *,
+        the group can then contain any number of groups of those included types as direct children
+
+        We make sure that we're inlined as a dict so our parent class can refer to us like::
+
+            parent.{slot_name}[{name}] = self
+
+        """
+        slot = SlotDefinition(
+            name=self._get_slot_name(),
+            description=self.cls.doc,
+            range=self._get_full_name(),
+            inlined=True,
+            inlined_as_list=True,
+            **QUANTITY_MAP[self.cls.quantity],
+        )
+        if self.debug:
+            slot.annotations["group_adapter"] = {"tag": "group_adapter", "value": "container_slot"}
+        return slot
+
     def _check_if_container(self, group: Group) -> bool:
         """
         Check if a given subgroup is a container subgroup,
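Nearly every slot built in these adapters unpacks **QUANTITY_MAP[...]. For orientation, NWB's `quantity` field expresses cardinality, and the mapping has roughly the shape sketched below; this is illustrative, not a verbatim copy of nwb_linkml.maps.QUANTITY_MAP:

    # rough shape of the quantity -> slot-kwargs table assumed by the ** unpacking
    QUANTITY_MAP_SKETCH = {
        "?": {"required": False, "multivalued": False},  # zero or one
        "*": {"required": False, "multivalued": True},   # zero or more
        "+": {"required": True, "multivalued": True},    # one or more
        1: {"required": True, "multivalued": False},     # exactly one
    }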
@@ -5,6 +5,7 @@ customized to support NWB models.
 See class and module docstrings for details :)
 """
 
+import pdb
 import re
 import sys
 from dataclasses import dataclass, field
@@ -27,7 +28,11 @@ from linkml_runtime.utils.compile_python import file_text
 from linkml_runtime.utils.formatutils import remove_empty_items
 from linkml_runtime.utils.schemaview import SchemaView
 
-from nwb_linkml.includes.base import BASEMODEL_GETITEM, BASEMODEL_COERCE_VALUE
+from nwb_linkml.includes.base import (
+    BASEMODEL_GETITEM,
+    BASEMODEL_COERCE_VALUE,
+    BASEMODEL_COERCE_CHILD,
+)
 from nwb_linkml.includes.hdmf import (
     DYNAMIC_TABLE_IMPORTS,
     DYNAMIC_TABLE_INJECTS,
@@ -36,7 +41,7 @@ from nwb_linkml.includes.hdmf import (
 )
 from nwb_linkml.includes.types import ModelTypeString, NamedImports, NamedString, _get_name
 
-OPTIONAL_PATTERN = re.compile(r"Optional\[([\w\.]*)\]")
+OPTIONAL_PATTERN = re.compile(r"Optional\[(.*)\]")
 
 
 @dataclass
@@ -53,6 +58,7 @@ class NWBPydanticGenerator(PydanticGenerator):
         'object_id: Optional[str] = Field(None, description="Unique UUID for each object")',
         BASEMODEL_GETITEM,
         BASEMODEL_COERCE_VALUE,
+        BASEMODEL_COERCE_CHILD,
     )
     split: bool = True
     imports: list[Import] = field(default_factory=lambda: [Import(module="numpy", alias="np")])
@@ -134,6 +140,7 @@ class NWBPydanticGenerator(PydanticGenerator):
         cls = AfterGenerateClass.inject_dynamictable(cls)
         cls = AfterGenerateClass.wrap_dynamictable_columns(cls, sv)
         cls = AfterGenerateClass.inject_elementidentifiers(cls, sv, self._get_element_import)
+        cls = AfterGenerateClass.strip_vector_data_slots(cls, sv)
         return cls
 
     def before_render_template(self, template: PydanticModule, sv: SchemaView) -> PydanticModule:
@@ -227,7 +234,8 @@ class AfterGenerateSlot:
         """
 
         if "named" in slot.source.annotations and slot.source.annotations["named"].value:
-            slot.attribute.range = f"Named[{slot.attribute.range}]"
+
+            slot.attribute.range = wrap_preserving_optional(slot.attribute.range, "Named")
             named_injects = [ModelTypeString, _get_name, NamedString]
             if slot.injected_classes is None:
                 slot.injected_classes = named_injects
@@ -325,7 +333,9 @@ class AfterGenerateClass:
                 else:
                     wrap_cls = "VectorData"
 
-                cls.cls.attributes[an_attr].range = "".join([wrap_cls, "[", slot_range, "]"])
+                cls.cls.attributes[an_attr].range = wrap_preserving_optional(
+                    slot_range, wrap_cls
+                )
 
         return cls
 
@@ -343,6 +353,15 @@ class AfterGenerateClass:
             cls.imports += [imp]
         return cls
 
+    @staticmethod
+    def strip_vector_data_slots(cls: ClassResult, sv: SchemaView) -> ClassResult:
+        """
+        Remove spurious ``vector_data`` slots from DynamicTables
+        """
+        if "vector_data" in cls.cls.attributes:
+            del cls.cls.attributes["vector_data"]
+        return cls
+
 
 def compile_python(
     text_or_fn: str, package_path: Path = None, module_name: str = "test"
@@ -364,3 +383,26 @@ def compile_python(
     exec(spec, module.__dict__)
     sys.modules[module_name] = module
     return module
+
+
+def wrap_preserving_optional(annotation: str, wrap: str) -> str:
+    """
+    Add a wrapping type to a type annotation string,
+    preserving any `Optional[]` annotation, bumping it to the outside
+
+    Examples:
+
+        >>> wrap_preserving_optional('Optional[list[str]]', 'NewType')
+        'Optional[NewType[list[str]]]'
+
+    """
+
+    is_optional = OPTIONAL_PATTERN.match(annotation)
+    if is_optional:
+        annotation = is_optional.groups()[0]
+        annotation = f"Optional[{wrap}[{annotation}]]"
+    else:
+        if "Optional" in annotation:
+            pdb.set_trace()
+        annotation = f"{wrap}[{annotation}]"
+    return annotation
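The OPTIONAL_PATTERN change is load-bearing for the new wrap_preserving_optional: the old character class [\w\.]* cannot match bracketed inner types, so an annotation like Optional[List[str]] never matched at all and the Optional would have been wrapped on the inside. A self-contained check:

    import re

    OLD = re.compile(r"Optional\[([\w\.]*)\]")
    NEW = re.compile(r"Optional\[(.*)\]")

    ann = "Optional[List[ElectrodeGroup]]"
    assert OLD.match(ann) is None                             # '[' and ']' fall outside [\w.]
    assert NEW.match(ann).group(1) == "List[ElectrodeGroup]"
    # hence: wrap_preserving_optional(ann, "Named") == 'Optional[Named[List[ElectrodeGroup]]]'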
@@ -23,8 +23,11 @@ BASEMODEL_COERCE_VALUE = """
     except Exception as e1:
         try:
             return handler(v.value)
-        except:
-            raise e1
+        except AttributeError:
+            try:
+                return handler(v["value"])
+            except (KeyError, TypeError):
+                raise e1
 """
 
 BASEMODEL_COERCE_CHILD = """
@@ -32,6 +35,10 @@ BASEMODEL_COERCE_CHILD = """
     @classmethod
     def coerce_subclass(cls, v: Any, info) -> Any:
         \"\"\"Recast parent classes into child classes\"\"\"
+        if isinstance(v, BaseModel):
+            annotation = cls.model_fields[info.field_name].annotation
+            annotation = annotation.__args__[0] if hasattr(annotation, "__args__") else annotation
+            if issubclass(annotation, type(v)) and annotation is not type(v):
+                v = annotation(**{**v.__dict__, **v.__pydantic_extra__})
         return v
-        pdb.set_trace()
 """
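BASEMODEL_COERCE_VALUE is injected source text, so the hunk above shows only the tail of the validator. As a self-contained sketch of the widened fallback chain (the leading try is assumed from the unshown lines above the hunk, and `handler` is pydantic's wrap-validator handler):

    def coerce_value_sketch(v, handler):
        # try the value as-is, then its .value attribute, then dict-style ["value"];
        # if none of those parse, surface the original validation error
        try:
            return handler(v)
        except Exception as e1:
            try:
                return handler(v.value)
            except AttributeError:
                try:
                    return handler(v["value"])
                except (KeyError, TypeError):
                    raise e1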
@@ -249,6 +249,7 @@ class DynamicTableMixin(BaseModel):
                 if k not in cls.NON_COLUMN_FIELDS
                 and not k.endswith("_index")
                 and not isinstance(model[k], VectorIndexMixin)
+                and model[k] is not None
             ]
             model["colnames"] = colnames
         else:
@@ -264,6 +265,7 @@ class DynamicTableMixin(BaseModel):
                     and not k.endswith("_index")
                     and k not in model["colnames"]
                     and not isinstance(model[k], VectorIndexMixin)
+                    and model[k] is not None
                 ]
             )
             model["colnames"] = colnames
@@ -336,9 +338,9 @@ class DynamicTableMixin(BaseModel):
         """
         Ensure that all columns are equal length
         """
-        lengths = [len(v) for v in self._columns.values()] + [len(self.id)]
+        lengths = [len(v) for v in self._columns.values() if v is not None] + [len(self.id)]
         assert all([length == lengths[0] for length in lengths]), (
-            "Columns are not of equal length! "
+            "DynamicTable columns are not of equal length! "
             f"Got colnames:\n{self.colnames}\nand lengths: {lengths}"
         )
         return self
@@ -386,11 +388,6 @@ class VectorDataMixin(BaseModel, Generic[T]):
     # redefined in `VectorData`, but included here for testing and type checking
     value: Optional[T] = None
 
-    # def __init__(self, value: Optional[NDArray] = None, **kwargs):
-    #     if value is not None and "value" not in kwargs:
-    #         kwargs["value"] = value
-    #     super().__init__(**kwargs)
-
     def __getitem__(self, item: Union[str, int, slice, Tuple[Union[str, int, slice], ...]]) -> Any:
         if self._index:
             # Following hdmf, VectorIndex is the thing that knows how to do the slicing
@@ -587,6 +584,7 @@ class AlignedDynamicTableMixin(BaseModel):
     __pydantic_extra__: Dict[str, Union["DynamicTableMixin", "VectorDataMixin", "VectorIndexMixin"]]
 
     NON_CATEGORY_FIELDS: ClassVar[tuple[str]] = (
+        "id",
         "name",
         "categories",
         "colnames",
@@ -622,7 +620,7 @@ class AlignedDynamicTableMixin(BaseModel):
         elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], str):
             # get a slice of a single table
             return self._categories[item[1]][item[0]]
-        elif isinstance(item, (int, slice, Iterable)):
+        elif isinstance(item, (int, slice, Iterable, np.int_)):
             # get a slice of all the tables
             ids = self.id[item]
             if not isinstance(ids, Iterable):
@@ -634,9 +632,9 @@ class AlignedDynamicTableMixin(BaseModel):
             if isinstance(table, pd.DataFrame):
                 table = table.reset_index()
             elif isinstance(table, np.ndarray):
-                table = pd.DataFrame({category_name: [table]})
+                table = pd.DataFrame({category_name: [table]}, index=ids.index)
             elif isinstance(table, Iterable):
-                table = pd.DataFrame({category_name: table})
+                table = pd.DataFrame({category_name: table}, index=ids.index)
             else:
                 raise ValueError(
                     f"Don't know how to construct category table for {category_name}"
@@ -756,7 +754,7 @@ class AlignedDynamicTableMixin(BaseModel):
         """
         lengths = [len(v) for v in self._categories.values()] + [len(self.id)]
         assert all([length == lengths[0] for length in lengths]), (
-            "Columns are not of equal length! "
+            "AlignedDynamicTableColumns are not of equal length! "
            f"Got colnames:\n{self.categories}\nand lengths: {lengths}"
        )
        return self
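The recurring `and model[k] is not None` guard, together with the `if v is not None` filter in the length validator, means an optional column that was never set no longer shows up in colnames or trips the equal-length assertion. Sketch with a hypothetical generated table class:

    # TrialsTable is a stand-in for any generated DynamicTable subclass
    table = TrialsTable(id=[0, 1, 2], start_time=[0.0, 1.0, 2.0], tag=None)
    assert "tag" not in table.colnames  # None columns skipped during colnames inference
    # the equal-length model validator likewise skips the None column
    # instead of failing on len(None)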
@@ -3,7 +3,7 @@ Dtype mappings
 """
 
 from datetime import datetime
-from typing import Any
+from typing import Any, Optional
 
 import numpy as np
 
@@ -160,14 +160,28 @@ def handle_dtype(dtype: DTypeType | None) -> str:
     elif isinstance(dtype, FlatDtype):
         return dtype.value
     elif isinstance(dtype, list) and isinstance(dtype[0], CompoundDtype):
-        # there is precisely one class that uses compound dtypes:
-        # TimeSeriesReferenceVectorData
-        # compoundDtypes are able to define a ragged table according to the schema
-        # but are used in this single case equivalently to attributes.
-        # so we'll... uh... treat them as slots.
-        # TODO
+        # Compound Dtypes are handled by the MapCompoundDtype dataset map,
+        # but this function is also used within ``check`` methods, so we should always
+        # return something from it rather than raise
         return "AnyType"
 
     else:
         # flat dtype
         return dtype
 
 
+def inlined(dtype: DTypeType | None) -> Optional[bool]:
+    """
+    Check if a slot should be inlined based on its dtype
+
+    for now that is equivalent to checking whether that dtype is a reference dtype,
+    but the function remains semantically reserved for answering this question w.r.t. dtype.
+
+    Returns ``None`` if not inlined to not clutter generated models with unnecessary props
+    """
+    return (
+        True
+        if isinstance(dtype, ReferenceDtype)
+        or (isinstance(dtype, CompoundDtype) and isinstance(dtype.dtype, ReferenceDtype))
+        else None
+    )
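Why inlined() returns None rather than False: kwargs that are None are pruned from the built SlotDefinition when the schema is serialized (the generator above imports linkml's remove_empty_items), so non-reference slots don't accumulate `inlined: false` noise. Usage sketch, assuming nwb_schema_language's ReferenceDtype takes target_type and reftype:

    from nwb_schema_language import ReferenceDtype

    inlined("text")  # -> None: plain dtype, kwarg pruned from the slot
    inlined(ReferenceDtype(target_type="VectorData", reftype="object"))  # -> True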
@@ -10,6 +10,7 @@ def test_read_from_nwbfile(nwb_file):
     Read data from a pynwb HDF5 NWB file
     """
     res = HDF5IO(nwb_file).read()
+    res.model_dump_json()
 
 
 def test_read_from_yaml(nwb_file):
@@ -222,6 +222,11 @@ def parser() -> ArgumentParser:
         ),
         action="store_true",
     )
+    parser.add_argument(
+        "--debug",
+        help="Add annotations to generated schema that indicate how they were generated",
+        action="store_true",
+    )
     parser.add_argument("--pdb", help="Launch debugger on an error", action="store_true")
     return parser
 
@@ -229,6 +234,12 @@ def parser() -> ArgumentParser:
 def main():
     args = parser().parse_args()
 
+    if args.debug:
+        os.environ["NWB_LINKML_DEBUG"] = "true"
+    else:
+        if "NWB_LINKML_DEBUG" in os.environ:
+            del os.environ["NWB_LINKML_DEBUG"]
+
     tmp_dir = make_tmp_dir(clear=True)
     git_dir = tmp_dir / "git"
     git_dir.mkdir(exist_ok=True)
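Note that main() both sets and clears the variable, so a value exported by an earlier --debug run can't silently leave debug annotations enabled. The guarded del is equivalent to this sketch:

    import os

    def apply_debug_flag(debug: bool) -> None:
        # mirrors the diff's if/else: set on --debug, actively clear otherwise
        if debug:
            os.environ["NWB_LINKML_DEBUG"] = "true"
        else:
            os.environ.pop("NWB_LINKML_DEBUG", None)  # no-op when unset, like the `in` check + del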