correctly handle attributes

sneakers-the-rat 2024-08-05 18:41:38 -07:00
parent 652ddb3b48
commit da6d0d8608
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
6 changed files with 290 additions and 96 deletions

View file

@@ -26,7 +26,7 @@ from linkml_runtime.linkml_model import (
 )
 from pydantic import BaseModel
-from nwb_schema_language import Attribute, Dataset, Group, Schema
+from nwb_schema_language import Attribute, Dataset, Group, Schema, CompoundDtype
 
 if sys.version_info.minor >= 11:
     from typing import TypeVarTuple, Unpack
@@ -238,3 +238,36 @@ class Adapter(BaseModel):
         for item in self.walk(input):
             if any([type(item) is atype for atype in get_type]):
                 yield item
+
+
+def is_1d(cls: Dataset | Attribute) -> bool:
+    """
+    Check if the values of a dataset are 1-dimensional.
+
+    Specifically:
+    * a single-layer dim/shape list of length 1, or
+    * a nested dim/shape list with a single nested spec, itself of length 1
+    """
+    return (
+        not any([isinstance(dim, list) for dim in cls.dims]) and len(cls.dims) == 1
+    ) or (  # nested list
+        all([isinstance(dim, list) for dim in cls.dims])
+        and len(cls.dims) == 1
+        and len(cls.dims[0]) == 1
+    )
+
+
+def is_compound(cls: Dataset) -> bool:
+    """Check if dataset has a compound dtype"""
+    return (
+        isinstance(cls.dtype, list)
+        and len(cls.dtype) > 0
+        and isinstance(cls.dtype[0], CompoundDtype)
+    )
+
+
+def has_attrs(cls: Dataset) -> bool:
+    """
+    Check if a dataset has any attributes, none of which have a fixed value
+    """
+    return len(cls.attributes) > 0 and all([not a.value for a in cls.attributes])
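For orientation, a minimal sketch of what the new helpers report — the specs are hypothetical, assuming the nwb_schema_language models validate with just these fields:

    from nwb_schema_language import Attribute, Dataset
    from nwb_linkml.adapters.adapter import has_attrs, is_1d

    # a flat dims list of length 1 is 1-D ...
    flat = Dataset(name="timestamps", doc="...", dtype="float64", dims=["num_times"], shape=[None])
    # ... and so is a nested dims list containing a single length-1 spec
    nested = Dataset(name="timestamps", doc="...", dtype="float64", dims=[["num_times"]], shape=[[None]])
    assert is_1d(flat) and is_1d(nested)

    # has_attrs is True only when attributes exist and none of them carry a fixed value
    unit = Attribute(name="unit", doc="...", dtype="text", value="seconds")
    fixed = Dataset(name="data", doc="...", dtype="float64", attributes=[unit])
    assert not has_attrs(fixed)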

View file

@@ -0,0 +1,194 @@
+"""
+Adapters for attribute types
+"""
+
+from abc import abstractmethod
+from typing import ClassVar, Optional, TypedDict, Type
+
+from linkml_runtime.linkml_model.meta import SlotDefinition
+
+from nwb_linkml.adapters.array import ArrayAdapter
+from nwb_linkml.adapters.adapter import BuildResult, is_1d, Adapter
+from nwb_linkml.maps import Map
+from nwb_linkml.maps.dtype import handle_dtype
+from nwb_schema_language import Attribute
+
+
+def _make_ifabsent(val: str | int | float | None) -> str | None:
+    if val is None:
+        return None
+    elif isinstance(val, str):
+        return f"string({val})"
+    elif isinstance(val, int):
+        return f"integer({val})"
+    elif isinstance(val, float):
+        return f"float({val})"
+    else:
+        return str(val)
+
+
+class AttrDefaults(TypedDict):
+    equals_string: str | None
+    equals_number: float | int | None
+    ifabsent: str | None
+
+
+class AttributeMap(Map):
+    @classmethod
+    def handle_defaults(cls, attr: Attribute) -> AttrDefaults:
+        """
+        Construct arguments for linkml slot default metaslots from nwb schema lang attribute props
+        """
+        equals_string = None
+        equals_number = None
+        default_value = None
+        if attr.value:
+            if isinstance(attr.value, (int, float)):
+                equals_number = attr.value
+            elif attr.value:
+                equals_string = str(attr.value)
+
+        if equals_number:
+            default_value = _make_ifabsent(equals_number)
+        elif equals_string:
+            default_value = _make_ifabsent(equals_string)
+        elif attr.default_value:
+            default_value = _make_ifabsent(attr.default_value)
+
+        return AttrDefaults(
+            equals_string=equals_string, equals_number=equals_number, ifabsent=default_value
+        )
+
+    @classmethod
+    @abstractmethod
+    def check(cls, attr: Attribute) -> bool:
+        """
+        Check if this map applies
+        """
+        pass  # pragma: no cover
+
+    @classmethod
+    @abstractmethod
+    def apply(
+        cls, attr: Attribute, res: Optional[BuildResult] = None, name: Optional[str] = None
+    ) -> BuildResult:
+        """
+        Apply this mapping
+        """
+        pass  # pragma: no cover
+
+
+class MapScalar(AttributeMap):
+    """
+    Map a simple scalar value
+    """
+
+    @classmethod
+    def check(cls, attr: Attribute) -> bool:
+        """
+        Check if we are a scalar value!
+        """
+        return not attr.dims and not attr.shape
+
+    @classmethod
+    def apply(cls, attr: Attribute, res: Optional[BuildResult] = None) -> BuildResult:
+        """
+        Make a slot for us!
+        """
+        slot = SlotDefinition(
+            name=attr.name,
+            range=handle_dtype(attr.dtype),
+            description=attr.doc,
+            required=attr.required,
+            **cls.handle_defaults(attr),
+        )
+        return BuildResult(slots=[slot])
+
+
+class MapArray(AttributeMap):
+    """
+    Map an array value!
+    """
+
+    @classmethod
+    def check(cls, attr: Attribute) -> bool:
+        """
+        Check that we have some array specification!
+        """
+        return bool(attr.dims) or bool(attr.shape)
+
+    @classmethod
+    def apply(cls, attr: Attribute, res: Optional[BuildResult] = None) -> BuildResult:
+        """
+        Make a slot with an array expression!
+
+        If we're just a 1D array, use a list (set multivalued: true).
+        If more than that, make an array descriptor
+        """
+        expressions = {}
+        multivalued = False
+        if is_1d(attr):
+            multivalued = True
+        else:
+            # ---------------------------------
+            # SPECIAL CASE: Some old versions of HDMF don't have ``dims``, only shape
+            # ---------------------------------
+            shape = attr.shape
+            dims = attr.dims
+            if shape and not dims:
+                dims = ["null"] * len(shape)
+
+            array_adapter = ArrayAdapter(dims, shape)
+            expressions = array_adapter.make_slot()
+
+        slot = SlotDefinition(
+            name=attr.name,
+            range=handle_dtype(attr.dtype),
+            multivalued=multivalued,
+            description=attr.doc,
+            required=attr.required,
+            **expressions,
+            **cls.handle_defaults(attr),
+        )
+        return BuildResult(slots=[slot])
+
+
+class AttributeAdapter(Adapter):
+    """
+    Create slot definitions from nwb schema language attributes
+    """
+
+    TYPE: ClassVar[Type] = Attribute
+
+    cls: Attribute
+
+    def build(self) -> "BuildResult":
+        """
+        Build the slot definition; every attribute should have a map.
+        """
+        map = self.match()
+        return map.apply(self.cls)
+
+    def match(self) -> Optional[Type[AttributeMap]]:
+        """
+        Find the map class that applies to this attribute
+
+        Returns:
+            :class:`.AttributeMap`
+
+        Raises:
+            RuntimeError - if more than one map matches
+        """
+        # find a map to use
+        matches = [m for m in AttributeMap.__subclasses__() if m.check(self.cls)]
+
+        if len(matches) > 1:  # pragma: no cover
+            raise RuntimeError(
+                "Only one map should apply to an attribute, you need to refactor the maps! Got"
+                f" maps: {matches}"
+            )
+        elif len(matches) == 0:
+            return None
+        else:
+            return matches[0]
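End to end, the intended flow looks roughly like this — a hedged sketch with hypothetical field values, assuming Attribute validates with just these arguments:

    from nwb_schema_language import Attribute
    from nwb_linkml.adapters.attribute import AttributeAdapter

    attr = Attribute(name="unit", doc="unit of measurement", dtype="text", value="seconds")
    result = AttributeAdapter(cls=attr).build()  # MapScalar matches: no dims, no shape
    slot = result.slots[0]
    assert slot.equals_string == "seconds"       # fixed value becomes an equality constraint
    assert slot.ifabsent == "string(seconds)"    # ... and an ifabsent default

Attributes with dims or shape take the MapArray path instead, producing either a multivalued slot (1-D) or an array expression from ArrayAdapter.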

View file

@@ -9,9 +9,10 @@ from linkml_runtime.linkml_model import ClassDefinition, SlotDefinition
 from pydantic import field_validator
 
 from nwb_linkml.adapters.adapter import Adapter, BuildResult
+from nwb_linkml.adapters.attribute import AttributeAdapter
 from nwb_linkml.maps import QUANTITY_MAP
 from nwb_linkml.maps.naming import camel_to_snake
-from nwb_schema_language import CompoundDtype, Dataset, DTypeType, FlatDtype, Group, ReferenceDtype
+from nwb_schema_language import Dataset, Group
 
 T = TypeVar("T", bound=Type[Dataset] | Type[Group])
 TI = TypeVar("TI", bound=Dataset | Group)
@@ -118,16 +119,9 @@ class ClassAdapter(Adapter):
         Returns:
             list[:class:`.SlotDefinition`]
         """
-        attrs = [
-            SlotDefinition(
-                name=attr.name,
-                description=attr.doc,
-                range=self.handle_dtype(attr.dtype),
-            )
-            for attr in cls.attributes
-        ]
-        return attrs
+        results = [AttributeAdapter(cls=attr).build() for attr in cls.attributes]
+        slots = [r.slots[0] for r in results]
+        return slots
 
     def _get_full_name(self) -> str:
         """The full name of the object in the generated linkml
@@ -205,37 +199,6 @@ class ClassAdapter(Adapter):
 
         return name
 
-    @classmethod
-    def handle_dtype(cls, dtype: DTypeType | None) -> str:
-        """
-        Get the string form of a dtype
-
-        Args:
-            dtype (:class:`.DTypeType`): Dtype to stringify
-
-        Returns:
-            str
-        """
-        if isinstance(dtype, ReferenceDtype):
-            return dtype.target_type
-        elif dtype is None or dtype == []:
-            # Some ill-defined datasets are "abstract" despite that not being in the schema language
-            return "AnyType"
-        elif isinstance(dtype, FlatDtype):
-            return dtype.value
-        elif isinstance(dtype, list) and isinstance(dtype[0], CompoundDtype):
-            # there is precisely one class that uses compound dtypes:
-            # TimeSeriesReferenceVectorData
-            # compoundDtypes are able to define a ragged table according to the schema
-            # but are used in this single case equivalently to attributes.
-            # so we'll... uh... treat them as slots.
-            # TODO
-            return "AnyType"
-        else:
-            # flat dtype
-            return dtype
-
     def build_name_slot(self) -> SlotDefinition:
         """
         If a class has a name, then that name should be a slot with a
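Note that handle_dtype moves off of ClassAdapter and into nwb_linkml.maps.dtype (later in this diff) rather than into the new attribute module. Presumably this is so that both the attribute maps and the dataset maps can call it as a plain function: keeping it as a classmethod would have required attribute.py to import classes.py, which itself now imports attribute.py, creating a circular import between the adapter modules.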

View file

@@ -7,13 +7,13 @@ from typing import ClassVar, Optional, Type
 
 from linkml_runtime.linkml_model.meta import ArrayExpression, SlotDefinition
 
-from nwb_linkml.adapters.adapter import BuildResult
+from nwb_linkml.adapters.adapter import BuildResult, is_1d, is_compound, has_attrs
 from nwb_linkml.adapters.array import ArrayAdapter
 from nwb_linkml.adapters.classes import ClassAdapter
 from nwb_linkml.maps import QUANTITY_MAP, Map
-from nwb_linkml.maps.dtype import flat_to_linkml
+from nwb_linkml.maps.dtype import flat_to_linkml, handle_dtype
 from nwb_linkml.maps.naming import camel_to_snake
-from nwb_schema_language import CompoundDtype, Dataset
+from nwb_schema_language import Dataset
 
 
 class DatasetMap(Map):
@@ -106,7 +106,7 @@ class MapScalar(DatasetMap):
         this_slot = SlotDefinition(
             name=cls.name,
             description=cls.doc,
-            range=ClassAdapter.handle_dtype(cls.dtype),
+            range=handle_dtype(cls.dtype),
             **QUANTITY_MAP[cls.quantity],
         )
         res = BuildResult(slots=[this_slot])
@@ -203,9 +203,7 @@ class MapScalarAttributes(DatasetMap):
         """
         Map to a scalar attribute with an adjoining "value" slot
         """
-        value_slot = SlotDefinition(
-            name="value", range=ClassAdapter.handle_dtype(cls.dtype), required=True
-        )
+        value_slot = SlotDefinition(name="value", range=handle_dtype(cls.dtype), required=True)
         res.classes[0].attributes["value"] = value_slot
         return res
@@ -271,7 +269,7 @@ class MapListlike(DatasetMap):
            * - ``dtype``
              - ``Class``
         """
-        dtype = ClassAdapter.handle_dtype(cls.dtype)
+        dtype = handle_dtype(cls.dtype)
         return (
             cls.neurodata_type_inc != "VectorData"
             and is_1d(cls)
@@ -289,7 +287,7 @@ class MapListlike(DatasetMap):
         slot = SlotDefinition(
             name="value",
             multivalued=True,
-            range=ClassAdapter.handle_dtype(cls.dtype),
+            range=handle_dtype(cls.dtype),
             description=cls.doc,
             required=cls.quantity not in ("*", "?"),
             annotations=[{"source_type": "reference"}],
@@ -378,7 +376,7 @@ class MapArraylike(DatasetMap):
              - ``False``
         """
-        dtype = ClassAdapter.handle_dtype(cls.dtype)
+        dtype = handle_dtype(cls.dtype)
         return (
             cls.name
             and (all([cls.dims, cls.shape]) or cls.neurodata_type_inc == "VectorData")
@@ -409,7 +407,7 @@ class MapArraylike(DatasetMap):
             SlotDefinition(
                 name=name,
                 multivalued=False,
-                range=ClassAdapter.handle_dtype(cls.dtype),
+                range=handle_dtype(cls.dtype),
                 description=cls.doc,
                 required=cls.quantity not in ("*", "?"),
                 **expressions,
@@ -513,7 +511,7 @@ class MapArrayLikeAttributes(DatasetMap):
         """
         Check that we're an array with some additional metadata
         """
-        dtype = ClassAdapter.handle_dtype(cls.dtype)
+        dtype = handle_dtype(cls.dtype)
         return (
             all([cls.dims, cls.shape])
             and cls.neurodata_type_inc != "VectorData"
@@ -532,9 +530,7 @@ class MapArrayLikeAttributes(DatasetMap):
         array_adapter = ArrayAdapter(cls.dims, cls.shape)
         expressions = array_adapter.make_slot()
         # make a slot for the arraylike class
-        array_slot = SlotDefinition(
-            name="value", range=ClassAdapter.handle_dtype(cls.dtype), **expressions
-        )
+        array_slot = SlotDefinition(name="value", range=handle_dtype(cls.dtype), **expressions)
         res.classes[0].attributes.update({"value": array_slot})
         return res
@@ -596,7 +592,7 @@ class MapVectorClassRange(DatasetMap):
         Check that we are a VectorData object without any additional attributes
         with a dtype that refers to another class
         """
-        dtype = ClassAdapter.handle_dtype(cls.dtype)
+        dtype = handle_dtype(cls.dtype)
         return (
             cls.neurodata_type_inc == "VectorData"
             and cls.name
@@ -617,7 +613,7 @@ class MapVectorClassRange(DatasetMap):
             name=cls.name,
             description=cls.doc,
             multivalued=True,
-            range=ClassAdapter.handle_dtype(cls.dtype),
+            range=handle_dtype(cls.dtype),
             required=cls.quantity not in ("*", "?"),
         )
         res = BuildResult(slots=[this_slot])
@@ -672,7 +668,7 @@ class MapVectorClassRange(DatasetMap):
 #         this_slot = SlotDefinition(
 #             name=cls.name,
 #             description=cls.doc,
-#             range=ClassAdapter.handle_dtype(cls.dtype),
+#             range=handle_dtype(cls.dtype),
 #             multivalued=True,
 #         )
 #         # No need to make a class for us, so we replace the existing build results
@@ -783,7 +779,7 @@ class MapCompoundDtype(DatasetMap):
             slots[a_dtype.name] = SlotDefinition(
                 name=a_dtype.name,
                 description=a_dtype.doc,
-                range=ClassAdapter.handle_dtype(a_dtype.dtype),
+                range=handle_dtype(a_dtype.dtype),
                 **QUANTITY_MAP[cls.quantity],
             )
         res.classes[0].attributes.update(slots)
@@ -836,36 +832,3 @@ class DatasetAdapter(ClassAdapter):
             return None
         else:
             return matches[0]
-
-
-def is_1d(cls: Dataset) -> bool:
-    """
-    Check if the values of a dataset are 1-dimensional.
-
-    Specifically:
-    * a single-layer dim/shape list of length 1, or
-    * a nested dim/shape list where every nested spec is of length 1
-    """
-    return (
-        not any([isinstance(dim, list) for dim in cls.dims]) and len(cls.dims) == 1
-    ) or (  # nested list
-        all([isinstance(dim, list) for dim in cls.dims])
-        and len(cls.dims) == 1
-        and len(cls.dims[0]) == 1
-    )
-
-
-def is_compound(cls: Dataset) -> bool:
-    """Check if dataset has a compound dtype"""
-    return (
-        isinstance(cls.dtype, list)
-        and len(cls.dtype) > 0
-        and isinstance(cls.dtype[0], CompoundDtype)
-    )
-
-
-def has_attrs(cls: Dataset) -> bool:
-    """
-    Check if a dataset has any attributes at all without defaults
-    """
-    return len(cls.attributes) > 0 and all([not a.value for a in cls.attributes])
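DatasetAdapter (above) and the new AttributeAdapter share the same dispatch idiom: each candidate map is a subclass of a common base, discovered at match time via __subclasses__(), and at most one map may claim a given input. A condensed, self-contained illustration of the pattern — toy classes, not project code:

    class Rule:
        @classmethod
        def check(cls, value: int) -> bool:
            raise NotImplementedError

    class EvenRule(Rule):
        @classmethod
        def check(cls, value: int) -> bool:
            return value % 2 == 0

    class OddRule(Rule):
        @classmethod
        def check(cls, value: int) -> bool:
            return value % 2 == 1

    # discover candidates and keep the ones whose predicate passes
    matches = [rule for rule in Rule.__subclasses__() if rule.check(3)]
    assert matches == [OddRule]  # exactly one rule may claim each input

Because discovery goes through __subclasses__(), adding a new map is just defining a new subclass; nothing registers it by hand, and the "more than one match" RuntimeError guards against overlapping predicates.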

View file

@@ -7,6 +7,7 @@ from typing import Any, Type
 
 import nptyping
 import numpy as np
+from nwb_schema_language import CompoundDtype, DTypeType, FlatDtype, ReferenceDtype
 
 flat_to_linkml = {
     "float": "float",
@@ -185,3 +186,34 @@ def struct_from_dtype(dtype: np.dtype) -> Type[nptyping.Structure]:
     struct_pieces = [f"{k}: {flat_to_nptyping[v[0].name]}" for k, v in dtype.fields.items()]
     struct_dtype = ", ".join(struct_pieces)
     return nptyping.Structure[struct_dtype]
+
+
+def handle_dtype(dtype: DTypeType | None) -> str:
+    """
+    Get the string form of a dtype
+
+    Args:
+        dtype (:class:`.DTypeType`): Dtype to stringify
+
+    Returns:
+        str
+    """
+    if isinstance(dtype, ReferenceDtype):
+        return dtype.target_type
+    elif dtype is None or dtype == []:
+        # Some ill-defined datasets are "abstract" despite that not being in the schema language
+        return "AnyType"
+    elif isinstance(dtype, FlatDtype):
+        return dtype.value
+    elif isinstance(dtype, list) and isinstance(dtype[0], CompoundDtype):
+        # there is precisely one class that uses compound dtypes:
+        # TimeSeriesReferenceVectorData
+        # compound dtypes can define a ragged table according to the schema,
+        # but in this single case they are used equivalently to attributes,
+        # so we treat them as slots.
+        # TODO
+        return "AnyType"
+    else:
+        # flat dtype
+        return dtype
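The function collapses the three dtype shapes in the schema language to a plain string usable as a linkml range. A sketch of the expected outputs, assuming the usual nwb_schema_language constructors (the enum member and constructor arguments here are illustrative):

    from nwb_schema_language import FlatDtype, ReferenceDtype

    handle_dtype(FlatDtype.int32)                           # -> "int32"
    handle_dtype(ReferenceDtype(target_type="TimeSeries"))  # -> "TimeSeries"
    handle_dtype(None)                                      # -> "AnyType" (abstract/ill-defined)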

View file

@@ -1,6 +1,8 @@
 import shutil
 import os
+import sys
 import traceback
+from pdb import post_mortem
 from argparse import ArgumentParser
 from pathlib import Path
@@ -53,6 +55,7 @@ def generate_versions(
     dry_run: bool = False,
     repo: GitRepo = NWB_CORE_REPO,
     hdmf_only=False,
+    pdb=False,
 ):
     """
     Generate linkml models for all versions
@@ -128,6 +131,11 @@ def generate_versions(
                     build_progress.update(pydantic_task, action="Built Pydantic")
 
         except Exception as e:
+            if pdb:
+                live.stop()
+                post_mortem()
+                sys.exit(1)
+
             build_progress.stop_task(linkml_task)
             if linkml_task is not None:
                 build_progress.update(linkml_task, action="[bold red]LinkML Build Failed")
@ -205,6 +213,7 @@ def parser() -> ArgumentParser:
), ),
action="store_true", action="store_true",
) )
parser.add_argument("--pdb", help="Launch debugger on an error", action="store_true")
return parser return parser
@@ -222,7 +231,7 @@ def main():
         generate_core_yaml(args.yaml, args.dry_run, args.hdmf)
         generate_core_pydantic(args.yaml, args.pydantic, args.dry_run)
     else:
-        generate_versions(args.yaml, args.pydantic, args.dry_run, repo, args.hdmf)
+        generate_versions(args.yaml, args.pydantic, args.dry_run, repo, args.hdmf, pdb=args.pdb)
 
 
 if __name__ == "__main__":