sneakers-the-rat 2024-07-24 22:49:35 -07:00
parent af11bb61ec
commit 3d8403e9e3
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
10 changed files with 111 additions and 90 deletions

View file

@@ -1,6 +1,7 @@
 """
 Adapter for NWB datasets to linkml Classes
 """
+
 from abc import abstractmethod
 from typing import ClassVar, Optional, Type

View file

@@ -63,7 +63,8 @@ from pydantic import BaseModel
 from nwb_linkml.maps import flat_to_nptyping
 from nwb_linkml.maps.naming import module_case, version_module_case
 
-OPTIONAL_PATTERN = re.compile(r'Optional\[([\w\.]*)\]')
+OPTIONAL_PATTERN = re.compile(r"Optional\[([\w\.]*)\]")
+
 
 @dataclass
 class NWBPydanticGenerator(PydanticGenerator):
@@ -77,7 +78,6 @@ class NWBPydanticGenerator(PydanticGenerator):
     )
     split: bool = True
     schema_map: Optional[Dict[str, SchemaDefinition]] = None
-
     """See :meth:`.LinkMLProvider.build` for usage - a list of specific versions to import from"""
 
     array_representations: List[ArrayRepresentation] = field(
@@ -89,8 +89,7 @@ class NWBPydanticGenerator(PydanticGenerator):
     gen_classvars: bool = True
     gen_slots: bool = True
-    skip_meta: ClassVar[Tuple[str]] = ('domain_of','alias')
+    skip_meta: ClassVar[Tuple[str]] = ("domain_of", "alias")
 
     def _check_anyof(
         self, s: SlotDefinition, sn: SlotDefinitionName, sv: SchemaView
@@ -125,9 +124,9 @@ class NWBPydanticGenerator(PydanticGenerator):
             del slot.attribute.meta[key]
 
         # make array ranges in any_of
-        if 'any_of' in slot.attribute.meta:
-            any_ofs = slot.attribute.meta['any_of']
-            if all(['array' in expr for expr in any_ofs]):
+        if "any_of" in slot.attribute.meta:
+            any_ofs = slot.attribute.meta["any_of"]
+            if all(["array" in expr for expr in any_ofs]):
                 ranges = []
                 is_optional = False
                 for expr in any_ofs:
@@ -136,20 +135,19 @@ class NWBPydanticGenerator(PydanticGenerator):
                     is_optional = OPTIONAL_PATTERN.match(pyrange)
                     if is_optional:
                         pyrange = is_optional.groups()[0]
-                    range_generator = NumpydanticArray(ArrayExpression(**expr['array']), pyrange)
+                    range_generator = NumpydanticArray(ArrayExpression(**expr["array"]), pyrange)
                     ranges.append(range_generator.make().range)
-                slot.attribute.range = 'Union[' + ', '.join(ranges) + ']'
+                slot.attribute.range = "Union[" + ", ".join(ranges) + "]"
                 if is_optional:
-                    slot.attribute.range = 'Optional[' + slot.attribute.range + ']'
-                del slot.attribute.meta['any_of']
+                    slot.attribute.range = "Optional[" + slot.attribute.range + "]"
+                del slot.attribute.meta["any_of"]
 
         return slot
 
     def before_render_template(self, template: PydanticModule, sv: SchemaView) -> PydanticModule:
-        if 'source_file' in template.meta:
-            del template.meta['source_file']
+        if "source_file" in template.meta:
+            del template.meta["source_file"]
 
     def compile_module(
         self, module_path: Path = None, module_name: str = "test", **kwargs
@@ -171,8 +169,6 @@ class NWBPydanticGenerator(PydanticGenerator):
             raise e
 
 
-
-
 def compile_python(
     text_or_fn: str, package_path: Path = None, module_name: str = "test"
 ) -> ModuleType:
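Aside on the any_of handling above: each member range has any Optional[...] wrapper stripped, the members are rebuilt as array ranges, and the result is re-joined into a single Union with Optional re-applied at the end. A minimal sketch of that composition (the pyranges are hypothetical and the NumpydanticArray step is elided):

import re

OPTIONAL_PATTERN = re.compile(r"Optional\[([\w\.]*)\]")

ranges = []
is_optional = None
for pyrange in ["Optional[float]", "Optional[int]"]:  # hypothetical any_of member ranges
    is_optional = OPTIONAL_PATTERN.match(pyrange)
    if is_optional:
        pyrange = is_optional.groups()[0]  # unwrap Optional[...] to the inner range
    ranges.append(pyrange)  # the real code wraps this in an array range first

slot_range = "Union[" + ", ".join(ranges) + "]"
if is_optional:
    slot_range = "Optional[" + slot_range + "]"
print(slot_range)  # Optional[Union[float, int]]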

View file

@@ -3,9 +3,10 @@ Utility functions for dealing with yaml files.
 No we are not going to implement a yaml parser here
 """
+
 import re
 from pathlib import Path
-from typing import Literal, List, Union, overload
+from typing import List, Literal, Union, overload
 
 import yaml
@@ -13,15 +14,26 @@ from nwb_linkml.maps.postload import apply_postload
 
 @overload
-def yaml_peek(key: str, path: Union[str, Path], root:bool = True, first:Literal[True]=True) -> str: ...
+def yaml_peek(
+    key: str, path: Union[str, Path], root: bool = True, first: Literal[True] = True
+) -> str: ...
 
+
 @overload
-def yaml_peek(key: str, path: Union[str, Path], root:bool = True, first:Literal[False]=False) -> List[str]: ...
+def yaml_peek(
+    key: str, path: Union[str, Path], root: bool = True, first: Literal[False] = False
+) -> List[str]: ...
 
+
 @overload
-def yaml_peek(key: str, path: Union[str, Path], root:bool = True, first:bool=True) -> Union[str, List[str]]: ...
+def yaml_peek(
+    key: str, path: Union[str, Path], root: bool = True, first: bool = True
+) -> Union[str, List[str]]: ...
 
+
-def yaml_peek(key: str, path: Union[str, Path], root:bool = True, first:bool=True) -> Union[str, List[str]]:
+def yaml_peek(
+    key: str, path: Union[str, Path], root: bool = True, first: bool = True
+) -> Union[str, List[str]]:
     """
     Peek into a yaml file without parsing the whole file to retrieve the value of a single key.
@@ -43,27 +55,27 @@ def yaml_peek(key: str, path: Union[str, Path], root:bool = True, first:bool=True) -> Union[str, List[str]]:
         str
     """
     if root:
-        pattern = re.compile(rf'^(?P<key>{key}):\s*(?P<value>\S.*)')
+        pattern = re.compile(rf"^(?P<key>{key}):\s*(?P<value>\S.*)")
     else:
-        pattern = re.compile(rf'^\s*(?P<key>{key}):\s*(?P<value>\S.*)')
+        pattern = re.compile(rf"^\s*(?P<key>{key}):\s*(?P<value>\S.*)")
 
     res = None
     if first:
-        with open(path, 'r') as yfile:
-            for l in yfile:
-                res = pattern.match(l)
+        with open(path) as yfile:
+            for line in yfile:
+                res = pattern.match(line)
                 if res:
                     break
         if res:
-            return res.groupdict()['value']
+            return res.groupdict()["value"]
     else:
-        with open(path, 'r') as yfile:
+        with open(path) as yfile:
             text = yfile.read()
-        res = [match.groupdict()['value'] for match in pattern.finditer(text)]
+        res = [match.groupdict()["value"] for match in pattern.finditer(text)]
         if res:
             return res
 
-    raise KeyError(f'Key {key} not found in {path}')
+    raise KeyError(f"Key {key} not found in {path}")
 
 
 def load_yaml(path: Path | str) -> dict:
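The overloads above encode the return shape of yaml_peek: first=True yields a single str, first=False a List[str], and root=False lets the pattern match indented (nested) keys as well. A usage sketch consistent with those regexes (file contents are illustrative):

from pathlib import Path

from nwb_linkml.io.yaml import yaml_peek

path = Path("test.yaml")  # hypothetical scratch file
path.write_text("key1: val1\nkey3:\n  key1: val3\n")

assert yaml_peek("key1", path) == "val1"  # root=True, first=True -> first match only
assert yaml_peek("key1", path, root=False, first=False) == ["val1", "val3"]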

View file

@@ -2,7 +2,9 @@
 Language elements in nwb schema language that have a fixed, alternative representation
 in LinkML. These are exported as an nwb.language.yml file along with every generated namespace
 """
+
 from typing import List
+
 from linkml_runtime.linkml_model import (
     ClassDefinition,
     EnumDefinition,
@@ -35,12 +37,15 @@ def _make_dtypes() -> List[TypeDefinition]:
         np_type = flat_to_np[nwbtype]
 
-        repr_string = f'np.{np_type.__name__}' if np_type.__module__ == 'numpy' else None
+        repr_string = f"np.{np_type.__name__}" if np_type.__module__ == "numpy" else None
 
-        atype = TypeDefinition(name=nwbtype, minimum_value=amin, typeof=linkmltype, repr=repr_string)
+        atype = TypeDefinition(
+            name=nwbtype, minimum_value=amin, typeof=linkmltype, repr=repr_string
+        )
         DTypeTypes.append(atype)
     return DTypeTypes
 
 
 DTypeTypes = _make_dtypes()
 
 AnyType = ClassDefinition(
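For context on the repr_string conditional above: only genuine numpy scalar types get an np.-prefixed repr; anything else falls through to None. A quick illustration with assumed example types:

import numpy as np

for np_type in (np.int32, str):
    repr_string = f"np.{np_type.__name__}" if np_type.__module__ == "numpy" else None
    print(np_type.__name__, "->", repr_string)
# int32 -> np.int32
# str -> None  (str.__module__ is "builtins")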

View file

@@ -2,7 +2,7 @@
 Mapping from one domain to another
 """
 
-from nwb_linkml.maps.dtype import flat_to_linkml, flat_to_nptyping, flat_to_np
+from nwb_linkml.maps.dtype import flat_to_linkml, flat_to_np, flat_to_nptyping
 from nwb_linkml.maps.map import Map
 from nwb_linkml.maps.postload import MAP_HDMF_DATATYPE_DEF, MAP_HDMF_DATATYPE_INC
 from nwb_linkml.maps.quantity import QUANTITY_MAP

View file

@@ -11,8 +11,8 @@ from pathlib import Path
 from types import ModuleType
 from typing import List, Optional, Type
 
+from linkml.generators.pydanticgen.pydanticgen import SplitMode, _ensure_inits, _import_to_path
 from linkml_runtime.linkml_model.meta import SchemaDefinition
-from linkml.generators.pydanticgen.pydanticgen import SplitMode, _import_to_path, _ensure_inits
 from pydantic import BaseModel
 
 from nwb_linkml.generators.pydantic import NWBPydanticGenerator
@@ -107,7 +107,6 @@ class PydanticProvider(Provider):
             # given a path to a namespace linkml yaml file
             path = Path(namespace)
-
         if split:
             result = self._build_split(path, dump, force, **kwargs)
         else:
@@ -116,14 +115,10 @@ class PydanticProvider(Provider):
             self.install_pathfinder()
         return result
 
-    def _build_unsplit(
-        self,
-        path: Path,
-        dump: bool,
-        force: bool,
-        **kwargs
-    ) -> str:
-        generator = NWBPydanticGenerator(str(path), split=False, split_pattern=self.SPLIT_PATTERN, **kwargs)
+    def _build_unsplit(self, path: Path, dump: bool, force: bool, **kwargs) -> str:
+        generator = NWBPydanticGenerator(
+            str(path), split=False, split_pattern=self.SPLIT_PATTERN, **kwargs
+        )
         out_module = generator.generate_module_import(generator.schemaview.schema)
         out_file = (self.path / _import_to_path(out_module)).resolve()
         if out_file.exists() and not force:
@@ -141,14 +136,9 @@ class PydanticProvider(Provider):
         return serialized
 
-    def _build_split(
-        self,
-        path: Path,
-        dump: bool,
-        force: bool,
-        **kwargs
-    ) -> List[str]:
-        # FIXME: This is messy as all fuck, we're just getting it to work again so we can start iterating on the models themselves
+    def _build_split(self, path: Path, dump: bool, force: bool, **kwargs) -> List[str]:
+        # FIXME: This is messy as all fuck, we're just getting it to work again
+        # so we can start iterating on the models themselves
         res = []
         module_paths = []
@@ -156,8 +146,10 @@ class PydanticProvider(Provider):
         # remove any directory traversal at the head of the pattern for this,
         # we're making relative to the provider's path not the generated schema at first
-        root_pattern = re.sub(r'^\.*', '', self.SPLIT_PATTERN)
-        gen = NWBPydanticGenerator(schema=path, split=True, split_pattern=root_pattern, split_mode=SplitMode.FULL)
+        root_pattern = re.sub(r"^\.*", "", self.SPLIT_PATTERN)
+        gen = NWBPydanticGenerator(
+            schema=path, split=True, split_pattern=root_pattern, split_mode=SplitMode.FULL
+        )
         mod_name = gen.generate_module_import(gen.schemaview.schema)
         ns_file = (self.path / _import_to_path(mod_name)).resolve()
@@ -170,11 +162,11 @@ class PydanticProvider(Provider):
            ns_file.parent.mkdir(exist_ok=True, parents=True)
            serialized = gen.serialize(rendered_module=rendered)
            if dump:
-                with open(ns_file, 'w') as ofile:
+                with open(ns_file, "w") as ofile:
                    ofile.write(serialized)
                module_paths.append(ns_file)
        else:
-            with open(ns_file, 'r') as ofile:
+            with open(ns_file) as ofile:
                serialized = ofile.read()
 
        res.append(serialized)
@@ -189,27 +181,29 @@ class PydanticProvider(Provider):
            import_file.parent.mkdir(exist_ok=True, parents=True)
 
            schema = imported_schema[generated_import.module]
            is_namespace = False
-            ns_annotation = schema.annotations.get('is_namespace', None)
+            ns_annotation = schema.annotations.get("is_namespace", None)
            if ns_annotation:
                is_namespace = ns_annotation.value
 
            # fix schema source to absolute path so schemaview can find imports
-            schema.source_file = (Path(gen.schemaview.schema.source_file).parent / schema.source_file).resolve()
+            schema.source_file = (
+                Path(gen.schemaview.schema.source_file).parent / schema.source_file
+            ).resolve()
 
            import_gen = NWBPydanticGenerator(
                schema,
                split=True,
                split_pattern=self.SPLIT_PATTERN,
-                split_mode=SplitMode.FULL if is_namespace else SplitMode.AUTO
+                split_mode=SplitMode.FULL if is_namespace else SplitMode.AUTO,
            )
            serialized = import_gen.serialize()
            if dump:
-                with open(import_file, 'w') as ofile:
+                with open(import_file, "w") as ofile:
                    ofile.write(serialized)
                module_paths.append(import_file)
            else:
-                with open(import_file, 'r') as ofile:
+                with open(import_file) as ofile:
                    serialized = ofile.read()
 
            res.append(serialized)
@@ -402,13 +396,17 @@ class PydanticProvider(Provider):
        mod = self.get(namespace, version)
        return getattr(mod, class_)
 
-    def install_pathfinder(self):
+    def install_pathfinder(self) -> None:
        """
        Add a :class:`.EctopicModelFinder` instance that allows us to import from
        the directory that we are generating models into
        """
        # check if one already exists
-        matches = [finder for finder in sys.meta_path if isinstance(finder, EctopicModelFinder) and finder.path == self.path]
+        matches = [
+            finder
+            for finder in sys.meta_path
+            if isinstance(finder, EctopicModelFinder) and finder.path == self.path
+        ]
        if len(matches) > 0:
            return
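The reflowed comprehension above is what keeps install_pathfinder idempotent: it scans sys.meta_path for an EctopicModelFinder already bound to this provider's path and returns early if one is found. Sketched as a free function, under the assumption (not shown in this hunk) that the method otherwise appends a new finder:

import sys

def install_pathfinder(path) -> None:
    # bail out if a finder for this path is already installed
    matches = [
        finder
        for finder in sys.meta_path
        if isinstance(finder, EctopicModelFinder) and finder.path == path
    ]
    if len(matches) > 0:
        return
    sys.meta_path.append(EctopicModelFinder(path))  # assumed constructor signature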

View file

@@ -18,15 +18,21 @@ def pytest_addoption(parser):
        "--without-cache", action="store_true", help="Don't use a sqlite cache for network requests"
    )
    parser.addoption(
-        "--dev", action="store_true", help="run tests that are intended only for development use, eg. those that generate output for inspection"
+        "--dev",
+        action="store_true",
+        help=(
+            "run tests that are intended only for development use, eg. those that generate output"
+            " for inspection"
+        ),
    )
 
 
 def pytest_collection_modifyitems(config, items: List[pytest.Item]):
    # remove dev tests from collection if we're not in dev mode!
-    if config.getoption('--dev'):
-        remove_tests = [t for t in items if not t.get_closest_marker('dev')]
+    if config.getoption("--dev"):
+        remove_tests = [t for t in items if not t.get_closest_marker("dev")]
    else:
-        remove_tests = [t for t in items if t.get_closest_marker('dev')]
+        remove_tests = [t for t in items if t.get_closest_marker("dev")]
    for t in remove_tests:
        items.remove(t)
@@ -56,4 +62,3 @@ def patch_requests_cache(pytestconfig):
    # delete cache file unless we have requested it to persist for inspection
    if not pytestconfig.getoption("--with-output"):
        cache_file.unlink(missing_ok=True)
-
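Note that the collection hook above makes --dev an exclusive mode: with the flag, every test not marked dev is dropped from the run; without it, the dev-marked tests are dropped instead. So a dev-only test is written like this (hypothetical name) and only runs under pytest --dev:

import pytest

@pytest.mark.dev
def test_generates_output_for_inspection():
    ...  # collected only when pytest is invoked with --dev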

View file

@@ -18,11 +18,13 @@ from nwb_linkml.generators.pydantic import NWBPydanticGenerator
 from nwb_linkml.lang_elements import NwbLangSchema
 from nwb_linkml.providers import LinkMLProvider, PydanticProvider
 
+
 @pytest.mark.dev
 def test_generate_nwblang(tmp_output_dir):
    output_file = (tmp_output_dir / NwbLangSchema.name).with_suffix(".yml")
    yaml_dumper.dump(NwbLangSchema, output_file)
 
+
 @pytest.mark.dev
 def test_generate_core(nwb_core_fixture, tmp_output_dir):
    schemas = nwb_core_fixture.build().schemas
@@ -43,6 +45,7 @@ def load_schema_files(path: Path) -> Dict[str, SchemaDefinition]:
        preloaded_schema[sch.name] = sch
    return preloaded_schema
 
+
 @pytest.mark.dev
 @pytest.mark.depends(on=["test_generate_core"])
 def test_generate_pydantic(tmp_output_dir):
@@ -72,6 +75,7 @@ def test_generate_pydantic(tmp_output_dir):
    with open(tmp_output_dir / "models" / "__init__.py", "w") as initfile:
        initfile.write("# Autogenerated module indicator")
 
+
 @pytest.mark.provider
 @pytest.mark.dev
 def test_generate_linkml_provider(tmp_output_dir, nwb_core_fixture):
@@ -80,9 +84,8 @@ def test_generate_linkml_provider(tmp_output_dir, nwb_core_fixture):
    result = provider.build(nwb_core_fixture)
 
-
 @pytest.mark.provider
 @pytest.mark.dev
 def test_generate_pydantic_provider(tmp_output_dir):
    provider = PydanticProvider(path=tmp_output_dir, verbose=False)
-    result = provider.build('core')
+    result = provider.build("core")

View file

@@ -32,6 +32,7 @@ class TestModules(TypedDict):
 TestModules.__test__ = False
 
+
 @pytest.mark.xfail()
 def generate_and_import(
    linkml_schema: TestSchemas, split: bool, generator_kwargs: Optional[dict] = None
@@ -77,6 +78,7 @@ def generate_and_import(
    return TestModules(core=core, imported=imported, namespace=namespace, split=split)
 
+
 @pytest.mark.xfail()
 @pytest.fixture(scope="module", params=["split", "unsplit"])
 def imported_schema(linkml_schema, request) -> TestModules:
@@ -105,6 +107,7 @@ def _model_correctness(modules: TestModules):
    assert issubclass(modules["core"].StillAnotherClass, BaseModel)
    assert issubclass(modules["imported"].MainThing, BaseModel)
 
+
 @pytest.mark.xfail()
 def test_generate(linkml_schema):
    """
@@ -130,6 +133,7 @@ def test_generate(linkml_schema):
    del sys.modules["test_schema.imported"]
    del sys.modules["test_schema.namespace"]
 
+
 @pytest.mark.xfail()
 def test_generate_split(linkml_schema):
    """
@@ -151,6 +155,7 @@ def test_generate_split(linkml_schema):
    del sys.modules["test_schema.imported"]
    del sys.modules["test_schema.namespace"]
 
+
 @pytest.mark.xfail()
 def test_versions(linkml_schema):
    """
@@ -195,6 +200,7 @@ def test_arraylike(imported_schema):
    assert not hasattr(imported_schema["core"], "MainTopLevel__Array")
    assert not hasattr(imported_schema["core"], "MainTopLevelArray")
 
+
 @pytest.mark.xfail()
 def test_inject_fields(imported_schema):
    """
@@ -215,6 +221,7 @@ def test_linkml_meta(imported_schema):
    assert imported_schema["core"].MainTopLevel.linkml_meta.default.tree_root
    assert not imported_schema["core"].OtherClass.linkml_meta.default.tree_root
 
+
 @pytest.mark.xfail()
 def test_skip(linkml_schema):
    """
@@ -246,6 +253,7 @@ def test_inline_with_identifier(imported_schema):
    assert otherclass is imported_schema["core"].OtherClass
    assert stillanother is imported_schema["core"].StillAnotherClass
 
+
 @pytest.mark.xfail()
 def test_namespace(imported_schema):
    """

View file

@@ -3,18 +3,12 @@ import yaml
 
 from nwb_linkml.io.yaml import yaml_peek
 
 
 @pytest.fixture()
 def yaml_file(tmp_path):
-    data = {
-        'key1': 'val1',
-        'key2': 'val2',
-        'key3': {
-            'key1': 'val3',
-            'key4': 'val4'
-        }
-    }
-    out_file = tmp_path / 'test.yaml'
-    with open(out_file, 'w') as yfile:
+    data = {"key1": "val1", "key2": "val2", "key3": {"key1": "val3", "key4": "val4"}}
+    out_file = tmp_path / "test.yaml"
+    with open(out_file, "w") as yfile:
        yaml.dump(data, yfile)
 
    yield out_file
@@ -22,19 +16,18 @@ def yaml_file(tmp_path):
    out_file.unlink()
 
 
 @pytest.mark.parametrize(
-    'key,expected,root,first',
+    "key,expected,root,first",
    [
-        ('key1', 'val1', True, True),
-        ('key1', 'val1', False, True),
-        ('key1', ['val1'], True, False),
-        ('key1', ['val1', 'val3'], False, False),
-        ('key2', 'val2', True, True),
-        ('key3', False, True, True),
-        ('key4', False, True, True),
-        ('key4', 'val4', False, True)
-    ]
+        ("key1", "val1", True, True),
+        ("key1", "val1", False, True),
+        ("key1", ["val1"], True, False),
+        ("key1", ["val1", "val3"], False, False),
+        ("key2", "val2", True, True),
+        ("key3", False, True, True),
+        ("key4", False, True, True),
+        ("key4", "val4", False, True),
+    ],
 )
 def test_peek_yaml(key, expected, root, first, yaml_file):
    if not expected: