ruff unsafe fixes

sneakers-the-rat 2024-07-01 23:05:47 -07:00
parent 084bceaa2e
commit 7c6e69c87e
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
16 changed files with 87 additions and 158 deletions
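The rewrites below are the kind ruff produces when its autofixes marked "unsafe" are enabled. As a hedged sketch (the helper and path are illustrative, not part of this commit), the workflow that yields a diff like this looks roughly like:

    # Hypothetical helper, not part of the repo: run ruff with unsafe fixes enabled,
    # which produces rewrites like the ones below (ternaries, return bool(...),
    # contextlib.suppress, dropping .keys(), zero-argument super(), etc.).
    import subprocess

    def run_ruff_unsafe_fixes(path: str = ".") -> int:
        """Apply ruff's safe and unsafe autofixes to ``path`` and return the exit code."""
        result = subprocess.run(
            ["ruff", "check", path, "--fix", "--unsafe-fixes"],
            check=False,
        )
        return result.returncode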

View file

@@ -27,7 +27,7 @@ import subprocess
 import warnings
 from pathlib import Path
 from types import ModuleType
-from typing import TYPE_CHECKING, Dict, List, Optional, Union, overload
+from typing import TYPE_CHECKING, Dict, List, Never, Optional, Union, overload

 import h5py
 import numpy as np
@@ -100,10 +100,7 @@ class HDF5IO:
         provider = self.make_provider()
         h5f = h5py.File(str(self.path))
-        if path:
-            src = h5f.get(path)
-        else:
-            src = h5f
+        src = h5f.get(path) if path else h5f

         # get all children of selected item
         if isinstance(src, (h5py.File, h5py.Group)):
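The hunk above collapses an if/else assignment into a conditional expression (the rewrite ruff's SIM108 rule applies). A minimal standalone sketch of the same change, with a plain dict standing in for the real h5py file handle:

    # Before: four lines to pick a source object
    # if path:
    #     src = h5f.get(path)
    # else:
    #     src = h5f
    # After: one conditional expression, same behavior for truthy/falsy `path`
    h5f = {"acquisition": "group-like object"}  # stand-in for an h5py.File
    path = "acquisition"
    src = h5f.get(path) if path else h5f
    print(src)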
@@ -127,7 +124,7 @@ class HDF5IO:
         else:
             return queue.completed[path].result

-    def write(self, path: Path):
+    def write(self, path: Path) -> Never:
         """
         Write to NWB file
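`write()` gains a ``Never`` return annotation (hence the new ``Never`` import above), which tells type checkers the method never returns normally, presumably because its body only raises. A hedged sketch of the idea, independent of the actual HDF5IO class:

    from typing import Never  # Python 3.11+; older versions can use typing_extensions.Never

    def write(path: str) -> Never:
        """Annotated as Never because every call path raises."""
        raise NotImplementedError("writing is not supported yet")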
@@ -193,7 +190,7 @@ def read_specs_as_dicts(group: h5py.Group) -> dict:
     """
     spec_dict = {}

-    def _read_spec(name, node):
+    def _read_spec(name, node) -> None:
         if isinstance(node, h5py.Dataset):
             # make containing dict if they dont exist
@@ -233,7 +230,7 @@ def find_references(h5f: h5py.File, path: str) -> List[str]:
     """
     references = []

-    def _find_references(name, obj: h5py.Group | h5py.Dataset):
+    def _find_references(name, obj: h5py.Group | h5py.Dataset) -> None:
         pbar.update()
         refs = []
         for attr in obj.attrs.values():
@@ -254,7 +251,6 @@ def find_references(h5f: h5py.File, path: str) -> List[str]:
         for ref in refs:
             assert isinstance(ref, h5py.h5r.Reference)
-            refname = h5f[ref].name
             if name == path:
                 references.append(name)
                 return
@@ -281,10 +277,7 @@ def truncate_file(source: Path, target: Optional[Path] = None, n: int = 10) -> P
     Returns:
         :class:`pathlib.Path` path of the truncated file
     """
-    if target is None:
-        target = source.parent / (source.stem + "_truncated.hdf5")
-    else:
-        target = Path(target)
+    target = source.parent / (source.stem + "_truncated.hdf5") if target is None else Path(target)

     source = Path(source)
@@ -300,9 +293,8 @@ def truncate_file(source: Path, target: Optional[Path] = None, n: int = 10) -> P
     to_resize = []

-    def _need_resizing(name: str, obj: h5py.Dataset | h5py.Group):
-        if isinstance(obj, h5py.Dataset):
-            if obj.size > n:
-                to_resize.append(name)
+    def _need_resizing(name: str, obj: h5py.Dataset | h5py.Group) -> None:
+        if isinstance(obj, h5py.Dataset) and obj.size > n:
+            to_resize.append(name)

     print("Resizing datasets...")

View file

@@ -17,7 +17,7 @@ from nwb_schema_language.datamodel.nwb_schema_pydantic import FlatDtype as FlatD
 FlatDType = EnumDefinition(
     name="FlatDType",
-    permissible_values=[PermissibleValue(p) for p in FlatDtype_source.__members__.keys()],
+    permissible_values=[PermissibleValue(p) for p in FlatDtype_source.__members__],
 )

 DTypeTypes = []

View file

@@ -5,6 +5,7 @@ We have sort of diverged from the initial idea of a generalized map as in :class
 so we will make our own mapping class here and re-evaluate whether they should be unified later
 """

+import contextlib
 import datetime
 import inspect
 from abc import abstractmethod
@@ -187,10 +188,7 @@ def check_empty(obj: h5py.Group) -> bool:
             children_empty = True

     # if we have no attrs and we are a leaf OR our children are empty, remove us
-    if no_attrs and (no_children or children_empty):
-        return True
-    else:
-        return False
+    return bool(no_attrs and (no_children or children_empty))


 class PruneEmpty(HDF5Map):
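`check_empty` previously branched only to return ``True`` or ``False``; the fix returns the condition itself, wrapped in ``bool()`` so the declared ``-> bool`` still holds even if the operands aren't booleans (the SIM103-style rewrite). A simplified stand-in that takes the already-computed flags instead of an ``h5py.Group``:

    def check_empty(no_attrs: bool, no_children: bool, children_empty: bool) -> bool:
        # Before:
        # if no_attrs and (no_children or children_empty):
        #     return True
        # else:
        #     return False
        # After: return the expression directly; bool() guards against non-bool operands
        return bool(no_attrs and (no_children or children_empty))

    assert check_empty(True, False, True) is True
    assert check_empty(False, True, True) is False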
@@ -244,10 +242,7 @@ class ResolveDynamicTable(HDF5Map):
                 # we might replace DynamicTable in the future, and there isn't a stable DynamicTable
                 # class to inherit from anyway because of the whole multiple versions thing
                 parents = [parent.__name__ for parent in model.__mro__]
-                if "DynamicTable" in parents:
-                    return True
-                else:
-                    return False
+                return "DynamicTable" in parents
         else:
             return False
@@ -322,10 +317,7 @@ class ResolveModelGroup(HDF5Map):
     def check(
         cls, src: H5SourceItem, provider: SchemaProvider, completed: Dict[str, H5ReadResult]
     ) -> bool:
-        if "neurodata_type" in src.attrs and src.h5_type == "group":
-            return True
-        else:
-            return False
+        return bool("neurodata_type" in src.attrs and src.h5_type == "group")

     @classmethod
     def apply(
@@ -336,14 +328,14 @@ class ResolveModelGroup(HDF5Map):
         depends = []
         with h5py.File(src.h5f_path, "r") as h5f:
             obj = h5f.get(src.path)
-            for key, type in model.model_fields.items():
+            for key in model.model_fields.keys():
                 if key == "children":
                     res[key] = {name: resolve_hardlink(child) for name, child in obj.items()}
                     depends.extend([resolve_hardlink(child) for child in obj.values()])
                 elif key in obj.attrs:
                     res[key] = obj.attrs[key]
                     continue
-                elif key in obj.keys():
+                elif key in obj:
                     # make sure it's not empty
                     if check_empty(obj[key]):
                         continue
@@ -386,10 +378,7 @@ class ResolveDatasetAsDict(HDF5Map):
         if src.h5_type == "dataset" and "neurodata_type" not in src.attrs:
             with h5py.File(src.h5f_path, "r") as h5f:
                 obj = h5f.get(src.path)
-                if obj.shape != ():
-                    return True
-                else:
-                    return False
+                return obj.shape != ()
         else:
             return False
@@ -420,10 +409,7 @@ class ResolveScalars(HDF5Map):
         if src.h5_type == "dataset" and "neurodata_type" not in src.attrs:
             with h5py.File(src.h5f_path, "r") as h5f:
                 obj = h5f.get(src.path)
-                if obj.shape == ():
-                    return True
-                else:
-                    return False
+                return obj.shape == ()
         else:
             return False
@@ -456,10 +442,7 @@ class ResolveContainerGroups(HDF5Map):
         if src.h5_type == "group" and "neurodata_type" not in src.attrs and len(src.attrs) == 0:
             with h5py.File(src.h5f_path, "r") as h5f:
                 obj = h5f.get(src.path)
-                if len(obj.keys()) > 0:
-                    return True
-                else:
-                    return False
+                return len(obj.keys()) > 0
         else:
             return False
@@ -515,10 +498,7 @@ class CompletePassThrough(HDF5Map):
     ) -> bool:
         passthrough_ops = ("ResolveDynamicTable", "ResolveDatasetAsDict", "ResolveScalars")

-        for op in passthrough_ops:
-            if hasattr(src, "applied") and op in src.applied:
-                return True
-        return False
+        return any(hasattr(src, "applied") and op in src.applied for op in passthrough_ops)

     @classmethod
     def apply(
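In ``CompletePassThrough.check`` above, a loop that returned on the first hit becomes a single ``any()`` over a generator (the SIM110-style rewrite), which short-circuits the same way. A standalone sketch with a dummy result object:

    from types import SimpleNamespace

    passthrough_ops = ("ResolveDynamicTable", "ResolveDatasetAsDict", "ResolveScalars")
    src = SimpleNamespace(applied=["ResolveScalars"])  # stand-in for an H5ReadResult

    # Before: explicit loop with an early `return True`, falling through to `return False`.
    # After: any() expresses the same predicate in one line.
    matched = any(hasattr(src, "applied") and op in src.applied for op in passthrough_ops)
    print(matched)  # True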
@@ -542,15 +522,7 @@ class CompleteContainerGroups(HDF5Map):
     def check(
         cls, src: H5ReadResult, provider: SchemaProvider, completed: Dict[str, H5ReadResult]
     ) -> bool:
-        if (
-            src.model is None
-            and src.neurodata_type is None
-            and src.source.h5_type == "group"
-            and all([depend in completed for depend in src.depends])
-        ):
-            return True
-        else:
-            return False
+        return (src.model is None and src.neurodata_type is None and src.source.h5_type == "group" and all([depend in completed for depend in src.depends]))

     @classmethod
     def apply(
@@ -574,15 +546,7 @@ class CompleteModelGroups(HDF5Map):
     def check(
         cls, src: H5ReadResult, provider: SchemaProvider, completed: Dict[str, H5ReadResult]
     ) -> bool:
-        if (
-            src.model is not None
-            and src.source.h5_type == "group"
-            and src.neurodata_type != "NWBFile"
-            and all([depend in completed for depend in src.depends])
-        ):
-            return True
-        else:
-            return False
+        return (src.model is not None and src.source.h5_type == "group" and src.neurodata_type != "NWBFile" and all([depend in completed for depend in src.depends]))

     @classmethod
     def apply(
@@ -639,10 +603,7 @@ class CompleteNWBFile(HDF5Map):
     def check(
         cls, src: H5ReadResult, provider: SchemaProvider, completed: Dict[str, H5ReadResult]
     ) -> bool:
-        if src.neurodata_type == "NWBFile" and all([depend in completed for depend in src.depends]):
-            return True
-        else:
-            return False
+        return (src.neurodata_type == "NWBFile" and all([depend in completed for depend in src.depends]))

     @classmethod
     def apply(
@@ -724,14 +685,14 @@ class ReadQueue(BaseModel):
         default_factory=list, description="Phases that have already been completed"
     )

-    def apply_phase(self, phase: ReadPhases, max_passes=5):
+    def apply_phase(self, phase: ReadPhases, max_passes=5) -> None:
         phase_maps = [m for m in HDF5Map.__subclasses__() if m.phase == phase]
         phase_maps = sorted(phase_maps, key=lambda x: x.priority)

         results = []
         # TODO: Thread/multiprocess this
-        for name, item in self.queue.items():
+        for item in self.queue.values():
             for op in phase_maps:
                 if op.check(item, self.provider, self.completed):
                     # Formerly there was an "exclusive" property in the maps which let potentially multiple
@@ -768,10 +729,8 @@ class ReadQueue(BaseModel):
         # delete the ones that were already completed but might have been
         # incorrectly added back in the pile
         for c in completes:
-            try:
-                del self.queue[c]
-            except KeyError:
-                pass
+            with contextlib.suppress(KeyError):
+                del self.queue[c]

         # if we have nothing left in our queue, we have completed this phase
         # and prepare only ever has one pass
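The try/except-KeyError/pass around the queue deletion becomes ``contextlib.suppress`` (the SIM105-style rewrite), which is why ``import contextlib`` was added at the top of this module. A minimal sketch:

    import contextlib

    queue = {"a": 1}
    completes = ["a", "b"]  # "b" was never queued, so deleting it would raise KeyError

    for c in completes:
        # Before:
        # try:
        #     del queue[c]
        # except KeyError:
        #     pass
        # After: the context manager swallows only KeyError, same semantics
        with contextlib.suppress(KeyError):
            del queue[c]

    print(queue)  # {}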
@@ -798,7 +757,7 @@ def flatten_hdf(h5f: h5py.File | h5py.Group, skip="specifications") -> Dict[str,
     """
     items = {}

-    def _itemize(name: str, obj: h5py.Dataset | h5py.Group):
+    def _itemize(name: str, obj: h5py.Dataset | h5py.Group) -> None:
         if skip in name:
             return

View file

@@ -23,10 +23,7 @@ def model_from_dynamictable(group: h5py.Group, base: Optional[BaseModel] = None)
     for col in colnames:
         nptype = group[col].dtype
-        if nptype.type == np.void:
-            nptype = struct_from_dtype(nptype)
-        else:
-            nptype = nptype.type
+        nptype = struct_from_dtype(nptype) if nptype.type == np.void else nptype.type

         type_ = Optional[NDArray[Any, nptype]]
@@ -53,7 +50,7 @@ def dynamictable_to_model(
     items = {}
     for col, col_type in model.model_fields.items():
-        if col not in group.keys():
+        if col not in group:
             if col in group.attrs:
                 items[col] = group.attrs[col]
             continue
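``col not in group.keys()`` becomes ``col not in group`` (the SIM118-style rewrite); for dicts and dict-like objects such as ``h5py.Group``, ``in`` already tests keys, so the ``.keys()`` call is redundant. A sketch with a plain dict standing in for the group:

    group = {"spike_times": [0.1, 0.2], "unit_id": [1, 2]}  # stand-in for an h5py.Group

    for col in ("spike_times", "waveforms"):
        # `col not in group.keys()` and `col not in group` are equivalent;
        # the latter avoids building a keys view and reads more directly.
        if col not in group:
            print(f"{col} missing, skipping")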

View file

@@ -3,7 +3,7 @@ Monkeypatches to external modules
 """


-def patch_npytyping_perf():
+def patch_npytyping_perf() -> None:
     """
     npytyping makes an expensive call to inspect.stack()
     that makes imports of pydantic models take ~200x longer than
@@ -43,7 +43,7 @@ def patch_npytyping_perf():
     base_meta_classes.SubscriptableMeta._get_module = new_get_module


-def patch_nptyping_warnings():
+def patch_nptyping_warnings() -> None:
     """
     nptyping shits out a bunch of numpy deprecation warnings from using
     olde aliases
@@ -53,7 +53,7 @@ def patch_nptyping_warnings():
     warnings.filterwarnings("ignore", category=DeprecationWarning, module="nptyping.*")


-def patch_schemaview():
+def patch_schemaview() -> None:
     """
     Patch schemaview to correctly resolve multiple layers of relative imports.
@@ -114,7 +114,7 @@ def patch_schemaview():
     SchemaView.imports_closure = imports_closure


-def apply_patches():
+def apply_patches() -> None:
     patch_npytyping_perf()
     patch_nptyping_warnings()
     patch_schemaview()

View file

@@ -55,10 +55,7 @@ class Node:
 def make_node(element: Group | Dataset, parent=None, recurse: bool = True) -> List[Node]:
     if element.neurodata_type_def is None:
         if element.name is None:
-            if element.neurodata_type_inc is None:
-                name = "anonymous"
-            else:
-                name = element.neurodata_type_inc
+            name = "anonymous" if element.neurodata_type_inc is None else element.neurodata_type_inc
         else:
             name = element.name
         id = name + "-" + str(random.randint(0, 1000))
@@ -88,7 +85,6 @@ def make_graph(namespaces: "NamespacesAdapter", recurse: bool = True) -> List[Cy
     nodes = []
     element: Namespace | Group | Dataset
     print("walking graph")
-    i = 0
     for element in namespaces.walk_types(namespaces, (Group, Dataset)):
         if element.neurodata_type_def is None:
             # skip child nodes at top level, we'll get them in recursion

View file

@@ -157,7 +157,7 @@ class GitRepo:
         return self._commit

     @commit.setter
-    def commit(self, commit: str | None):
+    def commit(self, commit: str | None) -> None:
         # setting commit as None should do nothing if we have already cloned,
         # and if we are just cloning we will always be at the most recent commit anyway
         if commit is not None:
@@ -199,7 +199,7 @@ class GitRepo:
         return res.stdout.decode("utf-8").strip()

     @tag.setter
-    def tag(self, tag: str):
+    def tag(self, tag: str) -> None:
         # first check that we have the most recent tags
         self._git_call("fetch", "--all", "--tags")
         self._git_call("checkout", f"tags/{tag}")
@@ -227,10 +227,7 @@ class GitRepo:
         """
         res = self._git_call("branch", "--show-current")
         branch = res.stdout.decode("utf-8").strip()
-        if not branch:
-            return True
-        else:
-            return False
+        return not branch

     def check(self) -> bool:
         """
@@ -262,7 +259,7 @@ class GitRepo:
         # otherwise we're good
         return True

-    def cleanup(self, force: bool = False):
+    def cleanup(self, force: bool = False) -> None:
         """
         Delete contents of temporary directory
@@ -285,7 +282,7 @@ class GitRepo:
             shutil.rmtree(str(self.temp_directory))
             self._temp_directory = None

-    def clone(self, force: bool = False):
+    def clone(self, force: bool = False) -> None:
         """
         Clone the repository into the temporary directory

View file

@@ -94,10 +94,7 @@ class Provider(ABC):
     PROVIDES_CLASS: P = None

     def __init__(self, path: Optional[Path] = None, allow_repo: bool = True, verbose: bool = True):
-        if path is not None:
-            config = Config(cache_dir=path)
-        else:
-            config = Config()
+        config = Config(cache_dir=path) if path is not None else Config()
         self.config = config
         self.cache_dir = config.cache_dir
         self.allow_repo = allow_repo
@@ -352,8 +349,7 @@ class LinkMLProvider(Provider):
         of the build. If ``force == False`` and the schema already exist, it will be ``None``
         """
-        if not force:
-            if all(
-                [
-                    (self.namespace_path(ns, version) / "namespace.yaml").exists()
-                    for ns, version in ns_adapter.versions.items()
+        if not force and all(
+            [
+                (self.namespace_path(ns, version) / "namespace.yaml").exists()
+                for ns, version in ns_adapter.versions.items()
@@ -427,7 +423,7 @@ class LinkMLProvider(Provider):
         self, sch: SchemaDefinition, ns_adapter: adapters.NamespacesAdapter, output_file: Path
     ) -> SchemaDefinition:
         for animport in sch.imports:
-            if animport.split(".")[0] in ns_adapter.versions.keys():
+            if animport.split(".")[0] in ns_adapter.versions:
                 imported_path = (
                     self.namespace_path(
                         animport.split(".")[0], ns_adapter.versions[animport.split(".")[0]]
@@ -485,7 +481,7 @@ class PydanticProvider(Provider):
     PROVIDES = "pydantic"

     def __init__(self, path: Optional[Path] = None, verbose: bool = True):
-        super(PydanticProvider, self).__init__(path, verbose)
+        super().__init__(path, verbose)
         # create a metapathfinder to find module we might create
         pathfinder = EctopicModelFinder(self.path)
         sys.meta_path.append(pathfinder)
@@ -771,7 +767,7 @@ class PydanticProvider(Provider):
         return module

     @staticmethod
-    def _clear_package_imports():
+    def _clear_package_imports() -> None:
        """
        When using allow_repo=False, delete any already-imported
        namespaces from sys.modules that are within the nwb_linkml package
@@ -826,7 +822,7 @@ class EctopicModelFinder(MetaPathFinder):
     MODEL_STEM = "nwb_linkml.models.pydantic"

     def __init__(self, path: Path, *args, **kwargs):
-        super(EctopicModelFinder, self).__init__(*args, **kwargs)
+        super().__init__(*args, **kwargs)
         self.path = path

     def find_spec(self, fullname, path, target=None):
@@ -883,7 +879,7 @@ class SchemaProvider(Provider):
             **kwargs: passed to superclass __init__ (see :class:`.Provider` )
         """
         self.versions = versions
-        super(SchemaProvider, self).__init__(**kwargs)
+        super().__init__(**kwargs)

     @property
     def path(self) -> Path:
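The ``super(PydanticProvider, self)`` style calls above are rewritten to the zero-argument ``super()`` form (the UP008-style pyupgrade rewrite), which resolves the same MRO without repeating the class name. A sketch with simplified stand-in classes:

    class Provider:
        def __init__(self, path=None, verbose=True):
            self.path = path
            self.verbose = verbose

    class PydanticProvider(Provider):
        def __init__(self, path=None, verbose=True):
            # Before: super(PydanticProvider, self).__init__(path, verbose)
            # After: zero-argument form, same lookup, no repeated class name
            super().__init__(path, verbose)

    p = PydanticProvider(path="build/models")
    print(p.path)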

View file

@@ -85,7 +85,7 @@ class DataFrame(BaseModel, pd.DataFrame):
         df = df.fillna(np.nan).replace([np.nan], [None])
         return df

-    def update_df(self):
+    def update_df(self) -> None:
         """
         Update the internal dataframe in the case that the model values are changed
         in a way that we can't detect, like appending to one of the lists.
@@ -99,7 +99,7 @@ class DataFrame(BaseModel, pd.DataFrame):
         """
         if item in ("df", "_df"):
             return self.__pydantic_private__["_df"]
-        elif item in self.model_fields.keys():
+        elif item in self.model_fields:
             return self._df[item]
         else:
             try:
@@ -108,7 +108,7 @@ class DataFrame(BaseModel, pd.DataFrame):
                 return object.__getattribute__(self, item)

     @model_validator(mode="after")
-    def recreate_df(self):
+    def recreate_df(self) -> None:
         """
         Remake DF when validating (eg. when updating values on assignment)
         """
@@ -137,11 +137,11 @@ def dynamictable_to_df(
     model = model_from_dynamictable(group, base)
     items = {}
-    for col, col_type in model.model_fields.items():
-        if col not in group.keys():
+    for col, _col_type in model.model_fields.items():
+        if col not in group:
             continue
         idxname = col + "_index"
-        if idxname in group.keys():
+        if idxname in group:
             idx = group.get(idxname)[:]
             data = group.get(col)[idx - 1]
         else:
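In the last hunk, the unused loop value is renamed to ``_col_type`` (the B007-style fix), keeping the tuple unpacking while signaling to linters and readers that the value is intentionally ignored. A sketch with dicts standing in for the model fields and group:

    model_fields = {"spike_times": "NDArray", "unit_id": "int"}  # stand-in for model.model_fields
    group = {"spike_times": [0.1, 0.2]}

    items = {}
    for col, _col_type in model_fields.items():  # leading underscore marks the value as unused
        if col not in group:
            continue
        items[col] = group[col]

    print(items)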

View file

@@ -40,16 +40,10 @@ def _list_of_lists_schema(shape, array_type_handler):
     for arg, label in zip(shape_args, shape_labels):
         # which handler to use? for the first we use the actual type
         # handler, everywhere else we use the prior list handler
-        if list_schema is None:
-            inner_schema = array_type_handler
-        else:
-            inner_schema = list_schema
+        inner_schema = array_type_handler if list_schema is None else list_schema

         # make a label annotation, if we have one
-        if label is not None:
-            metadata = {"name": label}
-        else:
-            metadata = None
+        metadata = {"name": label} if label is not None else None

         # make the current level list schema, accounting for shape
         if arg == "*":
@@ -66,7 +60,8 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
     """
     Kept here to allow for hooking into metaclass, which has
     been necessary on and off as we work this class into a stable
-    state"""
+    state
+    """


 class NDArray(NPTypingType, metaclass=NDArrayMeta):

View file

@@ -45,13 +45,13 @@ class AdapterProgress:
             self.progress, title="Building Namespaces", border_style="green", padding=(2, 2)
         )

-    def update(self, namespace: str, **kwargs):
+    def update(self, namespace: str, **kwargs) -> None:
         self.progress.update(self.task_ids[namespace], **kwargs)

-    def start(self):
+    def start(self) -> None:
         self.progress.start()

-    def stop(self):
+    def stop(self) -> None:
         self.progress.stop()

     def __enter__(self) -> Live:

View file

@@ -10,7 +10,7 @@ def test_nothing(nwb_core_fixture):
 def _compare_dicts(dict1, dict2) -> bool:
     """just in one direction - that all the entries in dict1 are in dict2"""
-    assert all([dict1[k] == dict2[k] for k in dict1.keys()])
+    assert all([dict1[k] == dict2[k] for k in dict1])
     # assert all([dict1[k] == dict2[k] for k in dict2.keys()])

View file

@@ -84,10 +84,7 @@ def imported_schema(linkml_schema, request) -> TestModules:
     Convenience fixture for testing non-core generator features without needing to re-generate and
     import every time.
     """
-    if request.param == "split":
-        split = True
-    else:
-        split = False
+    split = request.param == "split"

     yield generate_and_import(linkml_schema, split)
@@ -188,7 +185,7 @@ def test_arraylike(imported_schema):
     # check that we have gotten an NDArray annotation and its shape is correct
     array = imported_schema["core"].MainTopLevel.model_fields["array"].annotation
     args = typing.get_args(array)
-    for i, shape in enumerate(("* x, * y", "* x, * y, 3 z", "* x, * y, 3 z, 4 a")):
+    for i, _ in enumerate(("* x, * y", "* x, * y, 3 z", "* x, * y, 3 z, 4 a")):
         assert isinstance(args[i], NDArrayMeta)
         assert args[i].__args__[0].__args__
         assert args[i].__args__[1] == np.number
@@ -213,8 +210,8 @@ def test_linkml_meta(imported_schema):
     """
     meta = imported_schema["core"].LinkML_Meta
     assert "tree_root" in meta.model_fields
-    assert imported_schema["core"].MainTopLevel.linkml_meta.default.tree_root == True
-    assert imported_schema["core"].OtherClass.linkml_meta.default.tree_root == False
+    assert imported_schema["core"].MainTopLevel.linkml_meta.default.tree_root
+    assert not imported_schema["core"].OtherClass.linkml_meta.default.tree_root


 def test_skip(linkml_schema):
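The last hunk replaces ``assert x == True`` / ``assert x == False`` with plain truthiness asserts (the E712-style fix); for real booleans the meaning is unchanged and pytest's failure output stays just as informative. A trivial sketch with stand-in values:

    tree_root_main = True    # stand-ins for the linkml_meta defaults checked above
    tree_root_other = False

    # Before: assert tree_root_main == True / assert tree_root_other == False
    assert tree_root_main
    assert not tree_root_other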

View file

@@ -26,11 +26,11 @@ def test_preload_maps():
         yaml.dump(hdmf_style_naming, temp_f, Dumper=Dumper)

     loaded = load_yaml(Path(temp_name))
-    assert "neurodata_type_def" in loaded["groups"][0].keys()
-    assert "data_type_def" not in loaded["groups"][0].keys()
-    assert "neurodata_type_inc" in loaded["groups"][0].keys()
-    assert "data_type_inc" not in loaded["groups"][0].keys()
-    assert "neurodata_type_inc" in loaded["groups"][0]["datasets"][0].keys()
-    assert "data_type_inc" not in loaded["groups"][0]["datasets"][0].keys()
+    assert "neurodata_type_def" in loaded["groups"][0]
+    assert "data_type_def" not in loaded["groups"][0]
+    assert "neurodata_type_inc" in loaded["groups"][0]
+    assert "data_type_inc" not in loaded["groups"][0]
+    assert "neurodata_type_inc" in loaded["groups"][0]["datasets"][0]
+    assert "data_type_inc" not in loaded["groups"][0]["datasets"][0]

     os.remove(temp_name)

View file

@@ -88,7 +88,7 @@ def test_ndarray_serialize():
     mod_str = mod.model_dump_json()
     mod_json = json.loads(mod_str)
     for a in ("array", "shape", "dtype", "unpack_fns"):
-        assert a in mod_json["large_array"].keys()
+        assert a in mod_json["large_array"]
     assert isinstance(mod_json["large_array"]["array"], str)
     assert isinstance(mod_json["small_array"], list)

View file

@@ -83,7 +83,7 @@ patch_contact_single_multiple = Patch(
 )


-def run_patches(phase: Phases, verbose: bool = False):
+def run_patches(phase: Phases, verbose: bool = False) -> None:
     patches = [p for p in Patch.instances if p.phase == phase]
     for patch in patches:
         if verbose:
@@ -96,7 +96,7 @@ def run_patches(phase: Phases, verbose: bool = False):
             pfile.write(string)


-def main():
+def main() -> None:
     parser = argparse.ArgumentParser(description="Run patches for a given phase of code generation")
     parser.add_argument("--phase", choices=list(Phases.__members__.keys()), type=Phases)
     args = parser.parse_args()