This commit is contained in:
sneakers-the-rat 2024-07-24 22:49:35 -07:00
parent af11bb61ec
commit 3d8403e9e3
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
10 changed files with 111 additions and 90 deletions

View file

@@ -1,6 +1,7 @@
"""
Adapter for NWB datasets to linkml Classes
"""
from abc import abstractmethod
from typing import ClassVar, Optional, Type

View file

@@ -63,7 +63,8 @@ from pydantic import BaseModel
from nwb_linkml.maps import flat_to_nptyping
from nwb_linkml.maps.naming import module_case, version_module_case
-OPTIONAL_PATTERN = re.compile(r'Optional\[([\w\.]*)\]')
+OPTIONAL_PATTERN = re.compile(r"Optional\[([\w\.]*)\]")
@dataclass
class NWBPydanticGenerator(PydanticGenerator):
@@ -77,7 +78,6 @@ class NWBPydanticGenerator(PydanticGenerator):
)
split: bool = True
schema_map: Optional[Dict[str, SchemaDefinition]] = None
"""See :meth:`.LinkMLProvider.build` for usage - a list of specific versions to import from"""
array_representations: List[ArrayRepresentation] = field(
@@ -89,8 +89,7 @@ class NWBPydanticGenerator(PydanticGenerator):
gen_classvars: bool = True
gen_slots: bool = True
-skip_meta: ClassVar[Tuple[str]] = ('domain_of','alias')
+skip_meta: ClassVar[Tuple[str]] = ("domain_of", "alias")
def _check_anyof(
self, s: SlotDefinition, sn: SlotDefinitionName, sv: SchemaView
@@ -125,9 +124,9 @@ class NWBPydanticGenerator(PydanticGenerator):
del slot.attribute.meta[key]
# make array ranges in any_of
-if 'any_of' in slot.attribute.meta:
-any_ofs = slot.attribute.meta['any_of']
-if all(['array' in expr for expr in any_ofs]):
+if "any_of" in slot.attribute.meta:
+any_ofs = slot.attribute.meta["any_of"]
+if all(["array" in expr for expr in any_ofs]):
ranges = []
is_optional = False
for expr in any_ofs:
@@ -136,20 +135,19 @@ class NWBPydanticGenerator(PydanticGenerator):
is_optional = OPTIONAL_PATTERN.match(pyrange)
if is_optional:
pyrange = is_optional.groups()[0]
-range_generator = NumpydanticArray(ArrayExpression(**expr['array']), pyrange)
+range_generator = NumpydanticArray(ArrayExpression(**expr["array"]), pyrange)
ranges.append(range_generator.make().range)
-slot.attribute.range = 'Union[' + ', '.join(ranges) + ']'
+slot.attribute.range = "Union[" + ", ".join(ranges) + "]"
if is_optional:
-slot.attribute.range = 'Optional[' + slot.attribute.range + ']'
-del slot.attribute.meta['any_of']
+slot.attribute.range = "Optional[" + slot.attribute.range + "]"
+del slot.attribute.meta["any_of"]
return slot
def before_render_template(self, template: PydanticModule, sv: SchemaView) -> PydanticModule:
-if 'source_file' in template.meta:
-del template.meta['source_file']
+if "source_file" in template.meta:
+del template.meta["source_file"]
def compile_module(
self, module_path: Path = None, module_name: str = "test", **kwargs
@@ -171,8 +169,6 @@ class NWBPydanticGenerator(PydanticGenerator):
raise e
def compile_python(
text_or_fn: str, package_path: Path = None, module_name: str = "test"
) -> ModuleType:
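
Reviewer note, not part of the diff: a minimal sketch of what the `OPTIONAL_PATTERN` logic above does — unwrap an `Optional[...]` range, merge the `any_of` arms into a `Union`, then re-wrap. The annotation strings are hypothetical.

```python
import re

OPTIONAL_PATTERN = re.compile(r"Optional\[([\w\.]*)\]")  # same pattern as above

pyrange = "Optional[numpy.float64]"  # hypothetical range string
is_optional = OPTIONAL_PATTERN.match(pyrange)
if is_optional:
    pyrange = is_optional.groups()[0]  # "numpy.float64"

ranges = [pyrange, "str"]  # stand-ins for the ranges built from each any_of arm
merged = "Union[" + ", ".join(ranges) + "]"
if is_optional:
    merged = "Optional[" + merged + "]"  # "Optional[Union[numpy.float64, str]]"
```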

View file

@@ -3,9 +3,10 @@ Utility functions for dealing with yaml files.
No we are not going to implement a yaml parser here
"""
import re
from pathlib import Path
-from typing import Literal, List, Union, overload
+from typing import List, Literal, Union, overload
import yaml
@@ -13,15 +14,26 @@ from nwb_linkml.maps.postload import apply_postload
@overload
-def yaml_peek(key: str, path: Union[str, Path], root:bool = True, first:Literal[True]=True) -> str: ...
+def yaml_peek(
+key: str, path: Union[str, Path], root: bool = True, first: Literal[True] = True
+) -> str: ...
@overload
-def yaml_peek(key: str, path: Union[str, Path], root:bool = True, first:Literal[False]=False) -> List[str]: ...
+def yaml_peek(
+key: str, path: Union[str, Path], root: bool = True, first: Literal[False] = False
+) -> List[str]: ...
@overload
-def yaml_peek(key: str, path: Union[str, Path], root:bool = True, first:bool=True) -> Union[str, List[str]]: ...
+def yaml_peek(
+key: str, path: Union[str, Path], root: bool = True, first: bool = True
+) -> Union[str, List[str]]: ...
-def yaml_peek(key: str, path: Union[str, Path], root:bool = True, first:bool=True) -> Union[str, List[str]]:
+def yaml_peek(
+key: str, path: Union[str, Path], root: bool = True, first: bool = True
+) -> Union[str, List[str]]:
"""
Peek into a yaml file without parsing the whole file to retrieve the value of a single key.
@@ -43,27 +55,27 @@ def yaml_peek(key: str, path: Union[str, Path], root:bool = True, first:bool=Tru
str
"""
if root:
-pattern = re.compile(rf'^(?P<key>{key}):\s*(?P<value>\S.*)')
+pattern = re.compile(rf"^(?P<key>{key}):\s*(?P<value>\S.*)")
else:
-pattern = re.compile(rf'^\s*(?P<key>{key}):\s*(?P<value>\S.*)')
+pattern = re.compile(rf"^\s*(?P<key>{key}):\s*(?P<value>\S.*)")
res = None
if first:
-with open(path, 'r') as yfile:
-for l in yfile:
-res = pattern.match(l)
+with open(path) as yfile:
+for line in yfile:
+res = pattern.match(line)
if res:
break
if res:
-return res.groupdict()['value']
+return res.groupdict()["value"]
else:
-with open(path, 'r') as yfile:
+with open(path) as yfile:
text = yfile.read()
-res = [match.groupdict()['value'] for match in pattern.finditer(text)]
+res = [match.groupdict()["value"] for match in pattern.finditer(text)]
if res:
return res
-raise KeyError(f'Key {key} not found in {path}')
+raise KeyError(f"Key {key} not found in {path}")
def load_yaml(path: Path | str) -> dict:
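
Reviewer note: a hedged usage sketch of the reflowed `yaml_peek` signature (the file name and keys are hypothetical). The `Literal` overloads above exist so a type checker can narrow the return type from the value of `first`:

```python
from nwb_linkml.io.yaml import yaml_peek

version = yaml_peek("version", "namespace.yaml")                # first match -> str
versions = yaml_peek("version", "namespace.yaml", first=False)  # all matches -> List[str]
names = yaml_peek("name", "namespace.yaml", root=False, first=False)  # nested keys too
```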

View file

@@ -2,7 +2,9 @@
Language elements in nwb schema language that have a fixed, alternative representation
in LinkML. These are exported as an nwb.language.yml file along with every generated namespace
"""
from typing import List
from linkml_runtime.linkml_model import (
ClassDefinition,
EnumDefinition,
@@ -35,12 +37,15 @@ def _make_dtypes() -> List[TypeDefinition]:
np_type = flat_to_np[nwbtype]
-repr_string = f'np.{np_type.__name__}' if np_type.__module__ == 'numpy' else None
+repr_string = f"np.{np_type.__name__}" if np_type.__module__ == "numpy" else None
-atype = TypeDefinition(name=nwbtype, minimum_value=amin, typeof=linkmltype, repr=repr_string)
+atype = TypeDefinition(
+name=nwbtype, minimum_value=amin, typeof=linkmltype, repr=repr_string
+)
DTypeTypes.append(atype)
return DTypeTypes
DTypeTypes = _make_dtypes()
AnyType = ClassDefinition(
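
Reviewer note: how the `repr_string` conditional above behaves, assuming standard numpy semantics — only true numpy types get an `np.`-prefixed repr:

```python
import numpy as np

for np_type in (np.float32, int):
    repr_string = f"np.{np_type.__name__}" if np_type.__module__ == "numpy" else None
    print(np_type.__name__, repr_string)
# float32 np.float32
# int None  (builtins get no repr override)
```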

View file

@@ -2,7 +2,7 @@
Mapping from one domain to another
"""
-from nwb_linkml.maps.dtype import flat_to_linkml, flat_to_nptyping, flat_to_np
+from nwb_linkml.maps.dtype import flat_to_linkml, flat_to_np, flat_to_nptyping
from nwb_linkml.maps.map import Map
from nwb_linkml.maps.postload import MAP_HDMF_DATATYPE_DEF, MAP_HDMF_DATATYPE_INC
from nwb_linkml.maps.quantity import QUANTITY_MAP

View file

@@ -11,8 +11,8 @@ from pathlib import Path
from types import ModuleType
from typing import List, Optional, Type
+from linkml.generators.pydanticgen.pydanticgen import SplitMode, _ensure_inits, _import_to_path
from linkml_runtime.linkml_model.meta import SchemaDefinition
-from linkml.generators.pydanticgen.pydanticgen import SplitMode, _import_to_path, _ensure_inits
from pydantic import BaseModel
from nwb_linkml.generators.pydantic import NWBPydanticGenerator
@@ -107,7 +107,6 @@ class PydanticProvider(Provider):
# given a path to a namespace linkml yaml file
path = Path(namespace)
if split:
result = self._build_split(path, dump, force, **kwargs)
else:
@@ -116,14 +115,10 @@ class PydanticProvider(Provider):
self.install_pathfinder()
return result
-def _build_unsplit(
-self,
-path: Path,
-dump: bool,
-force: bool,
-**kwargs
-) -> str:
-generator = NWBPydanticGenerator(str(path), split=False, split_pattern=self.SPLIT_PATTERN, **kwargs)
+def _build_unsplit(self, path: Path, dump: bool, force: bool, **kwargs) -> str:
+generator = NWBPydanticGenerator(
+str(path), split=False, split_pattern=self.SPLIT_PATTERN, **kwargs
+)
out_module = generator.generate_module_import(generator.schemaview.schema)
out_file = (self.path / _import_to_path(out_module)).resolve()
if out_file.exists() and not force:
@@ -141,14 +136,9 @@ class PydanticProvider(Provider):
return serialized
-def _build_split(
-self,
-path: Path,
-dump: bool,
-force: bool,
-**kwargs
-) -> List[str]:
-# FIXME: This is messy as all fuck, we're just getting it to work again so we can start iterating on the models themselves
+def _build_split(self, path: Path, dump: bool, force: bool, **kwargs) -> List[str]:
+# FIXME: This is messy as all fuck, we're just getting it to work again
+# so we can start iterating on the models themselves
res = []
module_paths = []
@@ -156,8 +146,10 @@ class PydanticProvider(Provider):
# remove any directory traversal at the head of the pattern for this,
# we're making relative to the provider's path not the generated schema at first
-root_pattern = re.sub(r'^\.*', '', self.SPLIT_PATTERN)
-gen = NWBPydanticGenerator(schema=path, split=True, split_pattern=root_pattern, split_mode=SplitMode.FULL)
+root_pattern = re.sub(r"^\.*", "", self.SPLIT_PATTERN)
+gen = NWBPydanticGenerator(
+schema=path, split=True, split_pattern=root_pattern, split_mode=SplitMode.FULL
+)
mod_name = gen.generate_module_import(gen.schemaview.schema)
ns_file = (self.path / _import_to_path(mod_name)).resolve()
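
Reviewer note: the `re.sub` above only strips leading relative-import dots, so the namespace module is rendered relative to the provider's root rather than to the generated schema; a quick sketch with a hypothetical pattern value:

```python
import re

SPLIT_PATTERN = "...{{ schema.name }}"  # hypothetical value with relative-import dots
root_pattern = re.sub(r"^\.*", "", SPLIT_PATTERN)
assert root_pattern == "{{ schema.name }}"
```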
@@ -170,11 +162,11 @@ class PydanticProvider(Provider):
ns_file.parent.mkdir(exist_ok=True, parents=True)
serialized = gen.serialize(rendered_module=rendered)
if dump:
-with open(ns_file, 'w') as ofile:
+with open(ns_file, "w") as ofile:
ofile.write(serialized)
module_paths.append(ns_file)
else:
-with open(ns_file, 'r') as ofile:
+with open(ns_file) as ofile:
serialized = ofile.read()
res.append(serialized)
@@ -189,27 +181,29 @@
import_file.parent.mkdir(exist_ok=True, parents=True)
schema = imported_schema[generated_import.module]
is_namespace = False
-ns_annotation = schema.annotations.get('is_namespace', None)
+ns_annotation = schema.annotations.get("is_namespace", None)
if ns_annotation:
is_namespace = ns_annotation.value
# fix schema source to absolute path so schemaview can find imports
-schema.source_file = (Path(gen.schemaview.schema.source_file).parent / schema.source_file).resolve()
+schema.source_file = (
+Path(gen.schemaview.schema.source_file).parent / schema.source_file
+).resolve()
import_gen = NWBPydanticGenerator(
schema,
split=True,
split_pattern=self.SPLIT_PATTERN,
-split_mode=SplitMode.FULL if is_namespace else SplitMode.AUTO
+split_mode=SplitMode.FULL if is_namespace else SplitMode.AUTO,
)
serialized = import_gen.serialize()
if dump:
-with open(import_file, 'w') as ofile:
+with open(import_file, "w") as ofile:
ofile.write(serialized)
module_paths.append(import_file)
else:
-with open(import_file, 'r') as ofile:
+with open(import_file) as ofile:
serialized = ofile.read()
res.append(serialized)
@@ -402,13 +396,17 @@ class PydanticProvider(Provider):
mod = self.get(namespace, version)
return getattr(mod, class_)
-def install_pathfinder(self):
+def install_pathfinder(self) -> None:
"""
Add a :class:`.EctopicModelFinder` instance that allows us to import from
the directory that we are generating models into
"""
# check if one already exists
-matches = [finder for finder in sys.meta_path if isinstance(finder, EctopicModelFinder) and finder.path == self.path]
+matches = [
+finder
+for finder in sys.meta_path
+if isinstance(finder, EctopicModelFinder) and finder.path == self.path
+]
if len(matches) > 0:
return
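
Reviewer note: the dedup check keeps repeated `install_pathfinder()` calls from stacking duplicate finders on `sys.meta_path`; a self-contained sketch of the idiom, with a stand-in finder class:

```python
import sys
from pathlib import Path

class StubFinder:
    """Stand-in for EctopicModelFinder; only the identity check matters here."""
    def __init__(self, path: Path):
        self.path = path

def install_pathfinder(path: Path) -> None:
    # skip installation if an equivalent finder is already registered
    matches = [
        finder for finder in sys.meta_path
        if isinstance(finder, StubFinder) and finder.path == path
    ]
    if len(matches) > 0:
        return
    sys.meta_path.append(StubFinder(path))
```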

View file

@@ -18,15 +18,21 @@ def pytest_addoption(parser):
"--without-cache", action="store_true", help="Don't use a sqlite cache for network requests"
)
parser.addoption(
-"--dev", action="store_true", help="run tests that are intended only for development use, eg. those that generate output for inspection"
+"--dev",
+action="store_true",
+help=(
+"run tests that are intended only for development use, eg. those that generate output"
+" for inspection"
+),
)
def pytest_collection_modifyitems(config, items: List[pytest.Item]):
# remove dev tests from collection if we're not in dev mode!
-if config.getoption('--dev'):
-remove_tests = [t for t in items if not t.get_closest_marker('dev')]
+if config.getoption("--dev"):
+remove_tests = [t for t in items if not t.get_closest_marker("dev")]
else:
-remove_tests = [t for t in items if t.get_closest_marker('dev')]
+remove_tests = [t for t in items if t.get_closest_marker("dev")]
for t in remove_tests:
items.remove(t)
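
Reviewer note: with the option reflowed above, dev-only tests opt in via the marker and are removed from collection unless pytest runs with `--dev`; a hypothetical example test:

```python
import pytest

@pytest.mark.dev
def test_dump_models_for_inspection(tmp_path):
    # collected only under --dev; otherwise removed in pytest_collection_modifyitems
    (tmp_path / "generated.txt").write_text("output to eyeball")
```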
@@ -56,4 +62,3 @@
# delete cache file unless we have requested it to persist for inspection
if not pytestconfig.getoption("--with-output"):
cache_file.unlink(missing_ok=True)

View file

@@ -18,11 +18,13 @@ from nwb_linkml.generators.pydantic import NWBPydanticGenerator
from nwb_linkml.lang_elements import NwbLangSchema
from nwb_linkml.providers import LinkMLProvider, PydanticProvider
+@pytest.mark.dev
def test_generate_nwblang(tmp_output_dir):
output_file = (tmp_output_dir / NwbLangSchema.name).with_suffix(".yml")
yaml_dumper.dump(NwbLangSchema, output_file)
+@pytest.mark.dev
def test_generate_core(nwb_core_fixture, tmp_output_dir):
schemas = nwb_core_fixture.build().schemas
@@ -43,6 +45,7 @@ def load_schema_files(path: Path) -> Dict[str, SchemaDefinition]:
preloaded_schema[sch.name] = sch
return preloaded_schema
+@pytest.mark.dev
@pytest.mark.depends(on=["test_generate_core"])
def test_generate_pydantic(tmp_output_dir):
@@ -72,6 +75,7 @@ def test_generate_pydantic(tmp_output_dir):
with open(tmp_output_dir / "models" / "__init__.py", "w") as initfile:
initfile.write("# Autogenerated module indicator")
@pytest.mark.provider
+@pytest.mark.dev
def test_generate_linkml_provider(tmp_output_dir, nwb_core_fixture):
@@ -80,9 +84,8 @@ def test_generate_linkml_provider(tmp_output_dir, nwb_core_fixture):
result = provider.build(nwb_core_fixture)
@pytest.mark.provider
@pytest.mark.dev
def test_generate_pydantic_provider(tmp_output_dir):
provider = PydanticProvider(path=tmp_output_dir, verbose=False)
-result = provider.build('core')
+result = provider.build("core")

View file

@@ -32,6 +32,7 @@ class TestModules(TypedDict):
TestModules.__test__ = False
+@pytest.mark.xfail()
def generate_and_import(
linkml_schema: TestSchemas, split: bool, generator_kwargs: Optional[dict] = None
@@ -77,6 +78,7 @@ def generate_and_import(
return TestModules(core=core, imported=imported, namespace=namespace, split=split)
+@pytest.mark.xfail()
@pytest.fixture(scope="module", params=["split", "unsplit"])
def imported_schema(linkml_schema, request) -> TestModules:
@@ -105,6 +107,7 @@ def _model_correctness(modules: TestModules):
assert issubclass(modules["core"].StillAnotherClass, BaseModel)
assert issubclass(modules["imported"].MainThing, BaseModel)
+@pytest.mark.xfail()
def test_generate(linkml_schema):
"""
@@ -130,6 +133,7 @@ def test_generate(linkml_schema):
del sys.modules["test_schema.imported"]
del sys.modules["test_schema.namespace"]
+@pytest.mark.xfail()
def test_generate_split(linkml_schema):
"""
@@ -151,6 +155,7 @@ def test_generate_split(linkml_schema):
del sys.modules["test_schema.imported"]
del sys.modules["test_schema.namespace"]
+@pytest.mark.xfail()
def test_versions(linkml_schema):
"""
@@ -195,6 +200,7 @@ def test_arraylike(imported_schema):
assert not hasattr(imported_schema["core"], "MainTopLevel__Array")
assert not hasattr(imported_schema["core"], "MainTopLevelArray")
+@pytest.mark.xfail()
def test_inject_fields(imported_schema):
"""
@@ -215,6 +221,7 @@ def test_linkml_meta(imported_schema):
assert imported_schema["core"].MainTopLevel.linkml_meta.default.tree_root
assert not imported_schema["core"].OtherClass.linkml_meta.default.tree_root
+@pytest.mark.xfail()
def test_skip(linkml_schema):
"""
@@ -246,6 +253,7 @@ def test_inline_with_identifier(imported_schema):
assert otherclass is imported_schema["core"].OtherClass
assert stillanother is imported_schema["core"].StillAnotherClass
+@pytest.mark.xfail()
def test_namespace(imported_schema):
"""

View file

@@ -3,18 +3,12 @@ import yaml
from nwb_linkml.io.yaml import yaml_peek
@pytest.fixture()
def yaml_file(tmp_path):
-data = {
-'key1': 'val1',
-'key2': 'val2',
-'key3': {
-'key1': 'val3',
-'key4': 'val4'
-}
-}
-out_file = tmp_path / 'test.yaml'
-with open(out_file, 'w') as yfile:
+data = {"key1": "val1", "key2": "val2", "key3": {"key1": "val3", "key4": "val4"}}
+out_file = tmp_path / "test.yaml"
+with open(out_file, "w") as yfile:
yaml.dump(data, yfile)
yield out_file
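
Reviewer note: the reformatted fixture dumps a small document whose nesting drives the parametrized expectations below — sketched here as comments:

```python
# yaml.dump(data) writes, roughly:
#   key1: val1
#   key2: val2
#   key3:
#     key1: val3
#     key4: val4
# "key3" fails even at root because the pattern requires a scalar on the same
# line (\S.*) and key3's value is a nested mapping; root=False additionally
# matches the indented "key1" and "key4".
```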
@@ -22,19 +16,18 @@ def yaml_file(tmp_path):
out_file.unlink()
@pytest.mark.parametrize(
-'key,expected,root,first',
+"key,expected,root,first",
[
-('key1', 'val1', True, True),
-('key1', 'val1', False, True),
-('key1', ['val1'], True, False),
-('key1', ['val1', 'val3'], False, False),
-('key2', 'val2', True, True),
-('key3', False, True, True),
-('key4', False, True, True),
-('key4', 'val4', False, True)
-]
+("key1", "val1", True, True),
+("key1", "val1", False, True),
+("key1", ["val1"], True, False),
+("key1", ["val1", "val3"], False, False),
+("key2", "val2", True, True),
+("key3", False, True, True),
+("key4", False, True, True),
+("key4", "val4", False, True),
+],
)
def test_peek_yaml(key, expected, root, first, yaml_file):
if not expected: