mirror of
https://github.com/p2p-ld/nwb-linkml.git
synced 2024-11-12 17:54:29 +00:00
add logging. less janky adapter instantiation using model validators. correctly propagate properties from ancestor classes when building
This commit is contained in:
parent
c09b633cda
commit
0452a4359f
15 changed files with 415 additions and 140 deletions
|
@ -74,7 +74,9 @@ addopts = [
|
|||
]
|
||||
markers = [
|
||||
"dev: tests that are just for development rather than testing correctness",
|
||||
"provider: tests for providers!"
|
||||
"provider: tests for providers!",
|
||||
"linkml: tests related to linkml generation",
|
||||
"pydantic: tests related to pydantic generation"
|
||||
]
|
||||
testpaths = [
|
||||
"src/nwb_linkml",
|
||||
|
|
|
@ -5,16 +5,8 @@ Base class for adapters
|
|||
import sys
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (
|
||||
Any,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Generator, List, Literal, Optional, Tuple, Type, TypeVar, Union, overload
|
||||
from logging import Logger
|
||||
|
||||
from linkml_runtime.dumpers import yaml_dumper
|
||||
from linkml_runtime.linkml_model import (
|
||||
|
@ -26,6 +18,7 @@ from linkml_runtime.linkml_model import (
|
|||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from nwb_linkml.logging import init_logger
|
||||
from nwb_schema_language import Attribute, CompoundDtype, Dataset, Group, Schema
|
||||
|
||||
if sys.version_info.minor >= 11:
|
||||
|
@ -107,6 +100,14 @@ class BuildResult:
|
|||
class Adapter(BaseModel):
|
||||
"""Abstract base class for adapters"""
|
||||
|
||||
_logger: Optional[Logger] = None
|
||||
|
||||
@property
|
||||
def logger(self) -> Logger:
|
||||
if self._logger is None:
|
||||
self._logger = init_logger(self.__class__.__name__)
|
||||
return self._logger
|
||||
|
||||
@abstractmethod
|
||||
def build(self) -> "BuildResult":
|
||||
"""
|
||||
|
@ -152,8 +153,8 @@ class Adapter(BaseModel):
|
|||
# SchemaAdapters that should be located under the same
|
||||
# NamespacesAdapter when it's important to query across SchemaAdapters,
|
||||
# so skip to avoid combinatoric walking
|
||||
if key == "imports" and type(input).__name__ == "SchemaAdapter":
|
||||
continue
|
||||
# if key == "imports" and type(input).__name__ == "SchemaAdapter":
|
||||
# continue
|
||||
val = getattr(input, key)
|
||||
yield (key, val)
|
||||
if isinstance(val, (BaseModel, dict, list)):
|
||||
|
@ -196,6 +197,14 @@ class Adapter(BaseModel):
|
|||
if isinstance(item, tuple) and item[0] in field and item[1] is not None:
|
||||
yield item[1]
|
||||
|
||||
@overload
|
||||
def walk_field_values(
|
||||
self,
|
||||
input: Union[BaseModel, dict, list],
|
||||
field: Literal["neurodata_type_def"],
|
||||
value: Optional[Any] = None,
|
||||
) -> Generator[Group | Dataset, None, None]: ...
|
||||
|
||||
def walk_field_values(
|
||||
self, input: Union[BaseModel, dict, list], field: str, value: Optional[Any] = None
|
||||
) -> Generator[BaseModel, None, None]:
|
||||
|
@ -248,6 +257,9 @@ def is_1d(cls: Dataset | Attribute) -> bool:
|
|||
* a single-layer dim/shape list of length 1, or
|
||||
* a nested dim/shape list where every nested spec is of length 1
|
||||
"""
|
||||
if cls.dims is None:
|
||||
return False
|
||||
|
||||
return (
|
||||
not any([isinstance(dim, list) for dim in cls.dims]) and len(cls.dims) == 1
|
||||
) or ( # nested list
|
||||
|
@ -270,4 +282,8 @@ def has_attrs(cls: Dataset) -> bool:
|
|||
"""
|
||||
Check if a dataset has any attributes at all without defaults
|
||||
"""
|
||||
return len(cls.attributes) > 0 and all([not a.value for a in cls.attributes])
|
||||
return (
|
||||
cls.attributes is not None
|
||||
and len(cls.attributes) > 0
|
||||
and all([not a.value for a in cls.attributes])
|
||||
)
|
||||
|
|
|
@ -119,9 +119,12 @@ class ClassAdapter(Adapter):
|
|||
Returns:
|
||||
list[:class:`.SlotDefinition`]
|
||||
"""
|
||||
results = [AttributeAdapter(cls=attr).build() for attr in cls.attributes]
|
||||
slots = [r.slots[0] for r in results]
|
||||
return slots
|
||||
if cls.attributes is not None:
|
||||
results = [AttributeAdapter(cls=attr).build() for attr in cls.attributes]
|
||||
slots = [r.slots[0] for r in results]
|
||||
return slots
|
||||
else:
|
||||
return []
|
||||
|
||||
def _get_full_name(self) -> str:
|
||||
"""The full name of the object in the generated linkml
|
||||
|
|
|
@ -784,6 +784,12 @@ class MapCompoundDtype(DatasetMap):
|
|||
Make a new class for this dtype, using its sub-dtypes as fields,
|
||||
and use it as the range for the parent class
|
||||
"""
|
||||
# all the slots share the same ndarray spec if there is one
|
||||
array = {}
|
||||
if cls.dims or cls.shape:
|
||||
array_adapter = ArrayAdapter(cls.dims, cls.shape)
|
||||
array = array_adapter.make_slot()
|
||||
|
||||
slots = {}
|
||||
for a_dtype in cls.dtype:
|
||||
slots[a_dtype.name] = SlotDefinition(
|
||||
|
@ -791,8 +797,13 @@ class MapCompoundDtype(DatasetMap):
|
|||
description=a_dtype.doc,
|
||||
range=handle_dtype(a_dtype.dtype),
|
||||
**QUANTITY_MAP[cls.quantity],
|
||||
**array,
|
||||
)
|
||||
res.classes[0].attributes.update(slots)
|
||||
|
||||
# the compound dtype replaces the ``value`` slot, if present
|
||||
if "value" in res.classes[0].attributes:
|
||||
del res.classes[0].attributes["value"]
|
||||
return res
|
||||
|
||||
|
||||
|
|
|
@ -13,13 +13,13 @@ from typing import Dict, List, Optional
|
|||
|
||||
from linkml_runtime.dumpers import yaml_dumper
|
||||
from linkml_runtime.linkml_model import Annotation, SchemaDefinition
|
||||
from pydantic import Field, PrivateAttr
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from nwb_linkml.adapters.adapter import Adapter, BuildResult
|
||||
from nwb_linkml.adapters.schema import SchemaAdapter
|
||||
from nwb_linkml.lang_elements import NwbLangSchema
|
||||
from nwb_linkml.ui import AdapterProgress
|
||||
from nwb_schema_language import Namespaces
|
||||
from nwb_schema_language import Namespaces, Group, Dataset
|
||||
|
||||
|
||||
class NamespacesAdapter(Adapter):
|
||||
|
@ -31,12 +31,6 @@ class NamespacesAdapter(Adapter):
|
|||
schemas: List[SchemaAdapter]
|
||||
imported: List["NamespacesAdapter"] = Field(default_factory=list)
|
||||
|
||||
_imports_populated: bool = PrivateAttr(False)
|
||||
|
||||
def __init__(self, **kwargs: dict):
|
||||
super().__init__(**kwargs)
|
||||
self._populate_schema_namespaces()
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, path: Path) -> "NamespacesAdapter":
|
||||
"""
|
||||
|
@ -70,8 +64,6 @@ class NamespacesAdapter(Adapter):
|
|||
"""
|
||||
Build the NWB namespace to the LinkML Schema
|
||||
"""
|
||||
if not self._imports_populated and not skip_imports:
|
||||
self.populate_imports()
|
||||
|
||||
sch_result = BuildResult()
|
||||
for sch in self.schemas:
|
||||
|
@ -129,6 +121,7 @@ class NamespacesAdapter(Adapter):
|
|||
|
||||
return sch_result
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _populate_schema_namespaces(self) -> None:
|
||||
"""
|
||||
annotate for each schema which namespace imports it
|
||||
|
@ -143,6 +136,7 @@ class NamespacesAdapter(Adapter):
|
|||
sch.namespace = ns.name
|
||||
sch.version = ns.version
|
||||
break
|
||||
return self
|
||||
|
||||
def find_type_source(self, name: str) -> SchemaAdapter:
|
||||
"""
|
||||
|
@ -182,7 +176,8 @@ class NamespacesAdapter(Adapter):
|
|||
else:
|
||||
raise KeyError(f"No schema found that define {name}")
|
||||
|
||||
def populate_imports(self) -> None:
|
||||
@model_validator(mode="after")
|
||||
def populate_imports(self) -> "NamespacesAdapter":
|
||||
"""
|
||||
Populate the imports that are needed for each schema file
|
||||
|
||||
|
@ -199,11 +194,46 @@ class NamespacesAdapter(Adapter):
|
|||
if depends_on not in sch.imports:
|
||||
sch.imports.append(depends_on)
|
||||
|
||||
# do so recursively
|
||||
for imported in self.imported:
|
||||
imported.populate_imports()
|
||||
return self
|
||||
|
||||
self._imports_populated = True
|
||||
@model_validator(mode="after")
|
||||
def _populate_inheritance(self):
|
||||
"""
|
||||
ensure properties from `neurodata_type_inc` are propaged through to inheriting classes.
|
||||
|
||||
This seems super expensive but we'll optimize for perf later if that proves to be the case
|
||||
"""
|
||||
# don't use walk_types here so we can replace the objects as we mutate them
|
||||
for sch in self.schemas:
|
||||
for i, group in enumerate(sch.groups):
|
||||
if getattr(group, "neurodata_type_inc", None) is not None:
|
||||
merged_attrs = self._merge_inheritance(group)
|
||||
sch.groups[i] = Group(**merged_attrs)
|
||||
for i, dataset in enumerate(sch.datasets):
|
||||
if getattr(dataset, "neurodata_type_inc", None) is not None:
|
||||
merged_attrs = self._merge_inheritance(dataset)
|
||||
sch.datasets[i] = Dataset(**merged_attrs)
|
||||
return self
|
||||
|
||||
def _merge_inheritance(self, obj: Group | Dataset) -> dict:
|
||||
obj_dict = obj.model_dump(exclude_none=True)
|
||||
if obj.neurodata_type_inc:
|
||||
name = obj.neurodata_type_def if obj.neurodata_type_def else obj.name
|
||||
self.logger.debug(f"Merging {name} with {obj.neurodata_type_inc}")
|
||||
# there must be only one type with this name
|
||||
parent: Group | Dataset = next(
|
||||
self.walk_field_values(self, "neurodata_type_def", obj.neurodata_type_inc)
|
||||
)
|
||||
if obj.neurodata_type_def == "TimeSeriesReferenceVectorData":
|
||||
pdb.set_trace()
|
||||
parent_dict = copy(self._merge_inheritance(parent))
|
||||
# children don't inherit the type_def
|
||||
del parent_dict["neurodata_type_def"]
|
||||
# overwrite with child values
|
||||
parent_dict.update(obj_dict)
|
||||
return parent_dict
|
||||
|
||||
return obj_dict
|
||||
|
||||
def to_yaml(self, base_dir: Path) -> None:
|
||||
"""
|
||||
|
|
|
@ -42,7 +42,8 @@ class SchemaAdapter(Adapter):
|
|||
"""
|
||||
The namespace.schema name for a single schema
|
||||
"""
|
||||
return ".".join([self.namespace, self.path.with_suffix("").name])
|
||||
namespace = self.namespace if self.namespace is not None else ""
|
||||
return ".".join([namespace, self.path.with_suffix("").name])
|
||||
|
||||
def __repr__(self):
|
||||
out_str = "\n" + self.name + "\n"
|
||||
|
|
|
@ -2,10 +2,12 @@
|
|||
Manage the operation of nwb_linkml from environmental variables
|
||||
"""
|
||||
|
||||
from typing import Optional, Literal
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
DirectoryPath,
|
||||
Field,
|
||||
FieldValidationInfo,
|
||||
|
@ -15,15 +17,68 @@ from pydantic import (
|
|||
)
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
LOG_LEVELS = Literal["DEBUG", "INFO", "WARNING", "ERROR"]
|
||||
|
||||
|
||||
class LogConfig(BaseModel):
|
||||
"""
|
||||
Configuration for logging
|
||||
"""
|
||||
|
||||
level: LOG_LEVELS = "INFO"
|
||||
"""
|
||||
Severity of log messages to process.
|
||||
"""
|
||||
level_file: Optional[LOG_LEVELS] = None
|
||||
"""
|
||||
Severity for file-based logging. If unset, use ``level``
|
||||
"""
|
||||
level_stdout: Optional[LOG_LEVELS] = "WARNING"
|
||||
"""
|
||||
Severity for stream-based logging. If unset, use ``level``
|
||||
"""
|
||||
file_n: int = 5
|
||||
"""
|
||||
Number of log files to rotate through
|
||||
"""
|
||||
file_size: int = 2**22 # roughly 4MB
|
||||
"""
|
||||
Maximum size of log files (bytes)
|
||||
"""
|
||||
|
||||
@field_validator("level", "level_file", "level_stdout", mode="before")
|
||||
@classmethod
|
||||
def uppercase_levels(cls, value: Optional[str] = None) -> Optional[str]:
|
||||
"""
|
||||
Ensure log level strings are uppercased
|
||||
"""
|
||||
if value is not None:
|
||||
value = value.upper()
|
||||
return value
|
||||
|
||||
@model_validator(mode="after")
|
||||
def inherit_base_level(self) -> "LogConfig":
|
||||
"""
|
||||
If loglevels for specific output streams are unset, set from base :attr:`.level`
|
||||
"""
|
||||
levels = ("level_file", "level_stdout")
|
||||
for level_name in levels:
|
||||
if getattr(self, level_name) is None:
|
||||
setattr(self, level_name, self.level)
|
||||
return self
|
||||
|
||||
|
||||
class Config(BaseSettings):
|
||||
"""
|
||||
Configuration for nwb_linkml, populated by default but can be overridden
|
||||
by environment variables.
|
||||
|
||||
Nested models can be assigned from .env files with a __ (see examples)
|
||||
|
||||
Examples:
|
||||
|
||||
export NWB_LINKML_CACHE_DIR="/home/mycache/dir"
|
||||
export NWB_LINKML_LOGS__LEVEL="debug"
|
||||
|
||||
"""
|
||||
|
||||
|
@ -32,6 +87,11 @@ class Config(BaseSettings):
|
|||
default_factory=lambda: Path(tempfile.gettempdir()) / "nwb_linkml__cache",
|
||||
description="Location to cache generated schema and models",
|
||||
)
|
||||
log_dir: Path = Field(
|
||||
Path("logs"),
|
||||
description="Location to store logs. If a relative directory, relative to ``cache_dir``",
|
||||
)
|
||||
logs: LogConfig = Field(LogConfig(), description="Log configuration")
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
|
@ -62,6 +122,15 @@ class Config(BaseSettings):
|
|||
assert v.exists()
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def log_dir_relative_to_cache_dir(self) -> "Config":
|
||||
"""
|
||||
If log dir is relative, put it beneath the cache_dir
|
||||
"""
|
||||
if not self.log_dir.is_absolute():
|
||||
self.log_dir = self.cache_dir / self.log_dir
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def folders_exist(self) -> "Config":
|
||||
"""
|
||||
|
|
|
@ -70,6 +70,7 @@ def load_namespace_adapter(
|
|||
namespace: Path | NamespaceRepo | Namespaces,
|
||||
path: Optional[Path] = None,
|
||||
version: Optional[str] = None,
|
||||
imported: Optional[list[NamespacesAdapter]] = None,
|
||||
) -> NamespacesAdapter:
|
||||
"""
|
||||
Load all schema referenced by a namespace file
|
||||
|
@ -115,7 +116,10 @@ def load_namespace_adapter(
|
|||
yml_file = (path / schema.source).resolve()
|
||||
sch.append(load_schema_file(yml_file))
|
||||
|
||||
adapter = NamespacesAdapter(namespaces=namespaces, schemas=sch)
|
||||
if imported is not None:
|
||||
adapter = NamespacesAdapter(namespaces=namespaces, schemas=sch, imported=imported)
|
||||
else:
|
||||
adapter = NamespacesAdapter(namespaces=namespaces, schemas=sch)
|
||||
|
||||
return adapter
|
||||
|
||||
|
@ -148,8 +152,6 @@ def load_nwb_core(
|
|||
if hdmf_only:
|
||||
schema = hdmf_schema
|
||||
else:
|
||||
schema = load_namespace_adapter(NWB_CORE_REPO, version=core_version)
|
||||
|
||||
schema.imported.append(hdmf_schema)
|
||||
schema = load_namespace_adapter(NWB_CORE_REPO, version=core_version, imported=[hdmf_schema])
|
||||
|
||||
return schema
|
||||
|
|
100
nwb_linkml/src/nwb_linkml/logging.py
Normal file
100
nwb_linkml/src/nwb_linkml/logging.py
Normal file
|
@ -0,0 +1,100 @@
|
|||
"""
|
||||
Logging factory and handlers
|
||||
"""
|
||||
|
||||
import logging
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from rich.logging import RichHandler
|
||||
|
||||
from nwb_linkml.config import LOG_LEVELS, Config
|
||||
|
||||
|
||||
def init_logger(
|
||||
name: str,
|
||||
log_dir: Union[Optional[Path], bool] = None,
|
||||
level: Optional[LOG_LEVELS] = None,
|
||||
file_level: Optional[LOG_LEVELS] = None,
|
||||
log_file_n: Optional[int] = None,
|
||||
log_file_size: Optional[int] = None,
|
||||
) -> logging.Logger:
|
||||
"""
|
||||
Make a logger.
|
||||
|
||||
Log to a set of rotating files in the ``log_dir`` according to ``name`` ,
|
||||
as well as using the :class:`~rich.RichHandler` for pretty-formatted stdout logs.
|
||||
|
||||
Args:
|
||||
name (str): Name of this logger. Ideally names are hierarchical
|
||||
and indicate what they are logging for, eg. ``miniscope_io.sdcard``
|
||||
and don't contain metadata like timestamps, etc. (which are in the logs)
|
||||
log_dir (:class:`pathlib.Path`): Directory to store file-based logs in. If ``None``,
|
||||
get from :class:`.Config`. If ``False`` , disable file logging.
|
||||
level (:class:`.LOG_LEVELS`): Level to use for stdout logging. If ``None`` ,
|
||||
get from :class:`.Config`
|
||||
file_level (:class:`.LOG_LEVELS`): Level to use for file-based logging.
|
||||
If ``None`` , get from :class:`.Config`
|
||||
log_file_n (int): Number of rotating file logs to use.
|
||||
If ``None`` , get from :class:`.Config`
|
||||
log_file_size (int): Maximum size of logfiles before rotation.
|
||||
If ``None`` , get from :class:`.Config`
|
||||
|
||||
Returns:
|
||||
:class:`logging.Logger`
|
||||
"""
|
||||
config = Config()
|
||||
if log_dir is None:
|
||||
log_dir = config.log_dir
|
||||
if level is None:
|
||||
level = config.logs.level_stdout
|
||||
if file_level is None:
|
||||
file_level = config.logs.level_file
|
||||
if log_file_n is None:
|
||||
log_file_n = config.logs.file_n
|
||||
if log_file_size is None:
|
||||
log_file_size = config.logs.file_size
|
||||
|
||||
if not name.startswith("nwb_linkml"):
|
||||
name = "nwb_linkml." + name
|
||||
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(level)
|
||||
|
||||
# Add handlers for stdout and file
|
||||
if log_dir is not False:
|
||||
logger.addHandler(_file_handler(name, file_level, log_dir, log_file_n, log_file_size))
|
||||
|
||||
logger.addHandler(_rich_handler())
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
def _file_handler(
|
||||
name: str,
|
||||
file_level: LOG_LEVELS,
|
||||
log_dir: Path,
|
||||
log_file_n: int = 5,
|
||||
log_file_size: int = 2**22,
|
||||
) -> RotatingFileHandler:
|
||||
# See init_logger for arg docs
|
||||
|
||||
filename = Path(log_dir) / ".".join([name, "log"])
|
||||
file_handler = RotatingFileHandler(
|
||||
str(filename), mode="a", maxBytes=log_file_size, backupCount=log_file_n
|
||||
)
|
||||
file_formatter = logging.Formatter("[%(asctime)s] %(levelname)s [%(name)s]: %(message)s")
|
||||
file_handler.setLevel(file_level)
|
||||
file_handler.setFormatter(file_formatter)
|
||||
return file_handler
|
||||
|
||||
|
||||
def _rich_handler() -> RichHandler:
|
||||
rich_handler = RichHandler(rich_tracebacks=True, markup=True)
|
||||
rich_formatter = logging.Formatter(
|
||||
"[bold green]\[%(name)s][/bold green] %(message)s",
|
||||
datefmt="[%y-%m-%dT%H:%M:%S]",
|
||||
)
|
||||
rich_handler.setFormatter(rich_formatter)
|
||||
return rich_handler
|
|
@ -82,7 +82,6 @@ def tmp_output_dir_mod(tmp_output_dir) -> Path:
|
|||
@pytest.fixture(scope="session", params=[{"core_version": "2.7.0", "hdmf_version": "1.8.0"}])
|
||||
def nwb_core_fixture(request) -> NamespacesAdapter:
|
||||
nwb_core = io.load_nwb_core(**request.param)
|
||||
nwb_core.populate_imports()
|
||||
assert (
|
||||
request.param["core_version"] in nwb_core.versions["core"]
|
||||
) # 2.6.0 is actually 2.6.0-alpha
|
||||
|
|
|
@ -46,3 +46,17 @@ def test_skip_imports(nwb_core_fixture):
|
|||
# we shouldn't have any of the hdmf-common schema in with us
|
||||
namespaces = [sch.annotations["namespace"].value for sch in res.schemas]
|
||||
assert all([ns == "core" for ns in namespaces])
|
||||
|
||||
|
||||
@pytest.mark.skip()
|
||||
def test_populate_inheritance(nwb_core_fixture):
|
||||
"""
|
||||
Classes should receive and override the properties of their parents
|
||||
when they have neurodata_type_inc
|
||||
Args:
|
||||
nwb_core_fixture:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
pass
|
||||
|
|
|
@ -76,6 +76,7 @@ def test_generate_pydantic(tmp_output_dir):
|
|||
initfile.write("# Autogenerated module indicator")
|
||||
|
||||
|
||||
@pytest.mark.linkml
|
||||
@pytest.mark.provider
|
||||
@pytest.mark.dev
|
||||
def test_generate_linkml_provider(tmp_output_dir, nwb_core_fixture):
|
||||
|
@ -84,6 +85,7 @@ def test_generate_linkml_provider(tmp_output_dir, nwb_core_fixture):
|
|||
result = provider.build(nwb_core_fixture)
|
||||
|
||||
|
||||
@pytest.mark.pydantic
|
||||
@pytest.mark.provider
|
||||
@pytest.mark.dev
|
||||
def test_generate_pydantic_provider(tmp_output_dir):
|
||||
|
|
120
nwb_linkml/tests/test_includes/conftest.py
Normal file
120
nwb_linkml/tests/test_includes/conftest.py
Normal file
|
@ -0,0 +1,120 @@
|
|||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from nwb_linkml.models import (
|
||||
ElectricalSeries,
|
||||
ExtracellularEphysElectrodes,
|
||||
Device,
|
||||
ElectrodeGroup,
|
||||
DynamicTableRegion,
|
||||
Units,
|
||||
IntracellularElectrode,
|
||||
IntracellularElectrodesTable,
|
||||
IntracellularResponsesTable,
|
||||
IntracellularStimuliTable,
|
||||
IntracellularRecordingsTable,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def electrical_series() -> Tuple["ElectricalSeries", "ExtracellularEphysElectrodes"]:
|
||||
"""
|
||||
Demo electrical series with adjoining electrodes
|
||||
"""
|
||||
n_electrodes = 5
|
||||
n_times = 100
|
||||
data = np.arange(0, n_electrodes * n_times).reshape(n_times, n_electrodes).astype(float)
|
||||
timestamps = np.linspace(0, 1, n_times)
|
||||
|
||||
device = Device(name="my electrode")
|
||||
|
||||
# electrode group is the physical description of the electrodes
|
||||
electrode_group = ElectrodeGroup(
|
||||
name="GroupA",
|
||||
device=device,
|
||||
description="an electrode group",
|
||||
location="you know where it is",
|
||||
)
|
||||
|
||||
# make electrodes tables
|
||||
electrodes = ExtracellularEphysElectrodes(
|
||||
description="idk these are also electrodes",
|
||||
id=np.arange(0, n_electrodes),
|
||||
x=np.arange(0, n_electrodes).astype(float),
|
||||
y=np.arange(n_electrodes, n_electrodes * 2).astype(float),
|
||||
group=[electrode_group] * n_electrodes,
|
||||
group_name=[electrode_group.name] * n_electrodes,
|
||||
location=[str(i) for i in range(n_electrodes)],
|
||||
extra_column=["sup"] * n_electrodes,
|
||||
)
|
||||
|
||||
electrical_series = ElectricalSeries(
|
||||
name="my recording!",
|
||||
electrodes=DynamicTableRegion(
|
||||
table=electrodes,
|
||||
value=np.arange(n_electrodes - 1, -1, step=-1),
|
||||
name="electrodes",
|
||||
description="hey",
|
||||
),
|
||||
timestamps=timestamps,
|
||||
data=data,
|
||||
)
|
||||
return electrical_series, electrodes
|
||||
|
||||
|
||||
def _ragged_array(n_units: int) -> tuple[list[np.ndarray], np.ndarray]:
|
||||
generator = np.random.default_rng()
|
||||
spike_times = [
|
||||
np.full(shape=generator.integers(10, 50), fill_value=i, dtype=float) for i in range(n_units)
|
||||
]
|
||||
spike_idx = []
|
||||
for i in range(n_units):
|
||||
if i == 0:
|
||||
spike_idx.append(len(spike_times[0]))
|
||||
else:
|
||||
spike_idx.append(len(spike_times[i]) + spike_idx[i - 1])
|
||||
spike_idx = np.array(spike_idx)
|
||||
return spike_times, spike_idx
|
||||
|
||||
|
||||
@pytest.fixture(params=[True, False])
|
||||
def units(request) -> Tuple[Units, list[np.ndarray], np.ndarray]:
|
||||
"""
|
||||
Test case for units
|
||||
|
||||
Parameterized by extra_column because pandas likes to pivot dataframes
|
||||
to long when there is only one column and it's not len() == 1
|
||||
"""
|
||||
spike_times, spike_idx = _ragged_array(24)
|
||||
|
||||
spike_times_flat = np.concatenate(spike_times)
|
||||
|
||||
kwargs = {
|
||||
"description": "units!!!!",
|
||||
"spike_times": spike_times_flat,
|
||||
"spike_times_index": spike_idx,
|
||||
}
|
||||
if request.param:
|
||||
kwargs["extra_column"] = ["hey!"] * 24
|
||||
units = Units(**kwargs)
|
||||
return units, spike_times, spike_idx
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def intracellular_recordings_table() -> IntracellularRecordingsTable:
|
||||
n_recordings = 10
|
||||
device = Device(name="my device")
|
||||
electrode = IntracellularElectrode(
|
||||
name="my_electrode", description="an electrode", device=device
|
||||
)
|
||||
electrodes = IntracellularElectrodesTable(
|
||||
name="intracellular_electrodes", electrode=[electrode] * n_recordings
|
||||
)
|
||||
stimuli = IntracellularStimuliTable(
|
||||
name="intracellular_stimuli",
|
||||
)
|
||||
responses = IntracellularResponsesTable()
|
||||
|
||||
recordings_table = IntracellularRecordingsTable()
|
|
@ -1,103 +1,13 @@
|
|||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
# FIXME: Make this just be the output of the provider by patching into import machinery
|
||||
from nwb_linkml.models.pydantic.core.v2_7_0.namespace import (
|
||||
Device,
|
||||
DynamicTable,
|
||||
DynamicTableRegion,
|
||||
ElectricalSeries,
|
||||
ElectrodeGroup,
|
||||
ExtracellularEphysElectrodes,
|
||||
Units,
|
||||
VectorIndex,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def electrical_series() -> Tuple["ElectricalSeries", "ExtracellularEphysElectrodes"]:
|
||||
"""
|
||||
Demo electrical series with adjoining electrodes
|
||||
"""
|
||||
n_electrodes = 5
|
||||
n_times = 100
|
||||
data = np.arange(0, n_electrodes * n_times).reshape(n_times, n_electrodes).astype(float)
|
||||
timestamps = np.linspace(0, 1, n_times)
|
||||
|
||||
device = Device(name="my electrode")
|
||||
|
||||
# electrode group is the physical description of the electrodes
|
||||
electrode_group = ElectrodeGroup(
|
||||
name="GroupA",
|
||||
device=device,
|
||||
description="an electrode group",
|
||||
location="you know where it is",
|
||||
)
|
||||
|
||||
# make electrodes tables
|
||||
electrodes = ExtracellularEphysElectrodes(
|
||||
description="idk these are also electrodes",
|
||||
id=np.arange(0, n_electrodes),
|
||||
x=np.arange(0, n_electrodes).astype(float),
|
||||
y=np.arange(n_electrodes, n_electrodes * 2).astype(float),
|
||||
group=[electrode_group] * n_electrodes,
|
||||
group_name=[electrode_group.name] * n_electrodes,
|
||||
location=[str(i) for i in range(n_electrodes)],
|
||||
extra_column=["sup"] * n_electrodes,
|
||||
)
|
||||
|
||||
electrical_series = ElectricalSeries(
|
||||
name="my recording!",
|
||||
electrodes=DynamicTableRegion(
|
||||
table=electrodes,
|
||||
value=np.arange(n_electrodes - 1, -1, step=-1),
|
||||
name="electrodes",
|
||||
description="hey",
|
||||
),
|
||||
timestamps=timestamps,
|
||||
data=data,
|
||||
)
|
||||
return electrical_series, electrodes
|
||||
|
||||
|
||||
def _ragged_array(n_units: int) -> tuple[list[np.ndarray], np.ndarray]:
|
||||
generator = np.random.default_rng()
|
||||
spike_times = [
|
||||
np.full(shape=generator.integers(10, 50), fill_value=i, dtype=float) for i in range(n_units)
|
||||
]
|
||||
spike_idx = []
|
||||
for i in range(n_units):
|
||||
if i == 0:
|
||||
spike_idx.append(len(spike_times[0]))
|
||||
else:
|
||||
spike_idx.append(len(spike_times[i]) + spike_idx[i - 1])
|
||||
spike_idx = np.array(spike_idx)
|
||||
return spike_times, spike_idx
|
||||
|
||||
|
||||
@pytest.fixture(params=[True, False])
|
||||
def units(request) -> Tuple[Units, list[np.ndarray], np.ndarray]:
|
||||
"""
|
||||
Test case for units
|
||||
|
||||
Parameterized by extra_column because pandas likes to pivot dataframes
|
||||
to long when there is only one column and it's not len() == 1
|
||||
"""
|
||||
spike_times, spike_idx = _ragged_array(24)
|
||||
|
||||
spike_times_flat = np.concatenate(spike_times)
|
||||
|
||||
kwargs = {
|
||||
"description": "units!!!!",
|
||||
"spike_times": spike_times_flat,
|
||||
"spike_times_index": spike_idx,
|
||||
}
|
||||
if request.param:
|
||||
kwargs["extra_column"] = ["hey!"] * 24
|
||||
units = Units(**kwargs)
|
||||
return units, spike_times, spike_idx
|
||||
from .conftest import _ragged_array
|
||||
|
||||
|
||||
def test_dynamictable_indexing(electrical_series):
|
||||
|
|
|
@ -220,8 +220,8 @@ class DtypeMixin(ConfiguredBaseModel):
|
|||
class Attribute(DtypeMixin):
|
||||
|
||||
name: str = Field(...)
|
||||
dims: Optional[List[Union[Any, str]]] = Field(default_factory=list)
|
||||
shape: Optional[List[Union[Any, int, str]]] = Field(default_factory=list)
|
||||
dims: Optional[List[Union[Any, str]]] = Field(None)
|
||||
shape: Optional[List[Union[Any, int, str]]] = Field(None)
|
||||
value: Optional[Any] = Field(
|
||||
None, description="""Optional constant, fixed value for the attribute."""
|
||||
)
|
||||
|
@ -233,9 +233,7 @@ class Attribute(DtypeMixin):
|
|||
True,
|
||||
description="""Optional boolean key describing whether the attribute is required. Default value is True.""",
|
||||
)
|
||||
dtype: Optional[Union[List[CompoundDtype], FlatDtype, ReferenceDtype]] = Field(
|
||||
default_factory=list
|
||||
)
|
||||
dtype: Optional[Union[List[CompoundDtype], FlatDtype, ReferenceDtype]] = Field(None)
|
||||
|
||||
|
||||
class Dataset(DtypeMixin):
|
||||
|
@ -250,8 +248,8 @@ class Dataset(DtypeMixin):
|
|||
)
|
||||
name: Optional[str] = Field(None)
|
||||
default_name: Optional[str] = Field(None)
|
||||
dims: Optional[List[Union[Any, str]]] = Field(default_factory=list)
|
||||
shape: Optional[List[Union[Any, int, str]]] = Field(default_factory=list)
|
||||
dims: Optional[List[Union[Any, str]]] = Field(None)
|
||||
shape: Optional[List[Union[Any, int, str]]] = Field(None)
|
||||
value: Optional[Any] = Field(
|
||||
None, description="""Optional constant, fixed value for the attribute."""
|
||||
)
|
||||
|
@ -261,7 +259,5 @@ class Dataset(DtypeMixin):
|
|||
doc: str = Field(..., description="""Description of corresponding object.""")
|
||||
quantity: Optional[Union[QuantityEnum, int]] = Field(1)
|
||||
linkable: Optional[bool] = Field(None)
|
||||
attributes: Optional[List[Attribute]] = Field(default_factory=list)
|
||||
dtype: Optional[Union[List[CompoundDtype], FlatDtype, ReferenceDtype]] = Field(
|
||||
default_factory=list
|
||||
)
|
||||
attributes: Optional[List[Attribute]] = Field(None)
|
||||
dtype: Optional[Union[List[CompoundDtype], FlatDtype, ReferenceDtype]] = Field(None)
|
||||
|
|
Loading…
Reference in a new issue