Mirror of https://github.com/p2p-ld/numpydantic.git (synced 2025-01-10 05:54:26 +00:00)

Merge pull request #20 from p2p-ld/dump_json: Roundtrip JSON serialization/deserialization

This commit is contained in: commit bd5b93773b
39 changed files with 2774 additions and 203 deletions
@@ -1,6 +0,0 @@ (docs/api/monkeypatch.md, deleted)
# monkeypatch

```{eval-rst}
.. automodule:: numpydantic.monkeypatch
   :members:
```
7 docs/api/serialization.md Normal file
@@ -0,0 +1,7 @@
# serialization

```{eval-rst}
.. automodule:: numpydantic.serialization
   :members:
   :undoc-members:
```
@@ -2,6 +2,73 @@ (docs/changelog.md)

## 1.*

### 1.6.*

#### 1.6.0 - 24-09-23 - Roundtrip JSON Serialization

Roundtrip JSON serialization is here - with serialization to list of lists,
as well as file references that don't require copying the whole array if
used in data modeling, control over path relativization, and stamping of
interface version for the extra provenance conscious.

Please see [serialization](./serialization.md) for narrative documentation :)

**Potentially Breaking Changes**
- See [development](./development.md) for a statement about API stability
- An additional {meth}`.Interface.deserialize` method has been added to
  {meth}`.Interface.validate` - downstream users are not intended to override the
  `validate` method, but if they have, then JSON deserialization will not work for them.
- `Interface` subclasses now require a `name` attribute, a short string identifier for
  that interface, and a `json_model` that inherits from {class}`.interface.JsonDict`
  (see the sketch below). Interfaces without these attributes cannot be instantiated.
- {meth}`.Interface.to_json` is now an abstract method that all interfaces must define.
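As a minimal sketch of what those new requirements look like for a downstream interface (the `"myformat"` name and the `MyJsonDict`/`MyInterface` classes are hypothetical, for illustration only; the remaining abstract methods of `Interface` are elided):

```python
from typing import Literal

from numpydantic.interface import Interface, JsonDict


class MyJsonDict(JsonDict):
    """Hypothetical round-trip JSON form of some array format."""

    type: Literal["myformat"]
    dtype: str
    array: list

    def to_array_input(self) -> list:
        # convert back into something the interface accepts as input
        return self.array


class MyInterface(Interface):
    name = "myformat"        # new requirement: short string identifier
    json_model = MyJsonDict  # new requirement: JsonDict subclass for round-trips
    ...                      # check(), enabled(), to_json(), etc. elided
```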
**Features**
- Roundtrip JSON serialization - by default, dump to list-of-lists arrays, but
  support the `round_trip` keyword in `model_dump_json` for provenance-preserving
  dumps (see the sketch after this list)
- JSON Schema generation has been separated from `core_schema` generation in
  {class}`.NDArray`. Downstream interfaces can customize JSON schema generation
  without compromising the ability to validate.
- All proxy classes must have an `__eq__` dunder method to compare equality -
  in proxy classes, these compare equality of arguments, since the arrays that
  are referenced on disk should be equal by definition. Direct array comparison
  should use {func}`numpy.array_equal`
- Interfaces previously couldn't be instantiated without explicit shape and dtype
  arguments; these have been given `Any` defaults.
- New {mod}`numpydantic.serialization` module to contain serialization logic.
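For example, a round-trip dump and revalidation looks like this (a minimal sketch; the model and array values are illustrative):

```python
from pydantic import BaseModel

from numpydantic import NDArray


class MyModel(BaseModel):
    array: NDArray


model = MyModel(array=[[1, 2], [3, 4]])

# default: plain list-of-lists JSON
plain = model.model_dump_json()

# provenance-preserving dump that revalidates with the same interface
dumped = model.model_dump_json(round_trip=True)
restored = MyModel.model_validate_json(dumped)
```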
**New Classes**

See the docstrings for descriptions of each class

- `MarkMismatchError` for when an array serialized with `mark_interface` doesn't match
  the interface that's deserializing it
- {class}`.interface.InterfaceMark`
- {class}`.interface.MarkedJson`
- {class}`.interface.JsonDict`
- {class}`.dask.DaskJsonDict`
- {class}`.hdf5.H5JsonDict`
- {class}`.numpy.NumpyJsonDict`
- {class}`.video.VideoJsonDict`
- {class}`.zarr.ZarrJsonDict`

**Bugfix**
- [`#17`](https://github.com/p2p-ld/numpydantic/issues/17) - Arrays are re-validated as lists, rather than arrays
- Some proxy classes would fail to be serialized because they lacked an `__array__` method.
  `__array__` methods have been added, along with tests that coerce proxies to arrays to prevent regression.
- Some proxy classes lacked a `__name__` attribute, which caused failures to serialize
  when the `__getattr__` methods attempted to pass it through. These have been added where needed.

**Docs**
- Add statement about versioning and API stability to [development](./development.md)
- Add docs for serialization!
- Remove stranded docs from hooks and monkeypatch
- Added `myst_nb` to docs dependencies for direct rendering of code and output

**Tests**
- Marks have been added for running subsets of the tests for a given interface,
  package feature, etc. (see the sketch after this list)
- Tests for all of the above functionality
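A sketch of selecting a marked subset, assuming a dev install (marker names are registered in `pyproject.toml`, shown further below):

```python
# run only tests that exercise both the dask interface and serialization behavior
import pytest

pytest.main(["-m", "dask and serialization"])
```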
### 1.5.*

#### 1.5.3 - 24-09-03 - Bugfix, type checking for empty HDF5 datasets
@@ -25,7 +25,7 @@ extensions = [ (docs/conf.py)
     "sphinx.ext.doctest",
     "sphinx_design",
     "sphinxcontrib.mermaid",
-    "myst_parser",
+    "myst_nb",
     "sphinx.ext.todo",
 ]

@@ -77,3 +77,8 @@ napoleon_attr_annotations = True
 # todo
 todo_include_todos = True
 todo_link_only = True
+
+# myst
+# myst-nb
+nb_render_markdown_format = "myst"
+nb_execution_show_tb = True
BIN docs/data/test.avi Normal file (binary file not shown)
BIN docs/data/test.h5 Normal file (binary file not shown)

22 docs/data/test.zarr/.zarray Normal file
@@ -0,0 +1,22 @@
{
    "chunks": [
        2,
        2
    ],
    "compressor": {
        "blocksize": 0,
        "clevel": 5,
        "cname": "lz4",
        "id": "blosc",
        "shuffle": 1
    },
    "dtype": "<i8",
    "fill_value": 0,
    "filters": null,
    "order": "C",
    "shape": [
        2,
        2
    ],
    "zarr_format": 2
}
BIN docs/data/test.zarr/0.0 Normal file (binary file not shown)

84 docs/development.md Normal file
@@ -0,0 +1,84 @@
# Development

## Versioning

This package uses a colloquial form of [semantic versioning 2](https://semver.org/).

Specifically:

- Major version `2.*.*` is reserved for the transition from nptyping to using
  `TypeVarTuple`, `Generic`, and `Protocol`. Until `2.*.*`...
  - breaking changes will be indicated with an advance in `MINOR`
    version, taking the place of `MAJOR` in semver
  - backwards-compatible bugfixes **and** additions in functionality
    will be indicated by a `PATCH` release, taking the place of `MINOR` and
    `PATCH` in semver.
- After `2.*.*`, semver as usual will resume

You are encouraged to set an upper bound on your version dependencies until
we pass `2.*.*`: the major functionality of numpydantic is stable,
but there is still a decent amount of jostling things around to be expected.
### API Stability

- All breaking changes to the **public API** will be signaled by a major
  version's worth of deprecation warnings
- All breaking changes to the **development API** will be signaled by a
  minor version's worth of deprecation warnings.
- Changes to the remainder of the package, whether marked as private with a
  leading underscore or not, including the import structure of the package,
  are not considered part of the API and should not be relied on as stable
  until explicitly marked otherwise.

#### Public API

**Only the {class}`.NDArray` and {class}`.Shape` classes should be considered
part of the stable public API.**

All associated functionality for validation should also be considered
a stable part of the `NDArray` and `Shape` classes - functionality
will only be added here, and the departure from the string form of the
shape specifications (and its removal) will take place in `v3.*.*`.

End-users of numpydantic should pin an upper bound for the `MAJOR` version
until after `v2.*.*`, after which time it is up to your discretion -
no breaking changes are planned, but they would be signaled by a major version change.
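For reference, typical end-user code touches only those two classes (a sketch; the shape string follows the current string-form specification mentioned above):

```python
import numpy as np
from pydantic import BaseModel

from numpydantic import NDArray, Shape


class Image(BaseModel):
    # any array-like with two arbitrary-length dimensions and uint8 dtype
    array: NDArray[Shape["* x, * y"], np.uint8]
```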
#### Development API

**Only the {class}`.Interface` class and its subclasses,
along with the Public API,
should be considered part of the stable development API.**

The `Interface` class is the primary point of external development expected
for numpydantic. It is still somewhat in flux, but it is prioritized for stability
and deprecation warnings above the rest of the package.

Dependent packages that define their own `Interface`s should pin an upper
bound for the `PATCH` version until `2.*.*`, and afterwards likely pin a `MINOR` version.
Tests are designed such that it should be easy to test major features against
each interface, and that work is also ongoing. Once the test suite reaches
maturity, it should be possible for any downstream interface to simply use those
tests to ensure it is compatible with the latest version.

## Release Schedule

There is no release schedule. Versions are released according to need and available labor.
## Contributing

### Dev environment

```{todo}
Document dev environment.

Really it's very simple, you just clone a fork and install
the `dev` environment like `pip install '.[dev]'`
```

### Pull Requests

```{todo}
Document pull requests if we ever receive one
```
@@ -1,11 +0,0 @@ (docs/hooks.md, deleted)
# Hooks

What hooks do we want to expose to downstream users so they can use this without needing
to override everything?

```{todo}
**NWB Compatibility**

**Precision:** NWB allows for a sort of hierarchy of type specification -
a less precise type also allows the data to be specified in a more precise type
```
@@ -86,8 +86,8 @@ isinstance(np.zeros((1,2,3), dtype=float), array_type) (docs/index.md)
   and a simple extension system to make it work with whatever else you want! Provides
   a uniform and transparent interface so you can both use common indexing operations
   and also access any special features of a given array library.
-- **Serialization** - Dump an array as a JSON-compatible array-of-arrays with enough metadata to be able to
-  recreate the model in the native format
+- [**Serialization**](./serialization.md) - Dump an array as a JSON-compatible array-of-arrays with enough metadata to be able to
+  recreate the model in the native format. Full roundtripping is supported :)
 - **Schema Generation** - Correct JSON Schema for arrays, complete with shape and dtype constraints, to
   make your models interoperable
@@ -473,9 +473,8 @@ dumped = instance.model_dump_json(context={'zarr_dump_array': True})
 design
 syntax
+serialization
 interfaces
-todo
-changelog
 ```

 ```{toctree}

@@ -489,13 +488,23 @@ api/dtype
 api/ndarray
 api/maps
 api/meta
-api/monkeypatch
 api/schema
+api/serialization
 api/shape
 api/types

 ```

+```{toctree}
+:maxdepth: 2
+:caption: Meta
+:hidden: true
+
+changelog
+development
+todo
+```

 ## See Also

 - [`jaxtyping`](https://docs.kidger.site/jaxtyping/)
@@ -46,6 +46,11 @@ for interfaces to implement custom behavior that matches the array format.
 {meth}`.Interface.validate` calls the following methods, in order:

+A method to deserialize the array dumped with {func}`~pydantic.BaseModel.model_dump_json`
+with `round_trip = True` (see [serialization](./serialization.md)):
+
+- {meth}`.Interface.deserialize`
+
 An initial hook for modifying the input data before validation, eg.
 if it needs to be coerced or wrapped in some proxy class. This method
 should accept all and only the types specified in that interface's
332 docs/serialization.md Normal file
@@ -0,0 +1,332 @@
---
file_format: mystnb
mystnb:
  output_stderr: remove
  render_text_lexer: python
  render_markdown_format: myst
myst:
  enable_extensions: ["colon_fence"]
---

# Serialization

## Python

In most cases, dumping to python should work as expected.

When a given array framework doesn't provide a tidy means of interacting
with it from python, we substitute a proxy class like {class}`.hdf5.H5Proxy`,
but aside from that, numpydantic {class}`.NDArray` annotations
should be passthrough when using {func}`~pydantic.BaseModel.model_dump`.
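For instance, a minimal sketch (the HDF5 file path and dataset name here are the sample data used in the examples below):

```python
from pydantic import BaseModel

from numpydantic import NDArray


class MyModel(BaseModel):
    array: NDArray


model = MyModel(array=("data/test.h5", "/data"))
dumped = model.model_dump()
# the H5Proxy is passed through rather than being copied into memory
print(type(dumped["array"]))
```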
## JSON

JSON is the ~ ♥ fun one ♥ ~

There isn't necessarily a single optimal way to represent all possible
arrays in JSON. The standard way that n-dimensional arrays are rendered
in json is as a list-of-lists (or array of arrays, in JSON parlance),
but that's almost never what is desirable, especially for large arrays.

### Normal Style[^normalstyle]

Lists-of-lists are the standard, however, so it is the default behavior
for all interfaces, and all interfaces must support it.

```{code-cell}
---
tags: [hide-cell]
---

from pathlib import Path
from pydantic import BaseModel
from numpydantic import NDArray, Shape
from numpydantic.interface.dask import DaskJsonDict
from numpydantic.interface.numpy import NumpyJsonDict
import numpy as np
import dask.array as da
import zarr
import json
from rich import print
from rich.console import Console

def print_json(string: str):
    data = json.loads(string)
    console = Console(width=74)
    console.print(data)
```

For our humble model:

```{code-cell}
class MyModel(BaseModel):
    array: NDArray
```

We should get the same thing for each interface:

```{code-cell}
model = MyModel(array=[[1,2],[3,4]])
print(model.model_dump_json())
```

```{code-cell}
model = MyModel(array=da.array([[1,2],[3,4]], dtype=int))
print(model.model_dump_json())
```

```{code-cell}
model = MyModel(array=zarr.array([[1,2],[3,4]], dtype=int))
print(model.model_dump_json())
```

```{code-cell}
model = MyModel(array="data/test.avi")
print(model.model_dump_json())
```

(ok maybe not that last one, since the video reader still incorrectly
reads grayscale videos as BGR values for now, but you get the idea)

Since by default arrays are dumped into unadorned JSON arrays,
when they are re-validated, they will always be handled by the
{class}`.NumpyInterface`:

```{code-cell}
dask_array = da.array([[1,2],[3,4]], dtype=int)
model = MyModel(array=dask_array)
type(model.array)
```

```{code-cell}
model_json = model.model_dump_json()
deserialized_model = MyModel.model_validate_json(model_json)
type(deserialized_model.array)
```

All information about `dtype` will be lost, and numbers will be parsed
as either `int` ({class}`numpy.int64`) or `float` ({class}`numpy.float64`)
## Roundtripping

To make arrays round-trippable, use the `round_trip` argument
to {func}`~pydantic.BaseModel.model_dump_json`.

All of the following should return an equivalent array from the same
file/etc. as the source array when revalidated with
{func}`~pydantic.BaseModel.model_validate_json`.

```{code-cell}
print_json(model.model_dump_json(round_trip=True))
```

Each interface must implement a dataclass that describes a
json-able roundtrip form (see {class}`.interface.JsonDict`).

That dataclass then has a {meth}`JsonDict.is_valid` method that checks
whether an incoming dict matches its schema:

```{code-cell}
roundtrip_json = json.loads(model.model_dump_json(round_trip=True))['array']
DaskJsonDict.is_valid(roundtrip_json)
```

```{code-cell}
NumpyJsonDict.is_valid(roundtrip_json)
```
#### Controlling paths

When possible, the full content of the array is omitted in favor
of the path to the file that provided it.

```{code-cell}
model = MyModel(array="data/test.avi")
print_json(model.model_dump_json(round_trip=True))
```

```{code-cell}
model = MyModel(array=("data/test.h5", "/data"))
print_json(model.model_dump_json(round_trip=True))
```

You may notice the relative, rather than absolute, paths.

We expect that when people are dumping data to json in this roundtripped
form, they are either working locally
(e.g. transmitting an array specification across a socket in multiprocessing
or in a computing cluster),
or exporting to some directory structure of data,
where they are making an index file that refers to datasets in a directory
as part of a data standard or vernacular format.

By default, numpydantic uses the current working directory as the root to find
paths relative to, but this can be controlled by the [`relative_to`](#relative_to)
context parameter.

For example, if you're working on data in many subdirectories,
you might want to serialize relative to each of them:

```{code-cell}
print_json(
    model.model_dump_json(
        round_trip=True,
        context={"relative_to": Path('./data')}
    ))
```

Or in the other direction:

```{code-cell}
print_json(
    model.model_dump_json(
        round_trip=True,
        context={"relative_to": Path('../')}
    ))
```

Or you might be working in some completely different place;
numpydantic will try and find the way from here to there as long as it exists,
even if it means traversing to the root of the readthedocs filesystem:

```{code-cell}
print_json(
    model.model_dump_json(
        round_trip=True,
        context={"relative_to": Path('/a/long/distance/directory')}
    ))
```

You can force absolute paths with the `absolute_paths` context parameter:

```{code-cell}
print_json(
    model.model_dump_json(
        round_trip=True,
        context={"absolute_paths": True}
    ))
```
#### Durable Interface Metadata

Numpydantic tries to be [stable](./development.md#api-stability),
but we're not perfect. To preserve the full information about the
interface that's needed to load the data referred to by the value,
use the `mark_interface` context parameter:

```{code-cell}
print_json(
    model.model_dump_json(
        round_trip=True,
        context={"mark_interface": True}
    ))
```

When an array marked with the interface is deserialized,
it short-circuits the {meth}`.Interface.match` method,
attempting to directly return the indicated interface as long as the
array dumped in `value` still satisfies that interface's {meth}`.Interface.check`
method. Arrays dumped *without* `round_trip=True` might *not* validate with
the originating model, even when marked -- eg. an array dumped without `round_trip`
will be revalidated as a numpy array for the same reasons it is everywhere else,
since all connection to the source file is lost.

```{todo}
Currently, the version of the package the interface is from (usually `numpydantic`)
will be stored, but there is no means of resolving it on the fly.
If there is a mismatch between the marked interface description and the interface
that was matched on revalidation, a warning is emitted, but validation
attempts to proceed as normal.

This feature is for extra-verbose provenance, rather than airtight serialization
and deserialization, but PRs welcome if you would like to make it be that way.
```

```{todo}
We will also add a separate `mark_version` parameter for marking
the specific version of the relevant interface package, like `zarr` or `numpy` -
patience.
```
## Context parameters

A reference listing of all the things that can be passed to
{func}`~pydantic.BaseModel.model_dump_json`.

### `mark_interface`

Nest an additional layer of metadata for unambiguous serialization that
can be absolutely resolved across numpydantic versions
(for now for downstream metadata purposes only;
automatically resolving to a numpydantic version is not yet possible).

Supported interfaces:

- (all)

```{code-cell}
model = MyModel(array=[[1,2],[3,4]])
data = model.model_dump_json(
    round_trip=True,
    context={"mark_interface": True}
)
print_json(data)
```

### `absolute_paths`

Make all paths (that exist) absolute.

Supported interfaces:

- (all)

```{code-cell}
model = MyModel(array=("data/test.h5", "/data"))
data = model.model_dump_json(
    round_trip=True,
    context={"absolute_paths": True}
)
print_json(data)
```

### `relative_to`

Make all paths (that exist) relative to the given path.

Supported interfaces:

- (all)

```{code-cell}
model = MyModel(array=("data/test.h5", "/data"))
data = model.model_dump_json(
    round_trip=True,
    context={"relative_to": Path('../')}
)
print_json(data)
```

### `dump_array`

Dump the raw array contents when serializing to json inside an `array` field.

Supported interfaces:

- {class}`.ZarrInterface`

```{code-cell}
model = MyModel(array=("data/test.zarr",))
data = model.model_dump_json(
    round_trip=True,
    context={"dump_array": True}
)
print_json(data)
```

[^normalstyle]: o ya we're posting JSON [normal style](https://normal.style)
@@ -1,6 +1,6 @@ (pyproject.toml)
 [project]
 name = "numpydantic"
-version = "1.5.3"
+version = "1.6.0"
 description = "Type and shape validation and serialization for arbitrary array types in pydantic models"
 authors = [
     {name = "sneakers-the-rat", email = "sneakers-the-rat@protonmail.com"},
@@ -73,12 +73,15 @@ tests = [
     "coveralls<4.0.0,>=3.3.1",
 ]
 docs = [
+    "numpydantic[arrays]",
     "sphinx<8.0.0,>=7.2.6",
     "furo>=2024.1.29",
     "myst-parser<3.0.0,>=2.0.0",
     "autodoc-pydantic<3.0.0,>=2.0.1",
     "sphinx-design<1.0.0,>=0.5.0",
     "sphinxcontrib-mermaid>=0.9.2",
+    "myst-nb>=1.1.1",
+    "rich>=13.8.1",
 ]
 dev = [
     "numpydantic[tests,docs]",
@@ -109,6 +112,18 @@ filterwarnings = [
     # nptyping's alias warnings
     'ignore:.*deprecated alias.*Deprecated NumPy 1\.24.*'
 ]
+markers = [
+    "dtype: mark test related to dtype validation",
+    "shape: mark test related to shape validation",
+    "json_schema: mark test related to json schema generation",
+    "serialization: mark test related to serialization",
+    "proxy: test for proxy class in any interface",
+    "dask: dask interface",
+    "hdf5: hdf5 interface",
+    "numpy: numpy interface",
+    "video: video interface",
+    "zarr: zarr interface",
+]

 [tool.ruff]
 target-version = "py311"
@@ -25,3 +25,7 @@ class NoMatchError(MatchError): (numpydantic/exceptions.py)
 class TooManyMatchesError(MatchError):
     """Too many matches found by :class:`.Interface.match`"""
+
+
+class MarkMismatchError(MatchError):
+    """A serialized :class:`.InterfaceMark` doesn't match the receiving interface"""
@@ -4,15 +4,23 @@ Interfaces between nptyping types and array backends (numpydantic/interface/__init__.py)
 from numpydantic.interface.dask import DaskInterface
 from numpydantic.interface.hdf5 import H5Interface
-from numpydantic.interface.interface import Interface
+from numpydantic.interface.interface import (
+    Interface,
+    InterfaceMark,
+    JsonDict,
+    MarkedJson,
+)
 from numpydantic.interface.numpy import NumpyInterface
 from numpydantic.interface.video import VideoInterface
 from numpydantic.interface.zarr import ZarrInterface

 __all__ = [
-    "Interface",
     "DaskInterface",
     "H5Interface",
+    "Interface",
+    "InterfaceMark",
+    "JsonDict",
+    "MarkedJson",
     "NumpyInterface",
     "VideoInterface",
     "ZarrInterface",
@@ -2,34 +2,73 @@ (numpydantic/interface/dask.py)
 """
 Interface for Dask arrays
 """

-from typing import Any, Optional
+from typing import Any, Iterable, List, Literal, Optional, Union

 import numpy as np
 from pydantic import SerializationInfo

-from numpydantic.interface.interface import Interface
+from numpydantic.interface.interface import Interface, JsonDict
 from numpydantic.types import DtypeType, NDArrayType

 try:
+    from dask.array import from_array
     from dask.array.core import Array as DaskArray
 except ImportError:  # pragma: no cover
     DaskArray = None


+def _as_tuple(a_list: Any) -> tuple:
+    """Make a list of lists into a tuple of tuples"""
+    return tuple(
+        [_as_tuple(item) if isinstance(item, list) else item for item in a_list]
+    )
+
+
+class DaskJsonDict(JsonDict):
+    """
+    Round-trip json serialized form of a dask array
+    """
+
+    type: Literal["dask"]
+    name: str
+    chunks: Iterable[tuple[int, ...]]
+    dtype: str
+    array: list
+
+    def to_array_input(self) -> DaskArray:
+        """Construct a dask array"""
+        np_array = np.array(self.array, dtype=self.dtype)
+        array = from_array(
+            np_array,
+            name=self.name,
+            chunks=_as_tuple(self.chunks),
+        )
+        return array
+
+
 class DaskInterface(Interface):
     """
     Interface for Dask :class:`~dask.array.core.Array`
     """

-    input_types = (DaskArray,)
+    name = "dask"
+    input_types = (DaskArray, dict)
     return_type = DaskArray
+    json_model = DaskJsonDict

     @classmethod
     def check(cls, array: Any) -> bool:
         """
         check if array is a dask array
         """
-        return DaskArray is not None and isinstance(array, DaskArray)
+        if DaskArray is None:  # pragma: no cover - no tests for interface deps atm
+            return False
+        elif isinstance(array, DaskArray):
+            return True
+        elif isinstance(array, dict):
+            return DaskJsonDict.is_valid(array)
+        else:
+            return False

     def get_object_dtype(self, array: NDArrayType) -> DtypeType:
         """Dask arrays require a compute() call to retrieve a single value"""

@@ -43,7 +82,7 @@ class DaskInterface(Interface):
     @classmethod
     def to_json(
         cls, array: DaskArray, info: Optional[SerializationInfo] = None
-    ) -> list:
+    ) -> Union[List, DaskJsonDict]:
         """
         Convert an array to a JSON serializable array by first converting to a numpy
         array and then to a list.

@@ -56,4 +95,14 @@ class DaskInterface(Interface):
         method of serialization here using the python object itself rather than
         its JSON representation.
         """
-        return np.array(array).tolist()
+        np_array = np.array(array)
+        as_json = np_array.tolist()
+        if info.round_trip:
+            as_json = DaskJsonDict(
+                type=cls.name,
+                array=as_json,
+                name=array.name,
+                chunks=array.chunks,
+                dtype=str(np_array.dtype),
+            )
+        return as_json
@@ -47,7 +47,7 @@ from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, TypeVar, Union (numpydantic/interface/hdf5.py)
 import numpy as np
 from pydantic import SerializationInfo

-from numpydantic.interface.interface import Interface
+from numpydantic.interface.interface import Interface, JsonDict
 from numpydantic.types import DtypeType, NDArrayType

 try:

@@ -76,6 +76,20 @@ class H5ArrayPath(NamedTuple):
     """Refer to a specific field within a compound dtype"""


+class H5JsonDict(JsonDict):
+    """Round-trip Json-able version of an HDF5 dataset"""
+
+    file: str
+    path: str
+    field: Optional[str] = None
+
+    def to_array_input(self) -> H5ArrayPath:
+        """Construct an :class:`.H5ArrayPath`"""
+        return H5ArrayPath(
+            **{k: v for k, v in self.model_dump().items() if k in H5ArrayPath._fields}
+        )
+
+
 class H5Proxy:
     """
     Proxy class to mimic numpy-like array behavior with an HDF5 array

@@ -106,10 +120,11 @@ class H5Proxy:
         annotation_dtype: Optional[DtypeType] = None,
     ):
         self._h5f = None
-        self.file = Path(file)
+        self.file = Path(file).resolve()
         self.path = path
         self.field = field
         self._annotation_dtype = annotation_dtype
+        self._h5arraypath = H5ArrayPath(self.file, self.path, self.field)

     def array_exists(self) -> bool:
         """Check that there is in fact an array at :attr:`.path` within :attr:`.file`"""

@@ -134,10 +149,20 @@ class H5Proxy:
         else:
             return obj.dtype[self.field]

-    def __getattr__(self, item: str):
+    def __array__(self) -> np.ndarray:
+        """To a numpy array"""
         with h5py.File(self.file, "r") as h5f:
             obj = h5f.get(self.path)
-            return getattr(obj, item)
+            return obj[:]
+
+    def __getattr__(self, item: str):
+        if item == "__name__":
+            # special case for H5Proxies that don't refer to a real file during testing
+            return "H5Proxy"
+        with h5py.File(self.file, "r") as h5f:
+            obj = h5f.get(self.path)
+            val = getattr(obj, item)
+        return val

     def __getitem__(
         self, item: Union[int, slice, Tuple[Union[int, slice], ...]]

@@ -205,6 +230,15 @@ class H5Proxy:
         """self.shape[0]"""
         return self.shape[0]

+    def __eq__(self, other: "H5Proxy") -> bool:
+        """
+        Check that we are referring to the same hdf5 array
+        """
+        if isinstance(other, H5Proxy):
+            return self._h5arraypath == other._h5arraypath
+        else:
+            raise ValueError("Can only compare equality of two H5Proxies")
+
     def open(self, mode: str = "r") -> "h5py.Dataset":
         """
         Return the opened :class:`h5py.Dataset` object

@@ -244,8 +278,10 @@ class H5Interface(Interface):
     passthrough numpy-like interface to the dataset.
     """

+    name = "hdf5"
     input_types = (H5ArrayPath, H5Arraylike, H5Proxy)
     return_type = H5Proxy
+    json_model = H5JsonDict

     @classmethod
     def enabled(cls) -> bool:

@@ -261,6 +297,13 @@ class H5Interface(Interface):
         if isinstance(array, (H5ArrayPath, H5Proxy)):
             return True

+        if isinstance(array, dict):
+            if array.get("type", False) == cls.name:
+                return True
+            # continue checking if dict contains an hdf5 file
+            file = array.get("file", "")
+            array = (file, "")
+
         if isinstance(array, (tuple, list)) and len(array) in (2, 3):
             # check that the first arg is an hdf5 file
             try:

@@ -342,21 +385,27 @@ class H5Interface(Interface):
     @classmethod
     def to_json(cls, array: H5Proxy, info: Optional[SerializationInfo] = None) -> dict:
         """
-        Dump to a dictionary containing
+        Render HDF5 array as JSON
+
+        If ``round_trip == True``, we dump just the proxy info, a dictionary like:

         * ``file``: :attr:`.file`
         * ``path``: :attr:`.path`
         * ``attrs``: Any HDF5 attributes on the dataset
         * ``array``: The array as a list of lists
+
+        Otherwise, we dump the array as a list of lists
         """
-        try:
-            dset = array.open()
-            meta = {
-                "file": array.file,
-                "path": array.path,
-                "attrs": dict(dset.attrs),
-                "array": dset[:].tolist(),
-            }
-            return meta
-        finally:
-            array.close()
+        if info.round_trip:
+            as_json = {
+                "type": cls.name,
+            }
+            as_json.update(array._h5arraypath._asdict())
+        else:
+            try:
+                dset = array.open()
+                as_json = dset[:].tolist()
+            finally:
+                array.close()
+
+        return as_json
@@ -2,15 +2,20 @@ (numpydantic/interface/interface.py)
 """
 Base Interface metaclass
 """

+import inspect
+import warnings
 from abc import ABC, abstractmethod
+from functools import lru_cache
+from importlib.metadata import PackageNotFoundError, version
 from operator import attrgetter
 from typing import Any, Generic, Optional, Tuple, Type, TypeVar, Union

 import numpy as np
-from pydantic import SerializationInfo
+from pydantic import BaseModel, SerializationInfo, ValidationError

 from numpydantic.exceptions import (
     DtypeError,
+    MarkMismatchError,
     NoMatchError,
     ShapeError,
     TooManyMatchesError,

@@ -19,6 +24,130 @@ from numpydantic.shape import check_shape
 from numpydantic.types import DtypeType, NDArrayType, ShapeType

 T = TypeVar("T", bound=NDArrayType)
+U = TypeVar("U", bound="JsonDict")
+V = TypeVar("V")  # input type
+W = TypeVar("W")  # Any type in handle_input
+
+
+class InterfaceMark(BaseModel):
+    """JSON-able mark to be able to round-trip json dumps"""
+
+    module: str
+    cls: str
+    name: str
+    version: str
+
+    def is_valid(self, cls: Type["Interface"], raise_on_error: bool = False) -> bool:
+        """
+        Check that a given interface matches the mark.
+
+        Args:
+            cls (Type): Interface type to check
+            raise_on_error (bool): Raise a ``MarkMismatchError`` when the match
+                is incorrect
+
+        Returns:
+            bool
+
+        Raises:
+            :class:`.MarkMismatchError` if requested by ``raise_on_error``
+            for an invalid match
+        """
+        mark = cls.mark_interface()
+        valid = self == mark
+        if not valid and raise_on_error:
+            raise MarkMismatchError(
+                "Mismatch between serialized mark and current interface, "
+                f"Serialized: {self}; current: {cls}"
+            )
+        return valid
+
+    def match_by_name(self) -> Optional[Type["Interface"]]:
+        """
+        Try to find a matching interface by its name, returning it if found,
+        or None if not found.
+        """
+        for i in Interface.interfaces(sort=False):
+            if i.name == self.name:
+                return i
+        return None
+
+
+class JsonDict(BaseModel):
+    """
+    Representation of array when dumped with round_trip == True.
+    """
+
+    type: str
+
+    @abstractmethod
+    def to_array_input(self) -> V:
+        """
+        Convert this roundtrip specifier to the relevant input class
+        (one of the ``input_types`` of an interface).
+        """
+
+    @classmethod
+    def is_valid(cls, val: dict, raise_on_error: bool = False) -> bool:
+        """
+        Check whether a given dictionary matches this JsonDict specification
+
+        Args:
+            val (dict): The dictionary to check for validity
+            raise_on_error (bool): If ``True``, raise the validation error
+                rather than returning a bool. (default: ``False``)
+
+        Returns:
+            bool - true if valid, false if not
+        """
+        try:
+            _ = cls.model_validate(val)
+            return True
+        except ValidationError as e:
+            if raise_on_error:
+                raise e
+            return False
+
+    @classmethod
+    def handle_input(cls: Type[U], value: Union[dict, U, W]) -> Union[V, W]:
+        """
+        Handle input that is the json serialized roundtrip version
+        (from :func:`~pydantic.BaseModel.model_dump` with ``round_trip=True``),
+        converting it to the input format with :meth:`.JsonDict.to_array_input`
+        or passing it through if not applicable
+        """
+        if isinstance(value, dict):
+            value = cls(**value).to_array_input()
+        elif isinstance(value, cls):
+            value = value.to_array_input()
+        return value
+
+
+class MarkedJson(BaseModel):
+    """
+    Model of JSON dumped with an additional interface mark
+    with ``model_dump_json({'mark_interface': True})``
+    """
+
+    interface: InterfaceMark
+    value: Union[list, dict]
+    """
+    Inner value of the array; we don't validate for JsonDict here,
+    that should be downstream from us for performance reasons
+    """
+
+    @classmethod
+    def try_cast(cls, value: Union[V, dict]) -> Union[V, "MarkedJson"]:
+        """
+        Try to cast to MarkedJson if applicable, otherwise return input
+        """
+        if isinstance(value, dict) and "interface" in value and "value" in value:
+            try:
+                value = MarkedJson(**value)
+            except ValidationError:
+                # fine, just not a MarkedJson dict even if it looks like one
+                return value
+        return value
+
+
 class Interface(ABC, Generic[T]):

@@ -30,7 +159,7 @@ class Interface(ABC, Generic[T]):
     return_type: Type[T]
     priority: int = 0

-    def __init__(self, shape: ShapeType, dtype: DtypeType) -> None:
+    def __init__(self, shape: ShapeType = Any, dtype: DtypeType = Any) -> None:
         self.shape = shape
         self.dtype = dtype

@@ -40,6 +169,7 @@ class Interface(ABC, Generic[T]):

     Calls the methods, in order:

+    * array = :meth:`.deserialize` (array)
     * array = :meth:`.before_validation` (array)
     * dtype = :meth:`.get_dtype` (array) - get the dtype from the array,
       override if eg. the dtype is not contained in ``array.dtype``

@@ -74,6 +204,8 @@ class Interface(ABC, Generic[T]):
         :class:`.DtypeError` and :class:`.ShapeError` (both of which are children
         of :class:`.InterfaceError` )
         """
+        array = self.deserialize(array)
+
         array = self.before_validation(array)

         dtype = self.get_dtype(array)

@@ -86,8 +218,32 @@ class Interface(ABC, Generic[T]):
         self.raise_for_shape(shape_valid, shape)

         array = self.after_validation(array)

         return array

+    def deserialize(self, array: Any) -> Union[V, Any]:
+        """
+        If given a JSON serialized version of the array,
+        deserialize it first.
+
+        If a roundtrip-serialized :class:`.JsonDict`,
+        pass to :meth:`.JsonDict.handle_input`.
+
+        If a roundtrip-serialized :class:`.MarkedJson`,
+        unpack mark, check for validity, warn if not,
+        and try to continue with validation
+        """
+        if isinstance(marked_array := MarkedJson.try_cast(array), MarkedJson):
+            try:
+                marked_array.interface.is_valid(self.__class__, raise_on_error=True)
+            except MarkMismatchError as e:
+                warnings.warn(
+                    str(e) + "\nAttempting to continue validation...", stacklevel=2
+                )
+            array = marked_array.value
+
+        return self.json_model.handle_input(array)
+
     def before_validation(self, array: Any) -> NDArrayType:
         """
         Optional step pre-validation that coerces the input into a type that can be

@@ -117,8 +273,6 @@ class Interface(ABC, Generic[T]):
         """
         Validate the dtype of the given array, returning
         ``True`` if valid, ``False`` if not.
         """
         if self.dtype is Any:
             return True

@@ -211,17 +365,48 @@ class Interface(ABC, Generic[T]):
         installed, etc.)
         """

+    @property
+    @abstractmethod
+    def name(self) -> str:
+        """
+        Short name for this interface
+        """
+
+    @property
+    @abstractmethod
+    def json_model(self) -> JsonDict:
+        """
+        The :class:`.JsonDict` model used for roundtripping
+        JSON serialization
+        """
+
     @classmethod
-    def to_json(
-        cls, array: Type[T], info: Optional[SerializationInfo] = None
-    ) -> Union[list, dict]:
+    @abstractmethod
+    def to_json(cls, array: Type[T], info: SerializationInfo) -> Union[list, JsonDict]:
         """
         Convert an array of :attr:`.return_type` to a JSON-compatible format using
         base python types
         """
-        if not isinstance(array, np.ndarray):  # pragma: no cover
-            array = np.array(array)
-        return array.tolist()
+
+    @classmethod
+    def mark_json(cls, array: Union[list, dict]) -> dict:
+        """
+        When using ``model_dump_json`` with ``mark_interface: True`` in the ``context``,
+        add additional annotations that would allow the serialized array to be
+        roundtripped.
+
+        Default is just to add an :class:`.InterfaceMark`
+
+        Examples:
+
+            >>> from pprint import pprint
+            >>> pprint(Interface.mark_json([1.0, 2.0]))
+            {'interface': {'cls': 'Interface',
+                           'module': 'numpydantic.interface.interface',
+                           'version': '1.2.2'},
+             'value': [1.0, 2.0]}
+        """
+        return {"interface": cls.mark_interface(), "value": array}

     @classmethod
     def interfaces(

@@ -274,6 +459,28 @@ class Interface(ABC, Generic[T]):

         return tuple(in_types)

+    @classmethod
+    def match_mark(cls, array: Any) -> Optional[Type["Interface"]]:
+        """
+        Match a marked JSON dump of this array to the interface that it indicates.
+
+        First find an interface that matches by name, and then run its
+        ``check`` method, because arrays can be dumped with a mark
+        but without ``round_trip == True`` (and thus can't necessarily
+        use the same interface that they were dumped with)
+
+        Returns:
+            Interface if match found, None otherwise
+        """
+        mark = MarkedJson.try_cast(array)
+        if not isinstance(mark, MarkedJson):
+            return None
+
+        interface = mark.interface.match_by_name()
+        if interface is not None and interface.check(mark.value):
+            return interface
+        return None
+
     @classmethod
     def match(cls, array: Any, fast: bool = False) -> Type["Interface"]:
         """

@@ -291,11 +498,18 @@ class Interface(ABC, Generic[T]):
         check each interface (as ordered by its ``priority``, decreasing),
         and return on the first match.
         """
+        # Short-circuit match if this is a marked json dump
+        array = MarkedJson.try_cast(array)
+        if (match := cls.match_mark(array)) is not None:
+            return match
+        elif isinstance(array, MarkedJson):
+            array = array.value
+
         # first try and find a non-numpy interface, since the numpy interface
         # will try and load the array into memory in its check method
         interfaces = cls.interfaces()
-        non_np_interfaces = [i for i in interfaces if i.__name__ != "NumpyInterface"]
-        np_interface = [i for i in interfaces if i.__name__ == "NumpyInterface"][0]
+        non_np_interfaces = [i for i in interfaces if i.name != "numpy"]
+        np_interface = [i for i in interfaces if i.name == "numpy"][0]

         if fast:
             matches = []

@@ -335,3 +549,29 @@ class Interface(ABC, Generic[T]):
             raise NoMatchError(f"No matching interfaces found for output {array}")
         else:
             return matches[0]
+
+    @classmethod
+    @lru_cache(maxsize=32)
+    def mark_interface(cls) -> InterfaceMark:
+        """
+        Create an interface mark indicating this interface for validation after
+        JSON serialization with ``round_trip==True``
+        """
+        interface_module = inspect.getmodule(cls)
+        interface_module = (
+            None if interface_module is None else interface_module.__name__
+        )
+        try:
+            v = (
+                None
+                if interface_module is None
+                else version(interface_module.split(".")[0])
+            )
+        except (
+            PackageNotFoundError
+        ):  # pragma: no cover - no tests for missing interface deps
+            v = None
+
+        return InterfaceMark(
+            module=interface_module, cls=cls.__name__, name=cls.name, version=v
+        )
@@ -2,9 +2,11 @@
 Interface to numpy arrays
 """
 
-from typing import Any
+from typing import Any, Literal, Union
 
-from numpydantic.interface.interface import Interface
+from pydantic import SerializationInfo
+
+from numpydantic.interface.interface import Interface, JsonDict
 
 try:
     import numpy as np
@@ -18,13 +20,31 @@ except ImportError:  # pragma: no cover
     np = None
 
 
+class NumpyJsonDict(JsonDict):
+    """
+    JSON-able roundtrip representation of numpy array
+    """
+
+    type: Literal["numpy"]
+    dtype: str
+    array: list
+
+    def to_array_input(self) -> ndarray:
+        """
+        Construct a numpy array
+        """
+        return np.array(self.array, dtype=self.dtype)
+
+
 class NumpyInterface(Interface):
     """
     Numpy :class:`~numpy.ndarray` s!
     """
 
+    name = "numpy"
     input_types = (ndarray, list)
     return_type = ndarray
+    json_model = NumpyJsonDict
     priority = -999
     """
     The numpy interface is usually the interface of last resort.
@@ -41,6 +61,8 @@ class NumpyInterface(Interface):
         """
         if isinstance(array, ndarray):
             return True
+        elif isinstance(array, dict):
+            return NumpyJsonDict.is_valid(array)
         else:
             try:
                 _ = np.array(array)
@@ -61,3 +83,22 @@ class NumpyInterface(Interface):
     def enabled(cls) -> bool:
         """Check that numpy is present in the environment"""
         return ENABLED
+
+    @classmethod
+    def to_json(
+        cls, array: ndarray, info: SerializationInfo = None
+    ) -> Union[list, JsonDict]:
+        """
+        Convert an array of :attr:`.return_type` to a JSON-compatible format using
+        base python types
+        """
+        if not isinstance(array, np.ndarray):  # pragma: no cover
+            array = np.array(array)
+
+        json_array = array.tolist()
+
+        if info.round_trip:
+            json_array = NumpyJsonDict(
+                type=cls.name, dtype=str(array.dtype), array=json_array
+            )
+        return json_array
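A sketch of the roundtrip this enables (the `Model` class is illustrative): with `round_trip=True` the dump carries the `NumpyJsonDict` fields shown above, and validation rebuilds an equivalent array through `to_array_input`.

```python
import json

import numpy as np
from pydantic import BaseModel

from numpydantic import NDArray


class Model(BaseModel):
    array: NDArray


m = Model(array=np.array([[1, 2], [3, 4]], dtype="int32"))
dumped_json = m.model_dump_json(round_trip=True)

payload = json.loads(dumped_json)["array"]
assert payload["type"] == "numpy" and payload["dtype"] == "int32"

# deserializing restores dtype as well as values
m2 = Model.model_validate_json(dumped_json)
assert np.array_equal(m2.array, m.array) and m2.array.dtype == m.array.dtype
```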
@@ -3,10 +3,12 @@ Interface to support treating videos like arrays using OpenCV
 """
 
 from pathlib import Path
-from typing import Any, Optional, Tuple, Union
+from typing import Any, Literal, Optional, Tuple, Union
 
 import numpy as np
+from pydantic_core.core_schema import SerializationInfo
+
+from numpydantic.interface import JsonDict
 from numpydantic.interface.interface import Interface
 
 try:
@@ -19,6 +21,19 @@ except ImportError:  # pragma: no cover
 VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
 
 
+class VideoJsonDict(JsonDict):
+    """Json-able roundtrip representation of a video file"""
+
+    type: Literal["video"]
+    file: str
+
+    def to_array_input(self) -> "VideoProxy":
+        """
+        Construct a :class:`.VideoProxy`
+        """
+        return VideoProxy(path=Path(self.file))
+
+
 class VideoProxy:
     """
     Passthrough proxy class to interact with videos as arrays
@@ -33,7 +48,7 @@ class VideoProxy:
         )
 
         if path is not None:
-            path = Path(path)
+            path = Path(path).resolve()
         self.path = path
 
         self._video = video  # type: Optional[VideoCapture]
@@ -52,6 +67,9 @@ class VideoProxy:
                     "and it cant be reopened since source path cant be gotten "
                     "from VideoCapture objects"
                 )
+            if not self.path.exists():
+                raise FileNotFoundError(f"Video file {self.path} does not exist!")
+
             self._video = VideoCapture(str(self.path))
         return self._video
 
@@ -137,6 +155,10 @@ class VideoProxy:
             slice_ = slice(0, slice_.stop, slice_.step)
         return slice_
 
+    def __array__(self) -> np.ndarray:
+        """Whole video as a numpy array"""
+        return self[:]
+
     def __getitem__(self, item: Union[int, slice, tuple]) -> np.ndarray:
         if isinstance(item, int):
             # want a single frame
@@ -178,8 +200,16 @@ class VideoProxy:
         raise NotImplementedError("Setting pixel values on videos is not supported!")
 
     def __getattr__(self, item: str):
+        if item == "__name__":
+            return "VideoProxy"
         return getattr(self.video, item)
 
+    def __eq__(self, other: "VideoProxy") -> bool:
+        """Check if this is a proxy to the same video file"""
+        if not isinstance(other, VideoProxy):
+            raise TypeError("Can only compare equality of two VideoProxies")
+        return self.path == other.path
+
     def __len__(self) -> int:
         """Number of frames in the video"""
         return self.shape[0]
@@ -190,8 +220,10 @@ class VideoInterface(Interface):
     OpenCV interface to treat videos as arrays.
     """
 
+    name = "video"
     input_types = (str, Path, VideoCapture, VideoProxy)
     return_type = VideoProxy
+    json_model = VideoJsonDict
 
     @classmethod
     def enabled(cls) -> bool:
@@ -209,6 +241,9 @@ class VideoInterface(Interface):
         ):
             return True
 
+        if isinstance(array, dict):
+            array = array.get("file", "")
+
         if isinstance(array, str):
             try:
                 array = Path(array)
@@ -227,3 +262,13 @@ class VideoInterface(Interface):
         else:
             proxy = VideoProxy(path=array)
         return proxy
+
+    @classmethod
+    def to_json(
+        cls, array: VideoProxy, info: SerializationInfo
+    ) -> Union[list, VideoJsonDict]:
+        """Return a json-representation of a video"""
+        if info.round_trip:
+            return VideoJsonDict(type=cls.name, file=str(array.path))
+        else:
+            return np.array(array).tolist()
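For illustration, a hedged sketch of a round-trip dump of a video-backed array; `clip.avi` is a hypothetical path that must exist on disk for validation to succeed, and the stored `file` value may be resolved and relativized by the serialization helpers shown further down.

```python
import json

from pydantic import BaseModel

from numpydantic import NDArray


class Recording(BaseModel):
    array: NDArray


rec = Recording(array="clip.avi")  # hypothetical file, must exist
payload = json.loads(rec.model_dump_json(round_trip=True))["array"]

# only a file reference is stored; frames are re-read on validation
assert payload["type"] == "video"
assert payload["file"].endswith("clip.avi")
```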
@@ -5,12 +5,12 @@ Interface to zarr arrays
 import contextlib
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Optional, Sequence, Union
+from typing import Any, Literal, Optional, Sequence, Union
 
 import numpy as np
 from pydantic import SerializationInfo
 
-from numpydantic.interface.interface import Interface
+from numpydantic.interface.interface import Interface, JsonDict
 from numpydantic.types import DtypeType
 
 try:
@@ -56,13 +56,36 @@ class ZarrArrayPath:
         raise ValueError("Only len 1-2 iterables can be used for a ZarrArrayPath")
 
 
+class ZarrJsonDict(JsonDict):
+    """Round-trip Json-able version of a Zarr Array"""
+
+    info: dict[str, str]
+    type: Literal["zarr"]
+    file: Optional[str] = None
+    path: Optional[str] = None
+    array: Optional[list] = None
+
+    def to_array_input(self) -> Union[ZarrArray, ZarrArrayPath]:
+        """
+        Construct a ZarrArrayPath if file and path are present,
+        otherwise a ZarrArray
+        """
+        if self.file:
+            array = ZarrArrayPath(file=self.file, path=self.path)
+        else:
+            array = zarr.array(self.array)
+        return array
+
+
 class ZarrInterface(Interface):
     """
     Interface to in-memory or on-disk zarr arrays
     """
 
+    name = "zarr"
     input_types = (Path, ZarrArray, ZarrArrayPath)
     return_type = ZarrArray
+    json_model = ZarrJsonDict
 
     @classmethod
     def enabled(cls) -> bool:
@@ -71,7 +94,7 @@ class ZarrInterface(Interface):
 
     @staticmethod
     def _get_array(
-        array: Union[ZarrArray, str, Path, ZarrArrayPath, Sequence]
+        array: Union[ZarrArray, str, dict, ZarrJsonDict, Path, ZarrArrayPath, Sequence]
     ) -> ZarrArray:
         if isinstance(array, ZarrArray):
             return array
@@ -92,6 +115,12 @@ class ZarrInterface(Interface):
         if isinstance(array, ZarrArray):
             return True
 
+        if isinstance(array, dict):
+            if array.get("type", False) == cls.name:
+                return True
+            # continue checking if dict contains a zarr file
+            array = array.get("file", "")
+
         # See if can be coerced to ZarrArrayPath
         if isinstance(array, (Path, str)):
             array = ZarrArrayPath(file=array)
@@ -135,26 +164,48 @@ class ZarrInterface(Interface):
         cls,
         array: Union[ZarrArray, str, Path, ZarrArrayPath, Sequence],
         info: Optional[SerializationInfo] = None,
-    ) -> dict:
+    ) -> Union[list, ZarrJsonDict]:
         """
-        Dump just the metadata for an array from :meth:`zarr.core.Array.info_items`
-        plus the :meth:`zarr.core.Array.hexdigest`.
+        Dump a Zarr Array to JSON
 
-        The full array can be returned by passing ``'zarr_dump_array': True`` to the
-        serialization ``context`` ::
+        If ``info.round_trip == False``, dump the array as a list of lists.
+        This may be a memory-intensive operation.
+
+        Otherwise, dump the metadata for an array from
+        :meth:`zarr.core.Array.info_items`
+        plus the :meth:`zarr.core.Array.hexdigest` as a :class:`.ZarrJsonDict`
+
+        If either the ``dump_array`` value in the context dictionary is ``True``
+        or the zarr array is an in-memory array, dump the array as well
+        (since without a persistent array it would be impossible to roundtrip and
+        dumping to JSON would be meaningless)
+
+        Passing ``'dump_array': True`` to the serialization ``context``
+        looks like this::
 
-            model.model_dump_json(context={'zarr_dump_array': True})
+            model.model_dump_json(context={'dump_array': True})
         """
-        dump_array = False
-        if info is not None and info.context is not None:
-            dump_array = info.context.get("zarr_dump_array", False)
-
         array = cls._get_array(array)
-        info = array.info_items()
-        info_dict = {i[0]: i[1] for i in info}
-        info_dict["hexdigest"] = array.hexdigest()
 
-        if dump_array:
-            info_dict["array"] = array[:].tolist()
+        if info.round_trip:
+            dump_array = False
+            if info is not None and info.context is not None:
+                dump_array = info.context.get("dump_array", False)
+            is_file = False
 
-        return info_dict
+            as_json = {"type": cls.name}
+            if hasattr(array.store, "dir_path"):
+                is_file = True
+                as_json["file"] = array.store.dir_path()
+                as_json["path"] = array.name
+            as_json["info"] = {i[0]: i[1] for i in array.info_items()}
+            as_json["info"]["hexdigest"] = array.hexdigest()
+
+            if dump_array or not is_file:
+                as_json["array"] = array[:].tolist()
+
+            as_json = ZarrJsonDict(**as_json)
+        else:
+            as_json = array[:].tolist()
+
+        return as_json
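A sketch of the context switch the docstring above describes; `example.zarr` is a hypothetical on-disk store, so the round-trip dump defaults to metadata only until `dump_array` is requested.

```python
import json

import zarr
from pydantic import BaseModel

from numpydantic import NDArray


class Model(BaseModel):
    array: NDArray


# hypothetical DirectoryStore-backed array
z = zarr.open("example.zarr", mode="w", shape=(2, 2), chunks=(2, 2))
z[:] = 1
m = Model(array=z)

# file-backed arrays dump as metadata (file, path, info) without the data
meta_only = json.loads(m.model_dump_json(round_trip=True))["array"]
assert "array" not in meta_only

# opt in to dumping the data alongside the metadata
with_data = json.loads(
    m.model_dump_json(round_trip=True, context={"dump_array": True})
)["array"]
assert with_data["array"] == [[1.0, 1.0], [1.0, 1.0]]
```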
@@ -24,11 +24,10 @@ from numpydantic.exceptions import InterfaceError
 from numpydantic.interface import Interface
 from numpydantic.maps import python_to_nptyping
 from numpydantic.schema import (
-    _handler_type,
-    _jsonize_array,
     get_validate_interface,
     make_json_schema,
 )
+from numpydantic.serialization import jsonize_array
 from numpydantic.types import DtypeType, NDArrayType, ShapeType
 from numpydantic.vendor.nptyping.error import InvalidArgumentsError
 from numpydantic.vendor.nptyping.ndarray import NDArrayMeta as _NDArrayMeta
@@ -41,6 +40,9 @@ from numpydantic.vendor.nptyping.typing_ import (
 
 if TYPE_CHECKING:  # pragma: no cover
     from nptyping.base_meta_classes import SubscriptableMeta
+    from pydantic._internal._schema_generation_shared import (
+        CallbackGetCoreSchemaHandler,
+    )
 
     from numpydantic import Shape
 
@@ -164,33 +166,34 @@ class NDArray(NPTypingType, metaclass=NDArrayMeta):
     def __get_pydantic_core_schema__(
         cls,
         _source_type: "NDArray",
-        _handler: _handler_type,
+        _handler: "CallbackGetCoreSchemaHandler",
     ) -> core_schema.CoreSchema:
         shape, dtype = _source_type.__args__
         shape: ShapeType
         dtype: DtypeType
 
-        # get pydantic core schema as a list of lists for JSON schema
-        list_schema = make_json_schema(shape, dtype, _handler)
+        # make core schema for json schema, store it and any model definitions
+        # note that there is a bit of fragility in this function,
+        # as we need to access a private method of _handler to
+        # flatten out the json schema. See help(make_json_schema)
+        json_schema = make_json_schema(shape, dtype, _handler)
 
-        return core_schema.json_or_python_schema(
-            json_schema=list_schema,
-            python_schema=core_schema.with_info_plain_validator_function(
-                get_validate_interface(shape, dtype)
-            ),
+        return core_schema.with_info_plain_validator_function(
+            get_validate_interface(shape, dtype),
             serialization=core_schema.plain_serializer_function_ser_schema(
-                _jsonize_array, when_used="json", info_arg=True
+                jsonize_array, when_used="json", info_arg=True
             ),
+            metadata=json_schema,
         )
 
     @classmethod
     def __get_pydantic_json_schema__(
         cls, schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
     ) -> core_schema.JsonSchema:
-        json_schema = handler(schema)
+        shape, dtype = cls.__args__
+        json_schema = handler(schema["metadata"])
         json_schema = handler.resolve_ref_schema(json_schema)
 
-        dtype = cls.__args__[1]
         if not isinstance(dtype, tuple) and dtype.__module__ not in (
             "builtins",
             "typing",
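A quick sketch of the behavior this split preserves, under the assumption that an `NDArray` field still renders as a list-of-lists JSON schema from the stashed `metadata` (the model and shape here are illustrative):

```python
from pydantic import BaseModel

from numpydantic import NDArray, Shape


class Model(BaseModel):
    array: NDArray[Shape["2 x, * y"], int]


schema = Model.model_json_schema()
array_schema = schema["properties"]["array"]
# the field renders as nested JSON arrays with integer items
assert array_schema["type"] == "array"
```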
@@ -5,10 +5,10 @@ Helper functions for use with :class:`~numpydantic.NDArray` - see the note in
 
 import hashlib
 import json
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional
 
 import numpy as np
-from pydantic import BaseModel, SerializationInfo
+from pydantic import BaseModel
 from pydantic_core import CoreSchema, core_schema
 from pydantic_core.core_schema import ListSchema, ValidationInfo
 
@@ -19,13 +19,16 @@ from numpydantic.types import DtypeType, NDArrayType, ShapeType
 from numpydantic.vendor.nptyping.structure import StructureMeta
 
 if TYPE_CHECKING:  # pragma: no cover
+    from pydantic._internal._schema_generation_shared import (
+        CallbackGetCoreSchemaHandler,
+    )
+
     from numpydantic import Shape
 
-_handler_type = Callable[[Any], core_schema.CoreSchema]
-_UNSUPPORTED_TYPES = (complex,)
-
 
-def _numeric_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
+def _numeric_dtype(
+    dtype: DtypeType, _handler: "CallbackGetCoreSchemaHandler"
+) -> CoreSchema:
     """Make a numeric dtype that respects min/max values from extended numpy types"""
     if dtype in (np.number,):
         dtype = float
@@ -36,14 +39,15 @@ def _numeric_dtype(
     elif issubclass(dtype, np.integer):
         info = np.iinfo(dtype)
         schema = core_schema.int_schema(le=int(info.max), ge=int(info.min))
 
     else:
         schema = _handler.generate_schema(dtype)
 
     return schema
 
 
-def _lol_dtype(dtype: DtypeType, _handler: _handler_type) -> CoreSchema:
+def _lol_dtype(
+    dtype: DtypeType, _handler: "CallbackGetCoreSchemaHandler"
+) -> CoreSchema:
     """Get the innermost dtype schema to use in the generated pydantic schema"""
     if isinstance(dtype, StructureMeta):  # pragma: no cover
         raise NotImplementedError("Structured dtypes are currently unsupported")
@@ -79,11 +83,12 @@ def _lol_dtype(
         # does this need a warning?
         python_type = Any
 
-    if python_type in _UNSUPPORTED_TYPES:
-        array_type = core_schema.any_schema()
-        # TODO: warn and log here
-    elif python_type in (float, int):
+    if python_type in (float, int):
         array_type = _numeric_dtype(dtype, _handler)
+    elif python_type is bool:
+        array_type = core_schema.bool_schema()
+    elif python_type is Any:
+        array_type = core_schema.any_schema()
     else:
         array_type = _handler.generate_schema(python_type)
 
@@ -208,14 +213,24 @@ def _unbounded_shape(
 
 
 def make_json_schema(
-    shape: ShapeType, dtype: DtypeType, _handler: _handler_type
+    shape: ShapeType, dtype: DtypeType, _handler: "CallbackGetCoreSchemaHandler"
 ) -> ListSchema:
     """
-    Make a list of list JSON schema from a shape and a dtype.
+    Make a list of list pydantic core schema for an array from a shape and a dtype.
+    Used to generate JSON schema in the containing model, but not for validation,
+    which is handled by interfaces.
 
     First resolves the dtype into a pydantic ``CoreSchema`` ,
     and then uses that with :func:`.list_of_lists_schema` .
 
+    .. admonition:: Potentially Fragile
+
+        Uses a private method from the handler to flatten out nested definitions
+        (e.g. when dtype is a pydantic model)
+        so that they are present in the generated schema directly rather than
+        as references. Otherwise, at the time __get_pydantic_json_schema__ is called,
+        the definition references are lost.
+
     Args:
         shape ( ShapeType ): Specification of a shape, as a tuple or
             an nptyping ``Shape``
@@ -234,6 +249,8 @@ def make_json_schema(
     else:
         list_schema = list_of_lists_schema(shape, dtype_schema)
 
+    list_schema = _handler._generate_schema.clean_schema(list_schema)
+
     return list_schema
 
 
@@ -252,9 +269,3 @@ def get_validate_interface(shape: ShapeType, dtype: DtypeType) -> Callable:
         return value
 
     return validate_interface
-
-
-def _jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
-    """Use an interface class to render an array as JSON"""
-    interface_cls = Interface.match_output(value)
-    return interface_cls.to_json(value, info)
src/numpydantic/serialization.py (new file)
@@ -0,0 +1,128 @@
+"""
+Serialization helpers for :func:`pydantic.BaseModel.model_dump`
+and :func:`pydantic.BaseModel.model_dump_json` .
+"""
+
+from pathlib import Path
+from typing import Any, Callable, TypeVar, Union
+
+from pydantic_core.core_schema import SerializationInfo
+
+from numpydantic.interface import Interface, JsonDict
+
+T = TypeVar("T")
+U = TypeVar("U")
+
+
+def jsonize_array(value: Any, info: SerializationInfo) -> Union[list, dict]:
+    """Use an interface class to render an array as JSON"""
+    interface_cls = Interface.match_output(value)
+    array = interface_cls.to_json(value, info)
+    if isinstance(array, JsonDict):
+        array = array.model_dump(exclude_none=True)
+
+    if info.context:
+        if info.context.get("mark_interface", False):
+            array = interface_cls.mark_json(array)
+
+        if info.context.get("absolute_paths", False):
+            array = _absolutize_paths(array)
+        else:
+            relative_to = info.context.get("relative_to", ".")
+            array = _relativize_paths(array, relative_to)
+    else:
+        # relativize paths by default
+        array = _relativize_paths(array, ".")
+
+    return array
+
+
+def _relativize_paths(value: dict, relative_to: str = ".") -> dict:
+    """
+    Make paths relative to either the current directory or the provided
+    ``relative_to`` directory, if provided in the context
+    """
+    relative_to = Path(relative_to).resolve()
+
+    def _r_path(v: Any) -> Any:
+        try:
+            path = Path(v)
+            if not path.exists():
+                return v
+            return str(relative_path(path, relative_to))
+        except (TypeError, ValueError):
+            return v
+
+    return _walk_and_apply(value, _r_path)
+
+
+def _absolutize_paths(value: dict) -> dict:
+    def _a_path(v: Any) -> Any:
+        try:
+            path = Path(v)
+            if not path.exists():
+                return v
+            return str(path.resolve())
+        except (TypeError, ValueError):
+            return v
+
+    return _walk_and_apply(value, _a_path)
+
+
+def _walk_and_apply(value: T, f: Callable[[U], U]) -> T:
+    """
+    Walk an object, applying a function
+    """
+    if isinstance(value, dict):
+        for k, v in value.items():
+            if isinstance(v, dict):
+                _walk_and_apply(v, f)
+            elif isinstance(v, list):
+                value[k] = [_walk_and_apply(sub_v, f) for sub_v in v]
+            else:
+                value[k] = f(v)
+    elif isinstance(value, list):
+        value = [_walk_and_apply(v, f) for v in value]
+    else:
+        value = f(value)
+    return value
+
+
+def relative_path(self: Path, other: Path, walk_up: bool = True) -> Path:
+    """
+    "Backport" of :meth:`pathlib.Path.relative_to` with ``walk_up=True``
+    that's not available pre 3.12.
+
+    Return the relative path to another path identified by the passed
+    arguments. If the operation is not possible (because this is not
+    related to the other path), raise ValueError.
+
+    The *walk_up* parameter controls whether `..` may be used to resolve
+    the path.
+
+    References:
+        https://github.com/python/cpython/blob/8a2baedc4bcb606da937e4e066b4b3a18961cace/Lib/pathlib/_abc.py#L244-L270
+    """
+    if not isinstance(other, Path):  # pragma: no cover - ripped from cpython
+        other = Path(other)
+    self_parts = self.parts
+    other_parts = other.parts
+    anchor0, parts0 = self_parts[0], list(reversed(self_parts[1:]))
+    anchor1, parts1 = other_parts[0], list(reversed(other_parts[1:]))
+    if anchor0 != anchor1:
+        raise ValueError(f"{self!r} and {other!r} have different anchors")
+    while parts0 and parts1 and parts0[-1] == parts1[-1]:
+        parts0.pop()
+        parts1.pop()
+    for part in parts1:  # pragma: no cover - not testing, ripped off from cpython
+        if not part or part == ".":
+            pass
+        elif not walk_up:
+            raise ValueError(f"{self!r} is not in the subpath of {other!r}")
+        elif part == "..":
+            raise ValueError(f"'..' segment in {other!r} cannot be walked")
+        else:
+            parts0.append("..")
+    return Path(*reversed(parts0))
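Two usage notes for the helpers above, as sketches: `relative_path` walks up through `..` where `Path.relative_to` would raise before Python 3.12, and the context keys consumed by `jsonize_array` control how on-disk paths in a dump are rewritten (the `model` calls in the comments are illustrative).

```python
from pathlib import Path

from numpydantic.serialization import relative_path

# walk-up behavior: a sibling directory resolves through ".."
assert relative_path(Path("/data/a/file.h5"), Path("/data/b")) == Path(
    "../a/file.h5"
)

# context knobs consumed by jsonize_array:
#   model.model_dump_json(round_trip=True)
#       -> paths relative to "." (the default)
#   model.model_dump_json(round_trip=True, context={"relative_to": "out"})
#       -> paths relative to out/
#   model.model_dump_json(round_trip=True, context={"absolute_paths": True})
#       -> resolved absolute paths
```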
@@ -1,4 +1,3 @@
-import pdb
 import sys
 
 import pytest
 
@@ -1,6 +1,6 @@
 import pytest
 
-from typing import Tuple, Callable
+from typing import Callable, Tuple, Type
 import numpy as np
 import dask.array as da
 import zarr
 
@@ -12,27 +12,47 @@ from numpydantic import interface, NDArray
 @pytest.fixture(
     scope="function",
     params=[
-        ([[1, 2], [3, 4]], interface.NumpyInterface),
-        (np.zeros((3, 4)), interface.NumpyInterface),
-        ("hdf5_array", interface.H5Interface),
-        (da.random.random((10, 10)), interface.DaskInterface),
-        (zarr.ones((10, 10)), interface.ZarrInterface),
-        ("zarr_nested_array", interface.ZarrInterface),
-        ("zarr_array", interface.ZarrInterface),
-        ("avi_video", interface.VideoInterface),
-    ],
-    ids=[
-        "numpy_list",
-        "numpy",
-        "H5ArrayPath",
-        "dask",
-        "zarr_memory",
-        "zarr_nested",
-        "zarr_array",
-        "video",
+        pytest.param(
+            ([[1, 2], [3, 4]], interface.NumpyInterface),
+            marks=pytest.mark.numpy,
+            id="numpy-list",
+        ),
+        pytest.param(
+            (np.zeros((3, 4)), interface.NumpyInterface),
+            marks=pytest.mark.numpy,
+            id="numpy",
+        ),
+        pytest.param(
+            ("hdf5_array", interface.H5Interface),
+            marks=pytest.mark.hdf5,
+            id="h5-array-path",
+        ),
+        pytest.param(
+            (da.random.random((10, 10)), interface.DaskInterface),
+            marks=pytest.mark.dask,
+            id="dask",
+        ),
+        pytest.param(
+            (zarr.ones((10, 10)), interface.ZarrInterface),
+            marks=pytest.mark.zarr,
+            id="zarr-memory",
+        ),
+        pytest.param(
+            ("zarr_nested_array", interface.ZarrInterface),
+            marks=pytest.mark.zarr,
+            id="zarr-nested",
+        ),
+        pytest.param(
+            ("zarr_array", interface.ZarrInterface),
+            marks=pytest.mark.zarr,
+            id="zarr-array",
+        ),
+        pytest.param(
+            ("avi_video", interface.VideoInterface), marks=pytest.mark.video, id="video"
+        ),
     ],
 )
-def interface_type(request) -> Tuple[NDArray, interface.Interface]:
+def interface_type(request) -> Tuple[NDArray, Type[interface.Interface]]:
     """
     Test cases for each interface's ``check`` method - each input should match the
     provided interface and that interface only
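Since the fixture params now carry marks, subsets of the suite can be selected by mark expression; a sketch using pytest's programmatic runner (assuming the marks are registered in the project config):

```python
import pytest

# run only zarr-related serialization tests, equivalent to
# `pytest -m "zarr and serialization"` on the command line
pytest.main(["-m", "zarr and serialization"])
```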
@@ -1,5 +1,3 @@
-import pdb
-
 import pytest
 import json
 
@@ -11,6 +9,8 @@ from numpydantic.exceptions import DtypeError, ShapeError
 
 from tests.conftest import ValidationCase
 
+pytestmark = pytest.mark.dask
+
 
 def dask_array(case: ValidationCase) -> da.Array:
     if issubclass(case.dtype, BaseModel):
@@ -42,14 +42,17 @@ def test_dask_check(interface_type):
     assert not DaskInterface.check(interface_type[0])
 
 
+@pytest.mark.shape
 def test_dask_shape(shape_cases):
     _test_dask_case(shape_cases)
 
 
+@pytest.mark.dtype
 def test_dask_dtype(dtype_cases):
     _test_dask_case(dtype_cases)
 
 
+@pytest.mark.serialization
 def test_dask_to_json(array_model):
     array_list = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
     array = da.array(array_list)
@@ -1,10 +0,0 @@
-"""
-Tests for dunder methods on all interfaces
-"""
-
-
-def test_dunder_len(all_interfaces):
-    """
-    Each interface or proxy type should support __len__
-    """
-    assert len(all_interfaces.array) == all_interfaces.array.shape[0]
@@ -14,6 +14,8 @@ from numpydantic.exceptions import DtypeError, ShapeError
 
 from tests.conftest import ValidationCase
 
+pytestmark = pytest.mark.hdf5
+
 
 def hdf5_array_case(
     case: ValidationCase, array_func, compound: bool = False
@@ -72,11 +74,13 @@ def test_hdf5_check_not_hdf5(tmp_path):
     assert not H5Interface.check(spec)
 
 
+@pytest.mark.shape
 @pytest.mark.parametrize("compound", [True, False])
 def test_hdf5_shape(shape_cases, hdf5_array, compound):
     _test_hdf5_case(shape_cases, hdf5_array, compound)
 
 
+@pytest.mark.dtype
 @pytest.mark.parametrize("compound", [True, False])
 def test_hdf5_dtype(dtype_cases, hdf5_array, compound):
     _test_hdf5_case(dtype_cases, hdf5_array, compound)
@@ -90,6 +94,7 @@ def test_hdf5_dataset_not_exists(hdf5_array, model_blank):
     assert "no array found" in e
 
 
+@pytest.mark.proxy
 def test_assignment(hdf5_array, model_blank):
     array = hdf5_array()
 
@@ -101,7 +106,9 @@ def test_assignment(hdf5_array, model_blank):
     assert (model.array[1:3, 2:4] == 10).all()
 
 
-def test_to_json(hdf5_array, array_model):
+@pytest.mark.serialization
+@pytest.mark.parametrize("round_trip", (True, False))
+def test_to_json(hdf5_array, array_model, round_trip):
     """
     Test serialization of HDF5 arrays to JSON
     Args:
@@ -115,15 +122,19 @@ def test_to_json(hdf5_array, array_model, round_trip):
 
     instance = model(array=array)  # type: BaseModel
 
-    json_str = instance.model_dump_json()
-    json_dict = json.loads(json_str)["array"]
-
-    assert json_dict["file"] == str(array.file)
-    assert json_dict["path"] == str(array.path)
-    assert json_dict["attrs"] == {}
-    assert json_dict["array"] == instance.array[:].tolist()
+    json_str = instance.model_dump_json(
+        round_trip=round_trip, context={"absolute_paths": True}
+    )
+    json_dumped = json.loads(json_str)["array"]
+    if round_trip:
+        assert json_dumped["file"] == str(array.file)
+        assert json_dumped["path"] == str(array.path)
+    else:
+        assert json_dumped == instance.array[:].tolist()
 
 
+@pytest.mark.dtype
+@pytest.mark.proxy
 def test_compound_dtype(tmp_path):
     """
     hdf5 proxy indexes compound dtypes as single fields when field is given
@@ -158,6 +169,8 @@ def test_compound_dtype(tmp_path):
     assert all(instance.array[1] == 2)
 
 
+@pytest.mark.dtype
+@pytest.mark.proxy
 @pytest.mark.parametrize("compound", [True, False])
 def test_strings(hdf5_array, compound):
     """
@@ -177,6 +190,8 @@ def test_strings(hdf5_array, compound):
     assert all(instance.array[1] == "sup")
 
 
+@pytest.mark.dtype
+@pytest.mark.proxy
 @pytest.mark.parametrize("compound", [True, False])
 def test_datetime(hdf5_array, compound):
     """
@@ -218,3 +233,29 @@ def test_empty_dataset(dtype, tmp_path):
     array: NDArray[Any, dtype]
 
     _ = MyModel(array=(array_path, "/data"))
+
+
+@pytest.mark.proxy
+@pytest.mark.parametrize(
+    "comparison,valid",
+    [
+        (H5Proxy(file="test_file.h5", path="/subpath", field="sup"), True),
+        (H5Proxy(file="test_file.h5", path="/subpath"), False),
+        (H5Proxy(file="different_file.h5", path="/subpath"), False),
+        (("different_file.h5", "/subpath", "sup"), ValueError),
+        ("not even a proxy-like thing", ValueError),
+    ],
+)
+def test_proxy_eq(comparison, valid):
+    """
+    test the __eq__ method of H5ArrayProxy matches proxies to the same
+    dataset (and path), or raises a ValueError
+    """
+    proxy_a = H5Proxy(file="test_file.h5", path="/subpath", field="sup")
+    if valid is True:
+        assert proxy_a == comparison
+    elif valid is False:
+        assert proxy_a != comparison
+    else:
+        with pytest.raises(valid):
+            assert proxy_a == comparison
@@ -4,11 +4,32 @@ for tests that should apply to all interfaces, use ``test_interfaces.py``
 """
 
 import gc
+from typing import Literal
 
 import pytest
 import numpy as np
 
-from numpydantic.interface import Interface
+from numpydantic.interface import (
+    Interface,
+    JsonDict,
+    InterfaceMark,
+    NumpyInterface,
+    MarkedJson,
+)
+from pydantic import ValidationError
+
+from numpydantic.interface.interface import V
+
+
+class MyJsonDict(JsonDict):
+    type: Literal["my_json_dict"]
+    field: str
+    number: int
+
+    def to_array_input(self) -> V:
+        dumped = self.model_dump()
+        dumped["extra_input_param"] = True
+        return dumped
 
 
 @pytest.fixture(scope="module")
@@ -162,3 +183,66 @@ def test_interface_recursive(interfaces):
     assert issubclass(interfaces.interface3, interfaces.interface1)
     assert issubclass(interfaces.interface1, Interface)
     assert interfaces.interface4 in ifaces
+
+
+@pytest.mark.serialization
+def test_jsondict_is_valid():
+    """
+    A JsonDict should return a bool true/false if it is valid or not,
+    and raise an error when requested
+    """
+    invalid = {"doesnt": "have", "the": "props"}
+    valid = {"type": "my_json_dict", "field": "a_field", "number": 1}
+    assert MyJsonDict.is_valid(valid)
+    assert not MyJsonDict.is_valid(invalid)
+    with pytest.raises(ValidationError):
+        assert not MyJsonDict.is_valid(invalid, raise_on_error=True)
+
+
+@pytest.mark.serialization
+def test_jsondict_handle_input():
+    """
+    JsonDict should be able to parse a valid dict and return it to the input format
+    """
+    valid = {"type": "my_json_dict", "field": "a_field", "number": 1}
+    instantiated = MyJsonDict(**valid)
+    expected = {
+        "type": "my_json_dict",
+        "field": "a_field",
+        "number": 1,
+        "extra_input_param": True,
+    }
+
+    for item in (valid, instantiated):
+        result = MyJsonDict.handle_input(item)
+        assert result == expected
+
+
+@pytest.mark.serialization
+@pytest.mark.parametrize("interface", Interface.interfaces())
+def test_interface_mark_match_by_name(interface):
+    """
+    Interface mark should match an interface by its name
+    """
+    # other parts don't matter
+    mark = InterfaceMark(module="fake", cls="fake", version="fake", name=interface.name)
+    fake_mark = InterfaceMark(
+        module="fake", cls="fake", version="fake", name="also_fake"
+    )
+    assert mark.match_by_name() is interface
+    assert fake_mark.match_by_name() is None
+
+
+@pytest.mark.serialization
+def test_marked_json_try_cast():
+    """
+    MarkedJson.try_cast should try and cast to a markedjson!
+    returning the value unchanged if it's not a match
+    """
+    valid = {"interface": NumpyInterface.mark_interface(), "value": [[1, 2], [3, 4]]}
+    invalid = [1, 2, 3, 4, 5]
+    mimic = {"interface": "not really", "value": "still not really"}
+
+    assert isinstance(MarkedJson.try_cast(valid), MarkedJson)
+    assert MarkedJson.try_cast(invalid) is invalid
+    assert MarkedJson.try_cast(mimic) is mimic
@@ -2,6 +2,41 @@
 Tests that should be applied to all interfaces
 """
 
+import pytest
+from typing import Callable
+from importlib.metadata import version
+import json
+
+import numpy as np
+import dask.array as da
+from zarr.core import Array as ZarrArray
+from pydantic import BaseModel
+
+from numpydantic.interface import Interface, InterfaceMark, MarkedJson
+
+
+def _test_roundtrip(source: BaseModel, target: BaseModel, round_trip: bool):
+    """Test model equality for roundtrip tests"""
+    if round_trip:
+        assert type(target.array) is type(source.array)
+        if isinstance(source.array, (np.ndarray, ZarrArray)):
+            assert np.array_equal(target.array, np.array(source.array))
+        elif isinstance(source.array, da.Array):
+            assert np.all(da.equal(target.array, source.array))
+        else:
+            assert target.array == source.array
+
+        assert target.array.dtype == source.array.dtype
+    else:
+        assert np.array_equal(target.array, np.array(source.array))
+
+
+def test_dunder_len(all_interfaces):
+    """
+    Each interface or proxy type should support __len__
+    """
+    assert len(all_interfaces.array) == all_interfaces.array.shape[0]
+
+
 def test_interface_revalidate(all_interfaces):
     """
@@ -10,3 +45,86 @@ def test_interface_revalidate(all_interfaces):
     See: https://github.com/p2p-ld/numpydantic/pull/14
     """
     _ = type(all_interfaces)(array=all_interfaces.array)
+
+
+def test_interface_rematch(interface_type):
+    """
+    All interfaces should match the results of the object they return after validation
+    """
+    array, interface = interface_type
+    if isinstance(array, Callable):
+        array = array()
+
+    assert Interface.match(interface().validate(array)) is interface
+
+
+def test_interface_to_numpy_array(all_interfaces):
+    """
+    All interfaces should be able to have the output of their validation stage
+    coerced to a numpy array with np.array()
+    """
+    _ = np.array(all_interfaces.array)
+
+
+@pytest.mark.serialization
+def test_interface_dump_json(all_interfaces):
+    """
+    All interfaces should be able to dump to json
+    """
+    all_interfaces.model_dump_json()
+
+
+@pytest.mark.serialization
+@pytest.mark.parametrize("round_trip", [True, False])
+def test_interface_roundtrip_json(all_interfaces, round_trip):
+    """
+    All interfaces should be able to roundtrip to and from json
+    """
+    dumped_json = all_interfaces.model_dump_json(round_trip=round_trip)
+    model = all_interfaces.model_validate_json(dumped_json)
+    _test_roundtrip(all_interfaces, model, round_trip)
+
+
+@pytest.mark.serialization
+@pytest.mark.parametrize("an_interface", Interface.interfaces())
+def test_interface_mark_interface(an_interface):
+    """
+    All interfaces should be able to mark the current version and interface info
+    """
+    mark = an_interface.mark_interface()
+    assert isinstance(mark, InterfaceMark)
+    assert mark.name == an_interface.name
+    assert mark.cls == an_interface.__name__
+    assert mark.module == an_interface.__module__
+    assert mark.version == version(mark.module.split(".")[0])
+
+
+@pytest.mark.serialization
+@pytest.mark.parametrize("valid", [True, False])
+@pytest.mark.parametrize("round_trip", [True, False])
+@pytest.mark.filterwarnings("ignore:Mismatch between serialized mark")
+def test_interface_mark_roundtrip(all_interfaces, valid, round_trip):
+    """
+    All interfaces should be able to roundtrip with the marked interface,
+    and a mismatch should raise a warning and attempt to proceed
+    """
+    dumped_json = all_interfaces.model_dump_json(
+        round_trip=round_trip, context={"mark_interface": True}
+    )
+
+    data = json.loads(dumped_json)
+
+    # ensure that we are a MarkedJson
+    _ = MarkedJson.model_validate_json(json.dumps(data["array"]))
+
+    if not valid:
+        # ruin the version
+        data["array"]["interface"]["version"] = "v99999999"
+        dumped_json = json.dumps(data)
+
+        with pytest.warns(match="Mismatch.*"):
+            model = all_interfaces.model_validate_json(dumped_json)
+    else:
+        model = all_interfaces.model_validate_json(dumped_json)
+
+    _test_roundtrip(all_interfaces, model, round_trip)
@@ -5,6 +5,8 @@ from numpydantic.exceptions import DtypeError, ShapeError
 
 from tests.conftest import ValidationCase
 
+pytestmark = pytest.mark.numpy
+
 
 def numpy_array(case: ValidationCase) -> np.ndarray:
     if issubclass(case.dtype, BaseModel):
@@ -22,10 +24,12 @@ def _test_np_case(case: ValidationCase):
     case.model(array=array)
 
 
+@pytest.mark.shape
 def test_numpy_shape(shape_cases):
     _test_np_case(shape_cases)
 
 
+@pytest.mark.dtype
 def test_numpy_dtype(dtype_cases):
     _test_np_case(dtype_cases)
 
@@ -14,6 +14,8 @@ from numpydantic import NDArray, Shape
 from numpydantic import dtype as dt
 from numpydantic.interface.video import VideoProxy
 
+pytestmark = pytest.mark.video
+
 
 @pytest.mark.parametrize("input_type", [str, Path])
 def test_video_validation(avi_video, input_type):
@@ -49,6 +51,7 @@ def test_video_from_videocapture(avi_video):
     opened_vid.release()
 
 
+@pytest.mark.shape
 def test_video_wrong_shape(avi_video):
     shape = (100, 50)
 
@@ -65,6 +68,7 @@ def test_video_wrong_shape(avi_video):
     instance = MyModel(array=vid)
 
 
+@pytest.mark.proxy
 def test_video_getitem(avi_video):
     """
     Should be able to get individual frames and slices as if it were a normal array
@@ -127,6 +131,7 @@ def test_video_getitem(avi_video):
         instance.array[5] = 10
 
 
+@pytest.mark.proxy
 def test_video_attrs(avi_video):
     """Should be able to access opencv properties"""
     shape = (100, 50)
@@ -142,6 +147,7 @@ def test_video_attrs(avi_video):
     assert int(instance.array.get(cv2.CAP_PROP_POS_FRAMES)) == 5
 
 
+@pytest.mark.proxy
 def test_video_close(avi_video):
     """Should close and reopen video file if needed"""
     shape = (100, 50)
@@ -158,3 +164,42 @@ def test_video_close(avi_video):
     assert instance.array._video is None
     # reopen
     assert isinstance(instance.array.video, cv2.VideoCapture)
+
+
+@pytest.mark.proxy
+def test_video_not_exists(tmp_path):
+    """
+    A video file that doesn't exist should raise an error
+    """
+    video = VideoProxy(tmp_path / "not_real.avi")
+    with pytest.raises(FileNotFoundError):
+        _ = video.video
+
+
+@pytest.mark.proxy
+@pytest.mark.parametrize(
+    "comparison,valid",
+    [
+        (VideoProxy("test_video.avi"), True),
+        (VideoProxy("not_real_video.avi"), False),
+        ("not even a video proxy", TypeError),
+    ],
+)
+def test_video_proxy_eq(comparison, valid):
+    """
+    Comparing a video proxy's equality should be valid if the path matches
+    """
+    proxy_a = VideoProxy("test_video.avi")
+    if valid is True:
+        assert proxy_a == comparison
+    elif valid is False:
+        assert proxy_a != comparison
+    else:
+        with pytest.raises(valid):
+            assert proxy_a == comparison
@@ -6,13 +6,14 @@ import zarr
 from pydantic import BaseModel, ValidationError
 from numcodecs import Pickle
 
 from numpydantic.interface import ZarrInterface
 from numpydantic.interface.zarr import ZarrArrayPath
 from numpydantic.exceptions import DtypeError, ShapeError
 
 from tests.conftest import ValidationCase
 
+pytestmark = pytest.mark.zarr
+
 
 @pytest.fixture()
 def dir_array(tmp_output_dir_func) -> zarr.DirectoryStore:
@@ -87,10 +88,12 @@ def test_zarr_check(interface_type):
     assert not ZarrInterface.check(interface_type[0])
 
 
+@pytest.mark.shape
 def test_zarr_shape(store, shape_cases):
     _test_zarr_case(shape_cases, store)
 
 
+@pytest.mark.dtype
 def test_zarr_dtype(dtype_cases, store):
     _test_zarr_case(dtype_cases, store)
 
 
@@ -123,7 +126,10 @@ def test_zarr_array_path_from_iterable(zarr_array):
     assert apath.path == inner_path
 
 
-def test_zarr_to_json(store, model_blank):
+@pytest.mark.serialization
+@pytest.mark.parametrize("dump_array", [True, False])
+@pytest.mark.parametrize("roundtrip", [True, False])
+def test_zarr_to_json(store, model_blank, roundtrip, dump_array):
     expected_fields = (
         "Type",
         "Data type",
@@ -137,17 +143,22 @@ def test_zarr_to_json(store, model_blank):
 
     array = zarr.array(lol_array, store=store)
     instance = model_blank(array=array)
-    as_json = json.loads(instance.model_dump_json())["array"]
-    assert "array" not in as_json
-    for field in expected_fields:
-        assert field in as_json
-    assert len(as_json["hexdigest"]) == 40
 
-    # dump the array itself too
-    as_json = json.loads(instance.model_dump_json(context={"zarr_dump_array": True}))[
-        "array"
-    ]
-    for field in expected_fields:
-        assert field in as_json
-    assert len(as_json["hexdigest"]) == 40
-    assert as_json["array"] == lol_array
+    context = {"dump_array": dump_array}
+    as_json = json.loads(
+        instance.model_dump_json(round_trip=roundtrip, context=context)
+    )["array"]
+
+    if roundtrip:
+        if dump_array:
+            assert as_json["array"] == lol_array
+        else:
+            if as_json.get("file", False):
+                assert "array" not in as_json
+
+        for field in expected_fields:
+            assert field in as_json["info"]
+        assert len(as_json["info"]["hexdigest"]) == 40
+    else:
+        assert as_json == lol_array
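As rewritten, `test_zarr_to_json` fixes the serialization contract for Zarr-backed fields: a plain `model_dump_json()` yields the bare list of lists, while `round_trip=True` yields a dict whose `info` block carries the descriptive fields and a 40-character hexdigest, with the raw values embedded only when the `dump_array` context flag is set. A hedged usage sketch — `MyModel` is a stand-in for the test's `model_blank` fixture, and the in-memory array here stands in for the parametrized `store` fixture:

```python
import json

import zarr
from pydantic import BaseModel

from numpydantic import NDArray


class MyModel(BaseModel):
    array: NDArray


lol_array = [[1, 2, 3], [4, 5, 6]]
instance = MyModel(array=zarr.array(lol_array))

# without round_trip, a zarr-backed field dumps to a plain list of lists
plain = json.loads(instance.model_dump_json())["array"]
assert plain == lol_array

# with round_trip, we get a provenance-preserving dict instead;
# dump_array asks for the values to be embedded alongside the metadata
rt = json.loads(
    instance.model_dump_json(round_trip=True, context={"dump_array": True})
)["array"]
assert rt["array"] == lol_array  # values embedded per the dump_array flag
```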
@@ -1,5 +1,3 @@
-import pdb
-
 import pytest
 
 from typing import Union, Optional, Any
@@ -15,6 +13,7 @@ from numpydantic import dtype
 from numpydantic.dtype import Number
 
 
+@pytest.mark.json_schema
 def test_ndarray_type():
     class Model(BaseModel):
         array: NDArray[Shape["2 x, * y"], Number]
@@ -40,6 +39,8 @@ def test_ndarray_type():
     instance = Model(array=np.zeros((2, 3)), array_any=np.ones((3, 4, 5)))
 
 
+@pytest.mark.dtype
+@pytest.mark.json_schema
 def test_schema_unsupported_type():
     """
     Complex numbers should just be made with an `any` schema
@@ -55,9 +56,11 @@ def test_schema_unsupported_type():
     }
 
 
+@pytest.mark.dtype
+@pytest.mark.json_schema
 def test_schema_tuple():
     """
-    Types specified as tupled should have their schemas as a union
+    Types specified as tuples should have their schemas as a union
     """
 
     class Model(BaseModel):
@@ -72,6 +75,8 @@ def test_schema_tuple():
     assert all([i["minimum"] == 0 for i in conditions])
 
 
+@pytest.mark.dtype
+@pytest.mark.json_schema
 def test_schema_number():
     """
     np.numeric should just be the float schema
@@ -115,12 +120,12 @@ def test_ndarray_union():
     instance = Model(array=np.random.random((5, 10, 4, 6)))
 
 
+@pytest.mark.shape
+@pytest.mark.dtype
 @pytest.mark.parametrize("dtype", dtype.Number)
 def test_ndarray_unparameterized(dtype):
     """
     NDArray without any parameters is any shape, any type
-    Returns:
-
     """
 
     class Model(BaseModel):
@@ -134,6 +139,7 @@ def test_ndarray_unparameterized(dtype):
     _ = Model(array=np.zeros(dim_sizes, dtype=dtype))
 
 
+@pytest.mark.shape
 def test_ndarray_any():
     """
     using :class:`typing.Any` in for the shape means any shape
@@ -164,6 +170,19 @@ def test_ndarray_coercion():
     amod = Model(array=["a", "b", "c"])
 
 
+@pytest.mark.shape
+def test_shape_ellipsis():
+    """
+    Test that ellipsis is a wildcard, rather than "repeat the last index"
+    """
+
+    class MyModel(BaseModel):
+        array: NDArray[Shape["1, 2, ..."], Number]
+
+    _ = MyModel(array=np.zeros((1, 2, 3, 4, 5)))
+
+
+@pytest.mark.serialization
 def test_ndarray_serialize():
     """
     Arrays should be dumped to a list when using json, but kept as ndarray otherwise
@@ -188,6 +207,7 @@ _json_schema_types = [
 ]
 
 
+@pytest.mark.json_schema
 def test_json_schema_basic(array_model):
     """
     NDArray types should correctly generate a list of lists JSON schema
@@ -210,6 +230,8 @@ def test_json_schema_basic(array_model):
     assert inner["items"]["type"] == "number"
 
 
+@pytest.mark.dtype
+@pytest.mark.json_schema
 @pytest.mark.parametrize("dtype", [*dtype.Integer, *dtype.Float])
 def test_json_schema_dtype_single(dtype, array_model):
     """
@@ -240,6 +262,8 @@ def test_json_schema_dtype_single(dtype, array_model):
     )
 
 
+@pytest.mark.dtype
+@pytest.mark.json_schema
 @pytest.mark.parametrize(
     "dtype,expected",
     [
@@ -266,6 +290,8 @@ def test_json_schema_dtype_builtin(dtype, expected, array_model):
     assert inner_type["type"] == expected
 
 
+@pytest.mark.dtype
+@pytest.mark.json_schema
 def test_json_schema_dtype_model():
     """
     Pydantic models can be used in arrays as dtypes
@@ -314,6 +340,8 @@ def _recursive_array(schema):
     assert any_of[1]["minimum"] == 0
 
 
+@pytest.mark.shape
+@pytest.mark.json_schema
 def test_json_schema_ellipsis():
     """
     NDArray types should create a recursive JSON schema for any-shaped arrays
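`test_ndarray_serialize` above states the dual-mode contract directly: arrays are dumped to a list when using JSON, but kept as ndarray otherwise. A short sketch of what that looks like from the caller's side — inferred from that docstring rather than copied from the (elided) test body:

```python
import json

import numpy as np
from pydantic import BaseModel

from numpydantic import NDArray, Shape
from numpydantic.dtype import Number


class Model(BaseModel):
    array: NDArray[Shape["2 x, * y"], Number]


instance = Model(array=np.zeros((2, 3)))

# python-mode dump keeps the ndarray as-is
assert isinstance(instance.model_dump()["array"], np.ndarray)

# json-mode dump converts to a list of lists
assert json.loads(instance.model_dump_json())["array"] == [[0.0] * 3] * 2
```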
95 tests/test_serialization.py Normal file
@@ -0,0 +1,95 @@
+"""
+Test serialization-specific functionality that doesn't need to be
+applied across every interface (use test_interface/test_interfaces for that
+"""
+
+import h5py
+import pytest
+from pathlib import Path
+from typing import Callable
+import numpy as np
+import json
+
+pytestmark = pytest.mark.serialization
+
+
+@pytest.fixture(scope="module")
+def hdf5_at_path() -> Callable[[Path], None]:
+    _path = ""
+
+    def _hdf5_at_path(path: Path) -> None:
+        nonlocal _path
+        _path = path
+        h5f = h5py.File(path, "w")
+        _ = h5f.create_dataset("/data", data=np.array([[1, 2], [3, 4]]))
+        _ = h5f.create_dataset("subpath/to/dataset", data=np.array([[1, 2], [4, 5]]))
+        h5f.close()
+
+    yield _hdf5_at_path
+
+    Path(_path).unlink(missing_ok=True)
+
+
+def test_relative_path(hdf5_at_path, tmp_output_dir, model_blank):
+    """
+    By default, we should make all paths relative to the cwd
+    """
+    out_path = tmp_output_dir / "relative.h5"
+    hdf5_at_path(out_path)
+    model = model_blank(array=(out_path, "/data"))
+    rt = model.model_dump_json(round_trip=True)
+    file = json.loads(rt)["array"]["file"]
+
+    # should not be absolute
+    assert not Path(file).is_absolute()
+    # should be relative to cwd
+    out_file = (Path.cwd() / file).resolve()
+    assert out_file == out_path.resolve()
+
+
+def test_relative_to_path(hdf5_at_path, tmp_output_dir, model_blank):
+    """
+    When explicitly passed a path to be ``relative_to``,
+    relative to that instead of cwd
+    """
+    out_path = tmp_output_dir / "relative.h5"
+    relative_to_path = Path(__file__) / "fake_dir" / "sub_fake_dir"
+    expected_path = "../../../__tmp__/relative.h5"
+
+    hdf5_at_path(out_path)
+    model = model_blank(array=(out_path, "/data"))
+    rt = model.model_dump_json(
+        round_trip=True, context={"relative_to": str(relative_to_path)}
+    )
+    data = json.loads(rt)["array"]
+    file = data["file"]
+
+    # should not be absolute
+    assert not Path(file).is_absolute()
+    # should be expected path and reach the file
+    assert file == expected_path
+    assert (relative_to_path / file).resolve() == out_path.resolve()
+
+    # we shouldn't have touched `/data` even though it is pathlike
+    assert data["path"] == "/data"
+
+
+def test_absolute_path(hdf5_at_path, tmp_output_dir, model_blank):
+    """
+    When told, we make paths absolute
+    """
+    out_path = tmp_output_dir / "relative.h5"
+    expected_dataset = "subpath/to/dataset"
+
+    hdf5_at_path(out_path)
+    model = model_blank(array=(out_path, expected_dataset))
+    rt = model.model_dump_json(round_trip=True, context={"absolute_paths": True})
+    data = json.loads(rt)["array"]
+    file = data["file"]
+
+    # should be absolute and equal to out_path
+    assert Path(file).is_absolute()
+    assert Path(file) == out_path.resolve()
+
+    # shouldn't have absolutized subpath even if it's pathlike
+    assert data["path"] == expected_dataset
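Taken together, the three tests in the new file pin down the path-handling contract for file-backed interfaces during round-trip dumps: file paths default to being relative to the cwd, `context={"relative_to": ...}` rebases them against an arbitrary directory, `context={"absolute_paths": True}` keeps them absolute, and pathlike dataset paths inside the file (like `/data`) are never rewritten. A self-contained usage sketch — `Model` and the `data.h5` file are constructed here for illustration, standing in for the test fixtures:

```python
import json

import h5py
import numpy as np
from pydantic import BaseModel

from numpydantic import NDArray


class Model(BaseModel):
    array: NDArray


# build a small HDF5 file to reference (mirrors the hdf5_at_path fixture)
with h5py.File("data.h5", "w") as h5f:
    h5f.create_dataset("/data", data=np.array([[1, 2], [3, 4]]))

model = Model(array=("data.h5", "/data"))

# default: the file path is dumped relative to the current working directory
default_dump = json.loads(model.model_dump_json(round_trip=True))["array"]

# rebase the file path against another directory via the serialization context
rebased = json.loads(
    model.model_dump_json(round_trip=True, context={"relative_to": "/some/other/dir"})
)["array"]

# or keep it absolute; the in-file dataset path ("/data") is never rewritten
absolute = json.loads(
    model.model_dump_json(round_trip=True, context={"absolute_paths": True})
)["array"]
assert absolute["path"] == "/data"
```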
@@ -1,5 +1,3 @@
-import pdb
-
 import pytest
 
 from typing import Any
@@ -9,6 +7,8 @@ import numpy as np
 
 from numpydantic import NDArray, Shape
 
+pytestmark = pytest.mark.shape
+
 
 @pytest.mark.parametrize(
     "shape,valid",