ruff unsafe fixes

sneakers-the-rat 2024-07-01 23:05:47 -07:00
parent 084bceaa2e
commit 7c6e69c87e
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
16 changed files with 87 additions and 158 deletions
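
The changes below are consistent with running ruff's autofixer with unsafe fixes enabled (roughly `ruff check --fix --unsafe-fixes`; the exact invocation and rule selection are not recorded in the commit). A few rewrite patterns recur throughout: if/else assignments collapsed to conditional expressions, redundant `.keys()` calls dropped from membership tests and loops, if/else blocks that return True/False replaced by returning the condition, two-argument `super()` calls reduced to bare `super()`, `try/except: pass` replaced with `contextlib.suppress`, and `-> None` return annotations added. A minimal standalone sketch of those patterns, with made-up names rather than code from this repository:

# Illustrative sketch of the rewrite patterns in this diff -- none of these
# names are from the repository; the rule codes are the ruff rules that
# usually produce each rewrite.
from typing import Optional


def pick_source(container: dict, path: Optional[str] = None) -> object:
    # SIM108: an if/else assignment collapses into a conditional expression,
    # e.g. `src = h5f.get(path) if path else h5f` below.
    return container.get(path) if path else container


def has_column(mapping: dict, key: str) -> bool:
    # SIM118: membership tests drop the redundant `.keys()` call,
    # e.g. `elif key in obj:` rather than `elif key in obj.keys():`.
    return key in mapping


def is_empty(no_attrs: bool, no_children: bool, children_empty: bool) -> bool:
    # SIM103: an if/else returning True/False becomes a direct return of the
    # (bool-wrapped) condition, as in `check_empty` and the `check()` methods below.
    return bool(no_attrs and (no_children or children_empty))


class Base:
    def __init__(self, path: str) -> None:  # `-> None` annotations added throughout
        self.path = path


class Child(Base):
    def __init__(self, path: str) -> None:
        # UP008: `super(Child, self).__init__(path)` reduces to zero-argument `super()`.
        super().__init__(path)


assert pick_source({"a": 1}, "a") == 1
assert pick_source({"a": 1}) == {"a": 1}
assert has_column({"a": 1}, "a") and not is_empty(True, False, False)

Ruff gates these rewrites behind --unsafe-fixes because they can change behavior in edge cases or drop comments, so they are opt-in rather than applied by a plain `ruff check --fix`.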

View file

@ -27,7 +27,7 @@ import subprocess
import warnings
from pathlib import Path
from types import ModuleType
from typing import TYPE_CHECKING, Dict, List, Optional, Union, overload
from typing import TYPE_CHECKING, Dict, List, Never, Optional, Union, overload
import h5py
import numpy as np
@ -100,10 +100,7 @@ class HDF5IO:
provider = self.make_provider()
h5f = h5py.File(str(self.path))
if path:
src = h5f.get(path)
else:
src = h5f
src = h5f.get(path) if path else h5f
# get all children of selected item
if isinstance(src, (h5py.File, h5py.Group)):
@ -127,7 +124,7 @@ class HDF5IO:
else:
return queue.completed[path].result
def write(self, path: Path):
def write(self, path: Path) -> Never:
"""
Write to NWB file
@ -193,7 +190,7 @@ def read_specs_as_dicts(group: h5py.Group) -> dict:
"""
spec_dict = {}
def _read_spec(name, node):
def _read_spec(name, node) -> None:
if isinstance(node, h5py.Dataset):
# make containing dict if they dont exist
@ -233,7 +230,7 @@ def find_references(h5f: h5py.File, path: str) -> List[str]:
"""
references = []
def _find_references(name, obj: h5py.Group | h5py.Dataset):
def _find_references(name, obj: h5py.Group | h5py.Dataset) -> None:
pbar.update()
refs = []
for attr in obj.attrs.values():
@ -254,7 +251,6 @@ def find_references(h5f: h5py.File, path: str) -> List[str]:
for ref in refs:
assert isinstance(ref, h5py.h5r.Reference)
refname = h5f[ref].name
if name == path:
references.append(name)
return
@ -281,10 +277,7 @@ def truncate_file(source: Path, target: Optional[Path] = None, n: int = 10) -> P
Returns:
:class:`pathlib.Path` path of the truncated file
"""
if target is None:
target = source.parent / (source.stem + "_truncated.hdf5")
else:
target = Path(target)
target = source.parent / (source.stem + "_truncated.hdf5") if target is None else Path(target)
source = Path(source)
@ -300,10 +293,9 @@ def truncate_file(source: Path, target: Optional[Path] = None, n: int = 10) -> P
to_resize = []
def _need_resizing(name: str, obj: h5py.Dataset | h5py.Group):
if isinstance(obj, h5py.Dataset):
if obj.size > n:
to_resize.append(name)
def _need_resizing(name: str, obj: h5py.Dataset | h5py.Group) -> None:
if isinstance(obj, h5py.Dataset) and obj.size > n:
to_resize.append(name)
print("Resizing datasets...")
# first we get the items that need to be resized and then resize them below

View file

@ -17,7 +17,7 @@ from nwb_schema_language.datamodel.nwb_schema_pydantic import FlatDtype as FlatD
FlatDType = EnumDefinition(
name="FlatDType",
permissible_values=[PermissibleValue(p) for p in FlatDtype_source.__members__.keys()],
permissible_values=[PermissibleValue(p) for p in FlatDtype_source.__members__],
)
DTypeTypes = []
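
Dropping `.keys()` from `FlatDtype_source.__members__` above is behavior-preserving because `__members__` is a name-to-member mapping, and iterating a mapping yields its keys. A small sketch with a made-up enum, not the real FlatDtype:

from enum import Enum


class FlatDtypeDemo(Enum):
    # made-up members standing in for the real FlatDtype enum
    float32 = "float32"
    int8 = "int8"


# Enum.__members__ maps member names to members; iterating the mapping
# directly and iterating .keys() both yield the names, in definition order.
assert list(FlatDtypeDemo.__members__) == ["float32", "int8"]
assert list(FlatDtypeDemo.__members__) == list(FlatDtypeDemo.__members__.keys())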

View file

@ -5,6 +5,7 @@ We have sort of diverged from the initial idea of a generalized map as in :class
so we will make our own mapping class here and re-evaluate whether they should be unified later
"""
import contextlib
import datetime
import inspect
from abc import abstractmethod
@ -187,10 +188,7 @@ def check_empty(obj: h5py.Group) -> bool:
children_empty = True
# if we have no attrs and we are a leaf OR our children are empty, remove us
if no_attrs and (no_children or children_empty):
return True
else:
return False
return bool(no_attrs and (no_children or children_empty))
class PruneEmpty(HDF5Map):
@ -244,10 +242,7 @@ class ResolveDynamicTable(HDF5Map):
# we might replace DynamicTable in the future, and there isn't a stable DynamicTable
# class to inherit from anyway because of the whole multiple versions thing
parents = [parent.__name__ for parent in model.__mro__]
if "DynamicTable" in parents:
return True
else:
return False
return "DynamicTable" in parents
else:
return False
@ -322,10 +317,7 @@ class ResolveModelGroup(HDF5Map):
def check(
cls, src: H5SourceItem, provider: SchemaProvider, completed: Dict[str, H5ReadResult]
) -> bool:
if "neurodata_type" in src.attrs and src.h5_type == "group":
return True
else:
return False
return bool("neurodata_type" in src.attrs and src.h5_type == "group")
@classmethod
def apply(
@ -336,14 +328,14 @@ class ResolveModelGroup(HDF5Map):
depends = []
with h5py.File(src.h5f_path, "r") as h5f:
obj = h5f.get(src.path)
for key, type in model.model_fields.items():
for key in model.model_fields.keys():
if key == "children":
res[key] = {name: resolve_hardlink(child) for name, child in obj.items()}
depends.extend([resolve_hardlink(child) for child in obj.values()])
elif key in obj.attrs:
res[key] = obj.attrs[key]
continue
elif key in obj.keys():
elif key in obj:
# make sure it's not empty
if check_empty(obj[key]):
continue
@ -386,10 +378,7 @@ class ResolveDatasetAsDict(HDF5Map):
if src.h5_type == "dataset" and "neurodata_type" not in src.attrs:
with h5py.File(src.h5f_path, "r") as h5f:
obj = h5f.get(src.path)
if obj.shape != ():
return True
else:
return False
return obj.shape != ()
else:
return False
@ -420,10 +409,7 @@ class ResolveScalars(HDF5Map):
if src.h5_type == "dataset" and "neurodata_type" not in src.attrs:
with h5py.File(src.h5f_path, "r") as h5f:
obj = h5f.get(src.path)
if obj.shape == ():
return True
else:
return False
return obj.shape == ()
else:
return False
@ -456,10 +442,7 @@ class ResolveContainerGroups(HDF5Map):
if src.h5_type == "group" and "neurodata_type" not in src.attrs and len(src.attrs) == 0:
with h5py.File(src.h5f_path, "r") as h5f:
obj = h5f.get(src.path)
if len(obj.keys()) > 0:
return True
else:
return False
return len(obj.keys()) > 0
else:
return False
@ -515,10 +498,7 @@ class CompletePassThrough(HDF5Map):
) -> bool:
passthrough_ops = ("ResolveDynamicTable", "ResolveDatasetAsDict", "ResolveScalars")
for op in passthrough_ops:
if hasattr(src, "applied") and op in src.applied:
return True
return False
return any(hasattr(src, "applied") and op in src.applied for op in passthrough_ops)
@classmethod
def apply(
@ -542,15 +522,7 @@ class CompleteContainerGroups(HDF5Map):
def check(
cls, src: H5ReadResult, provider: SchemaProvider, completed: Dict[str, H5ReadResult]
) -> bool:
if (
src.model is None
and src.neurodata_type is None
and src.source.h5_type == "group"
and all([depend in completed for depend in src.depends])
):
return True
else:
return False
return (src.model is None and src.neurodata_type is None and src.source.h5_type == "group" and all([depend in completed for depend in src.depends]))
@classmethod
def apply(
@ -574,15 +546,7 @@ class CompleteModelGroups(HDF5Map):
def check(
cls, src: H5ReadResult, provider: SchemaProvider, completed: Dict[str, H5ReadResult]
) -> bool:
if (
src.model is not None
and src.source.h5_type == "group"
and src.neurodata_type != "NWBFile"
and all([depend in completed for depend in src.depends])
):
return True
else:
return False
return (src.model is not None and src.source.h5_type == "group" and src.neurodata_type != "NWBFile" and all([depend in completed for depend in src.depends]))
@classmethod
def apply(
@ -639,10 +603,7 @@ class CompleteNWBFile(HDF5Map):
def check(
cls, src: H5ReadResult, provider: SchemaProvider, completed: Dict[str, H5ReadResult]
) -> bool:
if src.neurodata_type == "NWBFile" and all([depend in completed for depend in src.depends]):
return True
else:
return False
return (src.neurodata_type == "NWBFile" and all([depend in completed for depend in src.depends]))
@classmethod
def apply(
@ -724,14 +685,14 @@ class ReadQueue(BaseModel):
default_factory=list, description="Phases that have already been completed"
)
def apply_phase(self, phase: ReadPhases, max_passes=5):
def apply_phase(self, phase: ReadPhases, max_passes=5) -> None:
phase_maps = [m for m in HDF5Map.__subclasses__() if m.phase == phase]
phase_maps = sorted(phase_maps, key=lambda x: x.priority)
results = []
# TODO: Thread/multiprocess this
for name, item in self.queue.items():
for item in self.queue.values():
for op in phase_maps:
if op.check(item, self.provider, self.completed):
# Formerly there was an "exclusive" property in the maps which let potentially multiple
@ -768,10 +729,8 @@ class ReadQueue(BaseModel):
# delete the ones that were already completed but might have been
# incorrectly added back in the pile
for c in completes:
try:
with contextlib.suppress(KeyError):
del self.queue[c]
except KeyError:
pass
# if we have nothing left in our queue, we have completed this phase
# and prepare only ever has one pass
@ -798,7 +757,7 @@ def flatten_hdf(h5f: h5py.File | h5py.Group, skip="specifications") -> Dict[str,
"""
items = {}
def _itemize(name: str, obj: h5py.Dataset | h5py.Group):
def _itemize(name: str, obj: h5py.Dataset | h5py.Group) -> None:
if skip in name:
return
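
The `try/except KeyError: pass` block replaced with `contextlib.suppress(KeyError)` earlier in this file's diff is ruff's usual SIM105 rewrite; the two forms are equivalent. A standalone sketch, using a plain dict as a stand-in rather than the actual ReadQueue:

import contextlib

# Stand-in for the completed-item bookkeeping above; not the real queue.
queue = {"a": 1, "b": 2}

# Previously: try: del queue[key] / except KeyError: pass
# contextlib.suppress(KeyError) expresses the same "ignore if missing" intent.
for key in ("a", "missing"):
    with contextlib.suppress(KeyError):
        del queue[key]

assert queue == {"b": 2}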

View file

@ -23,10 +23,7 @@ def model_from_dynamictable(group: h5py.Group, base: Optional[BaseModel] = None)
for col in colnames:
nptype = group[col].dtype
if nptype.type == np.void:
nptype = struct_from_dtype(nptype)
else:
nptype = nptype.type
nptype = struct_from_dtype(nptype) if nptype.type == np.void else nptype.type
type_ = Optional[NDArray[Any, nptype]]
@ -53,7 +50,7 @@ def dynamictable_to_model(
items = {}
for col, col_type in model.model_fields.items():
if col not in group.keys():
if col not in group:
if col in group.attrs:
items[col] = group.attrs[col]
continue

View file

@ -3,7 +3,7 @@ Monkeypatches to external modules
"""
def patch_npytyping_perf():
def patch_npytyping_perf() -> None:
"""
npytyping makes an expensive call to inspect.stack()
that makes imports of pydantic models take ~200x longer than
@ -43,7 +43,7 @@ def patch_npytyping_perf():
base_meta_classes.SubscriptableMeta._get_module = new_get_module
def patch_nptyping_warnings():
def patch_nptyping_warnings() -> None:
"""
nptyping shits out a bunch of numpy deprecation warnings from using
olde aliases
@ -53,7 +53,7 @@ def patch_nptyping_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning, module="nptyping.*")
def patch_schemaview():
def patch_schemaview() -> None:
"""
Patch schemaview to correctly resolve multiple layers of relative imports.
@ -114,7 +114,7 @@ def patch_schemaview():
SchemaView.imports_closure = imports_closure
def apply_patches():
def apply_patches() -> None:
patch_npytyping_perf()
patch_nptyping_warnings()
patch_schemaview()

View file

@ -55,10 +55,7 @@ class Node:
def make_node(element: Group | Dataset, parent=None, recurse: bool = True) -> List[Node]:
if element.neurodata_type_def is None:
if element.name is None:
if element.neurodata_type_inc is None:
name = "anonymous"
else:
name = element.neurodata_type_inc
name = "anonymous" if element.neurodata_type_inc is None else element.neurodata_type_inc
else:
name = element.name
id = name + "-" + str(random.randint(0, 1000))
@ -88,7 +85,6 @@ def make_graph(namespaces: "NamespacesAdapter", recurse: bool = True) -> List[Cy
nodes = []
element: Namespace | Group | Dataset
print("walking graph")
i = 0
for element in namespaces.walk_types(namespaces, (Group, Dataset)):
if element.neurodata_type_def is None:
# skip child nodes at top level, we'll get them in recursion

View file

@ -157,7 +157,7 @@ class GitRepo:
return self._commit
@commit.setter
def commit(self, commit: str | None):
def commit(self, commit: str | None) -> None:
# setting commit as None should do nothing if we have already cloned,
# and if we are just cloning we will always be at the most recent commit anyway
if commit is not None:
@ -199,7 +199,7 @@ class GitRepo:
return res.stdout.decode("utf-8").strip()
@tag.setter
def tag(self, tag: str):
def tag(self, tag: str) -> None:
# first check that we have the most recent tags
self._git_call("fetch", "--all", "--tags")
self._git_call("checkout", f"tags/{tag}")
@ -227,10 +227,7 @@ class GitRepo:
"""
res = self._git_call("branch", "--show-current")
branch = res.stdout.decode("utf-8").strip()
if not branch:
return True
else:
return False
return not branch
def check(self) -> bool:
"""
@ -262,7 +259,7 @@ class GitRepo:
# otherwise we're good
return True
def cleanup(self, force: bool = False):
def cleanup(self, force: bool = False) -> None:
"""
Delete contents of temporary directory
@ -285,7 +282,7 @@ class GitRepo:
shutil.rmtree(str(self.temp_directory))
self._temp_directory = None
def clone(self, force: bool = False):
def clone(self, force: bool = False) -> None:
"""
Clone the repository into the temporary directory

View file

@ -94,10 +94,7 @@ class Provider(ABC):
PROVIDES_CLASS: P = None
def __init__(self, path: Optional[Path] = None, allow_repo: bool = True, verbose: bool = True):
if path is not None:
config = Config(cache_dir=path)
else:
config = Config()
config = Config(cache_dir=path) if path is not None else Config()
self.config = config
self.cache_dir = config.cache_dir
self.allow_repo = allow_repo
@ -352,22 +349,21 @@ class LinkMLProvider(Provider):
of the build. If ``force == False`` and the schema already exist, it will be ``None``
"""
if not force:
if all(
[
(self.namespace_path(ns, version) / "namespace.yaml").exists()
for ns, version in ns_adapter.versions.items()
]
):
return {
k: LinkMLSchemaBuild(
name=k,
result=None,
namespace=self.namespace_path(k, v) / "namespace.yaml",
version=v,
)
for k, v in ns_adapter.versions.items()
}
if not force and all(
[
(self.namespace_path(ns, version) / "namespace.yaml").exists()
for ns, version in ns_adapter.versions.items()
]
):
return {
k: LinkMLSchemaBuild(
name=k,
result=None,
namespace=self.namespace_path(k, v) / "namespace.yaml",
version=v,
)
for k, v in ns_adapter.versions.items()
}
# self._find_imports(ns_adapter, versions, populate=True)
if self.verbose:
@ -427,7 +423,7 @@ class LinkMLProvider(Provider):
self, sch: SchemaDefinition, ns_adapter: adapters.NamespacesAdapter, output_file: Path
) -> SchemaDefinition:
for animport in sch.imports:
if animport.split(".")[0] in ns_adapter.versions.keys():
if animport.split(".")[0] in ns_adapter.versions:
imported_path = (
self.namespace_path(
animport.split(".")[0], ns_adapter.versions[animport.split(".")[0]]
@ -485,7 +481,7 @@ class PydanticProvider(Provider):
PROVIDES = "pydantic"
def __init__(self, path: Optional[Path] = None, verbose: bool = True):
super(PydanticProvider, self).__init__(path, verbose)
super().__init__(path, verbose)
# create a metapathfinder to find module we might create
pathfinder = EctopicModelFinder(self.path)
sys.meta_path.append(pathfinder)
@ -771,7 +767,7 @@ class PydanticProvider(Provider):
return module
@staticmethod
def _clear_package_imports():
def _clear_package_imports() -> None:
"""
When using allow_repo=False, delete any already-imported
namespaces from sys.modules that are within the nwb_linkml package
@ -826,7 +822,7 @@ class EctopicModelFinder(MetaPathFinder):
MODEL_STEM = "nwb_linkml.models.pydantic"
def __init__(self, path: Path, *args, **kwargs):
super(EctopicModelFinder, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
self.path = path
def find_spec(self, fullname, path, target=None):
@ -883,7 +879,7 @@ class SchemaProvider(Provider):
**kwargs: passed to superclass __init__ (see :class:`.Provider` )
"""
self.versions = versions
super(SchemaProvider, self).__init__(**kwargs)
super().__init__(**kwargs)
@property
def path(self) -> Path:

View file

@ -85,7 +85,7 @@ class DataFrame(BaseModel, pd.DataFrame):
df = df.fillna(np.nan).replace([np.nan], [None])
return df
def update_df(self):
def update_df(self) -> None:
"""
Update the internal dataframe in the case that the model values are changed
in a way that we can't detect, like appending to one of the lists.
@ -99,7 +99,7 @@ class DataFrame(BaseModel, pd.DataFrame):
"""
if item in ("df", "_df"):
return self.__pydantic_private__["_df"]
elif item in self.model_fields.keys():
elif item in self.model_fields:
return self._df[item]
else:
try:
@ -108,7 +108,7 @@ class DataFrame(BaseModel, pd.DataFrame):
return object.__getattribute__(self, item)
@model_validator(mode="after")
def recreate_df(self):
def recreate_df(self) -> None:
"""
Remake DF when validating (eg. when updating values on assignment)
"""
@ -137,11 +137,11 @@ def dynamictable_to_df(
model = model_from_dynamictable(group, base)
items = {}
for col, col_type in model.model_fields.items():
if col not in group.keys():
for col, _col_type in model.model_fields.items():
if col not in group:
continue
idxname = col + "_index"
if idxname in group.keys():
if idxname in group:
idx = group.get(idxname)[:]
data = group.get(col)[idx - 1]
else:

View file

@ -40,16 +40,10 @@ def _list_of_lists_schema(shape, array_type_handler):
for arg, label in zip(shape_args, shape_labels):
# which handler to use? for the first we use the actual type
# handler, everywhere else we use the prior list handler
if list_schema is None:
inner_schema = array_type_handler
else:
inner_schema = list_schema
inner_schema = array_type_handler if list_schema is None else list_schema
# make a label annotation, if we have one
if label is not None:
metadata = {"name": label}
else:
metadata = None
metadata = {"name": label} if label is not None else None
# make the current level list schema, accounting for shape
if arg == "*":
@ -66,7 +60,8 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
"""
Kept here to allow for hooking into metaclass, which has
been necessary on and off as we work this class into a stable
state"""
state
"""
class NDArray(NPTypingType, metaclass=NDArrayMeta):

View file

@ -45,13 +45,13 @@ class AdapterProgress:
self.progress, title="Building Namespaces", border_style="green", padding=(2, 2)
)
def update(self, namespace: str, **kwargs):
def update(self, namespace: str, **kwargs) -> None:
self.progress.update(self.task_ids[namespace], **kwargs)
def start(self):
def start(self) -> None:
self.progress.start()
def stop(self):
def stop(self) -> None:
self.progress.stop()
def __enter__(self) -> Live:

View file

@ -10,7 +10,7 @@ def test_nothing(nwb_core_fixture):
def _compare_dicts(dict1, dict2) -> bool:
"""just in one direction - that all the entries in dict1 are in dict2"""
assert all([dict1[k] == dict2[k] for k in dict1.keys()])
assert all([dict1[k] == dict2[k] for k in dict1])
# assert all([dict1[k] == dict2[k] for k in dict2.keys()])

View file

@ -84,10 +84,7 @@ def imported_schema(linkml_schema, request) -> TestModules:
Convenience fixture for testing non-core generator features without needing to re-generate and
import every time.
"""
if request.param == "split":
split = True
else:
split = False
split = request.param == "split"
yield generate_and_import(linkml_schema, split)
@ -188,7 +185,7 @@ def test_arraylike(imported_schema):
# check that we have gotten an NDArray annotation and its shape is correct
array = imported_schema["core"].MainTopLevel.model_fields["array"].annotation
args = typing.get_args(array)
for i, shape in enumerate(("* x, * y", "* x, * y, 3 z", "* x, * y, 3 z, 4 a")):
for i, _ in enumerate(("* x, * y", "* x, * y, 3 z", "* x, * y, 3 z, 4 a")):
assert isinstance(args[i], NDArrayMeta)
assert args[i].__args__[0].__args__
assert args[i].__args__[1] == np.number
@ -213,8 +210,8 @@ def test_linkml_meta(imported_schema):
"""
meta = imported_schema["core"].LinkML_Meta
assert "tree_root" in meta.model_fields
assert imported_schema["core"].MainTopLevel.linkml_meta.default.tree_root == True
assert imported_schema["core"].OtherClass.linkml_meta.default.tree_root == False
assert imported_schema["core"].MainTopLevel.linkml_meta.default.tree_root
assert not imported_schema["core"].OtherClass.linkml_meta.default.tree_root
def test_skip(linkml_schema):

View file

@ -26,11 +26,11 @@ def test_preload_maps():
yaml.dump(hdmf_style_naming, temp_f, Dumper=Dumper)
loaded = load_yaml(Path(temp_name))
assert "neurodata_type_def" in loaded["groups"][0].keys()
assert "data_type_def" not in loaded["groups"][0].keys()
assert "neurodata_type_inc" in loaded["groups"][0].keys()
assert "data_type_inc" not in loaded["groups"][0].keys()
assert "neurodata_type_inc" in loaded["groups"][0]["datasets"][0].keys()
assert "data_type_inc" not in loaded["groups"][0]["datasets"][0].keys()
assert "neurodata_type_def" in loaded["groups"][0]
assert "data_type_def" not in loaded["groups"][0]
assert "neurodata_type_inc" in loaded["groups"][0]
assert "data_type_inc" not in loaded["groups"][0]
assert "neurodata_type_inc" in loaded["groups"][0]["datasets"][0]
assert "data_type_inc" not in loaded["groups"][0]["datasets"][0]
os.remove(temp_name)

View file

@ -88,7 +88,7 @@ def test_ndarray_serialize():
mod_str = mod.model_dump_json()
mod_json = json.loads(mod_str)
for a in ("array", "shape", "dtype", "unpack_fns"):
assert a in mod_json["large_array"].keys()
assert a in mod_json["large_array"]
assert isinstance(mod_json["large_array"]["array"], str)
assert isinstance(mod_json["small_array"], list)

View file

@ -83,7 +83,7 @@ patch_contact_single_multiple = Patch(
)
def run_patches(phase: Phases, verbose: bool = False):
def run_patches(phase: Phases, verbose: bool = False) -> None:
patches = [p for p in Patch.instances if p.phase == phase]
for patch in patches:
if verbose:
@ -96,7 +96,7 @@ def run_patches(phase: Phases, verbose: bool = False):
pfile.write(string)
def main():
def main() -> None:
parser = argparse.ArgumentParser(description="Run patches for a given phase of code generation")
parser.add_argument("--phase", choices=list(Phases.__members__.keys()), type=Phases)
args = parser.parse_args()