working generation from provider

This commit is contained in:
sneakers-the-rat 2023-09-08 19:46:42 -07:00
parent 0b0fb6c67a
commit 2e87fa0556
9 changed files with 199 additions and 103 deletions

View file

@ -13,6 +13,7 @@ from pprint import pformat
from linkml_runtime.linkml_model import SchemaDefinition from linkml_runtime.linkml_model import SchemaDefinition
from linkml_runtime.dumpers import yaml_dumper from linkml_runtime.dumpers import yaml_dumper
from time import sleep from time import sleep
from copy import copy
from nwb_schema_language import Namespaces from nwb_schema_language import Namespaces
@ -76,7 +77,6 @@ class NamespacesAdapter(Adapter):
sch_result += sch.build() sch_result += sch.build()
if progress is not None: if progress is not None:
progress.update(sch.namespace, advance=1) progress.update(sch.namespace, advance=1)
sleep(1)
# recursive step # recursive step
if not skip_imports: if not skip_imports:
@ -84,12 +84,19 @@ class NamespacesAdapter(Adapter):
imported_build = imported.build(progress=progress) imported_build = imported.build(progress=progress)
sch_result += imported_build sch_result += imported_build
# add in monkeypatch nwb types
sch_result.schemas.append(NwbLangSchema)
# now generate the top-level namespaces that import everything # now generate the top-level namespaces that import everything
for ns in self.namespaces.namespaces: for ns in self.namespaces.namespaces:
# add in monkeypatch nwb types
nwb_lang = copy(NwbLangSchema)
lang_schema_name = '.'.join([ns.name, 'nwb.language'])
nwb_lang.name = lang_schema_name
sch_result.schemas.append(nwb_lang)
ns_schemas = [sch.name for sch in self.schemas if sch.namespace == ns.name] ns_schemas = [sch.name for sch in self.schemas if sch.namespace == ns.name]
ns_schemas.append(lang_schema_name)
# also add imports bc, well, we need them # also add imports bc, well, we need them
if not skip_imports: if not skip_imports:
ns_schemas.extend([ns.name for imported in self.imported for ns in imported.namespaces.namespaces]) ns_schemas.extend([ns.name for imported in self.imported for ns in imported.namespaces.namespaces])
@ -179,6 +186,11 @@ class NamespacesAdapter(Adapter):
""" """
Populate the imports that are needed for each schema file Populate the imports that are needed for each schema file
This function adds a string version of imported schema assuming the
generated schema will live in the same directory. If the path to
the imported schema needs to be adjusted, that should happen elsewhere
(eg in :class:`.LinkMLProvider`) because we shouldn't know about
directory structure or anything like that here.
""" """
for sch in self.schemas: for sch in self.schemas:
for needs in sch.needed_imports: for needs in sch.needed_imports:

View file

@ -93,7 +93,7 @@ class SchemaAdapter(Adapter):
types=res.types types=res.types
) )
# every schema needs the language elements # every schema needs the language elements
sch.imports.append('nwb.language') sch.imports.append('.'.join([self.namespace, 'nwb.language']))
return BuildResult(schemas=[sch]) return BuildResult(schemas=[sch])
def split_subclasses(self, classes: BuildResult) -> BuildResult: def split_subclasses(self, classes: BuildResult) -> BuildResult:

View file

@ -17,7 +17,7 @@ The `serialize` method
""" """
import pdb import pdb
from dataclasses import dataclass from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import List, Dict, Set, Tuple, Optional, TypedDict from typing import List, Dict, Set, Tuple, Optional, TypedDict
import os, sys import os, sys
@ -193,8 +193,8 @@ class NWBPydanticGenerator(PydanticGenerator):
SKIP_CLASSES=('',) SKIP_CLASSES=('',)
# SKIP_CLASSES=('VectorData','VectorIndex') # SKIP_CLASSES=('VectorData','VectorIndex')
split:bool=True split:bool=True
schema_map:Dict[str, SchemaDefinition]=None schema_map:Optional[Dict[str, SchemaDefinition]]=None
versions:List[dict] = None versions:dict = None
"""See :meth:`.LinkMLProvider.build` for usage - a list of specific versions to import from""" """See :meth:`.LinkMLProvider.build` for usage - a list of specific versions to import from"""
@ -220,8 +220,8 @@ class NWBPydanticGenerator(PydanticGenerator):
if module_name == self.schema.name: if module_name == self.schema.name:
continue continue
if self.versions and module_name in [v['name'] for v in self.versions]: if self.versions and module_name in self.versions:
version = version_module_case([v['version'] for v in self.versions if v['name'] == module_name][0]) version = version_module_case(self.versions[module_name])
local_mod_name = '....' + module_case(module_name) + '.' + version + '.' + 'namespace' local_mod_name = '....' + module_case(module_name) + '.' + version + '.' + 'namespace'
else: else:
@ -537,7 +537,8 @@ class NWBPydanticGenerator(PydanticGenerator):
sv: SchemaView sv: SchemaView
sv = self.schemaview sv = self.schemaview
sv.schema_map = self.schema_map if self.schema_map is not None:
sv.schema_map = self.schema_map
schema = sv.schema schema = sv.schema
pyschema = SchemaDefinition( pyschema = SchemaDefinition(
id=schema.id, id=schema.id,

View file

@ -1,3 +1,4 @@
import pdb
import re import re
from pathlib import Path from pathlib import Path
@ -44,9 +45,19 @@ def relative_path(target: Path, origin: Path):
References: References:
- https://stackoverflow.com/a/71874881 - https://stackoverflow.com/a/71874881
""" """
try: def _relative_path(target:Path, origin:Path):
return Path(target).resolve().relative_to(Path(origin).resolve()) try:
except ValueError as e: # target does not start with origin return Path(target).resolve().relative_to(Path(origin).resolve())
# recursion with origin (eventually origin is root so try will succeed) except ValueError as e: # target does not start with origin
return Path('..').joinpath(relative_path(target, Path(origin).parent)) # recursion with origin (eventually origin is root so try will succeed)
return Path('..').joinpath(_relative_path(target, Path(origin).parent))
try:
successful = Path(target).resolve().relative_to(Path(origin).resolve())
return successful
except ValueError as e: # target does not start with origin
# recursion with origin (eventually origin is root so try will succeed)
relative = Path('..').joinpath(_relative_path(target, Path(origin).parent))
# remove the first '..' because this thing freaking double counts
return Path(*relative.parts[1:])

View file

@ -14,11 +14,13 @@ Relationship to other modules:
""" """
import pdb import pdb
from typing import Dict, TypedDict, List, Optional, Literal, TypeVar, Any, Dict from typing import Dict, TypedDict, List, Optional, Literal, TypeVar, Any, Dict
from types import ModuleType
from pathlib import Path from pathlib import Path
import os import os
from abc import abstractmethod, ABC from abc import abstractmethod, ABC
import warnings import warnings
import importlib import importlib
import sys
from linkml_runtime.linkml_model import SchemaDefinition, SchemaDefinitionName from linkml_runtime.linkml_model import SchemaDefinition, SchemaDefinitionName
from linkml_runtime.dumpers import yaml_dumper from linkml_runtime.dumpers import yaml_dumper
@ -34,9 +36,6 @@ from nwb_linkml.generators.pydantic import NWBPydanticGenerator
from nwb_linkml.providers.git import DEFAULT_REPOS from nwb_linkml.providers.git import DEFAULT_REPOS
from nwb_linkml.ui import AdapterProgress from nwb_linkml.ui import AdapterProgress
class NamespaceVersion(TypedDict):
namespace: str
version: str
P = TypeVar('P') P = TypeVar('P')
@ -112,14 +111,16 @@ class Provider(ABC):
if version is not None: if version is not None:
version_path = namespace_path / version_module_case(version) version_path = namespace_path / version_module_case(version)
version_path.mkdir(exist_ok=True, parents=True) #version_path.mkdir(exist_ok=True, parents=True)
else: else:
# or find the most recently built one # or find the most recently built one
versions = sorted(namespace_path.iterdir(), key=os.path.getmtime) versions = sorted(namespace_path.iterdir(), key=os.path.getmtime)
versions = [v for v in versions if v.is_dir() and v.name not in ('__pycache__')]
if len(versions) == 0: if len(versions) == 0:
raise FileNotFoundError('No version provided, and no existing schema found') raise FileNotFoundError('No version provided, and no existing schema found')
version_path = versions[-1] version_path = versions[-1]
return version_path return version_path
@ -180,20 +181,21 @@ class LinkMLProvider(Provider):
def build( def build(
self, self,
ns_adapter: adapters.NamespacesAdapter, ns_adapter: adapters.NamespacesAdapter,
versions: Optional[List[NamespaceVersion]] = None, versions: Optional[dict] = None,
dump: bool = True, dump: bool = True,
) -> Dict[str | SchemaDefinitionName, LinkMLSchemaBuild]: ) -> Dict[str | SchemaDefinitionName, LinkMLSchemaBuild]:
""" """
Arguments: Arguments:
namespaces (:class:`.NamespacesAdapter`): Adapter (populated with any necessary imported namespaces) namespaces (:class:`.NamespacesAdapter`): Adapter (populated with any necessary imported namespaces)
to build to build
versions (List[NamespaceVersion]): List of specific versions to use versions (dict): Dict of specific versions to use
for cross-namespace imports. If none is provided, use the most recent version for cross-namespace imports. as ``{'namespace': 'version'}``
If none is provided, use the most recent version
available. available.
dump (bool): If ``True`` (default), dump generated schema to YAML. otherwise just return dump (bool): If ``True`` (default), dump generated schema to YAML. otherwise just return
""" """
self._find_imports(ns_adapter, versions, populate=True) #self._find_imports(ns_adapter, versions, populate=True)
if self.verbose: if self.verbose:
progress = AdapterProgress(ns_adapter) progress = AdapterProgress(ns_adapter)
#progress.start() #progress.start()
@ -212,16 +214,22 @@ class LinkMLProvider(Provider):
build_result = {} build_result = {}
namespace_sch = [sch for sch in built.schemas if 'namespace' in sch.annotations.keys()] namespace_sch = [sch for sch in built.schemas if 'namespace' in sch.annotations.keys()]
namespace_names = [sch.name for sch in namespace_sch]
for ns_linkml in namespace_sch: for ns_linkml in namespace_sch:
version = ns_adapter.versions[ns_linkml.name] version = ns_adapter.versions[ns_linkml.name]
version_path = self.namespace_path(ns_linkml.name, version, allow_repo=False) version_path = self.namespace_path(ns_linkml.name, version, allow_repo=False)
version_path.mkdir(exist_ok=True, parents=True)
ns_file = version_path / 'namespace.yaml' ns_file = version_path / 'namespace.yaml'
ns_linkml = self._fix_schema_imports(ns_linkml, ns_adapter, ns_file)
yaml_dumper.dump(ns_linkml, ns_file) yaml_dumper.dump(ns_linkml, ns_file)
# write the schemas for this namespace # write the schemas for this namespace
ns_schema_names = [name.strip('.yaml') for name in ns_adapter.namespace_schemas(ns_linkml.name)] other_schema = [sch for sch in built.schemas if sch.name.split('.')[0] == ns_linkml.name and sch not in namespace_sch]
other_schema = [sch for sch in built.schemas if sch.name in ns_schema_names]
for sch in other_schema: for sch in other_schema:
output_file = version_path / (sch.name + '.yaml') output_file = version_path / (sch.name + '.yaml')
# fix the paths for intra-schema imports
sch = self._fix_schema_imports(sch, ns_adapter, output_file)
yaml_dumper.dump(sch, output_file) yaml_dumper.dump(sch, output_file)
# make return result for just this namespace # make return result for just this namespace
@ -233,6 +241,20 @@ class LinkMLProvider(Provider):
return build_result return build_result
def _fix_schema_imports(self, sch: SchemaDefinition,
ns_adapter: adapters.NamespacesAdapter,
output_file: Path) -> SchemaDefinition:
for animport in sch.imports:
if animport.split('.')[0] in ns_adapter.versions.keys():
imported_path = self.namespace_path(animport.split('.')[0], ns_adapter.versions[animport.split('.')[0]]) / 'namespace'
rel_path = relative_path(imported_path, output_file)
if str(rel_path) == '.' or str(rel_path) == 'namespace':
# same directory, just keep the existing import
continue
idx = sch.imports.index(animport)
del sch.imports[idx]
sch.imports.insert(idx, str(rel_path))
return sch
def get(self, namespace: str, version: Optional[str] = None) -> SchemaView: def get(self, namespace: str, version: Optional[str] = None) -> SchemaView:
""" """
Get a schema view over the namespace Get a schema view over the namespace
@ -240,7 +262,9 @@ class LinkMLProvider(Provider):
path = self.namespace_path(namespace, version) / 'namespace.yaml' path = self.namespace_path(namespace, version) / 'namespace.yaml'
if not path.exists(): if not path.exists():
path = self._find_source(namespace, version) path = self._find_source(namespace, version)
return SchemaView(path) sv = SchemaView(path)
sv.path = path
return sv
def _find_source(self, namespace:str, version: Optional[str] = None) -> Path: def _find_source(self, namespace:str, version: Optional[str] = None) -> Path:
"""Try and find the namespace if it exists in our default repository and build it!""" """Try and find the namespace if it exists in our default repository and build it!"""
@ -254,43 +278,42 @@ class LinkMLProvider(Provider):
#
def _find_imports(self, # def _find_imports(self,
ns: adapters.NamespacesAdapter, # ns: adapters.NamespacesAdapter,
versions: Optional[List[NamespaceVersion]] = None, # versions: Optional[dict] = None,
populate: bool=True) -> Dict[str, List[str]]: # populate: bool=True) -> Dict[str, List[str]]:
""" # """
Find relative paths to other linkml schema that need to be # Find relative paths to other linkml schema that need to be
imported, but lack an explicit source # imported, but lack an explicit source
#
Arguments: # Arguments:
ns (:class:`.NamespacesAdapter`): Namespaces to find imports to # ns (:class:`.NamespacesAdapter`): Namespaces to find imports to
versions (List[:class:`.NamespaceVersion`]): Specific versions to import # versions (dict): Specific versions to import
populate (bool): If ``True`` (default), modify the namespace adapter to include the imports, # populate (bool): If ``True`` (default), modify the namespace adapter to include the imports,
otherwise just return # otherwise just return
#
Returns: # Returns:
dict of lists for relative paths to other schema namespaces # dict of lists for relative paths to other schema namespaces
""" # """
import_paths = {} # import_paths = {}
for ns_name, needed_imports in ns.needed_imports.items(): # for ns_name, needed_imports in ns.needed_imports.items():
our_path = self.namespace_path(ns_name, ns.versions[ns_name]) / 'namespace.yaml' # our_path = self.namespace_path(ns_name, ns.versions[ns_name], allow_repo=False) / 'namespace.yaml'
import_paths[ns_name] = [] # import_paths[ns_name] = []
for needed_import in needed_imports: # for needed_import in needed_imports:
needed_version = None # needed_version = None
if versions: # if versions:
needed_versions = [v['version'] for v in versions if v['namespace'] == needed_import] # needed_version = versions.get(needed_import, None)
if len(needed_versions) > 0: #
needed_version = needed_versions[0] # version_path = self.namespace_path(needed_import, needed_version, allow_repo=False) / 'namespace.yaml'
# import_paths[ns_name].append(str(relative_path(version_path, our_path)))
version_path = self.namespace_path(needed_import, needed_version) / 'namespace.yaml' #
import_paths[ns_name].append(str(relative_path(version_path, our_path))) # if populate:
# pdb.set_trace()
if populate: # for sch in ns.schemas:
for sch in ns.schemas: # sch.imports.extend(import_paths[ns_name])
sch.imports.extend(import_paths) #
# return import_paths
return import_paths
class PydanticProvider(Provider): class PydanticProvider(Provider):
@ -304,9 +327,24 @@ class PydanticProvider(Provider):
self, self,
namespace: str | Path, namespace: str | Path,
version: Optional[str] = None, version: Optional[str] = None,
versions: Optional[List[NamespaceVersion]] = None, versions: Optional[dict] = None,
dump: bool = True dump: bool = True,
**kwargs
) -> str: ) -> str:
"""
Args:
namespace:
version:
versions:
dump:
**kwargs: Passed to :class:`.NWBPydanticGenerator`
Returns:
"""
if isinstance(namespace, str) and not (namespace.endswith('.yaml') or namespace.endswith('.yml')): if isinstance(namespace, str) and not (namespace.endswith('.yaml') or namespace.endswith('.yml')):
# we're given a name of a namespace to build # we're given a name of a namespace to build
path = LinkMLProvider(path=self.config.cache_dir).namespace_path(namespace, version) / 'namespace.yaml' path = LinkMLProvider(path=self.config.cache_dir).namespace_path(namespace, version) / 'namespace.yaml'
@ -314,22 +352,67 @@ class PydanticProvider(Provider):
# given a path to a namespace linkml yaml file # given a path to a namespace linkml yaml file
path = Path(namespace) path = Path(namespace)
default_kwargs = {
'split': False,
'emit_metadata': True,
'gen_slots': True,
'pydantic_version': '2'
}
default_kwargs.update(kwargs)
generator = NWBPydanticGenerator( generator = NWBPydanticGenerator(
str(path), str(path),
split=False,
versions=versions, versions=versions,
emit_metadata=True, **default_kwargs
gen_slots=True,
pydantic_version='2'
) )
serialized = generator.serialize() serialized = generator.serialize()
if dump: if dump:
out_file = self.path / path.parts[-3] / path.parts[-2] / 'namespace.py' out_file = self.path / path.parts[-3] / path.parts[-2] / 'namespace.py'
out_file.parent.mkdir(parents=True,exist_ok=True)
with open(out_file, 'w') as ofile: with open(out_file, 'w') as ofile:
ofile.write(serialized) ofile.write(serialized)
return serialized return serialized
@classmethod
def module_name(self, namespace:str, version:Optional[str]=None) -> str:
name_pieces = ['nwb_linkml', 'models', namespace]
if version is not None:
name_pieces.append(version_module_case(version))
module_name = '.'.join(name_pieces)
return module_name
def import_module(
self,
namespace: str,
version: Optional[str] = None
) -> ModuleType:
path = self.namespace_path(namespace, version) / 'namespace.py'
if not path.exists():
raise ImportError(f'Module has not been built yet {path}')
module_name = self.module_name(namespace, version)
spec = importlib.util.spec_from_file_location(module_name, path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
def get(self, namespace: str, version: Optional[str] = None) -> ModuleType:
module_name = self.module_name(namespace, version)
if module_name in sys.modules:
return sys.modules[module_name]
try:
path = self.namespace_path(namespace, version)
except FileNotFoundError:
path = None
if path is None or not path.exists():
_ = self.build(namespace, version)
module = self.import_module(namespace, version)
return module
class SchemaProvider: class SchemaProvider:
""" """
@ -342,30 +425,29 @@ class SchemaProvider:
consistency. consistency.
Store each generated schema in a directory structure indexed by Store each generated schema in a directory structure indexed by
schema namespace name and a truncated hash of the loaded schema dictionaries schema namespace name and version
(not the hash of the .yaml file, since we are also provided schema in nwbfiles)
eg: eg:
cache_dir cache_dir
- linkml - linkml
- nwb_core - nwb_core
- hash_532gn90f - v0_2_0
- nwb.core.namespace.yaml - nwb.core.namespace.yaml
- nwb.fore.file.yaml - nwb.fore.file.yaml
- ... - ...
- hash_fuia082f - v0_2_1
- nwb.core.namespace.yaml - nwb.core.namespace.yaml
- ... - ...
- my_schema - my_schema
- hash_t3tn908h - v0_1_0
- ... - ...
- pydantic - pydantic
- nwb_core - nwb_core
- hash_532gn90f - v0_2_0
- core.py - core.py
- ... - ...
- hash_fuia082f - v0_2_1
- core.py - core.py
- ... - ...
@ -393,25 +475,6 @@ class SchemaProvider:
def generate_linkml(
self,
schemas:Dict[str, dict],
versions: Optional[List[NamespaceVersion]] = None
):
"""
Generate linkml from loaded nwb schemas, either from yaml or from an
nwb file's ``/specifications`` group.
Arguments:
schemas (dict): A dictionary of ``{'schema_name': {:schema_definition}}``.
The "namespace" schema should have the key ``namespace``, which is used
to infer version and schema name. Post-load maps should have already
been applied
versions (List[NamespaceVersion]): List of specific versions to use
for cross-namespace imports. If none is provided, use the most recent version
available.
"""

View file

@ -1,8 +0,0 @@
import pytest
from nwb_linkml.providers.schema import LinkMLProvider
def test_linkml_provider():
provider = LinkMLProvider()
core = provider.get('core')

View file

@ -0,0 +1,17 @@
import pdb
import pytest
from nwb_linkml.providers.schema import LinkMLProvider, PydanticProvider
def test_linkml_provider():
provider = LinkMLProvider()
core = provider.get('core')
@pytest.mark.depends(on=['test_linkml_provider'])
def test_pydantic_provider():
provider = PydanticProvider()
core = provider.get('core')