add logging. less janky adapter instantiation using model validators. correctly propagate properties from ancestor classes when building

sneakers-the-rat 2024-08-12 18:48:59 -07:00
parent c09b633cda
commit 0452a4359f
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
15 changed files with 415 additions and 140 deletions


@@ -74,7 +74,9 @@ addopts = [
 ]
 markers = [
     "dev: tests that are just for development rather than testing correctness",
-    "provider: tests for providers!"
+    "provider: tests for providers!",
+    "linkml: tests related to linkml generation",
+    "pydantic: tests related to pydantic generation"
 ]
 testpaths = [
     "src/nwb_linkml",


@@ -5,16 +5,8 @@ Base class for adapters
 import sys
 from abc import abstractmethod
 from dataclasses import dataclass, field
-from typing import (
-    Any,
-    Generator,
-    List,
-    Optional,
-    Tuple,
-    Type,
-    TypeVar,
-    Union,
-)
+from typing import Any, Generator, List, Literal, Optional, Tuple, Type, TypeVar, Union, overload
+from logging import Logger
 
 from linkml_runtime.dumpers import yaml_dumper
 from linkml_runtime.linkml_model import (
@@ -26,6 +18,7 @@ from linkml_runtime.linkml_model import (
 )
 from pydantic import BaseModel
 
+from nwb_linkml.logging import init_logger
 from nwb_schema_language import Attribute, CompoundDtype, Dataset, Group, Schema
 
 if sys.version_info.minor >= 11:
@@ -107,6 +100,14 @@ class BuildResult:
 class Adapter(BaseModel):
     """Abstract base class for adapters"""
 
+    _logger: Optional[Logger] = None
+
+    @property
+    def logger(self) -> Logger:
+        if self._logger is None:
+            self._logger = init_logger(self.__class__.__name__)
+        return self._logger
+
     @abstractmethod
     def build(self) -> "BuildResult":
         """
@@ -152,8 +153,8 @@ class Adapter(BaseModel):
             # SchemaAdapters that should be located under the same
             # NamespacesAdapter when it's important to query across SchemaAdapters,
             # so skip to avoid combinatoric walking
-            if key == "imports" and type(input).__name__ == "SchemaAdapter":
-                continue
+            # if key == "imports" and type(input).__name__ == "SchemaAdapter":
+            #     continue
             val = getattr(input, key)
             yield (key, val)
             if isinstance(val, (BaseModel, dict, list)):
@@ -196,6 +197,14 @@ class Adapter(BaseModel):
         if isinstance(item, tuple) and item[0] in field and item[1] is not None:
             yield item[1]
 
+    @overload
+    def walk_field_values(
+        self,
+        input: Union[BaseModel, dict, list],
+        field: Literal["neurodata_type_def"],
+        value: Optional[Any] = None,
+    ) -> Generator[Group | Dataset, None, None]: ...
+
     def walk_field_values(
         self, input: Union[BaseModel, dict, list], field: str, value: Optional[Any] = None
     ) -> Generator[BaseModel, None, None]:
@@ -248,6 +257,9 @@ def is_1d(cls: Dataset | Attribute) -> bool:
     * a single-layer dim/shape list of length 1, or
     * a nested dim/shape list where every nested spec is of length 1
     """
+    if cls.dims is None:
+        return False
+
     return (
         not any([isinstance(dim, list) for dim in cls.dims]) and len(cls.dims) == 1
     ) or (  # nested list
@@ -270,4 +282,8 @@ def has_attrs(cls: Dataset) -> bool:
     """
     Check if a dataset has any attributes at all without defaults
     """
-    return len(cls.attributes) > 0 and all([not a.value for a in cls.attributes])
+    return (
+        cls.attributes is not None
+        and len(cls.attributes) > 0
+        and all([not a.value for a in cls.attributes])
+    )
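
A minimal sketch of the lazy-logger pattern introduced above, outside the real adapter hierarchy (DemoAdapter and the stand-in init_logger below are illustrative, not project code): the logger lives in a pydantic private attribute and is only created the first time the property is read.

import logging
from typing import Optional

from pydantic import BaseModel


def init_logger(name: str) -> logging.Logger:
    # stand-in for nwb_linkml.logging.init_logger
    return logging.getLogger(name)


class DemoAdapter(BaseModel):
    _logger: Optional[logging.Logger] = None

    @property
    def logger(self) -> logging.Logger:
        # created lazily so constructing an adapter never touches logging config
        if self._logger is None:
            self._logger = init_logger(self.__class__.__name__)
        return self._logger


DemoAdapter().logger.info("built")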


@@ -119,9 +119,12 @@ class ClassAdapter(Adapter):
         Returns:
             list[:class:`.SlotDefinition`]
         """
-        results = [AttributeAdapter(cls=attr).build() for attr in cls.attributes]
-        slots = [r.slots[0] for r in results]
-        return slots
+        if cls.attributes is not None:
+            results = [AttributeAdapter(cls=attr).build() for attr in cls.attributes]
+            slots = [r.slots[0] for r in results]
+            return slots
+        else:
+            return []
 
     def _get_full_name(self) -> str:
         """The full name of the object in the generated linkml


@@ -784,6 +784,12 @@ class MapCompoundDtype(DatasetMap):
         Make a new class for this dtype, using its sub-dtypes as fields,
         and use it as the range for the parent class
         """
+        # all the slots share the same ndarray spec if there is one
+        array = {}
+        if cls.dims or cls.shape:
+            array_adapter = ArrayAdapter(cls.dims, cls.shape)
+            array = array_adapter.make_slot()
+
         slots = {}
         for a_dtype in cls.dtype:
             slots[a_dtype.name] = SlotDefinition(
@@ -791,8 +797,13 @@
                 description=a_dtype.doc,
                 range=handle_dtype(a_dtype.dtype),
                 **QUANTITY_MAP[cls.quantity],
+                **array,
             )
         res.classes[0].attributes.update(slots)
+
+        # the compound dtype replaces the ``value`` slot, if present
+        if "value" in res.classes[0].attributes:
+            del res.classes[0].attributes["value"]
+
         return res
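
A plain-dict sketch of what this mapping does (the field names and the array spec below are illustrative assumptions, not the real ArrayAdapter or SlotDefinition API): each member of the compound dtype becomes its own slot, every slot shares the one array spec, and the generic value slot is dropped.

compound_dtype = [
    {"name": "idx_start", "doc": "start index", "dtype": "int32"},
    {"name": "count", "doc": "number of samples", "dtype": "int32"},
]
array_spec = {"array": {"exact_number_dimensions": 1}}  # stand-in for ArrayAdapter.make_slot()

attributes = {"value": {"name": "value", "range": "AnyType"}}  # pre-existing generic slot
for sub in compound_dtype:
    attributes[sub["name"]] = {
        "name": sub["name"],
        "description": sub["doc"],
        "range": sub["dtype"],
        **array_spec,
    }
# the compound dtype replaces the generic ``value`` slot
attributes.pop("value", None)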


@@ -13,13 +13,13 @@ from typing import Dict, List, Optional
 from linkml_runtime.dumpers import yaml_dumper
 from linkml_runtime.linkml_model import Annotation, SchemaDefinition
-from pydantic import Field, PrivateAttr
+from pydantic import Field, model_validator
 
 from nwb_linkml.adapters.adapter import Adapter, BuildResult
 from nwb_linkml.adapters.schema import SchemaAdapter
 from nwb_linkml.lang_elements import NwbLangSchema
 from nwb_linkml.ui import AdapterProgress
-from nwb_schema_language import Namespaces
+from nwb_schema_language import Namespaces, Group, Dataset
 
 
 class NamespacesAdapter(Adapter):
@@ -31,12 +31,6 @@ class NamespacesAdapter(Adapter):
     schemas: List[SchemaAdapter]
     imported: List["NamespacesAdapter"] = Field(default_factory=list)
 
-    _imports_populated: bool = PrivateAttr(False)
-
-    def __init__(self, **kwargs: dict):
-        super().__init__(**kwargs)
-        self._populate_schema_namespaces()
-
     @classmethod
     def from_yaml(cls, path: Path) -> "NamespacesAdapter":
         """
@@ -70,8 +64,6 @@
         """
         Build the NWB namespace to the LinkML Schema
         """
-        if not self._imports_populated and not skip_imports:
-            self.populate_imports()
-
         sch_result = BuildResult()
         for sch in self.schemas:
@@ -129,6 +121,7 @@
         return sch_result
 
+    @model_validator(mode="after")
     def _populate_schema_namespaces(self) -> None:
         """
         annotate for each schema which namespace imports it
@@ -143,6 +136,7 @@
                     sch.namespace = ns.name
                     sch.version = ns.version
                     break
+        return self
 
     def find_type_source(self, name: str) -> SchemaAdapter:
         """
@@ -182,7 +176,8 @@
         else:
             raise KeyError(f"No schema found that define {name}")
 
-    def populate_imports(self) -> None:
+    @model_validator(mode="after")
+    def populate_imports(self) -> "NamespacesAdapter":
         """
         Populate the imports that are needed for each schema file
 
@@ -199,11 +194,46 @@
                 if depends_on not in sch.imports:
                     sch.imports.append(depends_on)
 
-        # do so recursively
-        for imported in self.imported:
-            imported.populate_imports()
-
-        self._imports_populated = True
+        return self
+
+    @model_validator(mode="after")
+    def _populate_inheritance(self) -> "NamespacesAdapter":
+        """
+        ensure properties from `neurodata_type_inc` are propagated through to inheriting classes.
+
+        This seems super expensive but we'll optimize for perf later if that proves to be the case
+        """
+        # don't use walk_types here so we can replace the objects as we mutate them
+        for sch in self.schemas:
+            for i, group in enumerate(sch.groups):
+                if getattr(group, "neurodata_type_inc", None) is not None:
+                    merged_attrs = self._merge_inheritance(group)
+                    sch.groups[i] = Group(**merged_attrs)
+            for i, dataset in enumerate(sch.datasets):
+                if getattr(dataset, "neurodata_type_inc", None) is not None:
+                    merged_attrs = self._merge_inheritance(dataset)
+                    sch.datasets[i] = Dataset(**merged_attrs)
+        return self
+
+    def _merge_inheritance(self, obj: Group | Dataset) -> dict:
+        """Recursively merge an object's dict with that of its ``neurodata_type_inc`` parent"""
+        obj_dict = obj.model_dump(exclude_none=True)
+        if obj.neurodata_type_inc:
+            name = obj.neurodata_type_def if obj.neurodata_type_def else obj.name
+            self.logger.debug(f"Merging {name} with {obj.neurodata_type_inc}")
+            # there must be only one type with this name
+            parent: Group | Dataset = next(
+                self.walk_field_values(self, "neurodata_type_def", obj.neurodata_type_inc)
+            )
+            parent_dict = copy(self._merge_inheritance(parent))
+            # children don't inherit the type_def
+            del parent_dict["neurodata_type_def"]
+            # overwrite with child values
+            parent_dict.update(obj_dict)
+            return parent_dict
+        return obj_dict
 
     def to_yaml(self, base_dir: Path) -> None:
         """


@@ -42,7 +42,8 @@ class SchemaAdapter(Adapter):
         """
         The namespace.schema name for a single schema
         """
-        return ".".join([self.namespace, self.path.with_suffix("").name])
+        namespace = self.namespace if self.namespace is not None else ""
+        return ".".join([namespace, self.path.with_suffix("").name])
 
     def __repr__(self):
         out_str = "\n" + self.name + "\n"


@@ -2,10 +2,12 @@
 Manage the operation of nwb_linkml from environmental variables
 """
 
+from typing import Optional, Literal
 import tempfile
 from pathlib import Path
 
 from pydantic import (
+    BaseModel,
     DirectoryPath,
     Field,
     FieldValidationInfo,
@@ -15,15 +17,68 @@
 )
 from pydantic_settings import BaseSettings, SettingsConfigDict
 
+LOG_LEVELS = Literal["DEBUG", "INFO", "WARNING", "ERROR"]
+
+
+class LogConfig(BaseModel):
+    """
+    Configuration for logging
+    """
+
+    level: LOG_LEVELS = "INFO"
+    """
+    Severity of log messages to process.
+    """
+    level_file: Optional[LOG_LEVELS] = None
+    """
+    Severity for file-based logging. If unset, use ``level``
+    """
+    level_stdout: Optional[LOG_LEVELS] = "WARNING"
+    """
+    Severity for stream-based logging. If unset, use ``level``
+    """
+    file_n: int = 5
+    """
+    Number of log files to rotate through
+    """
+    file_size: int = 2**22  # roughly 4MB
+    """
+    Maximum size of log files (bytes)
+    """
+
+    @field_validator("level", "level_file", "level_stdout", mode="before")
+    @classmethod
+    def uppercase_levels(cls, value: Optional[str] = None) -> Optional[str]:
+        """
+        Ensure log level strings are uppercased
+        """
+        if value is not None:
+            value = value.upper()
+        return value
+
+    @model_validator(mode="after")
+    def inherit_base_level(self) -> "LogConfig":
+        """
+        If loglevels for specific output streams are unset, set from base :attr:`.level`
+        """
+        levels = ("level_file", "level_stdout")
+        for level_name in levels:
+            if getattr(self, level_name) is None:
+                setattr(self, level_name, self.level)
+        return self
+
+
 class Config(BaseSettings):
     """
     Configuration for nwb_linkml, populated by default but can be overridden
     by environment variables.
 
+    Nested models can be assigned from .env files with a __ (see examples)
+
     Examples:
 
         export NWB_LINKML_CACHE_DIR="/home/mycache/dir"
+        export NWB_LINKML_LOGS__LEVEL="debug"
     """
@@ -32,6 +87,11 @@ class Config(BaseSettings):
         default_factory=lambda: Path(tempfile.gettempdir()) / "nwb_linkml__cache",
         description="Location to cache generated schema and models",
     )
+    log_dir: Path = Field(
+        Path("logs"),
+        description="Location to store logs. If a relative directory, relative to ``cache_dir``",
+    )
+    logs: LogConfig = Field(LogConfig(), description="Log configuration")
 
     @computed_field
     @property
@@ -62,6 +122,15 @@
         assert v.exists()
         return v
 
+    @model_validator(mode="after")
+    def log_dir_relative_to_cache_dir(self) -> "Config":
+        """
+        If log dir is relative, put it beneath the cache_dir
+        """
+        if not self.log_dir.is_absolute():
+            self.log_dir = self.cache_dir / self.log_dir
+        return self
+
     @model_validator(mode="after")
     def folders_exist(self) -> "Config":
         """


@@ -70,6 +70,7 @@ def load_namespace_adapter(
     namespace: Path | NamespaceRepo | Namespaces,
     path: Optional[Path] = None,
     version: Optional[str] = None,
+    imported: Optional[list[NamespacesAdapter]] = None,
 ) -> NamespacesAdapter:
     """
     Load all schema referenced by a namespace file
@@ -115,6 +116,9 @@ def load_namespace_adapter(
             yml_file = (path / schema.source).resolve()
             sch.append(load_schema_file(yml_file))
 
-    adapter = NamespacesAdapter(namespaces=namespaces, schemas=sch)
+    if imported is not None:
+        adapter = NamespacesAdapter(namespaces=namespaces, schemas=sch, imported=imported)
+    else:
+        adapter = NamespacesAdapter(namespaces=namespaces, schemas=sch)
 
     return adapter
@@ -148,8 +152,6 @@ def load_nwb_core(
     if hdmf_only:
         schema = hdmf_schema
     else:
-        schema = load_namespace_adapter(NWB_CORE_REPO, version=core_version)
-        schema.imported.append(hdmf_schema)
+        schema = load_namespace_adapter(NWB_CORE_REPO, version=core_version, imported=[hdmf_schema])
 
     return schema
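
A hedged sketch of the new imported argument (the module path and namespace-file paths below are placeholder assumptions): dependencies can now be passed at load time instead of being appended to .imported afterwards, which matters now that the adapter's validators run on construction.

from pathlib import Path

from nwb_linkml.io.schema import load_namespace_adapter  # assumed import path

hdmf = load_namespace_adapter(Path("hdmf-common-schema/common/namespace.yaml"))
core = load_namespace_adapter(
    Path("nwb-schema/core/nwb.namespace.yaml"), imported=[hdmf]
)
# roughly what load_nwb_core(core_version=..., hdmf_version=...) now does internally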


@@ -0,0 +1,100 @@
"""
Logging factory and handlers
"""
import logging
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Optional, Union
from rich.logging import RichHandler
from nwb_linkml.config import LOG_LEVELS, Config
def init_logger(
name: str,
log_dir: Union[Optional[Path], bool] = None,
level: Optional[LOG_LEVELS] = None,
file_level: Optional[LOG_LEVELS] = None,
log_file_n: Optional[int] = None,
log_file_size: Optional[int] = None,
) -> logging.Logger:
"""
Make a logger.
Log to a set of rotating files in the ``log_dir`` according to ``name`` ,
as well as using the :class:`~rich.RichHandler` for pretty-formatted stdout logs.
Args:
name (str): Name of this logger. Ideally names are hierarchical
and indicate what they are logging for, eg. ``miniscope_io.sdcard``
and don't contain metadata like timestamps, etc. (which are in the logs)
log_dir (:class:`pathlib.Path`): Directory to store file-based logs in. If ``None``,
get from :class:`.Config`. If ``False`` , disable file logging.
level (:class:`.LOG_LEVELS`): Level to use for stdout logging. If ``None`` ,
get from :class:`.Config`
file_level (:class:`.LOG_LEVELS`): Level to use for file-based logging.
If ``None`` , get from :class:`.Config`
log_file_n (int): Number of rotating file logs to use.
If ``None`` , get from :class:`.Config`
log_file_size (int): Maximum size of logfiles before rotation.
If ``None`` , get from :class:`.Config`
Returns:
:class:`logging.Logger`
"""
config = Config()
if log_dir is None:
log_dir = config.log_dir
if level is None:
level = config.logs.level_stdout
if file_level is None:
file_level = config.logs.level_file
if log_file_n is None:
log_file_n = config.logs.file_n
if log_file_size is None:
log_file_size = config.logs.file_size
if not name.startswith("nwb_linkml"):
name = "nwb_linkml." + name
logger = logging.getLogger(name)
logger.setLevel(level)
# Add handlers for stdout and file
if log_dir is not False:
logger.addHandler(_file_handler(name, file_level, log_dir, log_file_n, log_file_size))
logger.addHandler(_rich_handler())
return logger
def _file_handler(
name: str,
file_level: LOG_LEVELS,
log_dir: Path,
log_file_n: int = 5,
log_file_size: int = 2**22,
) -> RotatingFileHandler:
# See init_logger for arg docs
filename = Path(log_dir) / ".".join([name, "log"])
file_handler = RotatingFileHandler(
str(filename), mode="a", maxBytes=log_file_size, backupCount=log_file_n
)
file_formatter = logging.Formatter("[%(asctime)s] %(levelname)s [%(name)s]: %(message)s")
file_handler.setLevel(file_level)
file_handler.setFormatter(file_formatter)
return file_handler
def _rich_handler() -> RichHandler:
rich_handler = RichHandler(rich_tracebacks=True, markup=True)
rich_formatter = logging.Formatter(
"[bold green]\[%(name)s][/bold green] %(message)s",
datefmt="[%y-%m-%dT%H:%M:%S]",
)
rich_handler.setFormatter(rich_formatter)
return rich_handler
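
A short usage sketch of the factory above (assuming the configured log directory already exists, e.g. created by Config's folders_exist validator): names are namespaced under nwb_linkml., and file logging can be switched off per logger.

from nwb_linkml.logging import init_logger

logger = init_logger("adapters.namespaces", level="DEBUG")  # becomes "nwb_linkml.adapters.namespaces"
logger.debug("merging inherited attributes")

console_only = init_logger("adapters.dataset", log_dir=False)  # rich stdout handler only
console_only.warning("no file handler attached")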


@@ -82,7 +82,6 @@ def tmp_output_dir_mod(tmp_output_dir) -> Path:
 @pytest.fixture(scope="session", params=[{"core_version": "2.7.0", "hdmf_version": "1.8.0"}])
 def nwb_core_fixture(request) -> NamespacesAdapter:
     nwb_core = io.load_nwb_core(**request.param)
-    nwb_core.populate_imports()
     assert (
         request.param["core_version"] in nwb_core.versions["core"]
     )  # 2.6.0 is actually 2.6.0-alpha


@@ -46,3 +46,17 @@ def test_skip_imports(nwb_core_fixture):
     # we shouldn't have any of the hdmf-common schema in with us
     namespaces = [sch.annotations["namespace"].value for sch in res.schemas]
     assert all([ns == "core" for ns in namespaces])
+
+
+@pytest.mark.skip()
+def test_populate_inheritance(nwb_core_fixture):
+    """
+    Classes should receive and override the properties of their parents
+    when they have neurodata_type_inc
+
+    Args:
+        nwb_core_fixture:
+
+    Returns:
+
+    """
+    pass


@@ -76,6 +76,7 @@ def test_generate_pydantic(tmp_output_dir):
         initfile.write("# Autogenerated module indicator")
 
 
+@pytest.mark.linkml
 @pytest.mark.provider
 @pytest.mark.dev
 def test_generate_linkml_provider(tmp_output_dir, nwb_core_fixture):
@@ -84,6 +85,7 @@ def test_generate_linkml_provider(tmp_output_dir, nwb_core_fixture):
     result = provider.build(nwb_core_fixture)
 
 
+@pytest.mark.pydantic
 @pytest.mark.provider
 @pytest.mark.dev
 def test_generate_pydantic_provider(tmp_output_dir):


@@ -0,0 +1,120 @@
from typing import Tuple
import numpy as np
import pytest
from nwb_linkml.models import (
ElectricalSeries,
ExtracellularEphysElectrodes,
Device,
ElectrodeGroup,
DynamicTableRegion,
Units,
IntracellularElectrode,
IntracellularElectrodesTable,
IntracellularResponsesTable,
IntracellularStimuliTable,
IntracellularRecordingsTable,
)
@pytest.fixture()
def electrical_series() -> Tuple["ElectricalSeries", "ExtracellularEphysElectrodes"]:
"""
Demo electrical series with adjoining electrodes
"""
n_electrodes = 5
n_times = 100
data = np.arange(0, n_electrodes * n_times).reshape(n_times, n_electrodes).astype(float)
timestamps = np.linspace(0, 1, n_times)
device = Device(name="my electrode")
# electrode group is the physical description of the electrodes
electrode_group = ElectrodeGroup(
name="GroupA",
device=device,
description="an electrode group",
location="you know where it is",
)
# make electrodes tables
electrodes = ExtracellularEphysElectrodes(
description="idk these are also electrodes",
id=np.arange(0, n_electrodes),
x=np.arange(0, n_electrodes).astype(float),
y=np.arange(n_electrodes, n_electrodes * 2).astype(float),
group=[electrode_group] * n_electrodes,
group_name=[electrode_group.name] * n_electrodes,
location=[str(i) for i in range(n_electrodes)],
extra_column=["sup"] * n_electrodes,
)
electrical_series = ElectricalSeries(
name="my recording!",
electrodes=DynamicTableRegion(
table=electrodes,
value=np.arange(n_electrodes - 1, -1, step=-1),
name="electrodes",
description="hey",
),
timestamps=timestamps,
data=data,
)
return electrical_series, electrodes
def _ragged_array(n_units: int) -> tuple[list[np.ndarray], np.ndarray]:
generator = np.random.default_rng()
spike_times = [
np.full(shape=generator.integers(10, 50), fill_value=i, dtype=float) for i in range(n_units)
]
spike_idx = []
for i in range(n_units):
if i == 0:
spike_idx.append(len(spike_times[0]))
else:
spike_idx.append(len(spike_times[i]) + spike_idx[i - 1])
spike_idx = np.array(spike_idx)
return spike_times, spike_idx
@pytest.fixture(params=[True, False])
def units(request) -> Tuple[Units, list[np.ndarray], np.ndarray]:
"""
Test case for units
Parameterized by extra_column because pandas likes to pivot dataframes
to long when there is only one column and it's not len() == 1
"""
spike_times, spike_idx = _ragged_array(24)
spike_times_flat = np.concatenate(spike_times)
kwargs = {
"description": "units!!!!",
"spike_times": spike_times_flat,
"spike_times_index": spike_idx,
}
if request.param:
kwargs["extra_column"] = ["hey!"] * 24
units = Units(**kwargs)
return units, spike_times, spike_idx
@pytest.fixture()
def intracellular_recordings_table() -> IntracellularRecordingsTable:
n_recordings = 10
device = Device(name="my device")
electrode = IntracellularElectrode(
name="my_electrode", description="an electrode", device=device
)
electrodes = IntracellularElectrodesTable(
name="intracellular_electrodes", electrode=[electrode] * n_recordings
)
stimuli = IntracellularStimuliTable(
name="intracellular_stimuli",
)
responses = IntracellularResponsesTable()
recordings_table = IntracellularRecordingsTable()


@@ -1,103 +1,13 @@
-from typing import Tuple
-
 import numpy as np
-import pytest
 
 # FIXME: Make this just be the output of the provider by patching into import machinery
 from nwb_linkml.models.pydantic.core.v2_7_0.namespace import (
-    Device,
     DynamicTable,
     DynamicTableRegion,
-    ElectricalSeries,
     ElectrodeGroup,
-    ExtracellularEphysElectrodes,
-    Units,
     VectorIndex,
 )
 
+from .conftest import _ragged_array
+
-
-@pytest.fixture()
-def electrical_series() -> Tuple["ElectricalSeries", "ExtracellularEphysElectrodes"]:
-    """
-    Demo electrical series with adjoining electrodes
-    """
-    n_electrodes = 5
-    n_times = 100
-    data = np.arange(0, n_electrodes * n_times).reshape(n_times, n_electrodes).astype(float)
-    timestamps = np.linspace(0, 1, n_times)
-
-    device = Device(name="my electrode")
-
-    # electrode group is the physical description of the electrodes
-    electrode_group = ElectrodeGroup(
-        name="GroupA",
-        device=device,
-        description="an electrode group",
-        location="you know where it is",
-    )
-
-    # make electrodes tables
-    electrodes = ExtracellularEphysElectrodes(
-        description="idk these are also electrodes",
-        id=np.arange(0, n_electrodes),
-        x=np.arange(0, n_electrodes).astype(float),
-        y=np.arange(n_electrodes, n_electrodes * 2).astype(float),
-        group=[electrode_group] * n_electrodes,
-        group_name=[electrode_group.name] * n_electrodes,
-        location=[str(i) for i in range(n_electrodes)],
-        extra_column=["sup"] * n_electrodes,
-    )
-
-    electrical_series = ElectricalSeries(
-        name="my recording!",
-        electrodes=DynamicTableRegion(
-            table=electrodes,
-            value=np.arange(n_electrodes - 1, -1, step=-1),
-            name="electrodes",
-            description="hey",
-        ),
-        timestamps=timestamps,
-        data=data,
-    )
-    return electrical_series, electrodes
-
-
-def _ragged_array(n_units: int) -> tuple[list[np.ndarray], np.ndarray]:
-    generator = np.random.default_rng()
-    spike_times = [
-        np.full(shape=generator.integers(10, 50), fill_value=i, dtype=float) for i in range(n_units)
-    ]
-    spike_idx = []
-    for i in range(n_units):
-        if i == 0:
-            spike_idx.append(len(spike_times[0]))
-        else:
-            spike_idx.append(len(spike_times[i]) + spike_idx[i - 1])
-    spike_idx = np.array(spike_idx)
-    return spike_times, spike_idx
-
-
-@pytest.fixture(params=[True, False])
-def units(request) -> Tuple[Units, list[np.ndarray], np.ndarray]:
-    """
-    Test case for units
-
-    Parameterized by extra_column because pandas likes to pivot dataframes
-    to long when there is only one column and it's not len() == 1
-    """
-    spike_times, spike_idx = _ragged_array(24)
-    spike_times_flat = np.concatenate(spike_times)
-
-    kwargs = {
-        "description": "units!!!!",
-        "spike_times": spike_times_flat,
-        "spike_times_index": spike_idx,
-    }
-    if request.param:
-        kwargs["extra_column"] = ["hey!"] * 24
-    units = Units(**kwargs)
-    return units, spike_times, spike_idx
 
 
 def test_dynamictable_indexing(electrical_series):


@@ -220,8 +220,8 @@ class DtypeMixin(ConfiguredBaseModel):
 class Attribute(DtypeMixin):
     name: str = Field(...)
-    dims: Optional[List[Union[Any, str]]] = Field(default_factory=list)
-    shape: Optional[List[Union[Any, int, str]]] = Field(default_factory=list)
+    dims: Optional[List[Union[Any, str]]] = Field(None)
+    shape: Optional[List[Union[Any, int, str]]] = Field(None)
     value: Optional[Any] = Field(
         None, description="""Optional constant, fixed value for the attribute."""
     )
@@ -233,9 +233,7 @@
         True,
         description="""Optional boolean key describing whether the attribute is required. Default value is True.""",
     )
-    dtype: Optional[Union[List[CompoundDtype], FlatDtype, ReferenceDtype]] = Field(
-        default_factory=list
-    )
+    dtype: Optional[Union[List[CompoundDtype], FlatDtype, ReferenceDtype]] = Field(None)
 
 
 class Dataset(DtypeMixin):
@@ -250,8 +248,8 @@
     )
     name: Optional[str] = Field(None)
     default_name: Optional[str] = Field(None)
-    dims: Optional[List[Union[Any, str]]] = Field(default_factory=list)
-    shape: Optional[List[Union[Any, int, str]]] = Field(default_factory=list)
+    dims: Optional[List[Union[Any, str]]] = Field(None)
+    shape: Optional[List[Union[Any, int, str]]] = Field(None)
     value: Optional[Any] = Field(
         None, description="""Optional constant, fixed value for the attribute."""
     )
@@ -261,7 +259,5 @@
     doc: str = Field(..., description="""Description of corresponding object.""")
     quantity: Optional[Union[QuantityEnum, int]] = Field(1)
     linkable: Optional[bool] = Field(None)
-    attributes: Optional[List[Attribute]] = Field(default_factory=list)
-    dtype: Optional[Union[List[CompoundDtype], FlatDtype, ReferenceDtype]] = Field(
-        default_factory=list
-    )
+    attributes: Optional[List[Attribute]] = Field(None)
+    dtype: Optional[Union[List[CompoundDtype], FlatDtype, ReferenceDtype]] = Field(None)
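
A small sketch of why switching these generated fields from default_factory=list to Field(None) matters for the inheritance merge (DatasetStandIn is a made-up stand-in for the generated Dataset): unset list fields now stay None, so model_dump(exclude_none=True) drops them and a child spec no longer overwrites its parent's attributes with an empty list.

from typing import List, Optional

from pydantic import BaseModel, Field


class DatasetStandIn(BaseModel):
    name: Optional[str] = None
    attributes: Optional[List[str]] = Field(None)


child = DatasetStandIn(name="child")
# with default_factory=list this dump would also include "attributes": []
assert child.model_dump(exclude_none=True) == {"name": "child"}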