for pydantic provider test by escaping additional bad chars in numpy field name

This commit is contained in:
sneakers-the-rat 2023-10-12 00:02:52 -07:00
parent 3101797b0d
commit d2185ee1c3
2 changed files with 13 additions and 4 deletions

View file

@ -24,6 +24,7 @@ The `serialize` method
""" """
import pdb import pdb
import re
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import List, Dict, Set, Tuple, Optional, TypedDict, Type from typing import List, Dict, Set, Tuple, Optional, TypedDict, Type
@ -407,8 +408,15 @@ class NWBPydanticGenerator(PydanticGenerator):
shape_part = str(attr.maximum_cardinality) shape_part = str(attr.maximum_cardinality)
else: else:
shape_part = "*" shape_part = "*"
# do this cheaply instead of using regex because i want to see if this works at all first...
name_part = attr.name.replace(',', '_').replace(' ', '_').replace('__', '_').replace('|','_') # do this with the most heinous chain of string replacements rather than regex
# because i am still figuring out what needs to be subbed lol
name_part = attr.name.replace(',', '_'
).replace(' ', '_'
).replace('__', '_'
).replace('|','_'
).replace('-','_'
).replace('+','plus')
dim_pieces.append(' '.join([shape_part, name_part])) dim_pieces.append(' '.join([shape_part, name_part]))

View file

@ -17,6 +17,8 @@ import nwb_linkml
from nwb_linkml.maps.naming import version_module_case from nwb_linkml.maps.naming import version_module_case
from nwb_linkml.providers.git import DEFAULT_REPOS from nwb_linkml.providers.git import DEFAULT_REPOS
from nwb_linkml.adapters import NamespacesAdapter from nwb_linkml.adapters import NamespacesAdapter
from nwb_linkml.types.ndarray import NDArray
from nptyping import Shape, UByte
CORE_MODULES = ( CORE_MODULES = (
@ -71,7 +73,6 @@ def test_linkml_build_from_yaml(tmp_output_dir):
res = provider.build_from_yaml(ns_file) res = provider.build_from_yaml(ns_file)
@pytest.mark.skip()
@pytest.mark.depends(on=['test_linkml_provider']) @pytest.mark.depends(on=['test_linkml_provider'])
@pytest.mark.parametrize( @pytest.mark.parametrize(
['class_name', 'test_fields'], ['class_name', 'test_fields'],
@ -82,7 +83,7 @@ def test_linkml_build_from_yaml(tmp_output_dir):
'comments': Optional[str], 'comments': Optional[str],
'data': 'TimeSeriesData', 'data': 'TimeSeriesData',
'timestamps': 'Optional', # __name__ just gets the first part of Optional[TimeSeriesTimestamps] 'timestamps': 'Optional', # __name__ just gets the first part of Optional[TimeSeriesTimestamps]
'control': Optional[List[int]], 'control': Optional[NDArray[Shape['* num_times'], UByte]],
}) })
] ]
) )