working complete, strict validating io :)

sneakers-the-rat 2024-09-26 01:02:16 -07:00
parent 886d3db860
commit f9f1d49fca
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
8 changed files with 180 additions and 74 deletions

View file

@@ -354,3 +354,40 @@ def defaults(cls: Dataset | Attribute) -> dict:
         ret["ifabsent"] = cls.default_value
     return ret
 
+
+def is_container(group: Group) -> bool:
+    """
+    Check if a group is a container group.
+
+    i.e. a group that...
+    * has no name
+    * multivalued quantity
+    * has a ``neurodata_type_inc``
+    * has no ``neurodata_type_def``
+    * has no sub-groups
+    * has no datasets
+    * has no attributes
+
+    Examples:
+
+        .. code-block:: yaml
+
+            - name: templates
+              groups:
+              - neurodata_type_inc: TimeSeries
+                doc: TimeSeries objects containing template data of presented stimuli.
+                quantity: '*'
+              - neurodata_type_inc: Images
+                doc: Images objects containing images of presented stimuli.
+                quantity: '*'
+
+    """
+    return (
+        not group.name
+        and group.quantity == "*"
+        and group.neurodata_type_inc
+        and not group.neurodata_type_def
+        and not group.datasets
+        and not group.groups
+        and not group.attributes
+    )
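A quick illustrative sketch of how the new predicate behaves (not part of the commit; ``FakeGroup`` is a hypothetical duck-typed stand-in exposing only the fields ``is_container`` reads):

    from nwb_linkml.adapters.adapter import is_container

    class FakeGroup:
        """Hypothetical stand-in for the nwb_schema_language Group model."""
        def __init__(self, **kwargs):
            defaults = dict(
                name=None, quantity="1", neurodata_type_inc=None,
                neurodata_type_def=None, groups=[], datasets=[], attributes=[],
            )
            defaults.update(kwargs)
            self.__dict__.update(defaults)

    # the anonymous, multivalued TimeSeries entry from the docstring example
    assert is_container(FakeGroup(neurodata_type_inc="TimeSeries", quantity="*"))
    # a named group (like ``templates`` itself) is not a pure container
    assert not is_container(FakeGroup(name="templates"))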

View file

@@ -2,11 +2,11 @@
 Adapter for NWB groups to linkml Classes
 """
 
-from typing import List, Type
+from typing import Type
 
 from linkml_runtime.linkml_model import SlotDefinition
 
-from nwb_linkml.adapters.adapter import BuildResult
+from nwb_linkml.adapters.adapter import BuildResult, is_container
 from nwb_linkml.adapters.classes import ClassAdapter
 from nwb_linkml.adapters.dataset import DatasetAdapter
 from nwb_linkml.maps import QUANTITY_MAP
@@ -45,19 +45,21 @@ class GroupAdapter(ClassAdapter):
         ):
             return self.handle_container_slot(self.cls)
 
-        nested_res = self.build_subclasses()
-        # add links
-        links = self.build_links()
+        nested_res = self.build_datasets()
+        nested_res += self.build_groups()
+        nested_res += self.build_links()
+        nested_res += self.build_containers()
+        nested_res += self.build_special_cases()
 
         # we don't propagate slots up to the next level since they are meant for this
         # level (ie. a way to refer to our children)
-        res = self.build_base(extra_attrs=nested_res.slots + links)
+        res = self.build_base(extra_attrs=nested_res.slots)
 
         # we do propagate classes tho
         res.classes.extend(nested_res.classes)
         return res
 
-    def build_links(self) -> List[SlotDefinition]:
+    def build_links(self) -> BuildResult:
         """
         Build links specified in the ``links`` field as slots that refer to other
         classes, with an additional annotation specifying that they are in fact links.
@@ -66,7 +68,7 @@ class GroupAdapter(ClassAdapter):
         file hierarchy as a string.
         """
 
         if not self.cls.links:
-            return []
+            return BuildResult()
 
         annotations = [{"tag": "source_type", "value": "link"}]
@@ -83,7 +85,7 @@ class GroupAdapter(ClassAdapter):
             )
             for link in self.cls.links
         ]
-        return slots
+        return BuildResult(slots=slots)
 
     def handle_container_group(self, cls: Group) -> BuildResult:
         """
@@ -129,7 +131,7 @@ class GroupAdapter(ClassAdapter):
             # We are a top-level container class like ProcessingModule
             base = self.build_base()
             # remove all the attributes and replace with child slot
-            base.classes[0].attributes.append(slot)
+            base.classes[0].attributes.update({slot.name: slot})
             return base
 
     def handle_container_slot(self, cls: Group) -> BuildResult:
@@ -167,30 +169,88 @@ class GroupAdapter(ClassAdapter):
         return BuildResult(slots=[slot])
 
-    def build_subclasses(self) -> BuildResult:
+    def build_datasets(self) -> BuildResult:
         """
         Build nested groups and datasets
 
         Create ClassDefinitions for each, but then also create SlotDefinitions that
         will be used as attributes linking the main class to the subclasses
+
+        Datasets are simple, they are terminal classes, and all logic
+        for creating slots vs. classes is handled by the adapter class
         """
-        # Datasets are simple, they are terminal classes, and all logic
-        # for creating slots vs. classes is handled by the adapter class
         dataset_res = BuildResult()
         if self.cls.datasets:
             for dset in self.cls.datasets:
                 dset_adapter = DatasetAdapter(cls=dset, parent=self)
                 dataset_res += dset_adapter.build()
+        return dataset_res
+
+    def build_groups(self) -> BuildResult:
+        """
+        Build subgroups, excluding pure container subgroups
+        """
         group_res = BuildResult()
 
         if self.cls.groups:
             for group in self.cls.groups:
+                if is_container(group):
+                    continue
                 group_adapter = GroupAdapter(cls=group, parent=self)
                 group_res += group_adapter.build()
 
-        res = dataset_res + group_res
+        return group_res
+
+    def build_containers(self) -> BuildResult:
+        """
+        Build all container types into a single ``value`` slot
+        """
+        res = BuildResult()
+        if not self.cls.groups:
+            return res
+        containers = [grp for grp in self.cls.groups if is_container(grp)]
+        if not containers:
+            return res
+
+        if len(containers) == 1:
+            range = {"range": containers[0].neurodata_type_inc}
+            description = containers[0].doc
+        else:
+            range = {"any_of": [{"range": subcls.neurodata_type_inc} for subcls in containers]}
+            description = "\n\n".join([grp.doc for grp in containers])
+
+        slot = SlotDefinition(
+            name="value",
+            multivalued=True,
+            inlined=True,
+            inlined_as_list=False,
+            description=description,
+            **range,
+        )
+        if self.debug:  # pragma: no cover - only used in development
+            slot.annotations["group_adapter"] = {
+                "tag": "slot_adapter",
+                "value": "container_value_slot",
+            }
+        res.slots = [slot]
+        return res
+
+    def build_special_cases(self) -> BuildResult:
+        """
+        Special cases, at this point just for NWBFile, which has
+        extra ``.specloc`` and ``specifications`` attrs
+        """
+        res = BuildResult()
+        if self.cls.neurodata_type_def == "NWBFile":
+            res.slots = [
+                SlotDefinition(
+                    name="specifications",
+                    range="dict",
+                    description="Nested dictionary of schema specifications",
+                ),
+            ]
         return res
 
     def build_self_slot(self) -> SlotDefinition:
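For orientation (a sketch, not part of the commit): given the two anonymous container groups from the ``is_container`` docstring example, ``build_containers`` collapses them into roughly this single slot via the ``any_of`` branch above:

    from linkml_runtime.linkml_model import SlotDefinition

    # approximate result of build_containers() for two container subgroups
    value_slot = SlotDefinition(
        name="value",
        multivalued=True,
        inlined=True,
        inlined_as_list=False,
        description=(
            "TimeSeries objects containing template data of presented stimuli.\n\n"
            "Images objects containing images of presented stimuli."
        ),
        any_of=[{"range": "TimeSeries"}, {"range": "Images"}],
    )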

View file

@@ -15,7 +15,7 @@ from linkml.generators import PydanticGenerator
 from linkml.generators.pydanticgen.array import ArrayRepresentation, NumpydanticArray
 from linkml.generators.pydanticgen.build import ClassResult, SlotResult
 from linkml.generators.pydanticgen.pydanticgen import SplitMode
-from linkml.generators.pydanticgen.template import Import, Imports, PydanticModule
+from linkml.generators.pydanticgen.template import Import, Imports, PydanticModule, ObjectImport
 from linkml_runtime.linkml_model.meta import (
     ArrayExpression,
     SchemaDefinition,
@@ -30,6 +30,7 @@ from nwb_linkml.includes.base import (
     BASEMODEL_COERCE_CHILD,
     BASEMODEL_COERCE_VALUE,
     BASEMODEL_GETITEM,
+    BASEMODEL_EXTRA_TO_VALUE,
 )
 from nwb_linkml.includes.hdmf import (
     DYNAMIC_TABLE_IMPORTS,
@@ -58,9 +59,15 @@ class NWBPydanticGenerator(PydanticGenerator):
         BASEMODEL_COERCE_VALUE,
         BASEMODEL_CAST_WITH_VALUE,
         BASEMODEL_COERCE_CHILD,
+        BASEMODEL_EXTRA_TO_VALUE,
     )
     split: bool = True
-    imports: list[Import] = field(default_factory=lambda: [Import(module="numpy", alias="np")])
+    imports: list[Import] = field(
+        default_factory=lambda: [
+            Import(module="numpy", alias="np"),
+            Import(module="pydantic", objects=[ObjectImport(name="model_validator")]),
+        ]
+    )
 
     schema_map: Optional[Dict[str, SchemaDefinition]] = None
     """See :meth:`.LinkMLProvider.build` for usage - a list of specific versions to import from"""

View file

@@ -3,7 +3,7 @@ Modifications to the ConfiguredBaseModel used by all generated classes
 """
 
 BASEMODEL_GETITEM = """
-    def __getitem__(self, val: Union[int, slice]) -> Any:
+    def __getitem__(self, val: Union[int, slice, str]) -> Any:
         \"\"\"Try and get a value from value or "data" if we have it\"\"\"
         if hasattr(self, "value") and self.value is not None:
             return self.value[val]
@@ -64,3 +64,23 @@ BASEMODEL_COERCE_CHILD = """
                 pass
         return v
 """
+
+BASEMODEL_EXTRA_TO_VALUE = """
+    @model_validator(mode="before")
+    @classmethod
+    def gather_extra_to_value(cls, v: Any, handler) -> Any:
+        \"\"\"
+        For classes that don't allow extra fields and have a value slot,
+        pack those extra kwargs into ``value``
+        \"\"\"
+        if cls.model_config["extra"] == "forbid" and "value" in cls.model_fields and isinstance(v, dict):
+            extras = {key: val for key, val in v.items() if key not in cls.model_fields}
+            if extras:
+                for k in extras:
+                    del v[k]
+                if "value" in v:
+                    v["value"].update(extras)
+                else:
+                    v["value"] = extras
+        return v
+"""

View file

@@ -35,7 +35,7 @@ import h5py
 import networkx as nx
 import numpy as np
 from numpydantic.interface.hdf5 import H5ArrayPath
-from pydantic import BaseModel, ValidationError
+from pydantic import BaseModel
 from tqdm import tqdm
 
 from nwb_linkml.maps.hdf5 import (
@@ -166,24 +166,28 @@ def _load_node(
         raise TypeError(f"Nodes can only be h5py Datasets and Groups, got {obj}")
 
     if "neurodata_type" in obj.attrs:
+        # SPECIAL CASE: ignore `.specloc`
+        if ".specloc" in args:
+            del args[".specloc"]
+
         model = provider.get_class(obj.attrs["namespace"], obj.attrs["neurodata_type"])
-        try:
-            return model(**args)
-        except ValidationError as e1:
-            # try to restack extra fields into ``value``
-            if "value" in model.model_fields:
-                value_dict = {
-                    key: val for key, val in args.items() if key not in model.model_fields
-                }
-                for k in value_dict:
-                    del args[k]
-                args["value"] = value_dict
-                try:
-                    return model(**args)
-                except Exception as e2:
-                    raise e2 from e1
-            else:
-                raise e1
+        # try:
+        return model(**args)
+        # except ValidationError as e1:
+        #     # try to restack extra fields into ``value``
+        #     if "value" in model.model_fields:
+        #         value_dict = {
+        #             key: val for key, val in args.items() if key not in model.model_fields
+        #         }
+        #         for k in value_dict:
+        #             del args[k]
+        #         args["value"] = value_dict
+        #         try:
+        #             return model(**args)
+        #         except Exception as e2:
+        #             raise e2 from e1
+        #     else:
+        #         raise e1
 
     else:
         if "name" in args:

View file

@@ -39,6 +39,10 @@ def _make_dtypes() -> List[TypeDefinition]:
             repr=linkml_reprs.get(nwbtype, None),
         )
         DTypeTypes.append(atype)
 
+    # a dict type!
+    DTypeTypes.append(TypeDefinition(name="dict", repr="dict"))
+
     return DTypeTypes

View file

@@ -80,7 +80,7 @@ def test_position(read_nwbfile, read_pynwb):
     py_trials = read_pynwb.trials.to_dataframe()
     pd.testing.assert_frame_equal(py_trials, trials)
 
-    spatial = read_nwbfile.processing["behavior"].Position.SpatialSeries
+    spatial = read_nwbfile.processing["behavior"]["Position"]["SpatialSeries"]
     py_spatial = read_pynwb.processing["behavior"]["Position"]["SpatialSeries"]
     _compare_attrs(spatial, py_spatial)
     assert np.array_equal(spatial[:], py_spatial.data[:])

View file

@@ -19,37 +19,6 @@ from nwb_linkml.providers import LinkMLProvider, PydanticProvider
 from nwb_linkml.providers.git import NWB_CORE_REPO, HDMF_COMMON_REPO, GitRepo
 from nwb_linkml.io import schema as io
 
-
-def generate_core_yaml(output_path: Path, dry_run: bool = False, hdmf_only: bool = False):
-    """Just build the latest version of the core schema"""
-
-    core = io.load_nwb_core(hdmf_only=hdmf_only)
-    built_schemas = core.build().schemas
-    for schema in built_schemas:
-        output_file = output_path / (schema.name + ".yaml")
-        if not dry_run:
-            yaml_dumper.dump(schema, output_file)
-
-
-def generate_core_pydantic(yaml_path: Path, output_path: Path, dry_run: bool = False):
-    """Just generate the latest version of the core schema"""
-    for schema in yaml_path.glob("*.yaml"):
-        python_name = schema.stem.replace(".", "_").replace("-", "_")
-        pydantic_file = (output_path / python_name).with_suffix(".py")
-
-        generator = NWBPydanticGenerator(
-            str(schema),
-            pydantic_version="2",
-            emit_metadata=True,
-            gen_classvars=True,
-            gen_slots=True,
-        )
-        gen_pydantic = generator.serialize()
-
-        if not dry_run:
-            with open(pydantic_file, "w") as pfile:
-                pfile.write(gen_pydantic)
-
-
 def make_tmp_dir(clear: bool = False) -> Path:
     # use a directory underneath this one as the temporary directory rather than
     # the default hidden one
@@ -68,6 +37,7 @@ def generate_versions(
     dry_run: bool = False,
     repo: GitRepo = NWB_CORE_REPO,
     pdb=False,
+    latest: bool = False,
 ):
     """
     Generate linkml models for all versions
@@ -82,8 +52,13 @@ def generate_versions(
     failed_versions = {}
 
+    if latest:
+        versions = [repo.namespace.versions[-1]]
+    else:
+        versions = repo.namespace.versions
+
     overall_progress = Progress()
-    overall_task = overall_progress.add_task("All Versions", total=len(NWB_CORE_REPO.versions))
+    overall_task = overall_progress.add_task("All Versions", total=len(versions))
 
     build_progress = Progress(
         TextColumn(
@@ -100,7 +75,7 @@ def generate_versions(
         linkml_task = None
         pydantic_task = None
 
-        for version in repo.namespace.versions:
+        for version in versions:
             # build linkml
             try:
                 # check out the version (this should also refresh the hdmf-common schema)
@@ -251,11 +226,10 @@ def main():
     if not args.dry_run:
         args.yaml.mkdir(exist_ok=True)
         args.pydantic.mkdir(exist_ok=True)
 
-    if args.latest:
-        generate_core_yaml(args.yaml, args.dry_run)
-        generate_core_pydantic(args.yaml, args.pydantic, args.dry_run)
-    else:
-        generate_versions(args.yaml, args.pydantic, args.dry_run, repo, pdb=args.pdb)
+    generate_versions(
+        args.yaml, args.pydantic, args.dry_run, repo, pdb=args.pdb, latest=args.latest
+    )
 
 
 if __name__ == "__main__":
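Finally, a rough usage sketch (not part of the commit; the output directories are placeholders, and ``generate_versions`` is the function defined in this script): with ``latest=True`` only the newest namespace version is built, through the same loop as a full multi-version build.

    from pathlib import Path

    from nwb_linkml.providers.git import NWB_CORE_REPO

    # equivalent to passing --latest to the script:
    # only repo.namespace.versions[-1] is built instead of every version
    generate_versions(
        Path("yaml_out"),      # placeholder output dir for linkml yaml
        Path("pydantic_out"),  # placeholder output dir for pydantic modules
        dry_run=True,          # exercise the build without writing files
        repo=NWB_CORE_REPO,
        latest=True,
    )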