From 3e2e6915cf14e951778422c09cc7befcab73e58c Mon Sep 17 00:00:00 2001
From: sneakers-the-rat
Date: Thu, 14 Sep 2023 23:43:00 -0700
Subject: [PATCH] Add ability to make JSON schema with numpy arrays!

---
 nwb_linkml/pyproject.toml                    |   6 +-
 .../src/nwb_linkml/adapters/namespaces.py    |   3 +-
 .../src/nwb_linkml/generators/pydantic.py    |   3 +-
 nwb_linkml/src/nwb_linkml/io/hdf5.py         |  10 +-
 nwb_linkml/src/nwb_linkml/maps/dtype.py      |  16 ++-
 .../src/nwb_linkml/providers/__init__.py     |   1 +
 nwb_linkml/src/nwb_linkml/types/__init__.py  |   1 +
 nwb_linkml/src/nwb_linkml/types/ndarray.py   | 110 ++++++++++++++++++
 nwb_linkml/src/nwb_linkml/types/ndarray.pyi  |   3 +
 nwb_linkml/tests/test_types/__init__.py      |   0
 nwb_linkml/tests/test_types/ndarray.py       |  56 +++++++++
 11 files changed, 200 insertions(+), 9 deletions(-)
 create mode 100644 nwb_linkml/src/nwb_linkml/types/__init__.py
 create mode 100644 nwb_linkml/src/nwb_linkml/types/ndarray.py
 create mode 100644 nwb_linkml/src/nwb_linkml/types/ndarray.pyi
 create mode 100644 nwb_linkml/tests/test_types/__init__.py
 create mode 100644 nwb_linkml/tests/test_types/ndarray.py

diff --git a/nwb_linkml/pyproject.toml b/nwb_linkml/pyproject.toml
index 692fa89..77ac0db 100644
--- a/nwb_linkml/pyproject.toml
+++ b/nwb_linkml/pyproject.toml
@@ -26,7 +26,6 @@ pytest = { version="^7.4.0", optional=true}
 pytest-depends = {version="^1.0.1", optional=true}
 coverage = {version = "^6.1.1", optional = true}
 pytest-md = {version = "^0.2.0", optional = true}
-pytest-emoji = {version="^0.2.0", optional = true}
 pytest-cov = {version = "^4.1.0", optional = true}
 coveralls = {version = "^3.3.1", optional = true}
 pytest-profiling = {version = "^1.7.0", optional = true}
@@ -35,7 +34,7 @@ pydantic-settings = "^2.0.3"
 
 [tool.poetry.extras]
 tests = [
     "pytest", "pytest-depends", "coverage", "pytest-md",
-    "pytest-emoji", "pytest-cov", "coveralls", "pytest-profiling"
+    "pytest-cov", "coveralls", "pytest-profiling"
 ]
 plot = ["dash", "dash-cytoscape"]
@@ -69,8 +68,7 @@ build-backend = "poetry.core.masonry.api"
 addopts = [
     "--cov=nwb_linkml",
     "--cov-append",
-    "--cov-config=.coveragerc",
-    "--emoji",
+    "--cov-config=.coveragerc"
 ]
 testpaths = [
     "tests",
diff --git a/nwb_linkml/src/nwb_linkml/adapters/namespaces.py b/nwb_linkml/src/nwb_linkml/adapters/namespaces.py
index 9675a85..66615ce 100644
--- a/nwb_linkml/src/nwb_linkml/adapters/namespaces.py
+++ b/nwb_linkml/src/nwb_linkml/adapters/namespaces.py
@@ -47,8 +47,7 @@ class NamespacesAdapter(Adapter):
         with hdmf-common)
         """
         from nwb_linkml.io import schema as schema_io
-        ns_adapter = schema_io.load_namespaces(path)
-        ns_adapter = schema_io.load_namespace_adapter(ns_adapter, path)
+        ns_adapter = schema_io.load_namespace_adapter(path)
 
         # try and find imported schema
 
diff --git a/nwb_linkml/src/nwb_linkml/generators/pydantic.py b/nwb_linkml/src/nwb_linkml/generators/pydantic.py
index 9c89656..d121f4a 100644
--- a/nwb_linkml/src/nwb_linkml/generators/pydantic.py
+++ b/nwb_linkml/src/nwb_linkml/generators/pydantic.py
@@ -64,7 +64,8 @@ from datetime import datetime, date
 from enum import Enum
 from typing import List, Dict, Optional, Any, Union
 from pydantic import BaseModel as BaseModel, Field
-from nptyping import NDArray, Shape, Float, Float32, Double, Float64, LongLong, Int64, Int, Int32, Int16, Short, Int8, UInt, UInt32, UInt16, UInt8, UInt64, Number, String, Unicode, Unicode, Unicode, String, Bool, Datetime64
+from nptyping import Shape, Float, Float32, Double, Float64, LongLong, Int64, Int, Int32, Int16, Short, Int8, UInt, UInt32, UInt16, UInt8, UInt64, Number, String, Unicode, Bool, Datetime64
+from nwb_linkml.types import NDArray
 import sys
 if sys.version_info >= (3, 8):
     from typing import Literal
diff --git a/nwb_linkml/src/nwb_linkml/io/hdf5.py b/nwb_linkml/src/nwb_linkml/io/hdf5.py
index 7d17bb9..8f77f8e 100644
--- a/nwb_linkml/src/nwb_linkml/io/hdf5.py
+++ b/nwb_linkml/src/nwb_linkml/io/hdf5.py
@@ -26,6 +26,7 @@ class HDF5Element():
     cls: h5py.Dataset | h5py.Group
     parent: Type[BaseModel]
     model: Optional[Any] = None
+    root_model: Optional[Type[BaseModel]] = None
 
     @abstractmethod
     def read(self) -> BaseModel | List[BaseModel]:
@@ -89,6 +90,13 @@ def take_outer_type(annotation):
     if typing.get_origin(annotation) is list:
         return list
     return annotation
+
+def submodel_by_path(model: BaseModel, path: str) -> Type[BaseModel | dict | list]:
+    """
+    Given a pydantic model and an absolute HDF5 path, get the type annotation
+    """
+
+
 @dataclass
 class H5Dataset(HDF5Element):
     cls: h5py.Dataset
@@ -178,7 +186,7 @@ class HDF5IO():
         data = {}
         for k, v in src.items():
             if isinstance(v, h5py.Group):
-                data[k] = H5Group(cls=v, parent=parent).read()
+                data[k] = H5Group(cls=v, parent=parent, root_model=parent).read()
             elif isinstance(v, h5py.Dataset):
                 data[k] = H5Dataset(cls=v, parent=parent).read()
 
diff --git a/nwb_linkml/src/nwb_linkml/maps/dtype.py b/nwb_linkml/src/nwb_linkml/maps/dtype.py
index dd6cec6..d39dd60 100644
--- a/nwb_linkml/src/nwb_linkml/maps/dtype.py
+++ b/nwb_linkml/src/nwb_linkml/maps/dtype.py
@@ -1,4 +1,5 @@
-
+import numpy as np
+from typing import Any
 
 flat_to_linkml = {
     "float" : "float",
@@ -56,4 +57,17 @@ flat_to_npytyping = {
     "bool": "Bool",
     "isodatetime": "Datetime64",
     'AnyType': 'Any'
+}
+
+np_to_python = {
+    Any: Any,
+    np.number: float,
+    np.object_: Any,
+    np.bool_: bool,
+    np.integer: int,
+    np.byte: bytes,
+    np.bytes_: bytes,
+    **{n: int for n in (np.int8, np.int16, np.int32, np.int64, np.short, np.uint8, np.uint16, np.uint32, np.uint64, np.uint)},
+    **{n: float for n in (np.float16, np.float32, np.floating, np.float64, np.single, np.double, np.float_)},
+    **{n: str for n in (np.character, np.str_, np.string_, np.unicode_)}
 }
\ No newline at end of file
diff --git a/nwb_linkml/src/nwb_linkml/providers/__init__.py b/nwb_linkml/src/nwb_linkml/providers/__init__.py
index e69de29..791fb43 100644
--- a/nwb_linkml/src/nwb_linkml/providers/__init__.py
+++ b/nwb_linkml/src/nwb_linkml/providers/__init__.py
@@ -0,0 +1 @@
+from nwb_linkml.providers.schema import LinkMLProvider, SchemaProvider, PydanticProvider
\ No newline at end of file
diff --git a/nwb_linkml/src/nwb_linkml/types/__init__.py b/nwb_linkml/src/nwb_linkml/types/__init__.py
new file mode 100644
index 0000000..801a327
--- /dev/null
+++ b/nwb_linkml/src/nwb_linkml/types/__init__.py
@@ -0,0 +1 @@
+from nwb_linkml.types.ndarray import NDArray
\ No newline at end of file
diff --git a/nwb_linkml/src/nwb_linkml/types/ndarray.py b/nwb_linkml/src/nwb_linkml/types/ndarray.py
new file mode 100644
index 0000000..63ae36d
--- /dev/null
+++ b/nwb_linkml/src/nwb_linkml/types/ndarray.py
@@ -0,0 +1,110 @@
+"""
+Extension of nptyping NDArray for pydantic that allows for JSON-Schema serialization
+
+* Order to store data in (row first)
+"""
+import pdb
+from typing import (
+    Any,
+    Callable,
+    Annotated,
+    Generic,
+    TypeVar
+)
+
+from pydantic_core import core_schema
+from pydantic import (
+    BaseModel,
+    GetJsonSchemaHandler,
+    ValidationError,
+    GetCoreSchemaHandler
+)
+from pydantic.json_schema import JsonSchemaValue
+
+import numpy as np
+
+from nptyping import NDArray as _NDArray
+from nptyping.ndarray import NDArrayMeta
+from nptyping import Shape, Number
+from nptyping.shape_expression import check_shape
+
+from nwb_linkml.maps.dtype import np_to_python
+
+class NDArray(_NDArray):
+    """
+    Following the example here: https://docs.pydantic.dev/latest/usage/types/custom/#handling-third-party-types
+    """
+
+    @classmethod
+    def __get_pydantic_core_schema__(
+        cls,
+        _source_type: _NDArray,
+        _handler: Callable[[Any], core_schema.CoreSchema],
+    ) -> core_schema.CoreSchema:
+
+        shape, dtype = _source_type.__args__
+        # get the pydantic core schema for the specified item type
+        array_type_handler = _handler.generate_schema(np_to_python[dtype])
+
+        def validate_dtype(value: np.ndarray) -> np.ndarray:
+            assert value.dtype == dtype, f"Invalid dtype! expected {dtype}, got {value.dtype}"
+            return value
+
+        def validate_array(value: Any) -> np.ndarray:
+            assert cls.__instancecheck__(value), f'Invalid shape! expected shape {shape.prepared_args}, got shape {value.shape}'
+            return value
+
+        # get the names of the shape constraints, if any
+        shape_parts = shape.__args__[0].split(',')
+        split_parts = [p.split(' ')[1] if len(p.split(' ')) == 2 else None for p in shape_parts]
+
+        # Construct a list-of-lists schema:
+        # go in reverse order - construct list schemas such that
+        # the final schema is the one that checks the first dimension
+        shape_labels = reversed(split_parts)
+        shape_args = reversed(shape.prepared_args)
+        list_schema = None
+        for arg, label in zip(shape_args, shape_labels):
+            # which handler to use? for the first we use the actual type
+            # handler, everywhere else we use the prior list handler
+            if list_schema is None:
+                inner_schema = array_type_handler
+            else:
+                inner_schema = list_schema
+
+            # make a label annotation, if we have one
+            if label is not None:
+                metadata = {'name': label}
+            else:
+                metadata = None
+
+            # make the current level list schema, accounting for shape
+            if arg == '*':
+                list_schema = core_schema.list_schema(inner_schema, metadata=metadata)
+            else:
+                arg = int(arg)
+                list_schema = core_schema.list_schema(
+                    inner_schema,
+                    min_length=arg,
+                    max_length=arg,
+                    metadata=metadata
+                )
+
+        return core_schema.json_or_python_schema(
+            json_schema=list_schema,
+            python_schema=core_schema.chain_schema(
+                [
+                    core_schema.is_instance_schema(np.ndarray),
+                    core_schema.no_info_plain_validator_function(validate_dtype),
+                    core_schema.no_info_plain_validator_function(validate_array)
+                ]
+            ),
+            serialization=core_schema.plain_serializer_function_ser_schema(
+                lambda instance: instance.tolist(),
+                when_used='json'
+            )
+        )
\ No newline at end of file
diff --git a/nwb_linkml/src/nwb_linkml/types/ndarray.pyi b/nwb_linkml/src/nwb_linkml/types/ndarray.pyi
new file mode 100644
index 0000000..1aa5dd7
--- /dev/null
+++ b/nwb_linkml/src/nwb_linkml/types/ndarray.pyi
@@ -0,0 +1,3 @@
+import numpy as np
+
+NDArray = np.ndarray
\ No newline at end of file
diff --git a/nwb_linkml/tests/test_types/__init__.py b/nwb_linkml/tests/test_types/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nwb_linkml/tests/test_types/ndarray.py b/nwb_linkml/tests/test_types/ndarray.py
new file mode 100644
index 0000000..14f26b8
--- /dev/null
+++ b/nwb_linkml/tests/test_types/ndarray.py
@@ -0,0 +1,56 @@
+import pdb
+from typing import Union, Optional, Any
+
+import pytest
+
+import numpy as np
+
+from pydantic import BaseModel, ValidationError, Field
+from nwb_linkml.types.ndarray import NDArray
+from nptyping import Shape, Number
+
+def test_ndarray_type():
+
+    class Model(BaseModel):
+        array: NDArray[Shape["2 x, * y"], Number]
+
+    schema = Model.model_json_schema()
+    assert schema['properties']['array']['items'] == {'items': {'type': 'number'}, 'type': 'array'}
+    assert schema['properties']['array']['maxItems'] == 2
+    assert schema['properties']['array']['minItems'] == 2
+
+    # models should instantiate correctly!
+    instance = Model(array=np.zeros((2,3)))
+
+    with pytest.raises(ValidationError):
+        instance = Model(array=np.zeros((4,6)))
+
+    with pytest.raises(ValidationError):
+        instance = Model(array=np.ones((2,3), dtype=bool))
+
+
+def test_ndarray_union():
+    class Model(BaseModel):
+        array: Optional[Union[
+            NDArray[Shape["* x, * y"], Number],
+            NDArray[Shape["* x, * y, 3 r_g_b"], Number],
+            NDArray[Shape["* x, * y, 3 r_g_b, 4 r_g_b_a"], Number]
+        ]] = Field(None)
+
+    instance = Model()
+    instance = Model(array=np.random.random((5,10)))
+    instance = Model(array=np.random.random((5,10,3)))
+    instance = Model(array=np.random.random((5,10,3,4)))
+
+    with pytest.raises(ValidationError):
+        instance = Model(array=np.random.random((5,)))
+
+    with pytest.raises(ValidationError):
+        instance = Model(array=np.random.random((5,10,4)))
+
+    with pytest.raises(ValidationError):
+        instance = Model(array=np.random.random((5,10,3,6)))
+
+    with pytest.raises(ValidationError):
+        instance = Model(array=np.random.random((5,10,4,6)))
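
A few sketches of how the pieces fit together, mirroring the tests above
(illustrative only, not part of the diff). The np_to_python map is what lets
__get_pydantic_core_schema__ hand pydantic a scalar schema for whatever numpy
dtype the annotation names:

    import numpy as np
    from nwb_linkml.maps.dtype import np_to_python

    np_to_python[np.float32]  # -> float
    np_to_python[np.integer]  # -> int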
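Because the list schemas are built in reverse over the shape expression, the
outermost level of the JSON schema constrains the first dimension. For the
annotation exercised in test_ndarray_type:

    from pydantic import BaseModel
    from nptyping import Shape, Number
    from nwb_linkml.types import NDArray

    class Model(BaseModel):
        array: NDArray[Shape["2 x, * y"], Number]

    # per the assertions in test_ndarray_type, the 'array' property comes out
    # roughly as:
    # {'type': 'array', 'minItems': 2, 'maxItems': 2,
    #  'items': {'type': 'array', 'items': {'type': 'number'}}}
    print(Model.model_json_schema()['properties']['array'])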
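And since the plain serializer only runs when_used='json', instances keep real
numpy arrays in python while dumping plain nested lists on the wire - something
like:

    import numpy as np

    instance = Model(array=np.zeros((2, 3)))
    instance.array             # still an np.ndarray in python
    instance.model_dump_json()
    # -> '{"array": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]}'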