handle and test complex

This commit is contained in:
sneakers-the-rat 2024-05-17 18:05:36 -07:00
parent 3a6984f3f0
commit a1a440e6ad
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
3 changed files with 16 additions and 4 deletions

View file

@ -58,5 +58,11 @@ flat_to_nptyping = {
} }
"""Map from NWB-style flat dtypes to nptyping types""" """Map from NWB-style flat dtypes to nptyping types"""
python_to_nptyping = {float: dt.Float, str: dt.String, int: dt.Int, bool: dt.Bool} python_to_nptyping = {
float: dt.Float,
str: dt.String,
int: dt.Int,
bool: dt.Bool,
complex: dt.Complex,
}
"""Map from python types to nptyping types""" """Map from python types to nptyping types"""

View file

@ -41,7 +41,6 @@ def _numeric_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema: def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
"""Get the innermost dtype schema to use in the generated pydantic schema""" """Get the innermost dtype schema to use in the generated pydantic schema"""
if isinstance(dtype, nptyping.structure.StructureMeta): # pragma: no cover if isinstance(dtype, nptyping.structure.StructureMeta): # pragma: no cover
raise NotImplementedError("Structured dtypes are currently unsupported") raise NotImplementedError("Structured dtypes are currently unsupported")
@ -63,7 +62,10 @@ def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
else: else:
try: try:
python_type = np_to_python[dtype] python_type = np_to_python[dtype]
except KeyError as e: except KeyError as e: # pragma: no cover
# this should pretty much only happen in downstream/3rd-party interfaces
# that use interface-specific types. those need to provide mappings back
# to base python types (making this more streamlined is TODO)
if dtype in np_to_python.values(): if dtype in np_to_python.values():
# it's already a python type # it's already a python type
python_type = dtype python_type = dtype

View file

@ -169,6 +169,7 @@ def test_json_schema_dtype_single(dtype, array_model):
(int, "integer"), (int, "integer"),
(float, "number"), (float, "number"),
(bool, "boolean"), (bool, "boolean"),
(complex, "any"),
], ],
) )
def test_json_schema_dtype_builtin(dtype, expected, array_model): def test_json_schema_dtype_builtin(dtype, expected, array_model):
@ -179,6 +180,9 @@ def test_json_schema_dtype_builtin(dtype, expected, array_model):
model = array_model(dtype=dtype) model = array_model(dtype=dtype)
schema = model.model_json_schema() schema = model.model_json_schema()
inner_type = schema["properties"]["array"]["items"]["items"] inner_type = schema["properties"]["array"]["items"]["items"]
if expected == "any":
assert inner_type == {}
else:
assert inner_type["type"] == expected assert inner_type["type"] == expected