working complete, strict validating io :)

This commit is contained in:
sneakers-the-rat 2024-09-26 01:02:16 -07:00
parent 886d3db860
commit f9f1d49fca
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
8 changed files with 180 additions and 74 deletions

View file

@ -354,3 +354,40 @@ def defaults(cls: Dataset | Attribute) -> dict:
ret["ifabsent"] = cls.default_value
return ret
def is_container(group: Group) -> bool:
"""
Check if a group is a container group.
i.e. a group that...
* has no name
* multivalued quantity
* has a ``neurodata_type_inc``
* has no ``neurodata_type_def``
* has no sub-groups
* has no datasets
* has no attributes
Examples:
.. code-block:: yaml
- name: templates
groups:
- neurodata_type_inc: TimeSeries
doc: TimeSeries objects containing template data of presented stimuli.
quantity: '*'
- neurodata_type_inc: Images
doc: Images objects containing images of presented stimuli.
quantity: '*'
"""
return (
not group.name
and group.quantity == "*"
and group.neurodata_type_inc
and not group.neurodata_type_def
and not group.datasets
and not group.groups
and not group.attributes
)

View file

@ -2,11 +2,11 @@
Adapter for NWB groups to linkml Classes
"""
from typing import List, Type
from typing import Type
from linkml_runtime.linkml_model import SlotDefinition
from nwb_linkml.adapters.adapter import BuildResult
from nwb_linkml.adapters.adapter import BuildResult, is_container
from nwb_linkml.adapters.classes import ClassAdapter
from nwb_linkml.adapters.dataset import DatasetAdapter
from nwb_linkml.maps import QUANTITY_MAP
@ -45,19 +45,21 @@ class GroupAdapter(ClassAdapter):
):
return self.handle_container_slot(self.cls)
nested_res = self.build_subclasses()
# add links
links = self.build_links()
nested_res = self.build_datasets()
nested_res += self.build_groups()
nested_res += self.build_links()
nested_res += self.build_containers()
nested_res += self.build_special_cases()
# we don't propagate slots up to the next level since they are meant for this
# level (ie. a way to refer to our children)
res = self.build_base(extra_attrs=nested_res.slots + links)
res = self.build_base(extra_attrs=nested_res.slots)
# we do propagate classes tho
res.classes.extend(nested_res.classes)
return res
def build_links(self) -> List[SlotDefinition]:
def build_links(self) -> BuildResult:
"""
Build links specified in the ``links`` field as slots that refer to other
classes, with an additional annotation specifying that they are in fact links.
@ -66,7 +68,7 @@ class GroupAdapter(ClassAdapter):
file hierarchy as a string.
"""
if not self.cls.links:
return []
return BuildResult()
annotations = [{"tag": "source_type", "value": "link"}]
@ -83,7 +85,7 @@ class GroupAdapter(ClassAdapter):
)
for link in self.cls.links
]
return slots
return BuildResult(slots=slots)
def handle_container_group(self, cls: Group) -> BuildResult:
"""
@ -129,7 +131,7 @@ class GroupAdapter(ClassAdapter):
# We are a top-level container class like ProcessingModule
base = self.build_base()
# remove all the attributes and replace with child slot
base.classes[0].attributes.append(slot)
base.classes[0].attributes.update({slot.name: slot})
return base
def handle_container_slot(self, cls: Group) -> BuildResult:
@ -167,30 +169,88 @@ class GroupAdapter(ClassAdapter):
return BuildResult(slots=[slot])
def build_subclasses(self) -> BuildResult:
def build_datasets(self) -> BuildResult:
"""
Build nested groups and datasets
Create ClassDefinitions for each, but then also create SlotDefinitions that
will be used as attributes linking the main class to the subclasses
Datasets are simple, they are terminal classes, and all logic
for creating slots vs. classes is handled by the adapter class
"""
# Datasets are simple, they are terminal classes, and all logic
# for creating slots vs. classes is handled by the adapter class
dataset_res = BuildResult()
if self.cls.datasets:
for dset in self.cls.datasets:
dset_adapter = DatasetAdapter(cls=dset, parent=self)
dataset_res += dset_adapter.build()
return dataset_res
def build_groups(self) -> BuildResult:
"""
Build subgroups, excluding pure container subgroups
"""
group_res = BuildResult()
if self.cls.groups:
for group in self.cls.groups:
if is_container(group):
continue
group_adapter = GroupAdapter(cls=group, parent=self)
group_res += group_adapter.build()
res = dataset_res + group_res
return group_res
def build_containers(self) -> BuildResult:
"""
Build all container types into a single ``value`` slot
"""
res = BuildResult()
if not self.cls.groups:
return res
containers = [grp for grp in self.cls.groups if is_container(grp)]
if not containers:
return res
if len(containers) == 1:
range = {"range": containers[0].neurodata_type_inc}
description = containers[0].doc
else:
range = {"any_of": [{"range": subcls.neurodata_type_inc} for subcls in containers]}
description = "\n\n".join([grp.doc for grp in containers])
slot = SlotDefinition(
name="value",
multivalued=True,
inlined=True,
inlined_as_list=False,
description=description,
**range,
)
if self.debug: # pragma: no cover - only used in development
slot.annotations["group_adapter"] = {
"tag": "slot_adapter",
"value": "container_value_slot",
}
res.slots = [slot]
return res
def build_special_cases(self) -> BuildResult:
"""
Special cases, at this point just for NWBFile, which has
extra ``.specloc`` and ``specifications`` attrs
"""
res = BuildResult()
if self.cls.neurodata_type_def == "NWBFile":
res.slots = [
SlotDefinition(
name="specifications",
range="dict",
description="Nested dictionary of schema specifications",
),
]
return res
def build_self_slot(self) -> SlotDefinition:

