Add ability to make JSON schema with numpy arrays!

This commit is contained in:
sneakers-the-rat 2023-09-14 23:43:00 -07:00
parent e6a41415f5
commit 3e2e6915cf
11 changed files with 200 additions and 9 deletions

View file

@ -26,7 +26,6 @@ pytest = { version="^7.4.0", optional=true}
pytest-depends = {version="^1.0.1", optional=true}
coverage = {version = "^6.1.1", optional = true}
pytest-md = {version = "^0.2.0", optional = true}
pytest-emoji = {version="^0.2.0", optional = true}
pytest-cov = {version = "^4.1.0", optional = true}
coveralls = {version = "^3.3.1", optional = true}
pytest-profiling = {version = "^1.7.0", optional = true}
@ -35,7 +34,7 @@ pydantic-settings = "^2.0.3"
[tool.poetry.extras]
tests = [
"pytest", "pytest-depends", "coverage", "pytest-md",
"pytest-emoji", "pytest-cov", "coveralls", "pytest-profiling"
"pytest-cov", "coveralls", "pytest-profiling"
]
plot = ["dash", "dash-cytoscape"]
@ -69,8 +68,7 @@ build-backend = "poetry.core.masonry.api"
addopts = [
"--cov=nwb_linkml",
"--cov-append",
"--cov-config=.coveragerc",
"--emoji",
"--cov-config=.coveragerc"
]
testpaths = [
"tests",

View file

@ -47,8 +47,7 @@ class NamespacesAdapter(Adapter):
with hdmf-common)
"""
from nwb_linkml.io import schema as schema_io
ns_adapter = schema_io.load_namespaces(path)
ns_adapter = schema_io.load_namespace_adapter(ns_adapter, path)
ns_adapter = schema_io.load_namespace_adapter(path)
# try and find imported schema

View file

@ -64,7 +64,8 @@ from datetime import datetime, date
from enum import Enum
from typing import List, Dict, Optional, Any, Union
from pydantic import BaseModel as BaseModel, Field
from nptyping import NDArray, Shape, Float, Float32, Double, Float64, LongLong, Int64, Int, Int32, Int16, Short, Int8, UInt, UInt32, UInt16, UInt8, UInt64, Number, String, Unicode, Unicode, Unicode, String, Bool, Datetime64
from nptyping import Shape, Float, Float32, Double, Float64, LongLong, Int64, Int, Int32, Int16, Short, Int8, UInt, UInt32, UInt16, UInt8, UInt64, Number, String, Unicode, Unicode, Unicode, String, Bool, Datetime64
from nwb_linkml.types import NDArray
import sys
if sys.version_info >= (3, 8):
from typing import Literal

View file

@ -26,6 +26,7 @@ class HDF5Element():
cls: h5py.Dataset | h5py.Group
parent: Type[BaseModel]
model: Optional[Any] = None
root_model: Optional[Type[BaseModel]] = None
@abstractmethod
def read(self) -> BaseModel | List[BaseModel]:
@ -89,6 +90,13 @@ def take_outer_type(annotation):
if typing.get_origin(annotation) is list:
return list
return annotation
def submodel_by_path(model: BaseModel, path:str) -> Type[BaseModel | dict | list]:
    """
    Given a pydantic model and an absolute HDF5 path, get the type annotation

    NOTE(review): in the portion of the file visible here the body consists of
    this docstring only, so calling the function implicitly returns ``None``
    rather than the annotated type. It looks like a placeholder for a lookup
    that is still to be implemented -- confirm before relying on it.

    Args:
        model: pydantic model to look the path up in (presumably the root
            model of the file being read -- verify against callers).
        path: absolute HDF5 path whose annotation is wanted.
    """
@dataclass
class H5Dataset(HDF5Element):
cls: h5py.Dataset
@ -178,7 +186,7 @@ class HDF5IO():
data = {}
for k, v in src.items():
if isinstance(v, h5py.Group):
data[k] = H5Group(cls=v, parent=parent).read()
data[k] = H5Group(cls=v, parent=parent, root_model=parent).read()
elif isinstance(v, h5py.Dataset):
data[k] = H5Dataset(cls=v, parent=parent).read()

View file

@ -1,4 +1,5 @@
import numpy as np
from typing import Any
flat_to_linkml = {
"float" : "float",
@ -56,4 +57,17 @@ flat_to_npytyping = {
"bool": "Bool",
"isodatetime": "Datetime64",
'AnyType': 'Any'
}
# Map numpy scalar types (and ``typing.Any``) to the builtin python type used
# when a plain-python/JSON representation of an array element is needed.
#
# Several numpy names are aliases of one another, so listing them as separate
# keys only makes the later entry silently overwrite the earlier one:
#   * ``np.byte`` is ``np.int8``    -> resolves to ``int`` (not ``bytes``)
#   * ``np.string_`` was ``np.bytes_`` -> ``np.bytes_`` resolves to ``str``
#   * ``np.unicode_`` was ``np.str_``, ``np.float_`` was ``np.float64``
# ``np.float_``, ``np.string_`` and ``np.unicode_`` were removed in NumPy 2.0,
# so referencing them raises ``AttributeError`` at import time there. Only the
# canonical names are used below; the resulting mapping is the same one the
# aliased spelling produced on NumPy 1.x.
np_to_python = {
    Any: Any,
    # abstract base classes
    np.number: float,
    np.integer: int,
    np.floating: float,
    np.character: str,
    # concrete scalar types
    np.object_: Any,
    np.bool_: bool,
    **{
        n: int
        for n in (
            np.int8, np.int16, np.int32, np.int64, np.short, np.byte,
            np.uint8, np.uint16, np.uint32, np.uint64, np.uint,
        )
    },
    **{
        n: float
        for n in (np.float16, np.float32, np.float64, np.single, np.double)
    },
    # NOTE: byte strings are deliberately mapped to ``str`` to keep the
    # effective behaviour of the previous ``np.string_: str`` entry.
    **{n: str for n in (np.str_, np.bytes_)},
}

View file

@ -0,0 +1 @@
from nwb_linkml.providers.schema import LinkMLProvider, SchemaProvider, PydanticProvider

View file

@ -0,0 +1 @@
from nwb_linkml.types.ndarray import NDArray

View file

@ -0,0 +1,110 @@
"""
Extension of nptyping NDArray for pydantic that allows for JSON-Schema serialization
* Order to store data in (row first)
"""
import pdb
from typing import (
Any,
Callable,
Annotated,
Generic,
TypeVar
)
from pydantic_core import core_schema
from pydantic import (
BaseModel,
GetJsonSchemaHandler,
ValidationError,
GetCoreSchemaHandler
)
from pydantic.json_schema import JsonSchemaValue
import numpy as np
from nptyping import NDArray as _NDArray
from nptyping.ndarray import NDArrayMeta
from nptyping import Shape, Number
from nptyping.shape_expression import check_shape
from nwb_linkml.maps.dtype import np_to_python
class NDArray(_NDArray):
    """
    nptyping ``NDArray`` with a pydantic core schema attached, so that fields
    annotated as ``NDArray[Shape[...], dtype]`` can be validated as numpy
    arrays and rendered as (nested list) JSON Schema.

    Following the example here: https://docs.pydantic.dev/latest/usage/types/custom/#handling-third-party-types
    """

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: _NDArray,
        _handler: Callable[[Any], core_schema.CoreSchema],
    ) -> core_schema.CoreSchema:
        """
        Build the pydantic core schema for a parameterized array type.

        Args:
            _source_type: the parameterized type; its ``__args__`` are
                unpacked as ``(shape, dtype)``.
            _handler: schema handler supplied by pydantic; only its
                ``generate_schema`` method is used, to get the schema of the
                python type that ``dtype`` maps to in ``np_to_python``.

        Returns:
            A ``json_or_python_schema``: the JSON side is a nest of list
            schemas (one level per dimension, fixed-size dimensions carrying
            equal ``min_length``/``max_length``), the python side checks that
            the value is an ``np.ndarray`` with the expected dtype and shape.
            When serializing to JSON the array is converted with ``tolist()``.

        Raises:
            KeyError: if ``dtype`` has no entry in ``np_to_python``.
        """
        shape, dtype = _source_type.__args__

        # get pydantic core schema for the given specified type
        array_type_handler = _handler.generate_schema(np_to_python[dtype])

        # The validators raise ``ValueError`` rather than using ``assert``:
        # pydantic turns both into validation errors, but assert statements
        # are stripped under ``python -O``, which would silently disable
        # validation.
        def validate_dtype(value: np.ndarray) -> np.ndarray:
            if not value.dtype == dtype:
                raise ValueError(f"Invalid dtype! expected {dtype}, got {value.dtype}")
            return value

        def validate_array(value: Any) -> np.ndarray:
            if not cls.__instancecheck__(value):
                raise ValueError(
                    f'Invalid shape! expected shape {shape.prepared_args}, got shape {value.shape}'
                )
            return value

        # get the names of the shape constraints, if any
        # Parts after a comma carry a leading space (``"2 x, * y"`` splits to
        # ``"2 x"`` and ``" * y"``), so split on arbitrary whitespace --
        # splitting on a single ``' '`` would yield three tokens for every
        # part but the first and drop its label.
        shape_parts = shape.__args__[0].split(',')
        split_parts = [
            tokens[1] if len(tokens) == 2 else None
            for tokens in (part.split() for part in shape_parts)
        ]

        # Construct a list of list schema
        # go in reverse order - construct list schemas such that
        # the final schema is the one that checks the first dimension
        shape_labels = reversed(split_parts)
        shape_args = reversed(shape.prepared_args)
        list_schema = None
        for arg, label in zip(shape_args, shape_labels):
            # which handler to use? for the first we use the actual type
            # handler, everywhere else we use the prior list handler
            inner_schema = array_type_handler if list_schema is None else list_schema

            # make a label annotation, if we have one
            metadata = {'name': label} if label is not None else None

            # make the current level list schema, accounting for shape
            if arg == '*':
                # wildcard dimension: any length is accepted
                list_schema = core_schema.list_schema(inner_schema,
                                                      metadata=metadata)
            else:
                # fixed-size dimension: pin the length exactly
                length = int(arg)
                list_schema = core_schema.list_schema(
                    inner_schema,
                    min_length=length,
                    max_length=length,
                    metadata=metadata
                )

        return core_schema.json_or_python_schema(
            json_schema=list_schema,
            python_schema=core_schema.chain_schema(
                [
                    core_schema.is_instance_schema(np.ndarray),
                    core_schema.no_info_plain_validator_function(validate_dtype),
                    core_schema.no_info_plain_validator_function(validate_array)
                ]
            ),
            serialization=core_schema.plain_serializer_function_ser_schema(
                lambda instance: instance.tolist(),
                when_used='json'
            )
        )

View file

@ -0,0 +1,3 @@
import numpy as np
# Plain alias: ``NDArray`` is bound to numpy's ``ndarray`` class itself.
NDArray = np.ndarray

View file

View file

@ -0,0 +1,56 @@
import pdb
from typing import Union, Optional, Any
import pytest
import numpy as np
from pydantic import BaseModel, ValidationError, Field
from nwb_linkml.types.ndarray import NDArray
from nptyping import Shape, Number
def test_ndarray_type():
    """
    A field annotated with a shaped NDArray should produce a nested-array
    JSON schema and accept/reject numpy arrays by shape and dtype.
    """
    class Model(BaseModel):
        array: NDArray[Shape["2 x, * y"], Number]

    json_schema = Model.model_json_schema()
    expected_items = {'items': {'type': 'number'}, 'type': 'array'}
    assert json_schema['properties']['array']['items'] == expected_items
    assert json_schema['properties']['array']['maxItems'] == 2
    assert json_schema['properties']['array']['minItems'] == 2

    # models should instantiate correctly!
    Model(array=np.zeros((2, 3)))

    # first a wrong shape, then a wrong dtype: both must be rejected
    rejected = (
        np.zeros((4, 6)),
        np.ones((2, 3), dtype=bool),
    )
    for bad_array in rejected:
        with pytest.raises(ValidationError):
            Model(array=bad_array)
def test_ndarray_union():
    """
    An optional union of differently shaped NDArrays should accept any array
    matching one of the member shapes and reject everything else.
    """
    class Model(BaseModel):
        array: Optional[Union[
            NDArray[Shape["* x, * y"], Number],
            NDArray[Shape["* x, * y, 3 r_g_b"], Number],
            NDArray[Shape["* x, * y, 3 r_g_b, 4 r_g_b_a"], Number]
        ]] = Field(None)

    # the field is optional, so no array at all is fine
    Model()

    # one shape per union member
    for good_shape in ((5, 10), (5, 10, 3), (5, 10, 3, 4)):
        Model(array=np.random.random(good_shape))

    # too few dimensions, or a fixed dimension with the wrong size
    for bad_shape in ((5,), (5, 10, 4), (5, 10, 3, 6), (5, 10, 4, 6)):
        with pytest.raises(ValidationError):
            Model(array=np.random.random(bad_shape))