Merge pull request #6 from p2p-ld/dtype-models

Allow arbitrary dtypes, support pydantic models as dtypes :)
Jonny Saunders 2024-08-12 21:16:11 -07:00 committed by GitHub
commit e9d766aad1
12 changed files with 111 additions and 22 deletions

View file

@@ -2,6 +2,23 @@
 ## 1.*

+### 1.3.1 - 24-08-12 - Allow arbitrary dtypes, pydantic models as dtypes
+
+Previously we would only allow dtypes if we knew for sure that there was some
+python base type to generate a schema with.
+That seems overly restrictive, so relax the requirements to allow
+any type to be a dtype. If there are problems with serialization (we assume there
+will be) or with handling the object in a given array framework, we leave that up
+to the person who declared the model to handle :). Let people break things and have fun!
+
+Also support the ability to use a pydantic model as the inner type, which works
+as expected because pydantic already knows how to generate a schema from its own models.
+
+Only one substantial change: a `get_object_dtype` method that interfaces can
+override if they have some fancy way of getting types/items from an object array.
+
 ### 1.3.0 - 24-08-05 - Better string dtype handling

 API Changes:
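To make the changelog concrete, here is a minimal sketch of a pydantic model used as a dtype. The `Point` and `Scene` names are hypothetical; `NDArray` and `Shape` are numpydantic's public exports, and the `np.full` construction mirrors the test fixtures added later in this PR:

import numpy as np
from pydantic import BaseModel
from numpydantic import NDArray, Shape

class Point(BaseModel):  # hypothetical model used as a dtype
    x: int
    y: int

class Scene(BaseModel):
    points: NDArray[Shape["* n"], Point]

# numpy stores model instances as an object array, which the
# interfaces now know how to type-check
scene = Scene(points=np.full((3,), Point(x=1, y=2)))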

View file

@@ -1,6 +1,6 @@
 [project]
 name = "numpydantic"
-version = "1.3.0"
+version = "1.3.1"
 description = "Type and shape validation and serialization for numpy arrays in pydantic models"
 authors = [
     {name = "sneakers-the-rat", email = "sneakers-the-rat@protonmail.com"},

View file

@@ -8,6 +8,7 @@ import numpy as np
 from pydantic import SerializationInfo

 from numpydantic.interface.interface import Interface
+from numpydantic.types import DtypeType, NDArrayType

 try:
     from dask.array.core import Array as DaskArray
@@ -30,6 +31,10 @@ class DaskInterface(Interface):
         """
         return DaskArray is not None and isinstance(array, DaskArray)

+    def get_object_dtype(self, array: NDArrayType) -> DtypeType:
+        """Dask arrays require a compute() call to retrieve a single value"""
+        return type(array.ravel()[0].compute())
+
     @classmethod
     def enabled(cls) -> bool:
         """check if we successfully imported dask"""

View file

@@ -101,8 +101,18 @@ class Interface(ABC, Generic[T]):
         """
         Get the dtype from the input array
         """
-        return array.dtype
+        if hasattr(array.dtype, "type") and array.dtype.type is np.object_:
+            return self.get_object_dtype(array)
+        else:
+            return array.dtype
+
+    def get_object_dtype(self, array: NDArrayType) -> DtypeType:
+        """
+        When an array contains an object, get the dtype of the object contained
+        by the array.
+        """
+        return type(array.ravel()[0])

     def validate_dtype(self, dtype: DtypeType) -> bool:
         """
         Validate the dtype of the given array, returning
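The base implementation leans on how numpy reports object arrays: an array of arbitrary python objects has `dtype.type is np.object_`, and the class of a single element stands in for the dtype. A rough illustration of both halves of the new branch (a plain dict is used here just as an example object):

import numpy as np

arr = np.full((2, 2), fill_value={"x": 1})  # any python object will do
assert arr.dtype.type is np.object_         # the condition the branch checks
print(type(arr.ravel()[0]))                 # <class 'dict'> - what get_object_dtype() returns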

View file

@@ -125,14 +125,10 @@ class NDArrayMeta(_NDArrayMeta, implementation="NDArray"):
             check_type_names(dtype, dtype_per_name)
         elif isinstance(dtype_candidate, tuple):  # pragma: no cover
             dtype = tuple([cls._get_dtype(dt) for dt in dtype_candidate])
-        else:  # pragma: no cover
-            raise InvalidArgumentsError(
-                f"Unexpected argument '{dtype_candidate}', expecting"
-                " Structure[<StructureExpression>]"
-                " or Literal[<StructureExpression>]"
-                " or a dtype"
-                " or typing.Any."
-            )
+        else:
+            # arbitrary dtype - allow failure elsewhere :)
+            dtype = dtype_candidate
         return dtype

     def _dtype_to_str(cls, dtype: Any) -> str:
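With the `InvalidArgumentsError` gone, any object can occupy the dtype slot of an annotation; whether it actually validates or serializes is deferred to the interfaces. A sketch of the relaxed behavior (the `Anything` class is hypothetical):

from numpydantic import NDArray, Shape

class Anything:
    """not a numpy dtype, not a python builtin - previously rejected, now allowed"""

# no longer raises at annotation time
ArbitraryArray = NDArray[Shape["*"], Anything]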

View file

@@ -8,7 +8,7 @@ import json
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union

 import numpy as np
-from pydantic import SerializationInfo
+from pydantic import BaseModel, SerializationInfo
 from pydantic_core import CoreSchema, core_schema
 from pydantic_core.core_schema import ListSchema, ValidationInfo

@@ -66,18 +66,18 @@ def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
     else:
         try:
             python_type = np_to_python[dtype]
-        except KeyError as e:  # pragma: no cover
+        except KeyError:  # pragma: no cover
             # this should pretty much only happen in downstream/3rd-party interfaces
             # that use interface-specific types. those need to provide mappings back
             # to base python types (making this more streamlined is TODO)
             if dtype in np_to_python.values():
                 # it's already a python type
                 python_type = dtype
+            elif issubclass(dtype, BaseModel):
+                python_type = dtype
             else:
-                raise ValueError(
-                    "dtype given in model does not have a corresponding python base "
-                    "type - add one to the `maps.np_to_python` dict"
-                ) from e
+                # does this need a warning?
+                python_type = Any

     if python_type in _UNSUPPORTED_TYPES:
         array_type = core_schema.any_schema()
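The net effect on schema generation: a pydantic model dtype reuses the model's own schema, while a completely unmapped dtype now falls through to `Any` instead of raising. A hedged sketch of what we would expect (the `Weird` class is hypothetical, and the exact emitted schema depends on how pydantic renders `Any`):

from pydantic import BaseModel
from numpydantic import NDArray, Shape

class Weird:  # no entry in maps.np_to_python, not a BaseModel
    pass

class Model(BaseModel):
    array: NDArray[Shape["*"], Weird]

# previously a ValueError; now the items schema should simply be permissive
print(Model.model_json_schema()["properties"]["array"])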

View file

@@ -58,6 +58,14 @@ class ValidationCase(BaseModel):
     return Model

+class BasicModel(BaseModel):
+    x: int
+
+
+class BadModel(BaseModel):
+    x: int
+
+
 RGB_UNION: TypeAlias = Union[
     NDArray[Shape["* x, * y"], Number],
     NDArray[Shape["* x, * y, 3 r_g_b"], Number],
@@ -68,6 +76,7 @@ NUMBER: TypeAlias = NDArray[Shape["*, *, *"], Number]
 INTEGER: TypeAlias = NDArray[Shape["*, *, *"], Integer]
 FLOAT: TypeAlias = NDArray[Shape["*, *, *"], Float]
 STRING: TypeAlias = NDArray[Shape["*, *, *"], str]
+MODEL: TypeAlias = NDArray[Shape["*, *, *"], BasicModel]

 @pytest.fixture(
@@ -131,6 +140,9 @@ def shape_cases(request) -> ValidationCase:
         ValidationCase(annotation=STRING, dtype=str, passes=True),
         ValidationCase(annotation=STRING, dtype=int, passes=False),
         ValidationCase(annotation=STRING, dtype=float, passes=False),
+        ValidationCase(annotation=MODEL, dtype=BasicModel, passes=True),
+        ValidationCase(annotation=MODEL, dtype=BadModel, passes=False),
+        ValidationCase(annotation=MODEL, dtype=int, passes=False),
     ],
     ids=[
         "float",
@@ -154,6 +166,9 @@ def shape_cases(request) -> ValidationCase:
         "str-str",
         "str-int",
         "str-float",
+        "model-model",
+        "model-badmodel",
+        "model-int",
     ],
 )
 def dtype_cases(request) -> ValidationCase:
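Worth noting: `BadModel` is deliberately field-identical to `BasicModel` (both are just `x: int`), so the failing `model-badmodel` case pins down that dtype validation compares model classes, not model structure.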

View file

@@ -4,7 +4,7 @@ import pytest
 import json

 import dask.array as da
-from pydantic import ValidationError
+from pydantic import BaseModel, ValidationError

 from numpydantic.interface import DaskInterface
 from numpydantic.exceptions import DtypeError, ShapeError
@@ -13,6 +13,9 @@ from tests.conftest import ValidationCase

 def dask_array(case: ValidationCase) -> da.Array:
-    return da.zeros(shape=case.shape, dtype=case.dtype, chunks=10)
+    if issubclass(case.dtype, BaseModel):
+        return da.full(shape=case.shape, fill_value=case.dtype(x=1), chunks=-1)
+    else:
+        return da.zeros(shape=case.shape, dtype=case.dtype, chunks=10)
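A note on the fixture: `chunks=-1` asks dask for a single chunk spanning each axis, presumably to keep the object-dtype path simple; the PR itself does not say why, so treat that rationale as a reading of the test code.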

View file

@@ -20,6 +20,8 @@ def hdf5_array_case(case: ValidationCase, array_func) -> H5ArrayPath:

     Returns:
     """
+    if issubclass(case.dtype, BaseModel):
+        pytest.skip("hdf5 can't support arbitrary python objects")
     return array_func(case.shape, case.dtype)
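The skip reflects an HDF5 limitation rather than a numpydantic one: h5py has no native mapping for numpy object dtypes, so creating such a dataset raises. A quick demonstration of the underlying error (the file name here is arbitrary):

import h5py
import numpy as np

with h5py.File("scratch.h5", "w") as f:
    try:
        f.create_dataset("x", data=np.full((2,), fill_value={"x": 1}))
    except TypeError as e:
        print(e)  # "Object dtype dtype('O') has no native HDF5 equivalent"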

View file

@@ -1,12 +1,15 @@
 import numpy as np
 import pytest

-from pydantic import ValidationError
+from pydantic import ValidationError, BaseModel

 from numpydantic.exceptions import DtypeError, ShapeError
 from tests.conftest import ValidationCase


 def numpy_array(case: ValidationCase) -> np.ndarray:
-    return np.zeros(shape=case.shape, dtype=case.dtype)
+    if issubclass(case.dtype, BaseModel):
+        return np.full(shape=case.shape, fill_value=case.dtype(x=1))
+    else:
+        return np.zeros(shape=case.shape, dtype=case.dtype)

View file

@@ -3,7 +3,9 @@ import json
 import pytest

 import zarr
-from pydantic import ValidationError
+from pydantic import BaseModel, ValidationError
+from numcodecs import Pickle

 from numpydantic.interface import ZarrInterface
 from numpydantic.interface.zarr import ZarrArrayPath
@@ -31,6 +33,18 @@ def nested_dir_array(tmp_output_dir_func) -> zarr.NestedDirectoryStore:

 def _zarr_array(case: ValidationCase, store) -> zarr.core.Array:
-    return zarr.zeros(shape=case.shape, dtype=case.dtype, store=store)
+    if issubclass(case.dtype, BaseModel):
+        pytest.skip(
+            "Zarr can't handle objects properly at the moment, "
+            "see https://github.com/zarr-developers/zarr-python/issues/2081"
+        )
+        # return zarr.full(
+        #     shape=case.shape,
+        #     fill_value=case.dtype(x=1),
+        #     dtype=object,
+        #     object_codec=Pickle(),
+        # )
+    else:
+        return zarr.zeros(shape=case.shape, dtype=case.dtype, store=store)

View file

@@ -266,6 +266,30 @@ def test_json_schema_dtype_builtin(dtype, expected, array_model):
     assert inner_type["type"] == expected


+def test_json_schema_dtype_model():
+    """
+    Pydantic models can be used in arrays as dtypes
+    """
+
+    class TestModel(BaseModel):
+        x: int
+        y: int
+        z: int
+
+    class MyModel(BaseModel):
+        array: NDArray[Shape["*, *"], TestModel]
+
+    schema = MyModel.model_json_schema()
+
+    # we should have a "$defs" with TestModel in it,
+    # and our array should be objects of that type
+    assert schema["properties"]["array"]["items"]["items"] == {
+        "$ref": "#/$defs/TestModel"
+    }
+    # we don't test pydantic's generic json schema model generation,
+    # just that one was defined
+    assert "TestModel" in schema["$defs"]
+
+
 def _recursive_array(schema):
     assert "$defs" in schema
     # get the key used for the array