View file

@ -15,7 +15,7 @@ from linkml.generators import PydanticGenerator
from linkml.generators.pydanticgen.array import ArrayRepresentation, NumpydanticArray
from linkml.generators.pydanticgen.build import ClassResult, SlotResult
from linkml.generators.pydanticgen.pydanticgen import SplitMode
from linkml.generators.pydanticgen.template import Import, Imports, PydanticModule
from linkml.generators.pydanticgen.template import Import, Imports, PydanticModule, ObjectImport
from linkml_runtime.linkml_model.meta import (
ArrayExpression,
SchemaDefinition,
@ -30,6 +30,7 @@ from nwb_linkml.includes.base import (
BASEMODEL_COERCE_CHILD,
BASEMODEL_COERCE_VALUE,
BASEMODEL_GETITEM,
BASEMODEL_EXTRA_TO_VALUE,
)
from nwb_linkml.includes.hdmf import (
DYNAMIC_TABLE_IMPORTS,
@ -58,9 +59,15 @@ class NWBPydanticGenerator(PydanticGenerator):
BASEMODEL_COERCE_VALUE,
BASEMODEL_CAST_WITH_VALUE,
BASEMODEL_COERCE_CHILD,
BASEMODEL_EXTRA_TO_VALUE,
)
split: bool = True
imports: list[Import] = field(default_factory=lambda: [Import(module="numpy", alias="np")])
imports: list[Import] = field(
default_factory=lambda: [
Import(module="numpy", alias="np"),
Import(module="pydantic", objects=[ObjectImport(name="model_validator")]),
]
)
schema_map: Optional[Dict[str, SchemaDefinition]] = None
"""See :meth:`.LinkMLProvider.build` for usage - a list of specific versions to import from"""

View file

@ -3,7 +3,7 @@ Modifications to the ConfiguredBaseModel used by all generated classes
"""
BASEMODEL_GETITEM = """
def __getitem__(self, val: Union[int, slice]) -> Any:
def __getitem__(self, val: Union[int, slice, str]) -> Any:
\"\"\"Try and get a value from value or "data" if we have it\"\"\"
if hasattr(self, "value") and self.value is not None:
return self.value[val]
@ -64,3 +64,23 @@ BASEMODEL_COERCE_CHILD = """
pass
return v
"""
BASEMODEL_EXTRA_TO_VALUE = """
@model_validator(mode="before")
@classmethod
def gather_extra_to_value(cls, v: Any, handler) -> Any:
\"\"\"
For classes that don't allow extra fields and have a value slot,
pack those extra kwargs into ``value``
\"\"\"
if cls.model_config["extra"] == "forbid" and "value" in cls.model_fields and isinstance(v, dict):
extras = {key:val for key,val in v.items() if key not in cls.model_fields}
if extras:
for k in extras:
del v[k]
if "value" in v:
v["value"].update(extras)
else:
v["value"] = extras
return v
"""

View file

@ -35,7 +35,7 @@ import h5py
import networkx as nx
import numpy as np
from numpydantic.interface.hdf5 import H5ArrayPath
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel
from tqdm import tqdm
from nwb_linkml.maps.hdf5 import (
@ -166,24 +166,28 @@ def _load_node(
raise TypeError(f"Nodes can only be h5py Datasets and Groups, got {obj}")
if "neurodata_type" in obj.attrs:
# SPECIAL CASE: ignore `.specloc`
if ".specloc" in args:
del args[".specloc"]
model = provider.get_class(obj.attrs["namespace"], obj.attrs["neurodata_type"])
try:
return model(**args)
except ValidationError as e1:
# try to restack extra fields into ``value``
if "value" in model.model_fields:
value_dict = {
key: val for key, val in args.items() if key not in model.model_fields
}
for k in value_dict:
del args[k]
args["value"] = value_dict
try:
return model(**args)
except Exception as e2:
raise e2 from e1
else:
raise e1
# try:
return model(**args)
# except ValidationError as e1:
# # try to restack extra fields into ``value``
# if "value" in model.model_fields:
# value_dict = {
# key: val for key, val in args.items() if key not in model.model_fields
# }
# for k in value_dict:
# del args[k]
# args["value"] = value_dict
# try:
# return model(**args)
# except Exception as e2:
# raise e2 from e1
# else:
# raise e1
else:
if "name" in args:

View file

@ -39,6 +39,10 @@ def _make_dtypes() -> List[TypeDefinition]:
repr=linkml_reprs.get(nwbtype, None),
)
DTypeTypes.append(atype)
# a dict type!
DTypeTypes.append(TypeDefinition(name="dict", repr="dict"))
return DTypeTypes

View file

@ -80,7 +80,7 @@ def test_position(read_nwbfile, read_pynwb):
py_trials = read_pynwb.trials.to_dataframe()
pd.testing.assert_frame_equal(py_trials, trials)
spatial = read_nwbfile.processing["behavior"].Position.SpatialSeries
spatial = read_nwbfile.processing["behavior"]["Position"]["SpatialSeries"]
py_spatial = read_pynwb.processing["behavior"]["Position"]["SpatialSeries"]
_compare_attrs(spatial, py_spatial)
assert np.array_equal(spatial[:], py_spatial.data[:])

View file

@ -19,37 +19,6 @@ from nwb_linkml.providers import LinkMLProvider, PydanticProvider
from nwb_linkml.providers.git import NWB_CORE_REPO, HDMF_COMMON_REPO, GitRepo
from nwb_linkml.io import schema as io
def generate_core_yaml(output_path: Path, dry_run: bool = False, hdmf_only: bool = False):
"""Just build the latest version of the core schema"""
core = io.load_nwb_core(hdmf_only=hdmf_only)
built_schemas = core.build().schemas
for schema in built_schemas:
output_file = output_path / (schema.name + ".yaml")
if not dry_run:
yaml_dumper.dump(schema, output_file)
def generate_core_pydantic(yaml_path: Path, output_path: Path, dry_run: bool = False):
"""Just generate the latest version of the core schema"""
for schema in yaml_path.glob("*.yaml"):
python_name = schema.stem.replace(".", "_").replace("-", "_")
pydantic_file = (output_path / python_name).with_suffix(".py")
generator = NWBPydanticGenerator(
str(schema),
pydantic_version="2",
emit_metadata=True,
gen_classvars=True,
gen_slots=True,
)
gen_pydantic = generator.serialize()
if not dry_run:
with open(pydantic_file, "w") as pfile:
pfile.write(gen_pydantic)
def make_tmp_dir(clear: bool = False) -> Path:
# use a directory underneath this one as the temporary directory rather than
# the default hidden one
@ -68,6 +37,7 @@ def generate_versions(
dry_run: bool = False,
repo: GitRepo = NWB_CORE_REPO,
pdb=False,
latest: bool = False,
):
"""
Generate linkml models for all versions
@ -82,8 +52,13 @@ def generate_versions(
failed_versions = {}
if latest:
versions = [repo.namespace.versions[-1]]
else:
versions = repo.namespace.versions
overall_progress = Progress()
overall_task = overall_progress.add_task("All Versions", total=len(NWB_CORE_REPO.versions))
overall_task = overall_progress.add_task("All Versions", total=len(versions))
build_progress = Progress(
TextColumn(
@ -100,7 +75,7 @@ def generate_versions(
linkml_task = None
pydantic_task = None
for version in repo.namespace.versions:
for version in versions:
# build linkml
try:
# check out the version (this should also refresh the hdmf-common schema)
@ -251,11 +226,10 @@ def main():
if not args.dry_run:
args.yaml.mkdir(exist_ok=True)
args.pydantic.mkdir(exist_ok=True)
if args.latest:
generate_core_yaml(args.yaml, args.dry_run)
generate_core_pydantic(args.yaml, args.pydantic, args.dry_run)
else:
generate_versions(args.yaml, args.pydantic, args.dry_run, repo, pdb=args.pdb)
generate_versions(
args.yaml, args.pydantic, args.dry_run, repo, pdb=args.pdb, latest=args.latest
)
if __name__ == "__main__":