mirror of
https://github.com/p2p-ld/nwb-linkml.git
synced 2024-11-10 00:34:29 +00:00
Add ability to make JSON schema with numpy arrays!
This commit is contained in:
parent
e6a41415f5
commit
3e2e6915cf
11 changed files with 200 additions and 9 deletions
|
@ -26,7 +26,6 @@ pytest = { version="^7.4.0", optional=true}
|
||||||
pytest-depends = {version="^1.0.1", optional=true}
|
pytest-depends = {version="^1.0.1", optional=true}
|
||||||
coverage = {version = "^6.1.1", optional = true}
|
coverage = {version = "^6.1.1", optional = true}
|
||||||
pytest-md = {version = "^0.2.0", optional = true}
|
pytest-md = {version = "^0.2.0", optional = true}
|
||||||
pytest-emoji = {version="^0.2.0", optional = true}
|
|
||||||
pytest-cov = {version = "^4.1.0", optional = true}
|
pytest-cov = {version = "^4.1.0", optional = true}
|
||||||
coveralls = {version = "^3.3.1", optional = true}
|
coveralls = {version = "^3.3.1", optional = true}
|
||||||
pytest-profiling = {version = "^1.7.0", optional = true}
|
pytest-profiling = {version = "^1.7.0", optional = true}
|
||||||
|
@ -35,7 +34,7 @@ pydantic-settings = "^2.0.3"
|
||||||
[tool.poetry.extras]
|
[tool.poetry.extras]
|
||||||
tests = [
|
tests = [
|
||||||
"pytest", "pytest-depends", "coverage", "pytest-md",
|
"pytest", "pytest-depends", "coverage", "pytest-md",
|
||||||
"pytest-emoji", "pytest-cov", "coveralls", "pytest-profiling"
|
"pytest-cov", "coveralls", "pytest-profiling"
|
||||||
]
|
]
|
||||||
plot = ["dash", "dash-cytoscape"]
|
plot = ["dash", "dash-cytoscape"]
|
||||||
|
|
||||||
|
@ -69,8 +68,7 @@ build-backend = "poetry.core.masonry.api"
|
||||||
addopts = [
|
addopts = [
|
||||||
"--cov=nwb_linkml",
|
"--cov=nwb_linkml",
|
||||||
"--cov-append",
|
"--cov-append",
|
||||||
"--cov-config=.coveragerc",
|
"--cov-config=.coveragerc"
|
||||||
"--emoji",
|
|
||||||
]
|
]
|
||||||
testpaths = [
|
testpaths = [
|
||||||
"tests",
|
"tests",
|
||||||
|
|
|
@ -47,8 +47,7 @@ class NamespacesAdapter(Adapter):
|
||||||
with hdmf-common)
|
with hdmf-common)
|
||||||
"""
|
"""
|
||||||
from nwb_linkml.io import schema as schema_io
|
from nwb_linkml.io import schema as schema_io
|
||||||
ns_adapter = schema_io.load_namespaces(path)
|
ns_adapter = schema_io.load_namespace_adapter(path)
|
||||||
ns_adapter = schema_io.load_namespace_adapter(ns_adapter, path)
|
|
||||||
|
|
||||||
# try and find imported schema
|
# try and find imported schema
|
||||||
|
|
||||||
|
|
|
@ -64,7 +64,8 @@ from datetime import datetime, date
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Dict, Optional, Any, Union
|
from typing import List, Dict, Optional, Any, Union
|
||||||
from pydantic import BaseModel as BaseModel, Field
|
from pydantic import BaseModel as BaseModel, Field
|
||||||
from nptyping import NDArray, Shape, Float, Float32, Double, Float64, LongLong, Int64, Int, Int32, Int16, Short, Int8, UInt, UInt32, UInt16, UInt8, UInt64, Number, String, Unicode, Unicode, Unicode, String, Bool, Datetime64
|
from nptyping import Shape, Float, Float32, Double, Float64, LongLong, Int64, Int, Int32, Int16, Short, Int8, UInt, UInt32, UInt16, UInt8, UInt64, Number, String, Unicode, Unicode, Unicode, String, Bool, Datetime64
|
||||||
|
from nwb_linkml.types import NDArray
|
||||||
import sys
|
import sys
|
||||||
if sys.version_info >= (3, 8):
|
if sys.version_info >= (3, 8):
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
|
@ -26,6 +26,7 @@ class HDF5Element():
|
||||||
cls: h5py.Dataset | h5py.Group
|
cls: h5py.Dataset | h5py.Group
|
||||||
parent: Type[BaseModel]
|
parent: Type[BaseModel]
|
||||||
model: Optional[Any] = None
|
model: Optional[Any] = None
|
||||||
|
root_model: Optional[Type[BaseModel]] = None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def read(self) -> BaseModel | List[BaseModel]:
|
def read(self) -> BaseModel | List[BaseModel]:
|
||||||
|
@ -89,6 +90,13 @@ def take_outer_type(annotation):
|
||||||
if typing.get_origin(annotation) is list:
|
if typing.get_origin(annotation) is list:
|
||||||
return list
|
return list
|
||||||
return annotation
|
return annotation
|
||||||
|
|
||||||
|
def submodel_by_path(model: BaseModel, path:str) -> Type[BaseModel | dict | list]:
|
||||||
|
"""
|
||||||
|
Given a pydantic model and an absolute HDF5 path, get the type annotation
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class H5Dataset(HDF5Element):
|
class H5Dataset(HDF5Element):
|
||||||
cls: h5py.Dataset
|
cls: h5py.Dataset
|
||||||
|
@ -178,7 +186,7 @@ class HDF5IO():
|
||||||
data = {}
|
data = {}
|
||||||
for k, v in src.items():
|
for k, v in src.items():
|
||||||
if isinstance(v, h5py.Group):
|
if isinstance(v, h5py.Group):
|
||||||
data[k] = H5Group(cls=v, parent=parent).read()
|
data[k] = H5Group(cls=v, parent=parent, root_model=parent).read()
|
||||||
elif isinstance(v, h5py.Dataset):
|
elif isinstance(v, h5py.Dataset):
|
||||||
data[k] = H5Dataset(cls=v, parent=parent).read()
|
data[k] = H5Dataset(cls=v, parent=parent).read()
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
|
import numpy as np
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
flat_to_linkml = {
|
flat_to_linkml = {
|
||||||
"float" : "float",
|
"float" : "float",
|
||||||
|
@ -57,3 +58,16 @@ flat_to_npytyping = {
|
||||||
"isodatetime": "Datetime64",
|
"isodatetime": "Datetime64",
|
||||||
'AnyType': 'Any'
|
'AnyType': 'Any'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
np_to_python = {
|
||||||
|
Any: Any,
|
||||||
|
np.number: float,
|
||||||
|
np.object_: Any,
|
||||||
|
np.bool_: bool,
|
||||||
|
np.integer: int,
|
||||||
|
np.byte: bytes,
|
||||||
|
np.bytes_: bytes,
|
||||||
|
**{n:int for n in (np.int8, np.int16, np.int32, np.int64, np.short, np.uint8, np.uint16, np.uint32, np.uint64, np.uint)},
|
||||||
|
**{n:float for n in (np.float16, np.float32, np.floating, np.float32, np.float64, np.single, np.double, np.float_)},
|
||||||
|
**{n:str for n in (np.character, np.str_, np.string_, np.unicode_)}
|
||||||
|
}
|
|
@ -0,0 +1 @@
|
||||||
|
from nwb_linkml.providers.schema import LinkMLProvider, SchemaProvider, PydanticProvider
|
1
nwb_linkml/src/nwb_linkml/types/__init__.py
Normal file
1
nwb_linkml/src/nwb_linkml/types/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
from nwb_linkml.types.ndarray import NDArray
|
110
nwb_linkml/src/nwb_linkml/types/ndarray.py
Normal file
110
nwb_linkml/src/nwb_linkml/types/ndarray.py
Normal file
|
@ -0,0 +1,110 @@
|
||||||
|
"""
|
||||||
|
Extension of nptyping NDArray for pydantic that allows for JSON-Schema serialization
|
||||||
|
|
||||||
|
* Order to store data in (row first)
|
||||||
|
"""
|
||||||
|
import pdb
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Annotated,
|
||||||
|
Generic,
|
||||||
|
TypeVar
|
||||||
|
)
|
||||||
|
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
GetJsonSchemaHandler,
|
||||||
|
ValidationError,
|
||||||
|
GetCoreSchemaHandler
|
||||||
|
)
|
||||||
|
from pydantic.json_schema import JsonSchemaValue
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from nptyping import NDArray as _NDArray
|
||||||
|
from nptyping.ndarray import NDArrayMeta
|
||||||
|
from nptyping import Shape, Number
|
||||||
|
from nptyping.shape_expression import check_shape
|
||||||
|
|
||||||
|
from nwb_linkml.maps.dtype import np_to_python
|
||||||
|
|
||||||
|
class NDArray(_NDArray):
|
||||||
|
"""
|
||||||
|
Following the example here: https://docs.pydantic.dev/latest/usage/types/custom/#handling-third-party-types
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(
|
||||||
|
cls,
|
||||||
|
_source_type: _NDArray,
|
||||||
|
_handler: Callable[[Any], core_schema.CoreSchema],
|
||||||
|
|
||||||
|
) -> core_schema.CoreSchema:
|
||||||
|
|
||||||
|
shape, dtype = _source_type.__args__
|
||||||
|
# get pydantic core schema for the given specified type
|
||||||
|
array_type_handler = _handler.generate_schema(
|
||||||
|
np_to_python[dtype])
|
||||||
|
|
||||||
|
def validate_dtype(value: np.ndarray) -> np.ndarray:
|
||||||
|
assert value.dtype == dtype, f"Invalid dtype! expected {dtype}, got {value.dtype}"
|
||||||
|
return value
|
||||||
|
def validate_array(value: Any) -> np.ndarray:
|
||||||
|
assert cls.__instancecheck__(value), f'Invalid shape! expected shape {shape.prepared_args}, got shape {value.shape}'
|
||||||
|
return value
|
||||||
|
|
||||||
|
# get the names of the shape constraints, if any
|
||||||
|
shape_parts = shape.__args__[0].split(',')
|
||||||
|
split_parts = [p.split(' ')[1] if len(p.split(' ')) == 2 else None for p in shape_parts]
|
||||||
|
|
||||||
|
|
||||||
|
# Construct a list of list schema
|
||||||
|
# go in reverse order - construct list schemas such that
|
||||||
|
# the final schema is the one that checks the first dimension
|
||||||
|
shape_labels = reversed(split_parts)
|
||||||
|
shape_args = reversed(shape.prepared_args)
|
||||||
|
list_schema = None
|
||||||
|
for arg, label in zip(shape_args, shape_labels):
|
||||||
|
# which handler to use? for the first we use the actual type
|
||||||
|
# handler, everywhere else we use the prior list handler
|
||||||
|
if list_schema is None:
|
||||||
|
inner_schema = array_type_handler
|
||||||
|
else:
|
||||||
|
inner_schema = list_schema
|
||||||
|
|
||||||
|
# make a label annotation, if we have one
|
||||||
|
if label is not None:
|
||||||
|
metadata = {'name': label}
|
||||||
|
else:
|
||||||
|
metadata = None
|
||||||
|
|
||||||
|
# make the current level list schema, accounting for shape
|
||||||
|
if arg == '*':
|
||||||
|
list_schema = core_schema.list_schema(inner_schema,
|
||||||
|
metadata=metadata)
|
||||||
|
else:
|
||||||
|
arg = int(arg)
|
||||||
|
list_schema = core_schema.list_schema(
|
||||||
|
inner_schema,
|
||||||
|
min_length=arg,
|
||||||
|
max_length=arg,
|
||||||
|
metadata=metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
return core_schema.json_or_python_schema(
|
||||||
|
json_schema=list_schema,
|
||||||
|
python_schema=core_schema.chain_schema(
|
||||||
|
[
|
||||||
|
core_schema.is_instance_schema(np.ndarray),
|
||||||
|
core_schema.no_info_plain_validator_function(validate_dtype),
|
||||||
|
core_schema.no_info_plain_validator_function(validate_array)
|
||||||
|
]
|
||||||
|
),
|
||||||
|
serialization=core_schema.plain_serializer_function_ser_schema(
|
||||||
|
lambda instance: instance.tolist(),
|
||||||
|
when_used='json'
|
||||||
|
)
|
||||||
|
)
|
3
nwb_linkml/src/nwb_linkml/types/ndarray.pyi
Normal file
3
nwb_linkml/src/nwb_linkml/types/ndarray.pyi
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
NDArray = np.ndarray
|
0
nwb_linkml/tests/test_types/__init__.py
Normal file
0
nwb_linkml/tests/test_types/__init__.py
Normal file
56
nwb_linkml/tests/test_types/ndarray.py
Normal file
56
nwb_linkml/tests/test_types/ndarray.py
Normal file
|
@ -0,0 +1,56 @@
|
||||||
|
import pdb
|
||||||
|
from typing import Union, Optional, Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ValidationError, Field
|
||||||
|
from nwb_linkml.types.ndarray import NDArray
|
||||||
|
from nptyping import Shape, Number
|
||||||
|
|
||||||
|
def test_ndarray_type():
|
||||||
|
|
||||||
|
class Model(BaseModel):
|
||||||
|
array: NDArray[Shape["2 x, * y"], Number]
|
||||||
|
|
||||||
|
schema = Model.model_json_schema()
|
||||||
|
assert schema['properties']['array']['items'] == {'items': {'type': 'number'}, 'type': 'array'}
|
||||||
|
assert schema['properties']['array']['maxItems'] == 2
|
||||||
|
assert schema['properties']['array']['minItems'] == 2
|
||||||
|
|
||||||
|
# models should instantiate correctly!
|
||||||
|
instance = Model(array=np.zeros((2,3)))
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
instance = Model(array=np.zeros((4,6)))
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
instance = Model(array=np.ones((2,3), dtype=bool))
|
||||||
|
|
||||||
|
|
||||||
|
def test_ndarray_union():
|
||||||
|
class Model(BaseModel):
|
||||||
|
array: Optional[Union[
|
||||||
|
NDArray[Shape["* x, * y"], Number],
|
||||||
|
NDArray[Shape["* x, * y, 3 r_g_b"], Number],
|
||||||
|
NDArray[Shape["* x, * y, 3 r_g_b, 4 r_g_b_a"], Number]
|
||||||
|
]] = Field(None)
|
||||||
|
|
||||||
|
instance = Model()
|
||||||
|
instance = Model(array=np.random.random((5,10)))
|
||||||
|
instance = Model(array=np.random.random((5,10,3)))
|
||||||
|
instance = Model(array=np.random.random((5,10,3,4)))
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
instance = Model(array=np.random.random((5,)))
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
instance = Model(array=np.random.random((5,10,4)))
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
instance = Model(array=np.random.random((5,10,3,6)))
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
instance = Model(array=np.random.random((5,10,4,6)))
|
||||||
|
|
Loading…
Reference in a new issue