mirror of
https://github.com/p2p-ld/numpydantic.git
synced 2025-01-10 05:54:26 +00:00
Merge pull request #6 from p2p-ld/dtype-models
Allow arbitrary dtypes, support pydantic models as dtypes :)
This commit is contained in:
commit
e9d766aad1
12 changed files with 111 additions and 22 deletions
|
@ -2,6 +2,23 @@
|
||||||
|
|
||||||
## 1.*
|
## 1.*
|
||||||
|
|
||||||
|
### 1.3.1 - 24-08-12 - Allow arbitrary dtypes, pydantic models as dtypes
|
||||||
|
|
||||||
|
Previously we would only allow dtypes if we knew for sure that there was some
|
||||||
|
python base type to generate a schema with.
|
||||||
|
|
||||||
|
That seems overly restrictive, so relax the requirements to allow
|
||||||
|
any type to be a dtype. If there are problems with serialization (we assume there will)
|
||||||
|
or handling the object in a given array framework, we leave that up to the person
|
||||||
|
who declared the model to handle :). Let people break things and have fun!
|
||||||
|
|
||||||
|
Also support the ability to use a pydantic model as the inner type, which works
|
||||||
|
as expected because pydantic already knows how to generate a schema from its own models.
|
||||||
|
|
||||||
|
Only one substantial change, and that is a `get_object_dtype` method which
|
||||||
|
interfaces can override if there is some fancy way they have of getting
|
||||||
|
types/items from an object array.
|
||||||
|
|
||||||
### 1.3.0 - 24-08-05 - Better string dtype handling
|
### 1.3.0 - 24-08-05 - Better string dtype handling
|
||||||
|
|
||||||
API Changes:
|
API Changes:
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[project]
|
[project]
|
||||||
name = "numpydantic"
|
name = "numpydantic"
|
||||||
version = "1.3.0"
|
version = "1.3.1"
|
||||||
description = "Type and shape validation and serialization for numpy arrays in pydantic models"
|
description = "Type and shape validation and serialization for numpy arrays in pydantic models"
|
||||||
authors = [
|
authors = [
|
||||||
{name = "sneakers-the-rat", email = "sneakers-the-rat@protonmail.com"},
|
{name = "sneakers-the-rat", email = "sneakers-the-rat@protonmail.com"},
|
||||||
|
|
|
@ -8,6 +8,7 @@ import numpy as np
|
||||||
from pydantic import SerializationInfo
|
from pydantic import SerializationInfo
|
||||||
|
|
||||||
from numpydantic.interface.interface import Interface
|
from numpydantic.interface.interface import Interface
|
||||||
|
from numpydantic.types import DtypeType, NDArrayType
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from dask.array.core import Array as DaskArray
|
from dask.array.core import Array as DaskArray
|
||||||
|
@ -30,6 +31,10 @@ class DaskInterface(Interface):
|
||||||
"""
|
"""
|
||||||
return DaskArray is not None and isinstance(array, DaskArray)
|
return DaskArray is not None and isinstance(array, DaskArray)
|
||||||
|
|
||||||
|
def get_object_dtype(self, array: NDArrayType) -> DtypeType:
|
||||||
|
"""Dask arrays require a compute() call to retrieve a single value"""
|
||||||
|
return type(array.ravel()[0].compute())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def enabled(cls) -> bool:
|
def enabled(cls) -> bool:
|
||||||
"""check if we successfully imported dask"""
|
"""check if we successfully imported dask"""
|
||||||
|
|
|
@ -101,7 +101,17 @@ class Interface(ABC, Generic[T]):
|
||||||
"""
|
"""
|
||||||
Get the dtype from the input array
|
Get the dtype from the input array
|
||||||
"""
|
"""
|
||||||
return array.dtype
|
if hasattr(array.dtype, "type") and array.dtype.type is np.object_:
|
||||||
|
return self.get_object_dtype(array)
|
||||||
|
else:
|
||||||
|
return array.dtype
|
||||||
|
|
||||||
|
def get_object_dtype(self, array: NDArrayType) -> DtypeType:
|
||||||
|
"""
|
||||||
|
When an array contains an object, get the dtype of the object contained
|
||||||
|
by the array.
|
||||||
|
"""
|
||||||
|
return type(array.ravel()[0])
|
||||||
|
|
||||||
def validate_dtype(self, dtype: DtypeType) -> bool:
|
def validate_dtype(self, dtype: DtypeType) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -125,14 +125,10 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
|
||||||
check_type_names(dtype, dtype_per_name)
|
check_type_names(dtype, dtype_per_name)
|
||||||
elif isinstance(dtype_candidate, tuple): # pragma: no cover
|
elif isinstance(dtype_candidate, tuple): # pragma: no cover
|
||||||
dtype = tuple([cls._get_dtype(dt) for dt in dtype_candidate])
|
dtype = tuple([cls._get_dtype(dt) for dt in dtype_candidate])
|
||||||
else: # pragma: no cover
|
else:
|
||||||
raise InvalidArgumentsError(
|
# arbitrary dtype - allow failure elsewhere :)
|
||||||
f"Unexpected argument '{dtype_candidate}', expecting"
|
dtype = dtype_candidate
|
||||||
" Structure[<StructureExpression>]"
|
|
||||||
" or Literal[<StructureExpression>]"
|
|
||||||
" or a dtype"
|
|
||||||
" or typing.Any."
|
|
||||||
)
|
|
||||||
return dtype
|
return dtype
|
||||||
|
|
||||||
def _dtype_to_str(cls, dtype: Any) -> str:
|
def _dtype_to_str(cls, dtype: Any) -> str:
|
||||||
|
|
|
@ -8,7 +8,7 @@ import json
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import SerializationInfo
|
from pydantic import BaseModel, SerializationInfo
|
||||||
from pydantic_core import CoreSchema, core_schema
|
from pydantic_core import CoreSchema, core_schema
|
||||||
from pydantic_core.core_schema import ListSchema, ValidationInfo
|
from pydantic_core.core_schema import ListSchema, ValidationInfo
|
||||||
|
|
||||||
|
@ -66,18 +66,18 @@ def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
python_type = np_to_python[dtype]
|
python_type = np_to_python[dtype]
|
||||||
except KeyError as e: # pragma: no cover
|
except KeyError: # pragma: no cover
|
||||||
# this should pretty much only happen in downstream/3rd-party interfaces
|
# this should pretty much only happen in downstream/3rd-party interfaces
|
||||||
# that use interface-specific types. those need to provide mappings back
|
# that use interface-specific types. those need to provide mappings back
|
||||||
# to base python types (making this more streamlined is TODO)
|
# to base python types (making this more streamlined is TODO)
|
||||||
if dtype in np_to_python.values():
|
if dtype in np_to_python.values():
|
||||||
# it's already a python type
|
# it's already a python type
|
||||||
python_type = dtype
|
python_type = dtype
|
||||||
|
elif issubclass(dtype, BaseModel):
|
||||||
|
python_type = dtype
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
# does this need a warning?
|
||||||
"dtype given in model does not have a corresponding python base "
|
python_type = Any
|
||||||
"type - add one to the `maps.np_to_python` dict"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
if python_type in _UNSUPPORTED_TYPES:
|
if python_type in _UNSUPPORTED_TYPES:
|
||||||
array_type = core_schema.any_schema()
|
array_type = core_schema.any_schema()
|
||||||
|
|
|
@ -58,6 +58,14 @@ class ValidationCase(BaseModel):
|
||||||
return Model
|
return Model
|
||||||
|
|
||||||
|
|
||||||
|
class BasicModel(BaseModel):
|
||||||
|
x: int
|
||||||
|
|
||||||
|
|
||||||
|
class BadModel(BaseModel):
|
||||||
|
x: int
|
||||||
|
|
||||||
|
|
||||||
RGB_UNION: TypeAlias = Union[
|
RGB_UNION: TypeAlias = Union[
|
||||||
NDArray[Shape["* x, * y"], Number],
|
NDArray[Shape["* x, * y"], Number],
|
||||||
NDArray[Shape["* x, * y, 3 r_g_b"], Number],
|
NDArray[Shape["* x, * y, 3 r_g_b"], Number],
|
||||||
|
@ -68,6 +76,7 @@ NUMBER: TypeAlias = NDArray[Shape["*, *, *"], Number]
|
||||||
INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer]
|
INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer]
|
||||||
FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float]
|
FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float]
|
||||||
STRING: TypeAlias = NDArray[Shape["*, *, *"], str]
|
STRING: TypeAlias = NDArray[Shape["*, *, *"], str]
|
||||||
|
MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(
|
@pytest.fixture(
|
||||||
|
@ -131,6 +140,9 @@ def shape_cases(request) -> ValidationCase:
|
||||||
ValidationCase(annotation=STRING, dtype=str, passes=True),
|
ValidationCase(annotation=STRING, dtype=str, passes=True),
|
||||||
ValidationCase(annotation=STRING, dtype=int, passes=False),
|
ValidationCase(annotation=STRING, dtype=int, passes=False),
|
||||||
ValidationCase(annotation=STRING, dtype=float, passes=False),
|
ValidationCase(annotation=STRING, dtype=float, passes=False),
|
||||||
|
ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True),
|
||||||
|
ValidationCase(annotation=MODEL, dtype=BadModel, passes=False),
|
||||||
|
ValidationCase(annotation=MODEL, dtype=int, passes=False),
|
||||||
],
|
],
|
||||||
ids=[
|
ids=[
|
||||||
"float",
|
"float",
|
||||||
|
@ -154,6 +166,9 @@ def shape_cases(request) -> ValidationCase:
|
||||||
"str-str",
|
"str-str",
|
||||||
"str-int",
|
"str-int",
|
||||||
"str-float",
|
"str-float",
|
||||||
|
"model-model",
|
||||||
|
"model-badmodel",
|
||||||
|
"model-int",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def dtype_cases(request) -> ValidationCase:
|
def dtype_cases(request) -> ValidationCase:
|
||||||
|
|
|
@ -4,7 +4,7 @@ import pytest
|
||||||
import json
|
import json
|
||||||
|
|
||||||
import dask.array as da
|
import dask.array as da
|
||||||
from pydantic import ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
|
|
||||||
from numpydantic.interface import DaskInterface
|
from numpydantic.interface import DaskInterface
|
||||||
from numpydantic.exceptions import DtypeError, ShapeError
|
from numpydantic.exceptions import DtypeError, ShapeError
|
||||||
|
@ -13,7 +13,10 @@ from tests.conftest import ValidationCase
|
||||||
|
|
||||||
|
|
||||||
def dask_array(case: ValidationCase) -> da.Array:
|
def dask_array(case: ValidationCase) -> da.Array:
|
||||||
return da.zeros(shape=case.shape, dtype=case.dtype, chunks=10)
|
if issubclass(case.dtype, BaseModel):
|
||||||
|
return da.full(shape=case.shape, fill_value=case.dtype(x=1), chunks=-1)
|
||||||
|
else:
|
||||||
|
return da.zeros(shape=case.shape, dtype=case.dtype, chunks=10)
|
||||||
|
|
||||||
|
|
||||||
def _test_dask_case(case: ValidationCase):
|
def _test_dask_case(case: ValidationCase):
|
||||||
|
|
|
@ -20,6 +20,8 @@ def hdf5_array_case(case: ValidationCase, array_func) -> H5ArrayPath:
|
||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if issubclass(case.dtype, BaseModel):
|
||||||
|
pytest.skip("hdf5 cant support arbitrary python objects")
|
||||||
return array_func(case.shape, case.dtype)
|
return array_func(case.shape, case.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,13 +1,16 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError, BaseModel
|
||||||
from numpydantic.exceptions import DtypeError, ShapeError
|
from numpydantic.exceptions import DtypeError, ShapeError
|
||||||
|
|
||||||
from tests.conftest import ValidationCase
|
from tests.conftest import ValidationCase
|
||||||
|
|
||||||
|
|
||||||
def numpy_array(case: ValidationCase) -> np.ndarray:
|
def numpy_array(case: ValidationCase) -> np.ndarray:
|
||||||
return np.zeros(shape=case.shape, dtype=case.dtype)
|
if issubclass(case.dtype, BaseModel):
|
||||||
|
return np.full(shape=case.shape, fill_value=case.dtype(x=1))
|
||||||
|
else:
|
||||||
|
return np.zeros(shape=case.shape, dtype=case.dtype)
|
||||||
|
|
||||||
|
|
||||||
def _test_np_case(case: ValidationCase):
|
def _test_np_case(case: ValidationCase):
|
||||||
|
|
|
@ -3,7 +3,9 @@ import json
|
||||||
import pytest
|
import pytest
|
||||||
import zarr
|
import zarr
|
||||||
|
|
||||||
from pydantic import ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
|
from numcodecs import Pickle
|
||||||
|
|
||||||
|
|
||||||
from numpydantic.interface import ZarrInterface
|
from numpydantic.interface import ZarrInterface
|
||||||
from numpydantic.interface.zarr import ZarrArrayPath
|
from numpydantic.interface.zarr import ZarrArrayPath
|
||||||
|
@ -31,7 +33,19 @@ def nested_dir_array(tmp_output_dir_func) -> zarr.NestedDirectoryStore:
|
||||||
|
|
||||||
|
|
||||||
def _zarr_array(case: ValidationCase, store) -> zarr.core.Array:
|
def _zarr_array(case: ValidationCase, store) -> zarr.core.Array:
|
||||||
return zarr.zeros(shape=case.shape, dtype=case.dtype, store=store)
|
if issubclass(case.dtype, BaseModel):
|
||||||
|
pytest.skip(
|
||||||
|
f"Zarr can't handle objects properly at the moment, "
|
||||||
|
"see https://github.com/zarr-developers/zarr-python/issues/2081"
|
||||||
|
)
|
||||||
|
# return zarr.full(
|
||||||
|
# shape=case.shape,
|
||||||
|
# fill_value=case.dtype(x=1),
|
||||||
|
# dtype=object,
|
||||||
|
# object_codec=Pickle(),
|
||||||
|
# )
|
||||||
|
else:
|
||||||
|
return zarr.zeros(shape=case.shape, dtype=case.dtype, store=store)
|
||||||
|
|
||||||
|
|
||||||
def _test_zarr_case(case: ValidationCase, store):
|
def _test_zarr_case(case: ValidationCase, store):
|
||||||
|
|
|
@ -266,6 +266,30 @@ def test_json_schema_dtype_builtin(dtype, expected, array_model):
|
||||||
assert inner_type["type"] == expected
|
assert inner_type["type"] == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_json_schema_dtype_model():
|
||||||
|
"""
|
||||||
|
Pydantic models can be used in arrays as dtypes
|
||||||
|
"""
|
||||||
|
|
||||||
|
class TestModel(BaseModel):
|
||||||
|
x: int
|
||||||
|
y: int
|
||||||
|
z: int
|
||||||
|
|
||||||
|
class MyModel(BaseModel):
|
||||||
|
array: NDArray[Shape["*, *"], TestModel]
|
||||||
|
|
||||||
|
schema = MyModel.model_json_schema()
|
||||||
|
# we should have a "$defs" with TestModel in it,
|
||||||
|
# and our array should be objects of that type
|
||||||
|
assert schema["properties"]["array"]["items"]["items"] == {
|
||||||
|
"$ref": "#/$defs/TestModel"
|
||||||
|
}
|
||||||
|
# we don't test pydantic' generic json schema model generation,
|
||||||
|
# just that one was defined
|
||||||
|
assert "TestModel" in schema["$defs"]
|
||||||
|
|
||||||
|
|
||||||
def _recursive_array(schema):
|
def _recursive_array(schema):
|
||||||
assert "$defs" in schema
|
assert "$defs" in schema
|
||||||
# get the key uses for the array
|
# get the key uses for the array
|
||||||
|
|
Loading…
Reference in a new issue