From 0452a4359fa123c8fec948291aa8570b3d2426c0 Mon Sep 17 00:00:00 2001
From: sneakers-the-rat
Date: Mon, 12 Aug 2024 18:48:59 -0700
Subject: [PATCH] add logging. less janky adapter instantiation using model
 validators. correctly propagate properties from ancestor classes when
 building
---
 nwb_linkml/pyproject.toml                     |   4 +-
 nwb_linkml/src/nwb_linkml/adapters/adapter.py |  42 ++++--
 nwb_linkml/src/nwb_linkml/adapters/classes.py |   9 +-
 nwb_linkml/src/nwb_linkml/adapters/dataset.py |  11 ++
 .../src/nwb_linkml/adapters/namespaces.py     |  60 ++++---
 nwb_linkml/src/nwb_linkml/adapters/schema.py  |   3 +-
 nwb_linkml/src/nwb_linkml/config.py           |  69 ++++++++++
 nwb_linkml/src/nwb_linkml/io/schema.py        |  10 +-
 nwb_linkml/src/nwb_linkml/logging.py          | 100 +++++++++++++++
 nwb_linkml/tests/fixtures.py                  |   1 -
 .../test_adapters/test_adapter_namespaces.py  |  14 ++
 nwb_linkml/tests/test_generate.py             |   2 +
 nwb_linkml/tests/test_includes/conftest.py    | 120 ++++++++++++++++++
 nwb_linkml/tests/test_includes/test_hdmf.py   |  92 +------------
 .../datamodel/nwb_schema_pydantic.py          |  18 +--
 15 files changed, 415 insertions(+), 140 deletions(-)
 create mode 100644 nwb_linkml/src/nwb_linkml/logging.py
 create mode 100644 nwb_linkml/tests/test_includes/conftest.py

diff --git a/nwb_linkml/pyproject.toml b/nwb_linkml/pyproject.toml
index 6efc111..97e4cce 100644
--- a/nwb_linkml/pyproject.toml
+++ b/nwb_linkml/pyproject.toml
@@ -74,7 +74,9 @@ addopts = [
 ]
 markers = [
     "dev: tests that are just for development rather than testing correctness",
-    "provider: tests for providers!"
+    "provider: tests for providers!",
+    "linkml: tests related to linkml generation",
+    "pydantic: tests related to pydantic generation"
 ]
 testpaths = [
     "src/nwb_linkml",
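Registering the two generation-target markers lets runs be filtered per target. A minimal sketch of scoping a run (the test path is the one registered above; `pytest -m linkml` is the CLI equivalent):

    # run only the linkml-generation tests
    import pytest

    pytest.main(["-m", "linkml", "nwb_linkml/tests"])
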
diff --git a/nwb_linkml/src/nwb_linkml/adapters/adapter.py b/nwb_linkml/src/nwb_linkml/adapters/adapter.py
index 72f4248..e09e68f 100644
--- a/nwb_linkml/src/nwb_linkml/adapters/adapter.py
+++ b/nwb_linkml/src/nwb_linkml/adapters/adapter.py
@@ -5,16 +5,8 @@ Base class for adapters
 import sys
 from abc import abstractmethod
 from dataclasses import dataclass, field
-from typing import (
-    Any,
-    Generator,
-    List,
-    Optional,
-    Tuple,
-    Type,
-    TypeVar,
-    Union,
-)
+from logging import Logger
+from typing import Any, Generator, List, Literal, Optional, Tuple, Type, TypeVar, Union, overload
 
 from linkml_runtime.dumpers import yaml_dumper
 from linkml_runtime.linkml_model import (
@@ -26,6 +18,7 @@ from linkml_runtime.linkml_model import (
 )
 from pydantic import BaseModel
 
+from nwb_linkml.logging import init_logger
 from nwb_schema_language import Attribute, CompoundDtype, Dataset, Group, Schema
 
 if sys.version_info.minor >= 11:
@@ -107,6 +100,15 @@ class BuildResult:
 class Adapter(BaseModel):
     """Abstract base class for adapters"""
 
+    _logger: Optional[Logger] = None
+
+    @property
+    def logger(self) -> Logger:
+        """A logger named for the adapter class, created lazily on first access"""
+        if self._logger is None:
+            self._logger = init_logger(self.__class__.__name__)
+        return self._logger
+
     @abstractmethod
     def build(self) -> "BuildResult":
         """
@@ -152,8 +153,8 @@ class Adapter(BaseModel):
         # SchemaAdapters that should be located under the same
         # NamespacesAdapter when it's important to query across SchemaAdapters,
         # so skip to avoid combinatoric walking
-        if key == "imports" and type(input).__name__ == "SchemaAdapter":
-            continue
+        # if key == "imports" and type(input).__name__ == "SchemaAdapter":
+        #     continue
         val = getattr(input, key)
         yield (key, val)
         if isinstance(val, (BaseModel, dict, list)):
@@ -196,6 +197,14 @@ class Adapter(BaseModel):
             if isinstance(item, tuple) and item[0] in field and item[1] is not None:
                 yield item[1]
 
+    @overload
+    def walk_field_values(
+        self,
+        input: Union[BaseModel, dict, list],
+        field: Literal["neurodata_type_def"],
+        value: Optional[Any] = None,
+    ) -> Generator[Group | Dataset, None, None]: ...
+
     def walk_field_values(
         self, input: Union[BaseModel, dict, list], field: str, value: Optional[Any] = None
    ) -> Generator[BaseModel, None, None]:
@@ -248,6 +257,9 @@ def is_1d(cls: Dataset | Attribute) -> bool:
       * a single-layer dim/shape list of length 1, or
       * a nested dim/shape list where every nested spec is of length 1
     """
+    if cls.dims is None:
+        return False
+
     return (
         not any([isinstance(dim, list) for dim in cls.dims]) and len(cls.dims) == 1
     ) or (  # nested list
@@ -270,4 +282,8 @@ def has_attrs(cls: Dataset) -> bool:
     """
     Check if a dataset has any attributes at all without defaults
     """
-    return len(cls.attributes) > 0 and all([not a.value for a in cls.attributes])
+    return (
+        cls.attributes is not None
+        and len(cls.attributes) > 0
+        and all([not a.value for a in cls.attributes])
+    )
diff --git a/nwb_linkml/src/nwb_linkml/adapters/classes.py b/nwb_linkml/src/nwb_linkml/adapters/classes.py
index 054a401..0097e47 100644
--- a/nwb_linkml/src/nwb_linkml/adapters/classes.py
+++ b/nwb_linkml/src/nwb_linkml/adapters/classes.py
@@ -119,9 +119,12 @@ class ClassAdapter(Adapter):
         Returns:
             list[:class:`.SlotDefinition`]
         """
-        results = [AttributeAdapter(cls=attr).build() for attr in cls.attributes]
-        slots = [r.slots[0] for r in results]
-        return slots
+        if cls.attributes is not None:
+            results = [AttributeAdapter(cls=attr).build() for attr in cls.attributes]
+            slots = [r.slots[0] for r in results]
+            return slots
+        else:
+            return []
 
     def _get_full_name(self) -> str:
         """The full name of the object in the generated linkml
diff --git a/nwb_linkml/src/nwb_linkml/adapters/dataset.py b/nwb_linkml/src/nwb_linkml/adapters/dataset.py
index 3a49798..2490ef5 100644
--- a/nwb_linkml/src/nwb_linkml/adapters/dataset.py
+++ b/nwb_linkml/src/nwb_linkml/adapters/dataset.py
@@ -784,6 +784,12 @@ class MapCompoundDtype(DatasetMap):
         Make a new class for this dtype, using its sub-dtypes as fields,
         and use it as the range for the parent class
         """
+        # all the slots share the same ndarray spec if there is one
+        array = {}
+        if cls.dims or cls.shape:
+            array_adapter = ArrayAdapter(cls.dims, cls.shape)
+            array = array_adapter.make_slot()
+
         slots = {}
         for a_dtype in cls.dtype:
             slots[a_dtype.name] = SlotDefinition(
@@ -791,8 +797,13 @@ class MapCompoundDtype(DatasetMap):
                 description=a_dtype.doc,
                 range=handle_dtype(a_dtype.dtype),
                 **QUANTITY_MAP[cls.quantity],
+                **array,
             )
         res.classes[0].attributes.update(slots)
+
+        # the compound dtype replaces the ``value`` slot, if present
+        if "value" in res.classes[0].attributes:
+            del res.classes[0].attributes["value"]
         return res
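Every adapter now carries a lazily-initialized, class-named logger through the base model above, so subclasses get structured logging for free. A minimal sketch of the intended pattern (``DemoAdapter`` is hypothetical; ``BuildResult()`` assumed default-constructible per its dataclass field defaults):

    from nwb_linkml.adapters.adapter import Adapter, BuildResult

    class DemoAdapter(Adapter):
        def build(self) -> BuildResult:
            # first access creates a "nwb_linkml.DemoAdapter" logger via init_logger
            self.logger.debug("building!")
            return BuildResult()
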
diff --git a/nwb_linkml/src/nwb_linkml/adapters/namespaces.py b/nwb_linkml/src/nwb_linkml/adapters/namespaces.py
index f8ea857..59194e4 100644
--- a/nwb_linkml/src/nwb_linkml/adapters/namespaces.py
+++ b/nwb_linkml/src/nwb_linkml/adapters/namespaces.py
@@ -13,13 +13,13 @@ from typing import Dict, List, Optional
 
 from linkml_runtime.dumpers import yaml_dumper
 from linkml_runtime.linkml_model import Annotation, SchemaDefinition
-from pydantic import Field, PrivateAttr
+from pydantic import Field, model_validator
 
 from nwb_linkml.adapters.adapter import Adapter, BuildResult
 from nwb_linkml.adapters.schema import SchemaAdapter
 from nwb_linkml.lang_elements import NwbLangSchema
 from nwb_linkml.ui import AdapterProgress
-from nwb_schema_language import Namespaces
+from nwb_schema_language import Dataset, Group, Namespaces
 
 
 class NamespacesAdapter(Adapter):
@@ -31,12 +31,6 @@ class NamespacesAdapter(Adapter):
     schemas: List[SchemaAdapter]
     imported: List["NamespacesAdapter"] = Field(default_factory=list)
 
-    _imports_populated: bool = PrivateAttr(False)
-
-    def __init__(self, **kwargs: dict):
-        super().__init__(**kwargs)
-        self._populate_schema_namespaces()
-
     @classmethod
     def from_yaml(cls, path: Path) -> "NamespacesAdapter":
         """
@@ -70,8 +64,6 @@ class NamespacesAdapter(Adapter):
         """
         Build the NWB namespace to the LinkML Schema
         """
-        if not self._imports_populated and not skip_imports:
-            self.populate_imports()
 
         sch_result = BuildResult()
         for sch in self.schemas:
@@ -129,6 +121,7 @@ class NamespacesAdapter(Adapter):
 
         return sch_result
 
-    def _populate_schema_namespaces(self) -> None:
+    @model_validator(mode="after")
+    def _populate_schema_namespaces(self) -> "NamespacesAdapter":
         """
         annotate for each schema which namespace imports it
@@ -143,6 +136,7 @@ class NamespacesAdapter(Adapter):
                     sch.namespace = ns.name
                     sch.version = ns.version
                     break
+        return self
 
     def find_type_source(self, name: str) -> SchemaAdapter:
         """
@@ -182,7 +176,8 @@ class NamespacesAdapter(Adapter):
         else:
             raise KeyError(f"No schema found that define {name}")
 
-    def populate_imports(self) -> None:
+    @model_validator(mode="after")
+    def populate_imports(self) -> "NamespacesAdapter":
         """
         Populate the imports that are needed for each schema file
 
@@ -199,11 +194,49 @@ class NamespacesAdapter(Adapter):
                     if depends_on not in sch.imports:
                         sch.imports.append(depends_on)
 
-        # do so recursively
-        for imported in self.imported:
-            imported.populate_imports()
+        return self
 
-        self._imports_populated = True
+    @model_validator(mode="after")
+    def _populate_inheritance(self) -> "NamespacesAdapter":
+        """
+        ensure properties from `neurodata_type_inc` are propagated through to inheriting classes
+
+        This seems super expensive but we'll optimize for perf later if that proves to be the case
+        """
+        # don't use walk_types here so we can replace the objects as we mutate them
+        for sch in self.schemas:
+            for i, group in enumerate(sch.groups):
+                if getattr(group, "neurodata_type_inc", None) is not None:
+                    merged_attrs = self._merge_inheritance(group)
+                    sch.groups[i] = Group(**merged_attrs)
+            for i, dataset in enumerate(sch.datasets):
+                if getattr(dataset, "neurodata_type_inc", None) is not None:
+                    merged_attrs = self._merge_inheritance(dataset)
+                    sch.datasets[i] = Dataset(**merged_attrs)
+        return self
+
+    def _merge_inheritance(self, obj: Group | Dataset) -> dict:
+        """
+        Recursively merge an object's dict form with those of its
+        `neurodata_type_inc` ancestors, child values overriding parent values
+        """
+        obj_dict = obj.model_dump(exclude_none=True)
+        if obj.neurodata_type_inc:
+            name = obj.neurodata_type_def if obj.neurodata_type_def else obj.name
+            self.logger.debug(f"Merging {name} with {obj.neurodata_type_inc}")
+            # there must be only one type with this name
+            parent: Group | Dataset = next(
+                self.walk_field_values(self, "neurodata_type_def", obj.neurodata_type_inc)
+            )
+            # dict() makes a shallow copy so the parent's own dump isn't mutated below
+            parent_dict = dict(self._merge_inheritance(parent))
+            # children don't inherit the type_def
+            del parent_dict["neurodata_type_def"]
+            # overwrite with child values
+            parent_dict.update(obj_dict)
+            return parent_dict
+
+        return obj_dict
 
     def to_yaml(self, base_dir: Path) -> None:
         """
diff --git a/nwb_linkml/src/nwb_linkml/adapters/schema.py b/nwb_linkml/src/nwb_linkml/adapters/schema.py
index 4f03944..e6316b7 100644
--- a/nwb_linkml/src/nwb_linkml/adapters/schema.py
+++ b/nwb_linkml/src/nwb_linkml/adapters/schema.py
@@ -42,7 +42,8 @@ class SchemaAdapter(Adapter):
         """
         The namespace.schema name for a single schema
         """
-        return ".".join([self.namespace, self.path.with_suffix("").name])
+        namespace = self.namespace if self.namespace is not None else ""
+        return ".".join([namespace, self.path.with_suffix("").name])
 
     def __repr__(self):
         out_str = "\n" + self.name + "\n"
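With the construction-time validators replacing the old ``__init__`` hook, simply instantiating a ``NamespacesAdapter`` now annotates each schema with its namespace, links cross-schema imports, and merges inherited properties; callers no longer invoke ``populate_imports()`` by hand (see the fixture change below). A sketch, where ``namespaces`` and ``schemas`` are assumed to come from the loading helpers in ``nwb_linkml.io.schema``:

    from nwb_linkml.adapters.namespaces import NamespacesAdapter

    adapter = NamespacesAdapter(namespaces=namespaces, schemas=schemas)
    # the model validators have already run at this point: each SchemaAdapter
    # knows its namespace and version, imports are linked, and
    # neurodata_type_inc properties are merged into inheriting classes
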
diff --git a/nwb_linkml/src/nwb_linkml/config.py b/nwb_linkml/src/nwb_linkml/config.py
index 8fa84f7..bbfcaed 100644
--- a/nwb_linkml/src/nwb_linkml/config.py
+++ b/nwb_linkml/src/nwb_linkml/config.py
@@ -2,10 +2,12 @@
 Manage the operation of nwb_linkml from environmental variables
 """
 
+from typing import Literal, Optional
 import tempfile
 from pathlib import Path
 
 from pydantic import (
+    BaseModel,
     DirectoryPath,
     Field,
     FieldValidationInfo,
@@ -15,15 +17,68 @@
 )
 from pydantic_settings import BaseSettings, SettingsConfigDict
 
+LOG_LEVELS = Literal["DEBUG", "INFO", "WARNING", "ERROR"]
+
+
+class LogConfig(BaseModel):
+    """
+    Configuration for logging
+    """
+
+    level: LOG_LEVELS = "INFO"
+    """
+    Severity of log messages to process.
+    """
+    level_file: Optional[LOG_LEVELS] = None
+    """
+    Severity for file-based logging. If unset, use ``level``
+    """
+    level_stdout: Optional[LOG_LEVELS] = "WARNING"
+    """
+    Severity for stream-based logging. If unset, use ``level``
+    """
+    file_n: int = 5
+    """
+    Number of log files to rotate through
+    """
+    file_size: int = 2**22  # roughly 4MB
+    """
+    Maximum size of log files (bytes)
+    """
+
+    @field_validator("level", "level_file", "level_stdout", mode="before")
+    @classmethod
+    def uppercase_levels(cls, value: Optional[str] = None) -> Optional[str]:
+        """
+        Ensure log level strings are uppercased
+        """
+        if value is not None:
+            value = value.upper()
+        return value
+
+    @model_validator(mode="after")
+    def inherit_base_level(self) -> "LogConfig":
+        """
+        If loglevels for specific output streams are unset, set from base :attr:`.level`
+        """
+        levels = ("level_file", "level_stdout")
+        for level_name in levels:
+            if getattr(self, level_name) is None:
+                setattr(self, level_name, self.level)
+        return self
+
 
 class Config(BaseSettings):
     """
     Configuration for nwb_linkml, populated by default but can be overridden
     by environment variables.
 
+    Nested models can be set from environment variables or ``.env`` files
+    using a ``__`` delimiter between levels (see examples)
+
     Examples:
 
         export NWB_LINKML_CACHE_DIR="/home/mycache/dir"
+        export NWB_LINKML_LOGS__LEVEL="debug"
 
     """
 
@@ -32,6 +87,11 @@ class Config(BaseSettings):
         default_factory=lambda: Path(tempfile.gettempdir()) / "nwb_linkml__cache",
         description="Location to cache generated schema and models",
     )
+    log_dir: Path = Field(
+        Path("logs"),
+        description="Location to store logs. If relative, resolved against ``cache_dir``",
+    )
+    logs: LogConfig = Field(LogConfig(), description="Log configuration")
 
     @computed_field
     @property
@@ -62,6 +122,15 @@ class Config(BaseSettings):
         assert v.exists()
         return v
 
+    @model_validator(mode="after")
+    def log_dir_relative_to_cache_dir(self) -> "Config":
+        """
+        If log dir is relative, put it beneath the cache_dir
+        """
+        if not self.log_dir.is_absolute():
+            self.log_dir = self.cache_dir / self.log_dir
+        return self
+
     @model_validator(mode="after")
     def folders_exist(self) -> "Config":
         """
diff --git a/nwb_linkml/src/nwb_linkml/io/schema.py b/nwb_linkml/src/nwb_linkml/io/schema.py
index a162856..954fb3a 100644
--- a/nwb_linkml/src/nwb_linkml/io/schema.py
+++ b/nwb_linkml/src/nwb_linkml/io/schema.py
@@ -70,6 +70,7 @@ def load_namespace_adapter(
     namespace: Path | NamespaceRepo | Namespaces,
     path: Optional[Path] = None,
     version: Optional[str] = None,
+    imported: Optional[list[NamespacesAdapter]] = None,
 ) -> NamespacesAdapter:
     """
     Load all schema referenced by a namespace file
@@ -115,7 +116,10 @@ def load_namespace_adapter(
         yml_file = (path / schema.source).resolve()
         sch.append(load_schema_file(yml_file))
 
-    adapter = NamespacesAdapter(namespaces=namespaces, schemas=sch)
+    if imported is not None:
+        adapter = NamespacesAdapter(namespaces=namespaces, schemas=sch, imported=imported)
+    else:
+        adapter = NamespacesAdapter(namespaces=namespaces, schemas=sch)
 
     return adapter
 
@@ -148,8 +152,6 @@ def load_nwb_core(
     if hdmf_only:
         schema = hdmf_schema
     else:
-        schema = load_namespace_adapter(NWB_CORE_REPO, version=core_version)
-
-        schema.imported.append(hdmf_schema)
+        schema = load_namespace_adapter(NWB_CORE_REPO, version=core_version, imported=[hdmf_schema])
 
     return schema
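Because ``LogConfig`` nests under the existing settings object, the usual pydantic-settings ``__`` delimiter reaches its fields from the environment, as the docstring examples above show. A sketch (values are illustrative):

    import os

    os.environ["NWB_LINKML_LOGS__LEVEL"] = "debug"  # uppercased by the field validator
    os.environ["NWB_LINKML_LOGS__LEVEL_FILE"] = "DEBUG"

    from nwb_linkml.config import Config

    config = Config()
    assert config.logs.level == "DEBUG"
    assert config.logs.level_file == "DEBUG"
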
diff --git a/nwb_linkml/src/nwb_linkml/logging.py b/nwb_linkml/src/nwb_linkml/logging.py
new file mode 100644
index 0000000..35e4425
--- /dev/null
+++ b/nwb_linkml/src/nwb_linkml/logging.py
@@ -0,0 +1,105 @@
+"""
+Logging factory and handlers
+"""
+
+import logging
+from logging.handlers import RotatingFileHandler
+from pathlib import Path
+from typing import Optional, Union
+
+from rich.logging import RichHandler
+
+from nwb_linkml.config import LOG_LEVELS, Config
+
+
+def init_logger(
+    name: str,
+    log_dir: Union[Optional[Path], bool] = None,
+    level: Optional[LOG_LEVELS] = None,
+    file_level: Optional[LOG_LEVELS] = None,
+    log_file_n: Optional[int] = None,
+    log_file_size: Optional[int] = None,
+) -> logging.Logger:
+    """
+    Make a logger.
+
+    Log to a set of rotating files in the ``log_dir`` according to ``name`` ,
+    as well as using the :class:`~rich.RichHandler` for pretty-formatted stdout logs.
+
+    Args:
+        name (str): Name of this logger. Ideally names are hierarchical
+            and indicate what they are logging for, e.g. ``nwb_linkml.adapters.namespaces`` ,
+            and don't contain metadata like timestamps, etc. (which are in the logs)
+        log_dir (:class:`pathlib.Path`): Directory to store file-based logs in. If ``None``,
+            get from :class:`.Config`. If ``False`` , disable file logging.
+        level (:class:`.LOG_LEVELS`): Level to use for stdout logging. If ``None`` ,
+            get from :class:`.Config`
+        file_level (:class:`.LOG_LEVELS`): Level to use for file-based logging.
+            If ``None`` , get from :class:`.Config`
+        log_file_n (int): Number of rotating file logs to use.
+            If ``None`` , get from :class:`.Config`
+        log_file_size (int): Maximum size of logfiles before rotation.
+            If ``None`` , get from :class:`.Config`
+
+    Returns:
+        :class:`logging.Logger`
+    """
+    config = Config()
+    if log_dir is None:
+        log_dir = config.log_dir
+    if level is None:
+        level = config.logs.level_stdout
+    if file_level is None:
+        file_level = config.logs.level_file
+    if log_file_n is None:
+        log_file_n = config.logs.file_n
+    if log_file_size is None:
+        log_file_size = config.logs.file_size
+
+    if not name.startswith("nwb_linkml"):
+        name = "nwb_linkml." + name
+
+    logger = logging.getLogger(name)
+    # the logger must pass records at the most verbose of its handlers' levels;
+    # the handlers below then filter down to their own levels
+    logger.setLevel(min(logging.getLevelName(level), logging.getLevelName(file_level)))
+
+    # Add handlers for stdout and file
+    if log_dir is not False:
+        logger.addHandler(_file_handler(name, file_level, log_dir, log_file_n, log_file_size))
+
+    logger.addHandler(_rich_handler(level))
+
+    return logger
+
+
+def _file_handler(
+    name: str,
+    file_level: LOG_LEVELS,
+    log_dir: Path,
+    log_file_n: int = 5,
+    log_file_size: int = 2**22,
+) -> RotatingFileHandler:
+    # See init_logger for arg docs
+
+    filename = Path(log_dir) / ".".join([name, "log"])
+    file_handler = RotatingFileHandler(
+        str(filename), mode="a", maxBytes=log_file_size, backupCount=log_file_n
+    )
+    file_formatter = logging.Formatter("[%(asctime)s] %(levelname)s [%(name)s]: %(message)s")
+    file_handler.setLevel(file_level)
+    file_handler.setFormatter(file_formatter)
+    return file_handler
+
+
+def _rich_handler(level: LOG_LEVELS) -> RichHandler:
+    # See init_logger for arg docs
+    rich_handler = RichHandler(rich_tracebacks=True, markup=True)
+    rich_formatter = logging.Formatter(
+        # raw string so ``\[`` reaches rich as a literal escaped bracket
+        r"[bold green]\[%(name)s][/bold green] %(message)s",
+        datefmt="[%y-%m-%dT%H:%M:%S]",
+    )
+    rich_handler.setLevel(level)
+    rich_handler.setFormatter(rich_formatter)
+    return rich_handler
diff --git a/nwb_linkml/tests/fixtures.py b/nwb_linkml/tests/fixtures.py
index 3ab2d3c..a38e3e0 100644
--- a/nwb_linkml/tests/fixtures.py
+++ b/nwb_linkml/tests/fixtures.py
@@ -82,7 +82,6 @@ def tmp_output_dir_mod(tmp_output_dir) -> Path:
 @pytest.fixture(scope="session", params=[{"core_version": "2.7.0", "hdmf_version": "1.8.0"}])
 def nwb_core_fixture(request) -> NamespacesAdapter:
     nwb_core = io.load_nwb_core(**request.param)
-    nwb_core.populate_imports()
     assert (
         request.param["core_version"] in nwb_core.versions["core"]
     )  # 2.6.0 is actually 2.6.0-alpha
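Outside of adapters, the factory above can be called directly; both handlers fall back to ``Config`` for anything not passed explicitly. A sketch (the directory is hypothetical and must already exist, since ``RotatingFileHandler`` won't create it):

    from pathlib import Path

    from nwb_linkml.logging import init_logger

    logger = init_logger("adapters.namespaces", log_dir=Path("/tmp/nwb_logs"))
    logger.warning("hello")  # rich stdout + /tmp/nwb_logs/nwb_linkml.adapters.namespaces.log

    console_only = init_logger("adapters.dataset", log_dir=False)  # no file handler
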
diff --git a/nwb_linkml/tests/test_adapters/test_adapter_namespaces.py b/nwb_linkml/tests/test_adapters/test_adapter_namespaces.py
index 5124bdd..bbcb739 100644
--- a/nwb_linkml/tests/test_adapters/test_adapter_namespaces.py
+++ b/nwb_linkml/tests/test_adapters/test_adapter_namespaces.py
@@ -46,3 +46,12 @@ def test_skip_imports(nwb_core_fixture):
     # we shouldn't have any of the hdmf-common schema in with us
     namespaces = [sch.annotations["namespace"].value for sch in res.schemas]
     assert all([ns == "core" for ns in namespaces])
+
+
+@pytest.mark.skip(reason="not implemented yet")
+def test_populate_inheritance(nwb_core_fixture):
+    """
+    Classes should receive and override the properties of their parents
+    when they have neurodata_type_inc
+    """
+    pass
diff --git a/nwb_linkml/tests/test_generate.py b/nwb_linkml/tests/test_generate.py
index 70b08bc..529cdd1 100644
--- a/nwb_linkml/tests/test_generate.py
+++ b/nwb_linkml/tests/test_generate.py
@@ -76,6 +76,7 @@ def test_generate_pydantic(tmp_output_dir):
         initfile.write("# Autogenerated module indicator")
 
 
+@pytest.mark.linkml
 @pytest.mark.provider
 @pytest.mark.dev
 def test_generate_linkml_provider(tmp_output_dir, nwb_core_fixture):
@@ -84,6 +85,7 @@ def test_generate_linkml_provider(tmp_output_dir, nwb_core_fixture):
     result = provider.build(nwb_core_fixture)
 
 
+@pytest.mark.pydantic
 @pytest.mark.provider
 @pytest.mark.dev
 def test_generate_pydantic_provider(tmp_output_dir):
diff --git a/nwb_linkml/tests/test_includes/conftest.py b/nwb_linkml/tests/test_includes/conftest.py
new file mode 100644
index 0000000..9eacd9f
--- /dev/null
+++ b/nwb_linkml/tests/test_includes/conftest.py
@@ -0,0 +1,121 @@
+from typing import Tuple
+
+import numpy as np
+import pytest
+
+from nwb_linkml.models import (
+    Device,
+    DynamicTableRegion,
+    ElectricalSeries,
+    ElectrodeGroup,
+    ExtracellularEphysElectrodes,
+    IntracellularElectrode,
+    IntracellularElectrodesTable,
+    IntracellularRecordingsTable,
+    IntracellularResponsesTable,
+    IntracellularStimuliTable,
+    Units,
+)
+
+
+@pytest.fixture()
+def electrical_series() -> Tuple["ElectricalSeries", "ExtracellularEphysElectrodes"]:
+    """
+    Demo electrical series with adjoining electrodes
+    """
+    n_electrodes = 5
+    n_times = 100
+    data = np.arange(0, n_electrodes * n_times).reshape(n_times, n_electrodes).astype(float)
+    timestamps = np.linspace(0, 1, n_times)
+
+    device = Device(name="my electrode")
+
+    # electrode group is the physical description of the electrodes
+    electrode_group = ElectrodeGroup(
+        name="GroupA",
+        device=device,
+        description="an electrode group",
+        location="you know where it is",
+    )
+
+    # make electrodes tables
+    electrodes = ExtracellularEphysElectrodes(
+        description="idk these are also electrodes",
+        id=np.arange(0, n_electrodes),
+        x=np.arange(0, n_electrodes).astype(float),
+        y=np.arange(n_electrodes, n_electrodes * 2).astype(float),
+        group=[electrode_group] * n_electrodes,
+        group_name=[electrode_group.name] * n_electrodes,
+        location=[str(i) for i in range(n_electrodes)],
+        extra_column=["sup"] * n_electrodes,
+    )
+
+    electrical_series = ElectricalSeries(
+        name="my recording!",
+        electrodes=DynamicTableRegion(
+            table=electrodes,
+            value=np.arange(n_electrodes - 1, -1, step=-1),
+            name="electrodes",
+            description="hey",
+        ),
+        timestamps=timestamps,
+        data=data,
+    )
+    return electrical_series, electrodes
+
+
+def _ragged_array(n_units: int) -> tuple[list[np.ndarray], np.ndarray]:
+    generator = np.random.default_rng()
+    spike_times = [
+        np.full(shape=generator.integers(10, 50), fill_value=i, dtype=float) for i in range(n_units)
+    ]
+    spike_idx = []
+    for i in range(n_units):
+        if i == 0:
+            spike_idx.append(len(spike_times[0]))
+        else:
+            spike_idx.append(len(spike_times[i]) + spike_idx[i - 1])
+    spike_idx = np.array(spike_idx)
+    return spike_times, spike_idx
+
+
+@pytest.fixture(params=[True, False])
+def units(request) -> Tuple[Units, list[np.ndarray], np.ndarray]:
+    """
+    Test case for units
+
+    Parameterized by extra_column because pandas likes to pivot dataframes
+    to long when there is only one column and it's not len() == 1
+    """
+    spike_times, spike_idx = _ragged_array(24)
+
+    spike_times_flat = np.concatenate(spike_times)
+
+    kwargs = {
+        "description": "units!!!!",
+        "spike_times": spike_times_flat,
+        "spike_times_index": spike_idx,
+    }
+    if request.param:
+        kwargs["extra_column"] = ["hey!"] * 24
+    units = Units(**kwargs)
+    return units, spike_times, spike_idx
+
+
+@pytest.fixture()
+def intracellular_recordings_table() -> IntracellularRecordingsTable:
+    n_recordings = 10
+    device = Device(name="my device")
+    electrode = IntracellularElectrode(
+        name="my_electrode", description="an electrode", device=device
+    )
+    electrodes = IntracellularElectrodesTable(
+        name="intracellular_electrodes", electrode=[electrode] * n_recordings
+    )
+    stimuli = IntracellularStimuliTable(
+        name="intracellular_stimuli",
+    )
+    responses = IntracellularResponsesTable()
+
+    recordings_table = IntracellularRecordingsTable()
+    return recordings_table
diff --git a/nwb_linkml/tests/test_includes/test_hdmf.py b/nwb_linkml/tests/test_includes/test_hdmf.py
index e00c02e..b21e51a 100644
--- a/nwb_linkml/tests/test_includes/test_hdmf.py
+++ b/nwb_linkml/tests/test_includes/test_hdmf.py
@@ -1,103 +1,13 @@
-from typing import Tuple
-
 import numpy as np
-import pytest
 
 # FIXME: Make this just be the output of the provider by patching into import machinery
 from nwb_linkml.models.pydantic.core.v2_7_0.namespace import (
-    Device,
     DynamicTable,
     DynamicTableRegion,
-    ElectricalSeries,
     ElectrodeGroup,
-    ExtracellularEphysElectrodes,
-    Units,
     VectorIndex,
 )
-
-
-@pytest.fixture()
-def electrical_series() -> Tuple["ElectricalSeries", "ExtracellularEphysElectrodes"]:
-    """
-    Demo electrical series with adjoining electrodes
-    """
-    n_electrodes = 5
-    n_times = 100
-    data = np.arange(0, n_electrodes * n_times).reshape(n_times, n_electrodes).astype(float)
-    timestamps = np.linspace(0, 1, n_times)
-
-    device = Device(name="my electrode")
-
-    # electrode group is the physical description of the electrodes
-    electrode_group = ElectrodeGroup(
-        name="GroupA",
-        device=device,
-        description="an electrode group",
-        location="you know where it is",
-    )
-
-    # make electrodes tables
-    electrodes = ExtracellularEphysElectrodes(
-        description="idk these are also electrodes",
-        id=np.arange(0, n_electrodes),
-        x=np.arange(0, n_electrodes).astype(float),
-        y=np.arange(n_electrodes, n_electrodes * 2).astype(float),
-        group=[electrode_group] * n_electrodes,
-        group_name=[electrode_group.name] * n_electrodes,
-        location=[str(i) for i in range(n_electrodes)],
-        extra_column=["sup"] * n_electrodes,
-    )
-
-    electrical_series = ElectricalSeries(
-        name="my recording!",
-        electrodes=DynamicTableRegion(
-            table=electrodes,
-            value=np.arange(n_electrodes - 1, -1, step=-1),
-            name="electrodes",
-            description="hey",
-        ),
-        timestamps=timestamps,
-        data=data,
-    )
-    return electrical_series, electrodes
-
-
-def _ragged_array(n_units: int) -> tuple[list[np.ndarray], np.ndarray]:
-    generator = np.random.default_rng()
-    spike_times = [
-        np.full(shape=generator.integers(10, 50), fill_value=i, dtype=float) for i in range(n_units)
-    ]
-    spike_idx = []
-    for i in range(n_units):
-        if i == 0:
-            spike_idx.append(len(spike_times[0]))
-        else:
-            spike_idx.append(len(spike_times[i]) + spike_idx[i - 1])
-    spike_idx = np.array(spike_idx)
-    return spike_times, spike_idx
-
-
-@pytest.fixture(params=[True, False])
-def units(request) -> Tuple[Units, list[np.ndarray], np.ndarray]:
-    """
-    Test case for units
-
-    Parameterized by extra_column because pandas likes to pivot dataframes
-    to long when there is only one column and it's not len() == 1
-    """
-    spike_times, spike_idx = _ragged_array(24)
-
-    spike_times_flat = np.concatenate(spike_times)
-
-    kwargs = {
-        "description": "units!!!!",
-        "spike_times": spike_times_flat,
-        "spike_times_index": spike_idx,
-    }
-    if request.param:
-        kwargs["extra_column"] = ["hey!"] * 24
-    units = Units(**kwargs)
-    return units, spike_times, spike_idx
+from .conftest import _ragged_array
 
 
 def test_dynamictable_indexing(electrical_series):
diff --git a/nwb_schema_language/src/nwb_schema_language/datamodel/nwb_schema_pydantic.py b/nwb_schema_language/src/nwb_schema_language/datamodel/nwb_schema_pydantic.py
index ef04312..84132d0 100644
--- a/nwb_schema_language/src/nwb_schema_language/datamodel/nwb_schema_pydantic.py
+++ b/nwb_schema_language/src/nwb_schema_language/datamodel/nwb_schema_pydantic.py
@@ -220,8 +220,8 @@ class DtypeMixin(ConfiguredBaseModel):
 
 class Attribute(DtypeMixin):
     name: str = Field(...)
-    dims: Optional[List[Union[Any, str]]] = Field(default_factory=list)
-    shape: Optional[List[Union[Any, int, str]]] = Field(default_factory=list)
+    dims: Optional[List[Union[Any, str]]] = Field(None)
+    shape: Optional[List[Union[Any, int, str]]] = Field(None)
     value: Optional[Any] = Field(
         None, description="""Optional constant, fixed value for the attribute."""
     )
@@ -233,9 +233,7 @@ class Attribute(DtypeMixin):
         True,
         description="""Optional boolean key describing whether the attribute is required. Default value is True.""",
     )
-    dtype: Optional[Union[List[CompoundDtype], FlatDtype, ReferenceDtype]] = Field(
-        default_factory=list
-    )
+    dtype: Optional[Union[List[CompoundDtype], FlatDtype, ReferenceDtype]] = Field(None)
 
 
 class Dataset(DtypeMixin):
@@ -250,8 +248,8 @@ class Dataset(DtypeMixin):
     )
     name: Optional[str] = Field(None)
     default_name: Optional[str] = Field(None)
-    dims: Optional[List[Union[Any, str]]] = Field(default_factory=list)
-    shape: Optional[List[Union[Any, int, str]]] = Field(default_factory=list)
+    dims: Optional[List[Union[Any, str]]] = Field(None)
+    shape: Optional[List[Union[Any, int, str]]] = Field(None)
     value: Optional[Any] = Field(
         None, description="""Optional constant, fixed value for the attribute."""
     )
@@ -261,7 +259,5 @@ class Dataset(DtypeMixin):
     doc: str = Field(..., description="""Description of corresponding object.""")
     quantity: Optional[Union[QuantityEnum, int]] = Field(1)
     linkable: Optional[bool] = Field(None)
-    attributes: Optional[List[Attribute]] = Field(default_factory=list)
-    dtype: Optional[Union[List[CompoundDtype], FlatDtype, ReferenceDtype]] = Field(
-        default_factory=list
-    )
+    attributes: Optional[List[Attribute]] = Field(None)
+    dtype: Optional[Union[List[CompoundDtype], FlatDtype, ReferenceDtype]] = Field(None)
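The ``Field(None)`` defaults in the last file appear to be what makes the inheritance merge above behave: ``_merge_inheritance`` dumps models with ``exclude_none=True``, so unset ``dims``/``shape``/``attributes``/``dtype`` now vanish from the child's dict instead of appearing as empty lists that would clobber the parent's values on ``update()``. A sketch of the difference, using the generated models as declared above:

    from nwb_schema_language import Dataset

    child = Dataset(
        neurodata_type_def="MyTimeSeries",
        neurodata_type_inc="TimeSeries",
        doc="a child type that only overrides the doc",
    )
    dumped = child.model_dump(exclude_none=True)
    # unset fields are absent entirely, not empty lists...
    assert "dims" not in dumped and "attributes" not in dumped
    # ...so parent_dict.update(dumped) in _merge_inheritance keeps the
    # parent's dims/shape/attributes instead of overwriting them with []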