mirror of https://github.com/p2p-ld/numpydantic.git
synced 2025-01-10 05:54:26 +00:00

Merge pull request #20 from p2p-ld/dump_json

Roundtrip JSON serialization/deserialization

This commit is contained in commit bd5b93773b

39 changed files with 2774 additions and 203 deletions
docs/api/monkeypatch.md (deleted)

@@ -1,6 +0,0 @@
# monkeypatch

```{eval-rst}
.. automodule:: numpydantic.monkeypatch
   :members:
```
docs/api/serialization.md (new file, 7 lines)

@@ -0,0 +1,7 @@
# serialization

```{eval-rst}
.. automodule:: numpydantic.serialization
   :members:
   :undoc-members:
```
docs/changelog.md

@@ -2,6 +2,73 @@
## 1.*

### 1.6.*

#### 1.6.0 - 24-09-23 - Roundtrip JSON Serialization

Roundtrip JSON serialization is here - with serialization to list of lists,
as well as file references that don't require copying the whole array when
used in data modeling, control over path relativization, and stamping of
the interface version for the extra provenance-conscious.

Please see [serialization](./serialization.md) for narrative documentation :)

**Potentially Breaking Changes**
- See [development](./development.md) for a statement about API stability
- An additional {meth}`.Interface.deserialize` method has been added to
  {meth}`.Interface.validate` - downstream users are not intended to override the
  `validate` method, but if they have, then JSON deserialization will not work for them.
- `Interface` subclasses now require a `name` attribute, a short string identifier for that interface,
  and a `json_model` that inherits from {class}`.interface.JsonDict`. Interfaces without
  these attributes will not be able to be instantiated (see the sketch after this list).
- {meth}`.Interface.to_json` is now an abstract method that all interfaces must define.
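
A hypothetical minimal interface satisfying the new requirements might look like
this (the names, fields, and bodies here are illustrative, not from the release):

```python
from typing import Any, Literal, Union
from numpydantic.interface import Interface, JsonDict

class MyJsonDict(JsonDict):
    """Round-trip form for this interface's arrays"""
    type: Literal["myformat"]
    array: list

    def to_array_input(self) -> list:
        # convert the dumped form back to an input the interface accepts
        return self.array

class MyInterface(Interface):
    name = "myformat"        # new: required short string identifier
    input_types = (list,)
    return_type = list
    json_model = MyJsonDict  # new: required JsonDict subclass

    @classmethod
    def enabled(cls) -> bool:
        return True

    @classmethod
    def check(cls, array: Any) -> bool:
        return isinstance(array, (list, dict))

    @classmethod
    def to_json(cls, array: list, info: Any = None) -> Union[list, MyJsonDict]:
        # new: to_json is abstract and must be defined by every interface
        if info is not None and info.round_trip:
            return MyJsonDict(type=cls.name, array=array)
        return array
```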

**Features**
- Roundtrip JSON serialization (sketched below) - by default dump to list-of-lists arrays, but
  support the `round_trip` keyword in `model_dump_json` for provenance-preserving dumps
- JSON Schema generation has been separated from `core_schema` generation in {class}`.NDArray`.
  Downstream interfaces can customize JSON schema generation without compromising the ability to validate.
- All proxy classes must have an `__eq__` dunder method to compare equality -
  in proxy classes, these compare equality of arguments, since the arrays that
  are referenced on disk should be equal by definition. Direct array comparison
  should use {func}`numpy.array_equal`
- Interfaces previously couldn't be instantiated without explicit shape and dtype arguments;
  these have been given `Any` defaults.
- New {mod}`numpydantic.serialization` module to contain serialization logic.
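
A minimal sketch of the default vs. round-trip behavior (model and values
illustrative):

```python
from pydantic import BaseModel
from numpydantic import NDArray

class MyModel(BaseModel):
    array: NDArray

model = MyModel(array=[[1, 2], [3, 4]])

# default: a plain JSON array of arrays
plain = model.model_dump_json()

# round_trip=True: an interface-specific dict that can be revalidated
roundtrip = model.model_dump_json(round_trip=True)
restored = MyModel.model_validate_json(roundtrip)
```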

**New Classes**
See the docstrings for descriptions of each class; a sample round-trip dump is shown after the list
- `MarkMismatchError` for when an array serialized with `mark_interface` doesn't match
  the interface that's deserializing it
- {class}`.interface.InterfaceMark`
- {class}`.interface.MarkedJson`
- {class}`.interface.JsonDict`
- {class}`.dask.DaskJsonDict`
- {class}`.hdf5.H5JsonDict`
- {class}`.numpy.NumpyJsonDict`
- {class}`.video.VideoJsonDict`
- {class}`.zarr.ZarrJsonDict`
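
For example, a dask array dumped with `round_trip=True` serializes to a dict
shaped like {class}`.dask.DaskJsonDict` (field values here are illustrative):

```python
{
    "type": "dask",
    "name": "array-2f6",      # the dask graph key; value hypothetical
    "chunks": [[2], [2]],
    "dtype": "int64",
    "array": [[1, 2], [3, 4]],
}
```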

**Bugfix**
- [`#17`](https://github.com/p2p-ld/numpydantic/issues/17) - Arrays are re-validated as lists, rather than arrays
- Some proxy classes would fail to be serialized because they lacked an `__array__` method.
  `__array__` methods have been added, along with tests for coercing to an array to prevent regression.
- Some proxy classes lacked a `__name__` attribute, which caused failures to serialize
  when the `__getattr__` methods attempted to pass it through. These have been added where needed.

**Docs**
- Add statement about versioning and API stability to [development](./development.md)
- Add docs for serialization!
- Remove stranded docs from hooks and monkeypatch
- Added `myst_nb` to docs dependencies for direct rendering of code and output

**Tests**
- Marks have been added for running subsets of the tests for a given interface,
  package feature, etc. (e.g. `pytest -m dask` or `pytest -m serialization`)
- Tests for all the above functionality

### 1.5.*

#### 1.5.3 - 24-09-03 - Bugfix, type checking for empty HDF5 datasets
docs/conf.py

@@ -25,7 +25,7 @@ extensions = [
    "sphinx.ext.doctest",
    "sphinx_design",
    "sphinxcontrib.mermaid",
    "myst_parser",
    "myst_nb",
    "sphinx.ext.todo",
]
@@ -77,3 +77,8 @@ napoleon_attr_annotations = True
# todo
todo_include_todos = True
todo_link_only = True

-# myst
+# myst-nb
+nb_render_markdown_format = "myst"
+nb_execution_show_tb = True
docs/data/test.avi (new binary file, not shown)

docs/data/test.h5 (new binary file, not shown)

docs/data/test.zarr/.zarray (new file, 22 lines)
@@ -0,0 +1,22 @@
{
    "chunks": [
        2,
        2
    ],
    "compressor": {
        "blocksize": 0,
        "clevel": 5,
        "cname": "lz4",
        "id": "blosc",
        "shuffle": 1
    },
    "dtype": "<i8",
    "fill_value": 0,
    "filters": null,
    "order": "C",
    "shape": [
        2,
        2
    ],
    "zarr_format": 2
}

docs/data/test.zarr/0.0 (new binary file, not shown)
docs/development.md (new file, 84 lines)

@@ -0,0 +1,84 @@
# Development

## Versioning

This package uses a colloquial form of [semantic versioning 2](https://semver.org/).

Specifically:

- Major version `2.*.*` is reserved for the transition from nptyping to using
  `TypeVarTuple`, `Generic`, and `Protocol`. Until `2.*.*`...
  - breaking changes will be indicated with an advance in `MINOR`
    version, taking the place of `MAJOR` in semver
  - backwards-compatible bugfixes **and** additions in functionality
    will be indicated by a `PATCH` release, taking the place of `MINOR` and
    `PATCH` in semver.
- After `2.*.*`, semver as usual will resume

So, for example, under this scheme `1.6.0 -> 1.7.0` signals a breaking change,
while `1.6.0 -> 1.6.1` may contain both bugfixes and new features.

You are encouraged to set an upper bound on your version dependencies until
we pass `2.*.*`: the major functionality of numpydantic is stable,
but a fair amount of internal reorganization is still to be expected.
### API Stability

- All breaking changes to the **public API** will be signaled by a major
  version's worth of deprecation warnings
- All breaking changes to the **development API** will be signaled by a
  minor version's worth of deprecation warnings.
- Changes to the remainder of the package, whether marked as private with a
  leading underscore or not, including the import structure of the package,
  are not considered part of the API and should not be relied on as stable
  until explicitly marked otherwise.

#### Public API

**Only the {class}`.NDArray` and {class}`.Shape` classes should be considered
part of the stable public API.**

All associated functionality for validation should also be considered
a stable part of the `NDArray` and `Shape` classes - functionality
will only be added here, and the departure from the string form of the
shape specifications (and its removal) will take place in `v3.*.*`

End-users of numpydantic should pin an upper bound for the `MAJOR` version
until after `v2.*.*`, after which time it is up to your discretion -
no breaking changes are planned, but they would be signaled by a major version change.
#### Development API

**Only the {class}`.Interface` class and its subclasses,
along with the Public API,
should be considered part of the stable development API.**

The `Interface` class is the primary point of external development expected
for numpydantic. It is still somewhat in flux, but it is prioritized for stability
and deprecation warnings above the rest of the package.

Dependent packages that define their own `Interface`s should pin an upper
bound for the `PATCH` version until `2.*.*`, and afterwards likely pin a `MINOR` version.
Tests are designed such that it should be easy to test major features against
each interface, and that work is also ongoing. Once the test suite reaches
maturity, it should be possible for any downstream interface to simply run those
tests to ensure it is compatible with the latest version.

## Release Schedule

There is no release schedule. Versions are released according to need and available labor.

## Contributing

### Dev environment

```{todo}
Document dev environment

Really it's very simple, you just clone a fork and install
the `dev` environment like `pip install '.[dev]'`
```

### Pull Requests

```{todo}
Document pull requests if we ever receive one
```
docs/hooks.md (deleted)

@@ -1,11 +0,0 @@
# Hooks

What hooks do we want to expose to downstream users so they can use this without needing
to override everything?

```{todo}
**NWB Compatibility**

**Precision:** NWB allows for a sort of hierarchy of type specification -
a less precise type also allows the data to be specified in a more precise type
```
docs/index.md

@@ -86,8 +86,8 @@ isinstance(np.zeros((1,2,3), dtype=float), array_type)
  and a simple extension system to make it work with whatever else you want! Provides
  a uniform and transparent interface so you can both use common indexing operations
  and also access any special features of a given array library.
-- **Serialization** - Dump an array as a JSON-compatible array-of-arrays with enough metadata to be able to
-  recreate the model in the native format
+- [**Serialization**](./serialization.md) - Dump an array as a JSON-compatible array-of-arrays with enough metadata to be able to
+  recreate the model in the native format. Full roundtripping is supported :)
- **Schema Generation** - Correct JSON Schema for arrays, complete with shape and dtype constraints, to
  make your models interoperable
@@ -473,9 +473,8 @@ dumped = instance.model_dump_json(context={'zarr_dump_array': True})

design
syntax
serialization
interfaces
todo
changelog
```

```{toctree}
@@ -489,13 +488,23 @@ api/dtype
api/ndarray
api/maps
api/meta
api/monkeypatch
api/schema
api/serialization
api/shape
api/types

```

```{toctree}
:maxdepth: 2
:caption: Meta
:hidden: true

changelog
development
todo
```

## See Also

- [`jaxtyping`](https://docs.kidger.site/jaxtyping/)
docs/interfaces.md

@@ -46,6 +46,11 @@ for interfaces to implement custom behavior that matches the array format.

{meth}`.Interface.validate` calls the following methods, in order:

A method to deserialize the array dumped with {func}`~pydantic.BaseModel.model_dump_json`
with `round_trip = True` (see [serialization](./serialization.md))

- {meth}`.Interface.deserialize`

An initial hook for modifying the input data before validation, eg.
if it needs to be coerced or wrapped in some proxy class. This method
should accept all and only the types specified in that interface's
docs/serialization.md (new file, 332 lines)

@@ -0,0 +1,332 @@
---
file_format: mystnb
mystnb:
  output_stderr: remove
  render_text_lexer: python
  render_markdown_format: myst
myst:
  enable_extensions: ["colon_fence"]
---

# Serialization

## Python

In most cases, dumping to Python should work as expected.

When a given array framework doesn't provide a tidy means of interacting
with it from Python, we substitute a proxy class like {class}`.hdf5.H5Proxy`,
but aside from that, numpydantic {class}`.NDArray` annotations
should be passthrough when using {func}`~pydantic.BaseModel.model_dump`.

## JSON

JSON is the ~ ♥ fun one ♥ ~

There isn't necessarily a single optimal way to represent all possible
arrays in JSON. The standard way that n-dimensional arrays are rendered
in JSON is as a list-of-lists (or array of arrays, in JSON parlance),
but that's almost never what is desirable, especially for large arrays.

### Normal Style[^normalstyle]

Lists-of-lists are the standard, however, so it is the default behavior
for all interfaces, and all interfaces must support it.
```{code-cell}
---
tags: [hide-cell]
---

from pathlib import Path
from pydantic import BaseModel
from numpydantic import NDArray, Shape
from numpydantic.interface.dask import DaskJsonDict
from numpydantic.interface.numpy import NumpyJsonDict
import numpy as np
import dask.array as da
import zarr
import json
from rich import print
from rich.console import Console

def print_json(string: str):
    data = json.loads(string)
    console = Console(width=74)
    console.print(data)
```

For our humble model:

```{code-cell}
class MyModel(BaseModel):
    array: NDArray
```

We should get the same thing for each interface:

```{code-cell}
model = MyModel(array=[[1,2],[3,4]])
print(model.model_dump_json())
```

```{code-cell}
model = MyModel(array=da.array([[1,2],[3,4]], dtype=int))
print(model.model_dump_json())
```

```{code-cell}
model = MyModel(array=zarr.array([[1,2],[3,4]], dtype=int))
print(model.model_dump_json())
```

```{code-cell}
model = MyModel(array="data/test.avi")
print(model.model_dump_json())
```

(ok maybe not that last one, since the video reader still incorrectly
reads grayscale videos as BGR values for now, but you get the idea)

Since by default arrays are dumped into unadorned JSON arrays,
when they are re-validated, they will always be handled by the
{class}`.NumpyInterface`:

```{code-cell}
dask_array = da.array([[1,2],[3,4]], dtype=int)
model = MyModel(array=dask_array)
type(model.array)
```

```{code-cell}
model_json = model.model_dump_json()
deserialized_model = MyModel.model_validate_json(model_json)
type(deserialized_model.array)
```

All information about `dtype` will be lost, and numbers will be parsed
as either `int` ({class}`numpy.int64`) or `float` ({class}`numpy.float64`).
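
For instance (a small sketch; the `int16` dtype here is arbitrary):

```{code-cell}
model = MyModel(array=np.array([[1, 2], [3, 4]], dtype=np.int16))
restored = MyModel.model_validate_json(model.model_dump_json())
# the original int16 is gone; the values were re-parsed as default-width ints
restored.array.dtype
```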

## Roundtripping

To make arrays round-trippable, use the `round_trip` argument
to {func}`~pydantic.BaseModel.model_dump_json`.

All of the following should return an equivalent array from the same
file/etc. as the source array when using
{func}`~pydantic.BaseModel.model_validate_json`:

```{code-cell}
print_json(model.model_dump_json(round_trip=True))
```

Each interface must implement a pydantic model that describes its
JSON-able roundtrip form (see {class}`.interface.JsonDict`).

That model then has a {meth}`.JsonDict.is_valid` method that checks
whether an incoming dict matches its schema:

```{code-cell}
roundtrip_json = json.loads(model.model_dump_json(round_trip=True))['array']
DaskJsonDict.is_valid(roundtrip_json)
```

```{code-cell}
NumpyJsonDict.is_valid(roundtrip_json)
```

#### Controlling paths

When possible, the full content of the array is omitted in favor
of the path to the file that provided it.

```{code-cell}
model = MyModel(array="data/test.avi")
print_json(model.model_dump_json(round_trip=True))
```

```{code-cell}
model = MyModel(array=("data/test.h5", "/data"))
print_json(model.model_dump_json(round_trip=True))
```

You may notice the relative, rather than absolute, paths.

We expect that when people are dumping data to JSON in this roundtripped
form, they are either working locally
(e.g. transmitting an array specification across a socket in multiprocessing
or in a computing cluster),
or exporting to some directory structure of data,
where they are making an index file that refers to datasets in a directory
as part of a data standard or vernacular format.

By default, numpydantic uses the current working directory as the root to find
paths relative to, but this can be controlled by the [`relative_to`](#relative_to)
context parameter.

For example, if you're working on data in many subdirectories,
you might want to serialize relative to each of them:

```{code-cell}
print_json(
    model.model_dump_json(
        round_trip=True,
        context={"relative_to": Path('./data')}
    ))
```

Or in the other direction:

```{code-cell}
print_json(
    model.model_dump_json(
        round_trip=True,
        context={"relative_to": Path('../')}
    ))
```

Or you might be working in some completely different place;
numpydantic will try and find the way from here to there as long as it exists,
even if it means traversing to the root of the readthedocs filesystem:

```{code-cell}
print_json(
    model.model_dump_json(
        round_trip=True,
        context={"relative_to": Path('/a/long/distance/directory')}
    ))
```

You can force absolute paths with the `absolute_paths` context parameter:

```{code-cell}
print_json(
    model.model_dump_json(
        round_trip=True,
        context={"absolute_paths": True}
    ))
```
#### Durable Interface Metadata

Numpydantic tries to be [stable](./development.md#api-stability),
but we're not perfect. To preserve the full information about the
interface that's needed to load the data referred to by the value,
use the `mark_interface` context parameter:

```{code-cell}
print_json(
    model.model_dump_json(
        round_trip=True,
        context={"mark_interface": True}
    ))
```

When an array marked with the interface is deserialized,
it short-circuits the {meth}`.Interface.match` method,
attempting to directly return the indicated interface as long as the
array dumped in `value` still satisfies that interface's {meth}`.Interface.check`
method. Arrays dumped *without* `round_trip=True` might *not* validate with
the originating model, even when marked -- eg. an array dumped without `round_trip`
will be revalidated as a numpy array for the same reasons it is everywhere else,
since all connection to the source file is lost.
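
For example (a sketch, reusing the `model` from above): a dump made with both
`round_trip` and `mark_interface` should revalidate through the marked interface:

```{code-cell}
marked = model.model_dump_json(
    round_trip=True,
    context={"mark_interface": True}
)
restored = MyModel.model_validate_json(marked)
type(restored.array)
```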

```{todo}
Currently, the version of the package the interface is from (usually `numpydantic`)
will be stored, but there is no means of resolving it on the fly.
If there is a mismatch between the marked interface description and the interface
that was matched on revalidation, a warning is emitted, but validation
attempts to proceed as normal.

This feature is for extra-verbose provenance, rather than airtight serialization
and deserialization, but PRs welcome if you would like to make it be that way.
```

```{todo}
We will also add a separate `mark_version` parameter for marking
the specific version of the relevant interface package, like `zarr`, or `numpy`,
patience.
```

## Context parameters

A reference listing of all the things that can be passed to
{func}`~pydantic.BaseModel.model_dump_json`.

### `mark_interface`

Nest an additional layer of metadata for unambiguous serialization that
can be absolutely resolved across numpydantic versions
(for now for downstream metadata purposes only;
automatically resolving to a numpydantic version is not yet possible.)

Supported interfaces:

- (all)

```{code-cell}
model = MyModel(array=[[1,2],[3,4]])
data = model.model_dump_json(
    round_trip=True,
    context={"mark_interface": True}
)
print_json(data)
```

### `absolute_paths`

Make all paths (that exist) absolute.

Supported interfaces:

- (all)

```{code-cell}
model = MyModel(array=("data/test.h5", "/data"))
data = model.model_dump_json(
    round_trip=True,
    context={"absolute_paths": True}
)
print_json(data)
```

### `relative_to`

Make all paths (that exist) relative to the given path.

Supported interfaces:

- (all)

```{code-cell}
model = MyModel(array=("data/test.h5", "/data"))
data = model.model_dump_json(
    round_trip=True,
    context={"relative_to": Path('../')}
)
print_json(data)
```

### `dump_array`

Dump the raw array contents when serializing to JSON inside an `array` field.

Supported interfaces:

- {class}`.ZarrInterface`

```{code-cell}
model = MyModel(array=("data/test.zarr",))
data = model.model_dump_json(
    round_trip=True,
    context={"dump_array": True}
)
print_json(data)
```

[^normalstyle]: o ya we're posting JSON [normal style](https://normal.style)
pyproject.toml

@@ -1,6 +1,6 @@
[project]
name = "numpydantic"
-version = "1.5.3"
+version = "1.6.0"
description = "Type and shape validation and serialization for arbitrary array types in pydantic models"
authors = [
    {name = "sneakers-the-rat", email = "sneakers-the-rat@protonmail.com"},
@@ -73,12 +73,15 @@ tests = [
    "coveralls<4.0.0,>=3.3.1",
]
docs = [
    "numpydantic[arrays]",
    "sphinx<8.0.0,>=7.2.6",
    "furo>=2024.1.29",
    "myst-parser<3.0.0,>=2.0.0",
    "autodoc-pydantic<3.0.0,>=2.0.1",
    "sphinx-design<1.0.0,>=0.5.0",
    "sphinxcontrib-mermaid>=0.9.2",
+    "myst-nb>=1.1.1",
+    "rich>=13.8.1",
]
dev = [
    "numpydantic[tests,docs]",
@@ -109,6 +112,18 @@ filterwarnings = [
    # nptyping's alias warnings
    'ignore:.*deprecated alias.*Deprecated NumPy 1\.24.*'
]
+markers = [
+    "dtype: mark test related to dtype validation",
+    "shape: mark test related to shape validation",
+    "json_schema: mark test related to json schema generation",
+    "serialization: mark test related to serialization",
+    "proxy: test for proxy class in any interface",
+    "dask: dask interface",
+    "hdf5: hdf5 interface",
+    "numpy: numpy interface",
+    "video: video interface",
+    "zarr: zarr interface",
+]

[tool.ruff]
target-version = "py311"
src/numpydantic/exceptions.py

@@ -25,3 +25,7 @@ class NoMatchError(MatchError):

class TooManyMatchesError(MatchError):
    """Too many matches found by :class:`.Interface.match`"""


+class MarkMismatchError(MatchError):
+    """A serialized :class:`.InterfaceMark` doesn't match the receiving interface"""
src/numpydantic/interface/__init__.py

@@ -4,15 +4,23 @@ Interfaces between nptyping types and array backends

from numpydantic.interface.dask import DaskInterface
from numpydantic.interface.hdf5 import H5Interface
-from numpydantic.interface.interface import Interface
+from numpydantic.interface.interface import (
+    Interface,
+    InterfaceMark,
+    JsonDict,
+    MarkedJson,
+)
from numpydantic.interface.numpy import NumpyInterface
from numpydantic.interface.video import VideoInterface
from numpydantic.interface.zarr import ZarrInterface

__all__ = [
-    "Interface",
    "DaskInterface",
    "H5Interface",
+    "Interface",
+    "InterfaceMark",
+    "JsonDict",
+    "MarkedJson",
    "NumpyInterface",
    "VideoInterface",
    "ZarrInterface",
src/numpydantic/interface/dask.py

@@ -2,34 +2,73 @@
Interface for Dask arrays
"""

-from typing import Any, Optional
+from typing import Any, Iterable, List, Literal, Optional, Union

import numpy as np
from pydantic import SerializationInfo

-from numpydantic.interface.interface import Interface
+from numpydantic.interface.interface import Interface, JsonDict
from numpydantic.types import DtypeType, NDArrayType

try:
    from dask.array import from_array
    from dask.array.core import Array as DaskArray
except ImportError:  # pragma: no cover
    DaskArray = None


def _as_tuple(a_list: Any) -> tuple:
    """Make a list of list into a tuple of tuples"""
    return tuple(
        [_as_tuple(item) if isinstance(item, list) else item for item in a_list]
    )


class DaskJsonDict(JsonDict):
    """
    Round-trip json serialized form of a dask array
    """

    type: Literal["dask"]
    name: str
    chunks: Iterable[tuple[int, ...]]
    dtype: str
    array: list

    def to_array_input(self) -> DaskArray:
        """Construct a dask array"""
        np_array = np.array(self.array, dtype=self.dtype)
        array = from_array(
            np_array,
            name=self.name,
            chunks=_as_tuple(self.chunks),
        )
        return array


class DaskInterface(Interface):
    """
    Interface for Dask :class:`~dask.array.core.Array`
    """

-    input_types = (DaskArray,)
+    name = "dask"
+    input_types = (DaskArray, dict)
    return_type = DaskArray
+    json_model = DaskJsonDict

    @classmethod
    def check(cls, array: Any) -> bool:
        """
        check if array is a dask array
        """
-        return DaskArray is not None and isinstance(array, DaskArray)
+        if DaskArray is None:  # pragma: no cover - no tests for interface deps atm
+            return False
+        elif isinstance(array, DaskArray):
+            return True
+        elif isinstance(array, dict):
+            return DaskJsonDict.is_valid(array)
+        else:
+            return False

    def get_object_dtype(self, array: NDArrayType) -> DtypeType:
        """Dask arrays require a compute() call to retrieve a single value"""
@@ -43,7 +82,7 @@ class DaskInterface(Interface):
    @classmethod
    def to_json(
        cls, array: DaskArray, info: Optional[SerializationInfo] = None
-    ) -> list:
+    ) -> Union[List, DaskJsonDict]:
        """
        Convert an array to a JSON serializable array by first converting to a numpy
        array and then to a list.
@@ -56,4 +95,14 @@ class DaskInterface(Interface):
        method of serialization here using the python object itself rather than
        its JSON representation.
        """
-        return np.array(array).tolist()
+        np_array = np.array(array)
+        as_json = np_array.tolist()
+        if info.round_trip:
+            as_json = DaskJsonDict(
+                type=cls.name,
+                array=as_json,
+                name=array.name,
+                chunks=array.chunks,
+                dtype=str(np_array.dtype),
+            )
+        return as_json
src/numpydantic/interface/hdf5.py

@@ -47,7 +47,7 @@ from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, TypeVar, Union
import numpy as np
from pydantic import SerializationInfo

-from numpydantic.interface.interface import Interface
+from numpydantic.interface.interface import Interface, JsonDict
from numpydantic.types import DtypeType, NDArrayType

try:
@@ -76,6 +76,20 @@ class H5ArrayPath(NamedTuple):
    """Refer to a specific field within a compound dtype"""


+class H5JsonDict(JsonDict):
+    """Round-trip Json-able version of an HDF5 dataset"""
+
+    file: str
+    path: str
+    field: Optional[str] = None
+
+    def to_array_input(self) -> H5ArrayPath:
+        """Construct an :class:`.H5ArrayPath`"""
+        return H5ArrayPath(
+            **{k: v for k, v in self.model_dump().items() if k in H5ArrayPath._fields}
+        )
+
+
class H5Proxy:
    """
    Proxy class to mimic numpy-like array behavior with an HDF5 array
@@ -106,10 +120,11 @@ class H5Proxy:
        annotation_dtype: Optional[DtypeType] = None,
    ):
        self._h5f = None
-        self.file = Path(file)
+        self.file = Path(file).resolve()
        self.path = path
        self.field = field
        self._annotation_dtype = annotation_dtype
+        self._h5arraypath = H5ArrayPath(self.file, self.path, self.field)

    def array_exists(self) -> bool:
        """Check that there is in fact an array at :attr:`.path` within :attr:`.file`"""
@@ -134,10 +149,20 @@ class H5Proxy:
        else:
            return obj.dtype[self.field]

-    def __getattr__(self, item: str):
-        with h5py.File(self.file, "r") as h5f:
-            obj = h5f.get(self.path)
-            return getattr(obj, item)
+    def __array__(self) -> np.ndarray:
+        """To a numpy array"""
+        with h5py.File(self.file, "r") as h5f:
+            obj = h5f.get(self.path)
+            return obj[:]
+
+    def __getattr__(self, item: str):
+        if item == "__name__":
+            # special case for H5Proxies that don't refer to a real file during testing
+            return "H5Proxy"
+        with h5py.File(self.file, "r") as h5f:
+            obj = h5f.get(self.path)
+            val = getattr(obj, item)
+        return val

    def __getitem__(
        self, item: Union[int, slice, Tuple[Union[int, slice], ...]]
@@ -205,6 +230,15 @@ class H5Proxy:
        """self.shape[0]"""
        return self.shape[0]

+    def __eq__(self, other: "H5Proxy") -> bool:
+        """
+        Check that we are referring to the same hdf5 array
+        """
+        if isinstance(other, H5Proxy):
+            return self._h5arraypath == other._h5arraypath
+        else:
+            raise ValueError("Can only compare equality of two H5Proxies")
+
    def open(self, mode: str = "r") -> "h5py.Dataset":
        """
        Return the opened :class:`h5py.Dataset` object
@@ -244,8 +278,10 @@ class H5Interface(Interface):
    passthrough numpy-like interface to the dataset.
    """

+    name = "hdf5"
    input_types = (H5ArrayPath, H5Arraylike, H5Proxy)
    return_type = H5Proxy
+    json_model = H5JsonDict

    @classmethod
    def enabled(cls) -> bool:
@@ -261,6 +297,13 @@ class H5Interface(Interface):
        if isinstance(array, (H5ArrayPath, H5Proxy)):
            return True

+        if isinstance(array, dict):
+            if array.get("type", False) == cls.name:
+                return True
+            # continue checking if dict contains an hdf5 file
+            file = array.get("file", "")
+            array = (file, "")
+
        if isinstance(array, (tuple, list)) and len(array) in (2, 3):
            # check that the first arg is an hdf5 file
            try:
@@ -342,21 +385,27 @@ class H5Interface(Interface):
    @classmethod
    def to_json(cls, array: H5Proxy, info: Optional[SerializationInfo] = None) -> dict:
        """
-        Dump to a dictionary containing
+        Render HDF5 array as JSON
+
+        If ``round_trip == True``, we dump just the proxy info, a dictionary like:

        * ``file``: :attr:`.file`
        * ``path``: :attr:`.path`
-        * ``attrs``: Any HDF5 attributes on the dataset
-        * ``array``: The array as a list of lists
+
+        Otherwise, we dump the array as a list of lists
        """
+        if info.round_trip:
+            as_json = {
+                "type": cls.name,
+            }
+            as_json.update(array._h5arraypath._asdict())
+        else:
            try:
                dset = array.open()
-                meta = {
-                    "file": array.file,
-                    "path": array.path,
-                    "attrs": dict(dset.attrs),
-                    "array": dset[:].tolist(),
-                }
-                return meta
+                as_json = dset[:].tolist()
            finally:
                array.close()
+
+        return as_json
src/numpydantic/interface/interface.py

@@ -2,15 +2,20 @@
Base Interface metaclass
"""

+import inspect
+import warnings
from abc import ABC, abstractmethod
+from functools import lru_cache
+from importlib.metadata import PackageNotFoundError, version
from operator import attrgetter
from typing import Any, Generic, Optional, Tuple, Type, TypeVar, Union

import numpy as np
-from pydantic import SerializationInfo
+from pydantic import BaseModel, SerializationInfo, ValidationError

from numpydantic.exceptions import (
    DtypeError,
+    MarkMismatchError,
    NoMatchError,
    ShapeError,
    TooManyMatchesError,
@@ -19,6 +24,130 @@ from numpydantic.shape import check_shape
from numpydantic.types import DtypeType, NDArrayType, ShapeType

T = TypeVar("T", bound=NDArrayType)
U = TypeVar("U", bound="JsonDict")
V = TypeVar("V")  # input type
W = TypeVar("W")  # Any type in handle_input


class InterfaceMark(BaseModel):
    """JSON-able mark to be able to round-trip json dumps"""

    module: str
    cls: str
    name: str
    version: str

    def is_valid(self, cls: Type["Interface"], raise_on_error: bool = False) -> bool:
        """
        Check that a given interface matches the mark.

        Args:
            cls (Type): Interface type to check
            raise_on_error (bool): Raise a ``MarkMismatchError`` when the match
                is incorrect

        Returns:
            bool

        Raises:
            :class:`.MarkMismatchError` if requested by ``raise_on_error``
            for an invalid match
        """
        mark = cls.mark_interface()
        valid = self == mark
        if not valid and raise_on_error:
            raise MarkMismatchError(
                "Mismatch between serialized mark and current interface, "
                f"Serialized: {self}; current: {cls}"
            )
        return valid

    def match_by_name(self) -> Optional[Type["Interface"]]:
        """
        Try to find a matching interface by its name, returning it if found,
        or None if not found.
        """
        for i in Interface.interfaces(sort=False):
            if i.name == self.name:
                return i
        return None


class JsonDict(BaseModel):
    """
    Representation of array when dumped with round_trip == True.
    """

    type: str

    @abstractmethod
    def to_array_input(self) -> V:
        """
        Convert this roundtrip specifier to the relevant input class
        (one of the ``input_types`` of an interface).
        """

    @classmethod
    def is_valid(cls, val: dict, raise_on_error: bool = False) -> bool:
        """
        Check whether a given dictionary matches this JsonDict specification

        Args:
            val (dict): The dictionary to check for validity
            raise_on_error (bool): If ``True``, raise the validation error
                rather than returning a bool. (default: ``False``)

        Returns:
            bool - ``True`` if valid, ``False`` if not
        """
        try:
            _ = cls.model_validate(val)
            return True
        except ValidationError as e:
            if raise_on_error:
                raise e
            return False

    @classmethod
    def handle_input(cls: Type[U], value: Union[dict, U, W]) -> Union[V, W]:
        """
        Handle input that is the json serialized roundtrip version
        (from :func:`~pydantic.BaseModel.model_dump` with ``round_trip=True``),
        converting it to the input format with :meth:`.JsonDict.to_array_input`
        or passing it through if not applicable
        """
        if isinstance(value, dict):
            value = cls(**value).to_array_input()
        elif isinstance(value, cls):
            value = value.to_array_input()
        return value


class MarkedJson(BaseModel):
    """
    Model of JSON dumped with an additional interface mark
    with ``model_dump_json({'mark_interface': True})``
    """

    interface: InterfaceMark
    value: Union[list, dict]
    """
    Inner value of the array; we don't validate for JsonDict here,
    that should be downstream from us for performance reasons
    """

    @classmethod
    def try_cast(cls, value: Union[V, dict]) -> Union[V, "MarkedJson"]:
        """
        Try to cast to MarkedJson if applicable, otherwise return input
        """
        if isinstance(value, dict) and "interface" in value and "value" in value:
            try:
                value = MarkedJson(**value)
            except ValidationError:
                # fine, just not a MarkedJson dict even if it looks like one
                return value
        return value


class Interface(ABC, Generic[T]):
@@ -30,7 +159,7 @@ class Interface(ABC, Generic[T]):
    return_type: Type[T]
    priority: int = 0

-    def __init__(self, shape: ShapeType, dtype: DtypeType) -> None:
+    def __init__(self, shape: ShapeType = Any, dtype: DtypeType = Any) -> None:
        self.shape = shape
        self.dtype = dtype
@@ -40,6 +169,7 @@ class Interface(ABC, Generic[T]):

    Calls the methods, in order:

+    * array = :meth:`.deserialize` (array)
    * array = :meth:`.before_validation` (array)
    * dtype = :meth:`.get_dtype` (array) - get the dtype from the array,
      override if eg. the dtype is not contained in ``array.dtype``
@@ -74,6 +204,8 @@ class Interface(ABC, Generic[T]):
        :class:`.DtypeError` and :class:`.ShapeError` (both of which are children
        of :class:`.InterfaceError` )
        """
+        array = self.deserialize(array)
+
        array = self.before_validation(array)

        dtype = self.get_dtype(array)
@@ -86,8 +218,32 @@ class Interface(ABC, Generic[T]):
        self.raise_for_shape(shape_valid, shape)

        array = self.after_validation(array)

        return array

+    def deserialize(self, array: Any) -> Union[V, Any]:
+        """
+        If given a JSON serialized version of the array,
+        deserialize it first.
+
+        If a roundtrip-serialized :class:`.JsonDict`,
+        pass to :meth:`.JsonDict.handle_input`.
+
+        If a roundtrip-serialized :class:`.MarkedJson`,
+        unpack mark, check for validity, warn if not,
+        and try to continue with validation
+        """
+        if isinstance(marked_array := MarkedJson.try_cast(array), MarkedJson):
+            try:
+                marked_array.interface.is_valid(self.__class__, raise_on_error=True)
+            except MarkMismatchError as e:
+                warnings.warn(
+                    str(e) + "\nAttempting to continue validation...", stacklevel=2
+                )
+            array = marked_array.value
+
+        return self.json_model.handle_input(array)
+
    def before_validation(self, array: Any) -> NDArrayType:
        """
        Optional step pre-validation that coerces the input into a type that can be
@@ -117,8 +273,6 @@ class Interface(ABC, Generic[T]):
        """
        Validate the dtype of the given array, returning
        ``True`` if valid, ``False`` if not.
        """
        if self.dtype is Any:
            return True
@@ -211,17 +365,48 @@ class Interface(ABC, Generic[T]):
        installed, etc.)
        """

+    @property
+    @abstractmethod
+    def name(self) -> str:
+        """
+        Short name for this interface
+        """
+
+    @property
+    @abstractmethod
+    def json_model(self) -> JsonDict:
+        """
+        The :class:`.JsonDict` model used for roundtripping
+        JSON serialization
+        """
+
    @classmethod
-    def to_json(
-        cls, array: Type[T], info: Optional[SerializationInfo] = None
-    ) -> Union[list, dict]:
+    @abstractmethod
+    def to_json(cls, array: Type[T], info: SerializationInfo) -> Union[list, JsonDict]:
        """
        Convert an array of :attr:`.return_type` to a JSON-compatible format using
        base python types
        """
-        if not isinstance(array, np.ndarray):  # pragma: no cover
-            array = np.array(array)
-        return array.tolist()

+    @classmethod
+    def mark_json(cls, array: Union[list, dict]) -> dict:
+        """
+        When using ``model_dump_json`` with ``mark_interface: True`` in the ``context``,
+        add additional annotations that would allow the serialized array to be
+        roundtripped.
+
+        Default is just to add an :class:`.InterfaceMark`
+
+        Examples:
+
+            >>> from pprint import pprint
+            >>> pprint(Interface.mark_json([1.0, 2.0]))
+            {'interface': {'cls': 'Interface',
+                           'module': 'numpydantic.interface.interface',
+                           'version': '1.2.2'},
+             'value': [1.0, 2.0]}
+        """
+        return {"interface": cls.mark_interface(), "value": array}

    @classmethod
    def interfaces(
@@ -274,6 +459,28 @@ class Interface(ABC, Generic[T]):

        return tuple(in_types)

+    @classmethod
+    def match_mark(cls, array: Any) -> Optional[Type["Interface"]]:
+        """
+        Match a marked JSON dump of this array to the interface that it indicates.
+
+        First find an interface that matches by name, and then run its
+        ``check`` method, because arrays can be dumped with a mark
+        but without ``round_trip == True`` (and thus can't necessarily
+        use the same interface that they were dumped with)
+
+        Returns:
+            Interface if match found, None otherwise
+        """
+        mark = MarkedJson.try_cast(array)
+        if not isinstance(mark, MarkedJson):
+            return None
+
+        interface = mark.interface.match_by_name()
+        if interface is not None and interface.check(mark.value):
+            return interface
+        return None
+
    @classmethod
    def match(cls, array: Any, fast: bool = False) -> Type["Interface"]:
        """
@@ -291,11 +498,18 @@ class Interface(ABC, Generic[T]):
        check each interface (as ordered by its ``priority``, decreasing),
        and return on the first match.
        """
+        # Shortcircuit match if this is a marked json dump
+        array = MarkedJson.try_cast(array)
+        if (match := cls.match_mark(array)) is not None:
+            return match
+        elif isinstance(array, MarkedJson):
+            array = array.value
+
        # first try and find a non-numpy interface, since the numpy interface
        # will try and load the array into memory in its check method
        interfaces = cls.interfaces()
-        non_np_interfaces = [i for i in interfaces if i.__name__ != "NumpyInterface"]
-        np_interface = [i for i in interfaces if i.__name__ == "NumpyInterface"][0]
+        non_np_interfaces = [i for i in interfaces if i.name != "numpy"]
+        np_interface = [i for i in interfaces if i.name == "numpy"][0]

        if fast:
            matches = []
@@ -335,3 +549,29 @@ class Interface(ABC, Generic[T]):
            raise NoMatchError(f"No matching interfaces found for output {array}")
        else:
            return matches[0]

+    @classmethod
+    @lru_cache(maxsize=32)
+    def mark_interface(cls) -> InterfaceMark:
+        """
+        Create an interface mark indicating this interface for validation after
+        JSON serialization with ``round_trip==True``
+        """
+        interface_module = inspect.getmodule(cls)
+        interface_module = (
+            None if interface_module is None else interface_module.__name__
+        )
+        try:
+            v = (
+                None
+                if interface_module is None
+                else version(interface_module.split(".")[0])
+            )
+        except (
+            PackageNotFoundError
+        ):  # pragma: no cover - no tests for missing interface deps
+            v = None
+
+        return InterfaceMark(
+            module=interface_module, cls=cls.__name__, name=cls.name, version=v
+        )
src/numpydantic/interface/numpy.py

@@ -2,9 +2,11 @@
Interface to numpy arrays
"""

-from typing import Any
+from typing import Any, Literal, Union

-from numpydantic.interface.interface import Interface
+from pydantic import SerializationInfo
+
+from numpydantic.interface.interface import Interface, JsonDict

try:
    import numpy as np
@@ -18,13 +20,31 @@ except ImportError:  # pragma: no cover
    np = None


+class NumpyJsonDict(JsonDict):
+    """
+    JSON-able roundtrip representation of numpy array
+    """
+
+    type: Literal["numpy"]
+    dtype: str
+    array: list
+
+    def to_array_input(self) -> ndarray:
+        """
+        Construct a numpy array
+        """
+        return np.array(self.array, dtype=self.dtype)
+
+
class NumpyInterface(Interface):
    """
    Numpy :class:`~numpy.ndarray` s!
    """

+    name = "numpy"
    input_types = (ndarray, list)
    return_type = ndarray
+    json_model = NumpyJsonDict
    priority = -999
    """
    The numpy interface is usually the interface of last resort.
@@ -41,6 +61,8 @@ class NumpyInterface(Interface):
        """
        if isinstance(array, ndarray):
            return True
+        elif isinstance(array, dict):
+            return NumpyJsonDict.is_valid(array)
        else:
            try:
                _ = np.array(array)
@@ -61,3 +83,22 @@ class NumpyInterface(Interface):
    def enabled(cls) -> bool:
        """Check that numpy is present in the environment"""
        return ENABLED

+    @classmethod
+    def to_json(
+        cls, array: ndarray, info: SerializationInfo = None
+    ) -> Union[list, JsonDict]:
+        """
+        Convert an array of :attr:`.return_type` to a JSON-compatible format using
+        base python types
+        """
+        if not isinstance(array, np.ndarray):  # pragma: no cover
+            array = np.array(array)
+
+        json_array = array.tolist()
+
+        if info.round_trip:
+            json_array = NumpyJsonDict(
+                type=cls.name, dtype=str(array.dtype), array=json_array
+            )
+        return json_array
src/numpydantic/interface/video.py

@@ -3,10 +3,12 @@ Interface to support treating videos like arrays using OpenCV
"""

from pathlib import Path
-from typing import Any, Optional, Tuple, Union
+from typing import Any, Literal, Optional, Tuple, Union

import numpy as np
+from pydantic_core.core_schema import SerializationInfo

+from numpydantic.interface import JsonDict
from numpydantic.interface.interface import Interface

try:
@@ -19,6 +21,19 @@ except ImportError:  # pragma: no cover
VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")


+class VideoJsonDict(JsonDict):
+    """Json-able roundtrip representation of a video file"""
+
+    type: Literal["video"]
+    file: str
+
+    def to_array_input(self) -> "VideoProxy":
+        """
+        Construct a :class:`.VideoProxy`
+        """
+        return VideoProxy(path=Path(self.file))
+
+
class VideoProxy:
    """
    Passthrough proxy class to interact with videos as arrays
@@ -33,7 +48,7 @@ class VideoProxy:
        )

        if path is not None:
-            path = Path(path)
+            path = Path(path).resolve()
        self.path = path

        self._video = video  # type: Optional[VideoCapture]
@@ -52,6 +67,9 @@ class VideoProxy:
                "and it cant be reopened since source path cant be gotten "
                "from VideoCapture objects"
            )
+        if not self.path.exists():
+            raise FileNotFoundError(f"Video file {self.path} does not exist!")
+
        self._video = VideoCapture(str(self.path))
        return self._video
@@ -137,6 +155,10 @@ class VideoProxy:
            slice_ = slice(0, slice_.stop, slice_.step)
        return slice_

+    def __array__(self) -> np.ndarray:
+        """Whole video as a numpy array"""
+        return self[:]
+
    def __getitem__(self, item: Union[int, slice, tuple]) -> np.ndarray:
        if isinstance(item, int):
            # want a single frame
@@ -178,8 +200,16 @@ class VideoProxy:
        raise NotImplementedError("Setting pixel values on videos is not supported!")

    def __getattr__(self, item: str):
+        if item == "__name__":
+            return "VideoProxy"
        return getattr(self.video, item)

+    def __eq__(self, other: "VideoProxy") -> bool:
+        """Check if this is a proxy to the same video file"""
+        if not isinstance(other, VideoProxy):
+            raise TypeError("Can only compare equality of two VideoProxies")
+        return self.path == other.path
+
    def __len__(self) -> int:
        """Number of frames in the video"""
        return self.shape[0]
@@ -190,8 +220,10 @@ class VideoInterface(Interface):
    OpenCV interface to treat videos as arrays.
    """

+    name = "video"
    input_types = (str, Path, VideoCapture, VideoProxy)
    return_type = VideoProxy
+    json_model = VideoJsonDict

    @classmethod
    def enabled(cls) -> bool:
@@ -209,6 +241,9 @@ class VideoInterface(Interface):
        ):
            return True

+        if isinstance(array, dict):
+            array = array.get("file", "")
+
        if isinstance(array, str):
            try:
                array = Path(array)
@@ -227,3 +262,13 @@ class VideoInterface(Interface):
        else:
            proxy = VideoProxy(path=array)
        return proxy

+    @classmethod
+    def to_json(
+        cls, array: VideoProxy, info: SerializationInfo
+    ) -> Union[list, VideoJsonDict]:
+        """Return a json-representation of a video"""
+        if info.round_trip:
+            return VideoJsonDict(type=cls.name, file=str(array.path))
+        else:
+            return np.array(array).tolist()
src/numpydantic/interface/zarr.py

@@ -5,12 +5,12 @@ Interface to zarr arrays
import contextlib
from dataclasses import dataclass
from pathlib import Path
-from typing import Any, Optional, Sequence, Union
+from typing import Any, Literal, Optional, Sequence, Union

import numpy as np
from pydantic import SerializationInfo

-from numpydantic.interface.interface import Interface
+from numpydantic.interface.interface import Interface, JsonDict
from numpydantic.types import DtypeType

try:
@@ -56,13 +56,36 @@ class ZarrArrayPath:
        raise ValueError("Only len 1-2 iterables can be used for a ZarrArrayPath")


+class ZarrJsonDict(JsonDict):
+    """Round-trip Json-able version of a Zarr Array"""
+
+    info: dict[str, str]
+    type: Literal["zarr"]
+    file: Optional[str] = None
+    path: Optional[str] = None
+    array: Optional[list] = None
+
+    def to_array_input(self) -> Union[ZarrArray, ZarrArrayPath]:
+        """
+        Construct a ZarrArrayPath if file and path are present,
+        otherwise a ZarrArray
+        """
+        if self.file:
+            array = ZarrArrayPath(file=self.file, path=self.path)
+        else:
+            array = zarr.array(self.array)
+        return array
+
+
class ZarrInterface(Interface):
    """
    Interface to in-memory or on-disk zarr arrays
    """

+    name = "zarr"
    input_types = (Path, ZarrArray, ZarrArrayPath)
    return_type = ZarrArray
+    json_model = ZarrJsonDict

    @classmethod
    def enabled(cls) -> bool:
@@ -71,7 +94,7 @@ class ZarrInterface(Interface):

    @staticmethod
    def _get_array(
-        array: Union[ZarrArray, str, Path, ZarrArrayPath, Sequence]
+        array: Union[ZarrArray, str, dict, ZarrJsonDict, Path, ZarrArrayPath, Sequence]
    ) -> ZarrArray:
        if isinstance(array, ZarrArray):
            return array
@@ -92,6 +115,12 @@ class ZarrInterface(Interface):
        if isinstance(array, ZarrArray):
            return True

+        if isinstance(array, dict):
+            if array.get("type", False) == cls.name:
+                return True
+            # continue checking if dict contains a zarr file
+            array = array.get("file", "")
+
        # See if can be coerced to ZarrArrayPath
        if isinstance(array, (Path, str)):
            array = ZarrArrayPath(file=array)
@@ -135,26 +164,48 @@ class ZarrInterface(Interface):
        cls,
        array: Union[ZarrArray, str, Path, ZarrArrayPath, Sequence],
        info: Optional[SerializationInfo] = None,
-    ) -> dict:
+    ) -> Union[list, ZarrJsonDict]:
        """
-        Dump just the metadata for an array from :meth:`zarr.core.Array.info_items`
-        plus the :meth:`zarr.core.Array.hexdigest`.
-
-        The full array can be returned by passing ``'zarr_dump_array': True`` to the
-        serialization ``context`` ::
+        Dump a Zarr Array to JSON
+
+        If ``info.round_trip == False``, dump the array as a list of lists.
+        This may be a memory-intensive operation.
+
+        Otherwise, dump the metadata for an array from
+        :meth:`zarr.core.Array.info_items`
+        plus the :meth:`zarr.core.Array.hexdigest` as a :class:`.ZarrJsonDict`
+
+        If either the ``dump_array`` value in the context dictionary is ``True``
+        or the zarr array is an in-memory array, dump the array as well
+        (since without a persistent array it would be impossible to roundtrip and
+        dumping to JSON would be meaningless)
+
+        Passing ``'dump_array': True`` to the serialization ``context``
+        looks like this::

            model.model_dump_json(context={'zarr_dump_array': True})
        """
+        array = cls._get_array(array)
+
+        if info.round_trip:
            dump_array = False
            if info is not None and info.context is not None:
-                dump_array = info.context.get("zarr_dump_array", False)
+                dump_array = info.context.get("dump_array", False)
+            is_file = False

-        array = cls._get_array(array)
-        info = array.info_items()
-        info_dict = {i[0]: i[1] for i in info}
-        info_dict["hexdigest"] = array.hexdigest()
+            as_json = {"type": cls.name}
+            if hasattr(array.store, "dir_path"):
+                is_file = True
+                as_json["file"] = array.store.dir_path()
+                as_json["path"] = array.name
+            as_json["info"] = {i[0]: i[1] for i in array.info_items()}
+            as_json["info"]["hexdigest"] = array.hexdigest()

-        if dump_array:
-            info_dict["array"] = array[:].tolist()
+            if dump_array or not is_file:
+                as_json["array"] = array[:].tolist()

-        return info_dict
+            as_json = ZarrJsonDict(**as_json)
+        else:
+            as_json = array[:].tolist()
+
+        return as_json
src/numpydantic/ndarray.py

@@ -24,11 +24,10 @@ from numpydantic.exceptions import InterfaceError
from numpydantic.interface import Interface
from numpydantic.maps import python_to_nptyping
from numpydantic.schema import (
-    _handler_type,
-    _jsonize_array,
    get_validate_interface,
    make_json_schema,
)
+from numpydantic.serialization import jsonize_array
from numpydantic.types import DtypeType, NDArrayType, ShapeType
from numpydantic.vendor.nptyping.error import InvalidArgumentsError
from numpydantic.vendor.nptyping.ndarray import NDArrayMeta as _NDArrayMeta
@ -41,6 +40,9 @@ from numpydantic.vendor.nptyping.typing_ import (
|
|||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from nptyping.base_meta_classes import SubscriptableMeta
|
||||
from pydantic._internal._schema_generation_shared import (
|
||||
CallbackGetCoreSchemaHandler,
|
||||
)
|
||||
|
||||
from numpydantic import Shape
|
||||
|
||||
|
@ -164,33 +166,34 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
|
|||
def __get_pydantic_core_schema__(
|
||||
cls,
|
||||
_source_type: "NDArray",
|
||||
_handler: _handler_type,
|
||||
_handler: "CallbackGetCoreSchemaHandler",
|
||||
) -> core_schema.CoreSchema:
|
||||
shape, dtype = _source_type.__args__
|
||||
shape: ShapeType
|
||||
dtype: DtypeType
|
||||
|
||||
# get pydantic core schema as a list of lists for JSON schema
|
||||
list_schema = make_json_schema(shape, dtype, _handler)
|
||||
# make core schema for json schema, store it and any model definitions
|
||||
# note that there is a big of fragility in this function,
|
||||
# as we need to access a private method of _handler to
|
||||
# flatten out the json schema. See help(make_json_schema)
|
||||
json_schema = make_json_schema(shape, dtype, _handler)
|
||||
|
||||
return core_schema.json_or_python_schema(
|
||||
json_schema=list_schema,
|
||||
python_schema=core_schema.with_info_plain_validator_function(
|
||||
get_validate_interface(shape, dtype)
|
||||
),
|
||||
return core_schema.with_info_plain_validator_function(
|
||||
get_validate_interface(shape, dtype),
|
||||
serialization=core_schema.plain_serializer_function_ser_schema(
|
||||
_jsonize_array, when_used="json", info_arg=True
|
||||
jsonize_array, when_used="json", info_arg=True
|
||||
),
|
||||
metadata=json_schema,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(
|
||||
cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
|
||||
) -> core_schema.JsonSchema:
|
||||
json_schema = handler(schema)
|
||||
shape, dtype = cls.__args__
|
||||
json_schema = handler(schema["metadata"])
|
||||
json_schema = handler.resolve_ref_schema(json_schema)
|
||||
|
||||
dtype = cls.__args__[1]
|
||||
if not isinstance(dtype, tuple) and dtype.__module__ not in (
|
||||
"builtins",
|
||||
"typing",
|
||||
|
|
|
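To make the new split concrete: python-side validation always goes through the interface-matching validator, while the JSON schema is recovered from the list-of-lists core schema stashed in ``metadata``. A sketch of the observable behavior (the model and shapes here are illustrative):

```python
import numpy as np
from pydantic import BaseModel
from numpydantic import NDArray, Shape


class ImageModel(BaseModel):
    array: NDArray[Shape["* x, * y"], int]


# validation is handled by whichever interface matches the input
instance = ImageModel(array=np.zeros((2, 3), dtype=int))

# JSON schema generation reads the schema stored as core-schema metadata,
# roughly: {"type": "array", "items": {"type": "array", "items": {"type": "integer"}}}
schema = ImageModel.model_json_schema()
```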
@@ -5,10 +5,10 @@ Helper functions for use with :class:`~numpydantic.NDArray` - see the note in
 
 import hashlib
 import json
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional
 
 import numpy as np
-from pydantic import BaseModel, SerializationInfo
+from pydantic import BaseModel
 from pydantic_core import CoreSchema, core_schema
 from pydantic_core.core_schema import ListSchema, ValidationInfo
 

@@ -19,13 +19,16 @@ from numpydantic.types import DtypeType, NDArrayType, ShapeType
 from numpydantic.vendor.nptyping.structure import StructureMeta
 
 if TYPE_CHECKING:  # pragma: no cover
+    from pydantic._internal._schema_generation_shared import (
+        CallbackGetCoreSchemaHandler,
+    )
+
     from numpydantic import Shape
 
-_handler_type = Callable[[Any], core_schema.CoreSchema]
-_UNSUPPORTED_TYPES = (complex,)
-
 
-def _numeric_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
+def _numeric_dtype(
+    dtype: DtypeType, _handler: "CallbackGetCoreSchemaHandler"
+) -> CoreSchema:
     """Make a numeric dtype that respects min/max values from extended numpy types"""
     if dtype in (np.number,):
         dtype = float

@@ -36,14 +39,15 @@ def _numeric_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
     elif issubclass(dtype, np.integer):
         info = np.iinfo(dtype)
         schema = core_schema.int_schema(le=int(info.max), ge=int(info.min))
-
     else:
         schema = _handler.generate_schema(dtype)
 
     return schema
 
 
-def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
+def _lol_dtype(
+    dtype: DtypeType, _handler: "CallbackGetCoreSchemaHandler"
+) -> CoreSchema:
     """Get the innermost dtype schema to use in the generated pydantic schema"""
     if isinstance(dtype, StructureMeta):  # pragma: no cover
         raise NotImplementedError("Structured dtypes are currently unsupported")

@@ -79,11 +83,12 @@ def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
         # does this need a warning?
         python_type = Any
 
-    if python_type in _UNSUPPORTED_TYPES:
-        array_type = core_schema.any_schema()
-        # TODO: warn and log here
-    elif python_type in (float, int):
+    if python_type in (float, int):
         array_type = _numeric_dtype(dtype, _handler)
     elif python_type is bool:
         array_type = core_schema.bool_schema()
+    elif python_type is Any:
+        array_type = core_schema.any_schema()
     else:
         array_type = _handler.generate_schema(python_type)

@@ -208,14 +213,24 @@ def _unbounded_shape(
 
 
 def make_json_schema(
-    shape: ShapeType, dtype: DtypeType, _handler: _handler_type
+    shape: ShapeType, dtype: DtypeType, _handler: "CallbackGetCoreSchemaHandler"
 ) -> ListSchema:
     """
-    Make a list of list JSON schema from a shape and a dtype.
+    Make a list of list pydantic core schema for an array from a shape and a dtype.
+    Used to generate JSON schema in the containing model, but not for validation,
+    which is handled by interfaces.
 
     First resolves the dtype into a pydantic ``CoreSchema`` ,
     and then uses that with :func:`.list_of_lists_schema` .
 
+    .. admonition:: Potentially Fragile
+
+        Uses a private method from the handler to flatten out nested definitions
+        (e.g. when dtype is a pydantic model)
+        so that they are present in the generated schema directly rather than
+        as references. Otherwise, at the time __get_pydantic_json_schema__ is called,
+        the definition references are lost.
+
     Args:
         shape ( ShapeType ): Specification of a shape, as a tuple or
             an nptyping ``Shape``

@@ -234,6 +249,8 @@ def make_json_schema(
     else:
         list_schema = list_of_lists_schema(shape, dtype_schema)
 
+    list_schema = _handler._generate_schema.clean_schema(list_schema)
+
     return list_schema
 

@@ -252,9 +269,3 @@ def get_validate_interface(shape: ShapeType, dtype: DtypeType) -> Callable:
         return value
 
     return validate_interface
-
-
-def _jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
-    """Use an interface class to render an array as JSON"""
-    interface_cls = Interface.match_output(value)
-    return interface_cls.to_json(value, info)
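A sketch of the case the ``clean_schema`` flattening guards against: when the dtype is itself a pydantic model, the model's definition must be inlined into the list-of-lists schema rather than left as a ``$ref`` (the `Point`/`Cloud` models here are illustrative):

```python
from pydantic import BaseModel
from numpydantic import NDArray, Shape


class Point(BaseModel):
    x: float
    y: float


class Cloud(BaseModel):
    points: NDArray[Shape["* point"], Point]


# without the flattening step, the core schema stored in ``metadata`` would
# carry a definition reference to Point that is no longer resolvable by the
# time __get_pydantic_json_schema__ runs; clean_schema inlines it instead
schema = Cloud.model_json_schema()
```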
128 src/numpydantic/serialization.py Normal file

@@ -0,0 +1,128 @@
"""
Serialization helpers for :func:`pydantic.BaseModel.model_dump`
and :func:`pydantic.BaseModel.model_dump_json` .
"""

from pathlib import Path
from typing import Any, Callable, TypeVar, Union

from pydantic_core.core_schema import SerializationInfo

from numpydantic.interface import Interface, JsonDict

T = TypeVar("T")
U = TypeVar("U")


def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
    """Use an interface class to render an array as JSON"""
    interface_cls = Interface.match_output(value)
    array = interface_cls.to_json(value, info)
    if isinstance(array, JsonDict):
        array = array.model_dump(exclude_none=True)

    if info.context:
        if info.context.get("mark_interface", False):
            array = interface_cls.mark_json(array)

        if info.context.get("absolute_paths", False):
            array = _absolutize_paths(array)
        else:
            relative_to = info.context.get("relative_to", ".")
            array = _relativize_paths(array, relative_to)
    else:
        # relativize paths by default
        array = _relativize_paths(array, ".")

    return array


def _relativize_paths(value: dict, relative_to: str = ".") -> dict:
    """
    Make paths relative to either the current directory or the provided
    ``relative_to`` directory, if provided in the context
    """
    relative_to = Path(relative_to).resolve()

    def _r_path(v: Any) -> Any:
        try:
            path = Path(v)
            if not path.exists():
                return v
            return str(relative_path(path, relative_to))
        except (TypeError, ValueError):
            return v

    return _walk_and_apply(value, _r_path)


def _absolutize_paths(value: dict) -> dict:
    def _a_path(v: Any) -> Any:
        try:
            path = Path(v)
            if not path.exists():
                return v
            return str(path.resolve())
        except (TypeError, ValueError):
            return v

    return _walk_and_apply(value, _a_path)


def _walk_and_apply(value: T, f: Callable[[U], U]) -> T:
    """
    Walk an object, applying a function
    """
    if isinstance(value, dict):
        for k, v in value.items():
            if isinstance(v, dict):
                _walk_and_apply(v, f)
            elif isinstance(v, list):
                value[k] = [_walk_and_apply(sub_v, f) for sub_v in v]
            else:
                value[k] = f(v)
    elif isinstance(value, list):
        value = [_walk_and_apply(v, f) for v in value]
    else:
        value = f(value)
    return value


def relative_path(self: Path, other: Path, walk_up: bool = True) -> Path:
    """
    "Backport" of :meth:`pathlib.Path.relative_to` with ``walk_up=True``
    that's not available pre 3.12.

    Return the relative path to another path identified by the passed
    arguments. If the operation is not possible (because this is not
    related to the other path), raise ValueError.

    The *walk_up* parameter controls whether `..` may be used to resolve
    the path.

    References:
        https://github.com/python/cpython/blob/8a2baedc4bcb606da937e4e066b4b3a18961cace/Lib/pathlib/_abc.py#L244-L270
    """
    if not isinstance(other, Path):  # pragma: no cover - ripped from cpython
        other = Path(other)
    self_parts = self.parts
    other_parts = other.parts
    anchor0, parts0 = self_parts[0], list(reversed(self_parts[1:]))
    anchor1, parts1 = other_parts[0], list(reversed(other_parts[1:]))
    if anchor0 != anchor1:
        raise ValueError(f"{self!r} and {other!r} have different anchors")
    while parts0 and parts1 and parts0[-1] == parts1[-1]:
        parts0.pop()
        parts1.pop()
    for part in parts1:  # pragma: no cover - not testing, ripped off from cpython
        if not part or part == ".":
            pass
        elif not walk_up:
            raise ValueError(f"{self!r} is not in the subpath of {other!r}")
        elif part == "..":
            raise ValueError(f"'..' segment in {other!r} cannot be walked")
        else:
            parts0.append("..")
    return Path(*reversed(parts0))
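Taken together, a sketch of how the serialization context controls path handling for file-backed arrays (assuming `model` wraps an array stored on disk; the directory name is illustrative):

```python
# default: any existing filesystem path in the dump is made relative to cwd
dumped = model.model_dump_json(round_trip=True)

# relativize against a chosen directory instead of cwd
dumped = model.model_dump_json(
    round_trip=True, context={"relative_to": "out/exports"}
)

# or resolve paths to absolute
dumped = model.model_dump_json(
    round_trip=True, context={"absolute_paths": True}
)

# only values that exist on disk are rewritten - a dataset path inside an
# HDF5 file like "/data" is pathlike but nonexistent, so it is left alone
```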
@@ -1,4 +1,3 @@
-import pdb
 import sys
 
 import pytest
@@ -1,6 +1,6 @@
 import pytest
 
-from typing import Tuple, Callable
+from typing import Callable, Tuple, Type
 import numpy as np
 import dask.array as da
 import zarr

@@ -12,27 +12,47 @@ from numpydantic import interface, NDArray
 @pytest.fixture(
     scope="function",
     params=[
-        ([[1, 2], [3, 4]], interface.NumpyInterface),
-        (np.zeros((3, 4)), interface.NumpyInterface),
-        ("hdf5_array", interface.H5Interface),
-        (da.random.random((10, 10)), interface.DaskInterface),
-        (zarr.ones((10, 10)), interface.ZarrInterface),
-        ("zarr_nested_array", interface.ZarrInterface),
-        ("zarr_array", interface.ZarrInterface),
-        ("avi_video", interface.VideoInterface),
-    ],
-    ids=[
-        "numpy_list",
-        "numpy",
-        "H5ArrayPath",
-        "dask",
-        "zarr_memory",
-        "zarr_nested",
-        "zarr_array",
-        "video",
+        pytest.param(
+            ([[1, 2], [3, 4]], interface.NumpyInterface),
+            marks=pytest.mark.numpy,
+            id="numpy-list",
+        ),
+        pytest.param(
+            (np.zeros((3, 4)), interface.NumpyInterface),
+            marks=pytest.mark.numpy,
+            id="numpy",
+        ),
+        pytest.param(
+            ("hdf5_array", interface.H5Interface),
+            marks=pytest.mark.hdf5,
+            id="h5-array-path",
+        ),
+        pytest.param(
+            (da.random.random((10, 10)), interface.DaskInterface),
+            marks=pytest.mark.dask,
+            id="dask",
+        ),
+        pytest.param(
+            (zarr.ones((10, 10)), interface.ZarrInterface),
+            marks=pytest.mark.zarr,
+            id="zarr-memory",
+        ),
+        pytest.param(
+            ("zarr_nested_array", interface.ZarrInterface),
+            marks=pytest.mark.zarr,
+            id="zarr-nested",
+        ),
+        pytest.param(
+            ("zarr_array", interface.ZarrInterface),
+            marks=pytest.mark.zarr,
+            id="zarr-array",
+        ),
+        pytest.param(
+            ("avi_video", interface.VideoInterface), marks=pytest.mark.video, id="video"
+        ),
     ],
 )
-def interface_type(request) -> Tuple[NDArray, interface.Interface]:
+def interface_type(request) -> Tuple[NDArray, Type[interface.Interface]]:
     """
     Test cases for each interface's ``check`` method - each input should match the
     provided interface and that interface only
@@ -1,5 +1,3 @@
-import pdb
-
 import pytest
 import json
 

@@ -11,6 +9,8 @@ from numpydantic.exceptions import DtypeError, ShapeError
 
 from tests.conftest import ValidationCase
 
+pytestmark = pytest.mark.dask
+
 
 def dask_array(case: ValidationCase) -> da.Array:
     if issubclass(case.dtype, BaseModel):

@@ -42,14 +42,17 @@ def test_dask_check(interface_type):
     assert not DaskInterface.check(interface_type[0])
 
 
+@pytest.mark.shape
 def test_dask_shape(shape_cases):
     _test_dask_case(shape_cases)
 
 
+@pytest.mark.dtype
 def test_dask_dtype(dtype_cases):
     _test_dask_case(dtype_cases)
 
 
+@pytest.mark.serialization
 def test_dask_to_json(array_model):
     array_list = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
     array = da.array(array_list)
@@ -1,10 +0,0 @@
-"""
-Tests for dunder methods on all interfaces
-"""
-
-
-def test_dunder_len(all_interfaces):
-    """
-    Each interface or proxy type should support __len__
-    """
-    assert len(all_interfaces.array) == all_interfaces.array.shape[0]
@@ -14,6 +14,8 @@ from numpydantic.exceptions import DtypeError, ShapeError
 
 from tests.conftest import ValidationCase
 
+pytestmark = pytest.mark.hdf5
+
 
 def hdf5_array_case(
     case: ValidationCase, array_func, compound: bool = False

@@ -72,11 +74,13 @@ def test_hdf5_check_not_hdf5(tmp_path):
     assert not H5Interface.check(spec)
 
 
+@pytest.mark.shape
 @pytest.mark.parametrize("compound", [True, False])
 def test_hdf5_shape(shape_cases, hdf5_array, compound):
     _test_hdf5_case(shape_cases, hdf5_array, compound)
 
 
+@pytest.mark.dtype
 @pytest.mark.parametrize("compound", [True, False])
 def test_hdf5_dtype(dtype_cases, hdf5_array, compound):
     _test_hdf5_case(dtype_cases, hdf5_array, compound)

@@ -90,6 +94,7 @@ def test_hdf5_dataset_not_exists(hdf5_array, model_blank):
     assert "no array found" in e
 
 
+@pytest.mark.proxy
 def test_assignment(hdf5_array, model_blank):
     array = hdf5_array()
 

@@ -101,7 +106,9 @@ def test_assignment(hdf5_array, model_blank):
     assert (model.array[1:3, 2:4] == 10).all()
 
 
-def test_to_json(hdf5_array, array_model):
+@pytest.mark.serialization
+@pytest.mark.parametrize("round_trip", (True, False))
+def test_to_json(hdf5_array, array_model, round_trip):
     """
     Test serialization of HDF5 arrays to JSON
     Args:

@@ -115,15 +122,19 @@ def test_to_json(hdf5_array, array_model):
 
     instance = model(array=array)  # type: BaseModel
 
-    json_str = instance.model_dump_json()
-    json_dict = json.loads(json_str)["array"]
-
-    assert json_dict["file"] == str(array.file)
-    assert json_dict["path"] == str(array.path)
-    assert json_dict["attrs"] == {}
-    assert json_dict["array"] == instance.array[:].tolist()
+    json_str = instance.model_dump_json(
+        round_trip=round_trip, context={"absolute_paths": True}
+    )
+    json_dumped = json.loads(json_str)["array"]
+    if round_trip:
+        assert json_dumped["file"] == str(array.file)
+        assert json_dumped["path"] == str(array.path)
+    else:
+        assert json_dumped == instance.array[:].tolist()
 
 
+@pytest.mark.dtype
+@pytest.mark.proxy
 def test_compound_dtype(tmp_path):
     """
     hdf5 proxy indexes compound dtypes as single fields when field is given

@@ -158,6 +169,8 @@ def test_compound_dtype(tmp_path):
     assert all(instance.array[1] == 2)
 
 
+@pytest.mark.dtype
+@pytest.mark.proxy
 @pytest.mark.parametrize("compound", [True, False])
 def test_strings(hdf5_array, compound):
     """

@@ -177,6 +190,8 @@ def test_strings(hdf5_array, compound):
     assert all(instance.array[1] == "sup")
 
 
+@pytest.mark.dtype
+@pytest.mark.proxy
 @pytest.mark.parametrize("compound", [True, False])
 def test_datetime(hdf5_array, compound):
     """

@@ -218,3 +233,29 @@ def test_empty_dataset(dtype, tmp_path):
     array: NDArray[Any, dtype]
 
     _ = MyModel(array=(array_path, "/data"))
+
+
+@pytest.mark.proxy
+@pytest.mark.parametrize(
+    "comparison,valid",
+    [
+        (H5Proxy(file="test_file.h5", path="/subpath", field="sup"), True),
+        (H5Proxy(file="test_file.h5", path="/subpath"), False),
+        (H5Proxy(file="different_file.h5", path="/subpath"), False),
+        (("different_file.h5", "/subpath", "sup"), ValueError),
+        ("not even a proxy-like thing", ValueError),
+    ],
+)
+def test_proxy_eq(comparison, valid):
+    """
+    test the __eq__ method of H5ArrayProxy matches proxies to the same
+    dataset (and path), or raises a ValueError
+    """
+    proxy_a = H5Proxy(file="test_file.h5", path="/subpath", field="sup")
+    if valid is True:
+        assert proxy_a == comparison
+    elif valid is False:
+        assert proxy_a != comparison
+    else:
+        with pytest.raises(valid):
+            assert proxy_a == comparison
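The `test_to_json` change above pins down the new HDF5 dump shape; as a sketch, for a model whose `NDArray` field is backed by an HDF5 dataset (the model, file, and dataset names are illustrative):

```python
import json
from pydantic import BaseModel
from numpydantic import NDArray


class MyModel(BaseModel):
    array: NDArray


model = MyModel(array=("data.h5", "/data"))  # (file path, dataset path)

# round_trip=True dumps a reference, roughly {"file": "data.h5", "path": "/data", ...}
ref = json.loads(model.model_dump_json(round_trip=True))["array"]

# without round_trip, the values are dumped as a list of lists
values = json.loads(model.model_dump_json())["array"]

# the reference form can be validated back into an equivalent model
restored = MyModel.model_validate_json(model.model_dump_json(round_trip=True))
```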
@@ -4,11 +4,32 @@ for tests that should apply to all interfaces, use ``test_interfaces.py``
 """
 
 import gc
+from typing import Literal
 
 import pytest
 import numpy as np
 
-from numpydantic.interface import Interface
+from numpydantic.interface import (
+    Interface,
+    JsonDict,
+    InterfaceMark,
+    NumpyInterface,
+    MarkedJson,
+)
+from pydantic import ValidationError
+
+from numpydantic.interface.interface import V
+
+
+class MyJsonDict(JsonDict):
+    type: Literal["my_json_dict"]
+    field: str
+    number: int
+
+    def to_array_input(self) -> V:
+        dumped = self.model_dump()
+        dumped["extra_input_param"] = True
+        return dumped
+
 
 @pytest.fixture(scope="module")

@@ -162,3 +183,66 @@ def test_interface_recursive(interfaces):
     assert issubclass(interfaces.interface3, interfaces.interface1)
     assert issubclass(interfaces.interface1, Interface)
     assert interfaces.interface4 in ifaces
+
+
+@pytest.mark.serialization
+def test_jsondict_is_valid():
+    """
+    A JsonDict should return a bool true/false if it is valid or not,
+    and raise an error when requested
+    """
+    invalid = {"doesnt": "have", "the": "props"}
+    valid = {"type": "my_json_dict", "field": "a_field", "number": 1}
+    assert MyJsonDict.is_valid(valid)
+    assert not MyJsonDict.is_valid(invalid)
+    with pytest.raises(ValidationError):
+        assert not MyJsonDict.is_valid(invalid, raise_on_error=True)
+
+
+@pytest.mark.serialization
+def test_jsondict_handle_input():
+    """
+    JsonDict should be able to parse a valid dict and return it to the input format
+    """
+    valid = {"type": "my_json_dict", "field": "a_field", "number": 1}
+    instantiated = MyJsonDict(**valid)
+    expected = {
+        "type": "my_json_dict",
+        "field": "a_field",
+        "number": 1,
+        "extra_input_param": True,
+    }
+
+    for item in (valid, instantiated):
+        result = MyJsonDict.handle_input(item)
+        assert result == expected
+
+
+@pytest.mark.serialization
+@pytest.mark.parametrize("interface", Interface.interfaces())
+def test_interface_mark_match_by_name(interface):
+    """
+    Interface mark should match an interface by its name
+    """
+    # other parts don't matter
+    mark = InterfaceMark(module="fake", cls="fake", version="fake", name=interface.name)
+    fake_mark = InterfaceMark(
+        module="fake", cls="fake", version="fake", name="also_fake"
+    )
+    assert mark.match_by_name() is interface
+    assert fake_mark.match_by_name() is None
+
+
+@pytest.mark.serialization
+def test_marked_json_try_cast():
+    """
+    MarkedJson.try_cast should try and cast to a markedjson!
+    returning the value unchanged if it's not a match
+    """
+    valid = {"interface": NumpyInterface.mark_interface(), "value": [[1, 2], [3, 4]]}
+    invalid = [1, 2, 3, 4, 5]
+    mimic = {"interface": "not really", "value": "still not really"}
+
+    assert isinstance(MarkedJson.try_cast(valid), MarkedJson)
+    assert MarkedJson.try_cast(invalid) is invalid
+    assert MarkedJson.try_cast(mimic) is mimic
@@ -2,6 +2,41 @@
 Tests that should be applied to all interfaces
 """
 
+import pytest
+from typing import Callable
+from importlib.metadata import version
+import json
+
 import numpy as np
+import dask.array as da
+from zarr.core import Array as ZarrArray
+from pydantic import BaseModel
+
+from numpydantic.interface import Interface, InterfaceMark, MarkedJson
+
+
+def _test_roundtrip(source: BaseModel, target: BaseModel, round_trip: bool):
+    """Test model equality for roundtrip tests"""
+    if round_trip:
+        assert type(target.array) is type(source.array)
+        if isinstance(source.array, (np.ndarray, ZarrArray)):
+            assert np.array_equal(target.array, np.array(source.array))
+        elif isinstance(source.array, da.Array):
+            assert np.all(da.equal(target.array, source.array))
+        else:
+            assert target.array == source.array
+
+        assert target.array.dtype == source.array.dtype
+    else:
+        assert np.array_equal(target.array, np.array(source.array))
+
+
+def test_dunder_len(all_interfaces):
+    """
+    Each interface or proxy type should support __len__
+    """
+    assert len(all_interfaces.array) == all_interfaces.array.shape[0]
+
 
 def test_interface_revalidate(all_interfaces):
     """

@@ -10,3 +45,86 @@ def test_interface_revalidate(all_interfaces):
     See: https://github.com/p2p-ld/numpydantic/pull/14
     """
     _ = type(all_interfaces)(array=all_interfaces.array)
+
+
+def test_interface_rematch(interface_type):
+    """
+    All interfaces should match the results of the object they return after validation
+    """
+    array, interface = interface_type
+    if isinstance(array, Callable):
+        array = array()
+
+    assert Interface.match(interface().validate(array)) is interface
+
+
+def test_interface_to_numpy_array(all_interfaces):
+    """
+    All interfaces should be able to have the output of their validation stage
+    coerced to a numpy array with np.array()
+    """
+    _ = np.array(all_interfaces.array)
+
+
+@pytest.mark.serialization
+def test_interface_dump_json(all_interfaces):
+    """
+    All interfaces should be able to dump to json
+    """
+    all_interfaces.model_dump_json()
+
+
+@pytest.mark.serialization
+@pytest.mark.parametrize("round_trip", [True, False])
+def test_interface_roundtrip_json(all_interfaces, round_trip):
+    """
+    All interfaces should be able to roundtrip to and from json
+    """
+    dumped_json = all_interfaces.model_dump_json(round_trip=round_trip)
+    model = all_interfaces.model_validate_json(dumped_json)
+    _test_roundtrip(all_interfaces, model, round_trip)
+
+
+@pytest.mark.serialization
+@pytest.mark.parametrize("an_interface", Interface.interfaces())
+def test_interface_mark_interface(an_interface):
+    """
+    All interfaces should be able to mark the current version and interface info
+    """
+    mark = an_interface.mark_interface()
+    assert isinstance(mark, InterfaceMark)
+    assert mark.name == an_interface.name
+    assert mark.cls == an_interface.__name__
+    assert mark.module == an_interface.__module__
+    assert mark.version == version(mark.module.split(".")[0])
+
+
+@pytest.mark.serialization
+@pytest.mark.parametrize("valid", [True, False])
+@pytest.mark.parametrize("round_trip", [True, False])
+@pytest.mark.filterwarnings("ignore:Mismatch between serialized mark")
+def test_interface_mark_roundtrip(all_interfaces, valid, round_trip):
+    """
+    All interfaces should be able to roundtrip with the marked interface,
+    and a mismatch should raise a warning and attempt to proceed
+    """
+    dumped_json = all_interfaces.model_dump_json(
+        round_trip=round_trip, context={"mark_interface": True}
+    )
+
+    data = json.loads(dumped_json)
+
+    # ensure that we are a MarkedJson
+    _ = MarkedJson.model_validate_json(json.dumps(data["array"]))
+
+    if not valid:
+        # ruin the version
+        data["array"]["interface"]["version"] = "v99999999"
+        dumped_json = json.dumps(data)
+
+        with pytest.warns(match="Mismatch.*"):
+            model = all_interfaces.model_validate_json(dumped_json)
+    else:
+        model = all_interfaces.model_validate_json(dumped_json)
+
+    _test_roundtrip(all_interfaces, model, round_trip)
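For reference, a sketch of what the marked dump exercised above looks like on the wire, following the `InterfaceMark` fields (`module`, `cls`, `name`, `version`); the concrete values shown are illustrative:

```python
dumped = model.model_dump_json(round_trip=True, context={"mark_interface": True})
# the array value is wrapped as MarkedJson, roughly:
# {
#     "interface": {
#         "module": "numpydantic.interface.numpy",
#         "cls": "NumpyInterface",
#         "name": "numpy",
#         "version": "1.6.0"
#     },
#     "value": [[1, 2], [3, 4]]
# }
#
# on validation, a mismatched mark warns (per the test above) and
# deserialization attempts to proceed with the installed interface
```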
@@ -5,6 +5,8 @@ from numpydantic.exceptions import DtypeError, ShapeError
 
 from tests.conftest import ValidationCase
 
+pytestmark = pytest.mark.numpy
+
 
 def numpy_array(case: ValidationCase) -> np.ndarray:
     if issubclass(case.dtype, BaseModel):

@@ -22,10 +24,12 @@ def _test_np_case(case: ValidationCase):
     case.model(array=array)
 
 
+@pytest.mark.shape
 def test_numpy_shape(shape_cases):
     _test_np_case(shape_cases)
 
 
+@pytest.mark.dtype
 def test_numpy_dtype(dtype_cases):
     _test_np_case(dtype_cases)
 
 
@@ -14,6 +14,8 @@ from numpydantic import NDArray, Shape
 from numpydantic import dtype as dt
 from numpydantic.interface.video import VideoProxy
 
+pytestmark = pytest.mark.video
+
 
 @pytest.mark.parametrize("input_type", [str, Path])
 def test_video_validation(avi_video, input_type):

@@ -49,6 +51,7 @@ def test_video_from_videocapture(avi_video):
     opened_vid.release()
 
 
+@pytest.mark.shape
 def test_video_wrong_shape(avi_video):
     shape = (100, 50)
 

@@ -65,6 +68,7 @@ def test_video_wrong_shape(avi_video):
     instance = MyModel(array=vid)
 
 
+@pytest.mark.proxy
 def test_video_getitem(avi_video):
     """
     Should be able to get individual frames and slices as if it were a normal array

@@ -127,6 +131,7 @@ def test_video_getitem(avi_video):
     instance.array[5] = 10
 
 
+@pytest.mark.proxy
 def test_video_attrs(avi_video):
     """Should be able to access opencv properties"""
     shape = (100, 50)

@@ -142,6 +147,7 @@ def test_video_attrs(avi_video):
     assert int(instance.array.get(cv2.CAP_PROP_POS_FRAMES)) == 5
 
 
+@pytest.mark.proxy
 def test_video_close(avi_video):
     """Should close and reopen video file if needed"""
     shape = (100, 50)

@@ -158,3 +164,42 @@ def test_video_close(avi_video):
     assert instance.array._video is None
     # reopen
     assert isinstance(instance.array.video, cv2.VideoCapture)
+
+
+@pytest.mark.proxy
+def test_video_not_exists(tmp_path):
+    """
+    A video file that doesn't exist should raise an error
+    """
+    video = VideoProxy(tmp_path / "not_real.avi")
+    with pytest.raises(FileNotFoundError):
+        _ = video.video
+
+
+@pytest.mark.proxy
+@pytest.mark.parametrize(
+    "comparison,valid",
+    [
+        (VideoProxy("test_video.avi"), True),
+        (VideoProxy("not_real_video.avi"), False),
+        ("not even a video proxy", TypeError),
+    ],
+)
+def test_video_proxy_eq(comparison, valid):
+    """
+    Comparing a video proxy's equality should be valid if the path matches
+    """
+    proxy_a = VideoProxy("test_video.avi")
+    if valid is True:
+        assert proxy_a == comparison
+    elif valid is False:
+        assert proxy_a != comparison
+    else:
+        with pytest.raises(valid):
+            assert proxy_a == comparison
@@ -6,13 +6,14 @@ import zarr
 from pydantic import BaseModel, ValidationError
 from numcodecs import Pickle
 
-
 from numpydantic.interface import ZarrInterface
 from numpydantic.interface.zarr import ZarrArrayPath
 from numpydantic.exceptions import DtypeError, ShapeError
 
 from tests.conftest import ValidationCase
 
+pytestmark = pytest.mark.zarr
+
 
 @pytest.fixture()
 def dir_array(tmp_output_dir_func) -> zarr.DirectoryStore:

@@ -87,10 +88,12 @@ def test_zarr_check(interface_type):
     assert not ZarrInterface.check(interface_type[0])
 
 
+@pytest.mark.shape
 def test_zarr_shape(store, shape_cases):
     _test_zarr_case(shape_cases, store)
 
 
+@pytest.mark.dtype
 def test_zarr_dtype(dtype_cases, store):
     _test_zarr_case(dtype_cases, store)
 

@@ -123,7 +126,10 @@ def test_zarr_array_path_from_iterable(zarr_array):
     assert apath.path == inner_path
 
 
-def test_zarr_to_json(store, model_blank):
+@pytest.mark.serialization
+@pytest.mark.parametrize("dump_array", [True, False])
+@pytest.mark.parametrize("roundtrip", [True, False])
+def test_zarr_to_json(store, model_blank, roundtrip, dump_array):
     expected_fields = (
         "Type",
         "Data type",

@@ -137,17 +143,22 @@ def test_zarr_to_json(store, model_blank):
 
     array = zarr.array(lol_array, store=store)
     instance = model_blank(array=array)
-    as_json = json.loads(instance.model_dump_json())["array"]
-    assert "array" not in as_json
-    for field in expected_fields:
-        assert field in as_json
-    assert len(as_json["hexdigest"]) == 40
 
-    # dump the array itself too
-    as_json = json.loads(instance.model_dump_json(context={"zarr_dump_array": True}))[
-        "array"
-    ]
-    for field in expected_fields:
-        assert field in as_json
-    assert len(as_json["hexdigest"]) == 40
+    context = {"dump_array": dump_array}
+    as_json = json.loads(
+        instance.model_dump_json(round_trip=roundtrip, context=context)
+    )["array"]
+
+    if roundtrip:
+        if dump_array:
+            assert as_json["array"] == lol_array
+        else:
+            if as_json.get("file", False):
+                assert "array" not in as_json
+
+        for field in expected_fields:
+            assert field in as_json["info"]
+        assert len(as_json["info"]["hexdigest"]) == 40
+
+    else:
+        assert as_json == lol_array
@@ -1,5 +1,3 @@
-import pdb
-
 import pytest
 
 from typing import Union, Optional, Any

@@ -15,6 +13,7 @@ from numpydantic import dtype
 from numpydantic.dtype import Number
 
 
+@pytest.mark.json_schema
 def test_ndarray_type():
     class Model(BaseModel):
         array: NDArray[Shape["2 x, * y"], Number]

@@ -40,6 +39,8 @@ def test_ndarray_type():
     instance = Model(array=np.zeros((2, 3)), array_any=np.ones((3, 4, 5)))
 
 
+@pytest.mark.dtype
+@pytest.mark.json_schema
 def test_schema_unsupported_type():
     """
     Complex numbers should just be made with an `any` schema

@@ -55,9 +56,11 @@ def test_schema_unsupported_type():
     }
 
 
+@pytest.mark.dtype
+@pytest.mark.json_schema
 def test_schema_tuple():
     """
-    Types specified as tupled should have their schemas as a union
+    Types specified as tuples should have their schemas as a union
     """
 
     class Model(BaseModel):

@@ -72,6 +75,8 @@ def test_schema_tuple():
     assert all([i["minimum"] == 0 for i in conditions])
 
 
+@pytest.mark.dtype
+@pytest.mark.json_schema
 def test_schema_number():
     """
     np.numeric should just be the float schema

@@ -115,12 +120,12 @@ def test_ndarray_union():
     instance = Model(array=np.random.random((5, 10, 4, 6)))
 
 
+@pytest.mark.shape
+@pytest.mark.dtype
 @pytest.mark.parametrize("dtype", dtype.Number)
 def test_ndarray_unparameterized(dtype):
     """
     NDArray without any parameters is any shape, any type
-    Returns:
-
     """
 
     class Model(BaseModel):

@@ -134,6 +139,7 @@ def test_ndarray_unparameterized(dtype):
     _ = Model(array=np.zeros(dim_sizes, dtype=dtype))
 
 
+@pytest.mark.shape
 def test_ndarray_any():
     """
     using :class:`typing.Any` in for the shape means any shape

@@ -164,6 +170,19 @@ def test_ndarray_coercion():
     amod = Model(array=["a", "b", "c"])
 
 
+@pytest.mark.shape
+def test_shape_ellipsis():
+    """
+    Test that ellipsis is a wildcard, rather than "repeat the last index"
+    """
+
+    class MyModel(BaseModel):
+        array: NDArray[Shape["1, 2, ..."], Number]
+
+    _ = MyModel(array=np.zeros((1, 2, 3, 4, 5)))
+
+
+@pytest.mark.serialization
 def test_ndarray_serialize():
     """
     Arrays should be dumped to a list when using json, but kept as ndarray otherwise

@@ -188,6 +207,7 @@ _json_schema_types = [
 ]
 
 
+@pytest.mark.json_schema
 def test_json_schema_basic(array_model):
     """
     NDArray types should correctly generate a list of lists JSON schema

@@ -210,6 +230,8 @@ def test_json_schema_basic(array_model):
     assert inner["items"]["type"] == "number"
 
 
+@pytest.mark.dtype
+@pytest.mark.json_schema
 @pytest.mark.parametrize("dtype", [*dtype.Integer, *dtype.Float])
 def test_json_schema_dtype_single(dtype, array_model):
     """

@@ -240,6 +262,8 @@ def test_json_schema_dtype_single(dtype, array_model):
     )
 
 
+@pytest.mark.dtype
+@pytest.mark.json_schema
 @pytest.mark.parametrize(
     "dtype,expected",
     [

@@ -266,6 +290,8 @@ def test_json_schema_dtype_builtin(dtype, expected, array_model):
     assert inner_type["type"] == expected
 
 
+@pytest.mark.dtype
+@pytest.mark.json_schema
 def test_json_schema_dtype_model():
     """
     Pydantic models can be used in arrays as dtypes

@@ -314,6 +340,8 @@ def _recursive_array(schema):
     assert any_of[1]["minimum"] == 0
 
 
+@pytest.mark.shape
+@pytest.mark.json_schema
 def test_json_schema_ellipsis():
     """
     NDArray types should create a recursive JSON schema for any-shaped arrays
95 tests/test_serialization.py Normal file

@@ -0,0 +1,95 @@
"""
Test serialization-specific functionality that doesn't need to be
applied across every interface (use test_interface/test_interfaces for that)
"""

import h5py
import pytest
from pathlib import Path
from typing import Callable
import numpy as np
import json

pytestmark = pytest.mark.serialization


@pytest.fixture(scope="module")
def hdf5_at_path() -> Callable[[Path], None]:
    _path = ""

    def _hdf5_at_path(path: Path) -> None:
        nonlocal _path
        _path = path
        h5f = h5py.File(path, "w")
        _ = h5f.create_dataset("/data", data=np.array([[1, 2], [3, 4]]))
        _ = h5f.create_dataset("subpath/to/dataset", data=np.array([[1, 2], [4, 5]]))
        h5f.close()

    yield _hdf5_at_path

    Path(_path).unlink(missing_ok=True)


def test_relative_path(hdf5_at_path, tmp_output_dir, model_blank):
    """
    By default, we should make all paths relative to the cwd
    """
    out_path = tmp_output_dir / "relative.h5"
    hdf5_at_path(out_path)
    model = model_blank(array=(out_path, "/data"))
    rt = model.model_dump_json(round_trip=True)
    file = json.loads(rt)["array"]["file"]

    # should not be absolute
    assert not Path(file).is_absolute()
    # should be relative to cwd
    out_file = (Path.cwd() / file).resolve()
    assert out_file == out_path.resolve()


def test_relative_to_path(hdf5_at_path, tmp_output_dir, model_blank):
    """
    When explicitly passed a path to be ``relative_to`` ,
    relative to that instead of cwd
    """
    out_path = tmp_output_dir / "relative.h5"
    relative_to_path = Path(__file__) / "fake_dir" / "sub_fake_dir"
    expected_path = "../../../__tmp__/relative.h5"

    hdf5_at_path(out_path)
    model = model_blank(array=(out_path, "/data"))
    rt = model.model_dump_json(
        round_trip=True, context={"relative_to": str(relative_to_path)}
    )
    data = json.loads(rt)["array"]
    file = data["file"]

    # should not be absolute
    assert not Path(file).is_absolute()
    # should be expected path and reach the file
    assert file == expected_path
    assert (relative_to_path / file).resolve() == out_path.resolve()

    # we shouldn't have touched `/data` even though it is pathlike
    assert data["path"] == "/data"


def test_absolute_path(hdf5_at_path, tmp_output_dir, model_blank):
    """
    When told, we make paths absolute
    """
    out_path = tmp_output_dir / "relative.h5"
    expected_dataset = "subpath/to/dataset"

    hdf5_at_path(out_path)
    model = model_blank(array=(out_path, expected_dataset))
    rt = model.model_dump_json(round_trip=True, context={"absolute_paths": True})
    data = json.loads(rt)["array"]
    file = data["file"]

    # should be absolute and equal to out_path
    assert Path(file).is_absolute()
    assert Path(file) == out_path.resolve()

    # shouldn't have absolutized subpath even if it's pathlike
    assert data["path"] == expected_dataset
@@ -1,5 +1,3 @@
-import pdb
-
 import pytest
 
 from typing import Any

@@ -9,6 +7,8 @@ import numpy as np
 
 from numpydantic import NDArray, Shape
 
+pytestmark = pytest.mark.shape
+
 
 @pytest.mark.parametrize(
     "shape,valid",