nearing the end - need to do final top-level packing but we're almost there!

sneakers-the-rat 2023-09-27 01:19:10 -07:00
parent a2da236b2b
commit 9fcc1458fb
7 changed files with 338 additions and 62 deletions


@@ -118,7 +118,7 @@ class DatasetAdapter(ClassAdapter):
- null
"""
if self.cls.name and ((
if self.cls.name and len(self.cls.attributes) == 0 and ((
# single-layer list
not any([isinstance(dim, list) for dim in self.cls.dims]) and
len(self.cls.dims) == 1
@@ -209,27 +209,36 @@ class DatasetAdapter(ClassAdapter):
dims_shape = tuple(dict.fromkeys(dims_shape).keys())
# if we only have one possible dimension, it's equivalent to a list, so we just return the slot
if len(dims_shape) == 1 and self.parent:
quantity = QUANTITY_MAP[dataset.quantity]
slot = SlotDefinition(
name=dataset.name,
range=dtype,
description=dataset.doc,
required=quantity['required'],
multivalued=True
)
res.classes[0].attributes.update({dataset.name: slot})
self._handlers.append('arraylike-1d')
return res
# if len(dims_shape) == 1 and self.parent:
# quantity = QUANTITY_MAP[dataset.quantity]
# slot = SlotDefinition(
# name=dataset.name,
# range=dtype,
# description=dataset.doc,
# required=quantity['required'],
# multivalued=True
# )
# res.classes[0].attributes.update({dataset.name: slot})
# self._handlers.append('arraylike-1d')
# return res
# --------------------------------------------------
# SPECIAL CASE - allen institute's ndx-aibs-ecephys.extension
# confuses "dims" with "shape" , eg shape = [None], dims = [3].
# So we hardcode that here...
# --------------------------------------------------
if len(dims_shape) == 1 and isinstance(dims_shape[0][0], int) and dims_shape[0][1] is None:
dims_shape = (('dim', dims_shape[0][0]),)
# now make slots for each of them
slots = []
for dims, shape in dims_shape:
# if a dim is present in all possible combinations of dims, make it required
if all([dims in inner_dim for inner_dim in dataset.dims]):
# if there is just a single list of possible dimensions, it's required
if not any([isinstance(inner_dim, list) for inner_dim in dataset.dims]):
required = True
# or if there is just a single list of possible dimensions
elif not any([isinstance(inner_dim, list) for inner_dim in dataset.dims]):
# if a dim is present in all possible combinations of dims, make it required
elif all([dims in inner_dim for inner_dim in dataset.dims]):
required = True
else:
required = False
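A worked illustration of the branching above, with hypothetical NWB-style dims (not taken from the diff):

    # a single flat list of dims -> every dim gets required = True
    dims = ['num_times']

    # several alternative dim lists -> a dim is required only if it appears in all of them
    dims = [['num_times'], ['num_times', 'num_channels']]
    # 'num_times'    -> required = True   (present in both alternatives)
    # 'num_channels' -> required = False  (present in only one)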


@@ -217,6 +217,7 @@ class NWBPydanticGenerator(PydanticGenerator):
SKIP_CLASSES=('',)
INJECTED_FIELDS = (
'hdf5_path: Optional[str] = Field(None, description="The absolute path that this object is stored in an NWB file")',
'object_id: Optional[str] = Field(None, description="Unique UUID for each object")'
)
# SKIP_CLASSES=('VectorData','VectorIndex')
split:bool=True


@@ -37,6 +37,7 @@ from nwb_linkml.translate import generate_from_nwbfile
if TYPE_CHECKING:
from nwb_linkml.models import NWBFile
from nwb_linkml.providers.schema import SchemaProvider
from nwb_linkml.types.hdf5 import HDF5_Path
class HDF5IO():
@@ -58,8 +59,9 @@ class HDF5IO():
def read(self, path:str) -> BaseModel | Dict[str, BaseModel]: ...
def read(self, path:Optional[str] = None):
print('starting read')
provider = self.make_provider()
print('provider made')
h5f = h5py.File(str(self.path))
if path:
src = h5f.get(path)
@@ -71,7 +73,7 @@ class HDF5IO():
children = flatten_hdf(src)
else:
raise NotImplementedError('directly read individual datasets')
print('hdf flattened')
queue = ReadQueue(
h5f=self.path,
queue=children,
@@ -81,11 +83,30 @@ class HDF5IO():
#pdb.set_trace()
# Apply initial planning phase of reading
queue.apply_phase(ReadPhases.plan)
print('phase - plan completed')
# Now do read operations until we're finished
queue.apply_phase(ReadPhases.read)
print('phase - read completed')
# if len(queue.queue)> 0:
# warnings.warn('Did not complete all items during read phase!')
queue.apply_phase(ReadPhases.construct)
# --------------------------------------------------
# FIXME: Hardcoding top-level file reading just for the win
# --------------------------------------------------
root = finish_root_hackily(queue)
file = NWBFile(**root)
pdb.set_trace()
#
#
# data = {}
@@ -169,6 +190,22 @@ class HDF5IO():
return list(data[:])
def finish_root_hackily(queue: ReadQueue) -> dict:
root = {'name': 'root'}
for k, v in queue.queue.items():
if isinstance(v.result, dict):
res_dict = {}
for inner_k, inner_v in v.result.items():
if isinstance(inner_v, HDF5_Path):
inner_res = queue.completed.get(inner_v)
if inner_res is not None:
res_dict[inner_k] = inner_res.result
else:
res_dict[inner_k] = inner_v
root[res_dict['name']] = res_dict
else:
root[v.path.split('/')[-1]] = v.result
return root
def read_specs_as_dicts(group: h5py.Group) -> dict:
"""


@@ -70,4 +70,22 @@ np_to_python = {
**{n:int for n in (np.int8, np.int16, np.int32, np.int64, np.short, np.uint8, np.uint16, np.uint32, np.uint64, np.uint)},
**{n:float for n in (np.float16, np.float32, np.floating, np.float32, np.float64, np.single, np.double, np.float_)},
**{n:str for n in (np.character, np.str_, np.string_, np.unicode_)}
}
}
allowed_precisions = {
'float': ['double'],
'int8': ['short', 'int', 'long', 'int16', 'int32', 'int64'],
'short': ['int', 'long'],
'int': ['long'],
'uint8': ['uint8', 'uint16', 'uint32', 'uint64'],
'uint16': ['uint16', 'uint32', 'uint64'],
'uint32': ['uint32', 'uint64'],
'float16': ['float16', 'float32', 'float64'],
'float32': ['float32', 'float64'],
'utf': ['ascii']
}
"""
Following HDMF, it turns out that specifying precision actually specifies minimum precision
https://github.com/hdmf-dev/hdmf/blob/ddc842b5c81d96e0b957b96e88533b16c137e206/src/hdmf/validate/validator.py#L22
https://github.com/hdmf-dev/hdmf/blob/ddc842b5c81d96e0b957b96e88533b16c137e206/src/hdmf/spec/spec.py#L694-L714
"""


@@ -4,18 +4,21 @@ Maps for reading and writing from HDF5
We have sort of diverged from the initial idea of a generalized map as in :class:`linkml.map.Map` ,
so we will make our own mapping class here and re-evaluate whether they should be unified later
"""
import pdb
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Literal, List, Dict, Optional, Type
from typing import Literal, List, Dict, Optional, Type, Union
import h5py
from enum import StrEnum
from pydantic import BaseModel, Field, ConfigDict
from pydantic import BaseModel, Field, ConfigDict, ValidationError
import dask.array as da
from nwb_linkml.providers.schema import SchemaProvider
from nwb_linkml.maps.hdmf import dynamictable_to_model
from nwb_linkml.types.hdf5 import HDF5_Path
from nwb_linkml.types.ndarray import NDArrayProxy
class ReadPhases(StrEnum):
@@ -55,7 +58,7 @@ class H5ReadResult(BaseModel):
"""Result returned by each of our mapping operations"""
path: str
"""absolute hdf5 path of element"""
source: H5SourceItem
source: Union[H5SourceItem, 'H5ReadResult']
"""
Source that this result is based on.
The map can modify this item, so the container should update the source
@@ -66,7 +69,7 @@ class H5ReadResult(BaseModel):
Was this item completed by this map step? False for cases where eg.
we still have dependencies that need to be completed before this one
"""
result: Optional[BaseModel | dict | str | int | float] = None
result: Optional[dict | str | int | float | BaseModel] = None
"""
If completed, built result. A dict that can be instantiated into the model.
If completed is True and result is None, then remove this object
@@ -87,6 +90,14 @@ class H5ReadResult(BaseModel):
"""
Optional: The neurodata type to use for this object
"""
applied: List[str] = Field(default_factory=list)
"""
Which stages were applied to this item
"""
errors: List[str] = Field(default_factory=list)
"""
Problems that occurred during resolution
"""
FlatH5 = Dict[str, H5SourceItem]
@@ -133,6 +144,8 @@ class PruneEmpty(HDF5Map):
completed=True
)
# class ResolveVectorData(HDF5Map):
# """
# We will load vanilla VectorData as part of :class:`.ResolveDynamicTable`
@@ -194,7 +207,8 @@ class ResolveDynamicTable(HDF5Map):
source=src,
result=model,
completes=completes,
completed = True
completed = True,
applied=['ResolveDynamicTable']
)
@@ -212,6 +226,8 @@ class ResolveModelGroup(HDF5Map):
@classmethod
def apply(cls, src: H5SourceItem, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> H5ReadResult:
model = provider.get_class(src.namespace, src.neurodata_type)
res = {}
with h5py.File(src.h5f_path, 'r') as h5f:
@@ -222,7 +238,12 @@ class ResolveModelGroup(HDF5Map):
continue
if key in obj.keys():
# stash a reference to this, we'll compile it at the end
res[key] = HDF5_Path('/'.join([src.path, key]))
if src.path == '/':
target_path = '/' + key
else:
target_path = '/'.join([src.path, key])
res[key] = HDF5_Path(target_path)
res['hdf5_path'] = src.path
res['name'] = src.parts[-1]
@@ -233,25 +254,44 @@ class ResolveModelGroup(HDF5Map):
result = res,
model = model,
namespace=src.namespace,
neurodata_type=src.neurodata_type
neurodata_type=src.neurodata_type,
applied=['ResolveModelGroup']
)
class ResolveDatasetAsDict(HDF5Map):
"""Mutually exclusive with :class:`.ResolveScalars`"""
phase = ReadPhases.read
priority = 11
exclusive = True
@classmethod
def check(cls, src: H5SourceItem, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> bool:
if src.h5_type == 'dataset' and 'neurodata_type' not in src.attrs:
with h5py.File(src.h5f_path, 'r') as h5f:
obj = h5f.get(src.path)
if obj.shape != ():
return True
else: return False
else:
return False
@classmethod
def apply(cls, src: H5SourceItem, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> H5ReadResult:
res = {
'array': NDArrayProxy(h5f_file=src.h5f_path, path=src.path),
'hdf5_path' : src.path,
'name': src.parts[-1],
**src.attrs
}
return H5ReadResult(
path = src.path,
source=src,
completed=True,
result=res,
applied=['ResolveDatasetAsDict']
)
#
# class ResolveModelDataset(HDF5Map):
# phase = ReadPhases.read
# priority = 10
# exclusive = True
#
# @classmethod
# def check(cls, src: H5SourceItem, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> bool:
# if 'neurodata_type' in src.attrs and src.h5_type == 'dataset':
# return True
# else:
# return False
#
# def apply(cls, src: H5SourceItem, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> H5ReadResult:
#
class ResolveScalars(HDF5Map):
phase = ReadPhases.read
priority = 11 #catchall
@@ -264,6 +304,8 @@ class ResolveScalars(HDF5Map):
obj = h5f.get(src.path)
if obj.shape == ():
return True
else:
return False
else:
return False
@classmethod
@@ -275,9 +317,138 @@ class ResolveScalars(HDF5Map):
path=src.path,
source = src,
completed=True,
result = res
result = res,
applied=['ResolveScalars']
)
class ResolveContainerGroups(HDF5Map):
phase = ReadPhases.read
priority = 9
@classmethod
def check(cls, src: H5SourceItem, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> bool:
if src.h5_type == 'group' and 'neurodata_type' not in src.attrs and len(src.attrs) == 0:
with h5py.File(src.h5f_path, 'r') as h5f:
obj = h5f.get(src.path)
if len(obj.keys()) > 0:
return True
else:
return False
else:
return False
@classmethod
def apply(cls, src: H5SourceItem, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> H5ReadResult:
"""Simple, just return a dict with references to its children"""
with h5py.File(src.h5f_path, 'r') as h5f:
obj = h5f.get(src.path)
children = {}
for k, v in obj.items():
children[k] = HDF5_Path(v.name)
res = {
'name': src.parts[-1],
**children
}
return H5ReadResult(
path=src.path,
source=src,
completed=True,
result=res,
applied=['ResolveContainerGroups']
)
# --------------------------------------------------
# Completion Steps
# --------------------------------------------------
class CompleteDynamicTables(HDF5Map):
"""Nothing to do! already done!"""
phase = ReadPhases.construct
priority = 1
exclusive = True
@classmethod
def check(cls, src: H5ReadResult, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> bool:
if 'ResolveDynamicTable' in src.applied:
return True
else:
return False
@classmethod
def apply(cls, src: H5ReadResult, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> H5ReadResult:
return src
class CompleteModelGroups(HDF5Map):
phase = ReadPhases.construct
priority = 2
@classmethod
def check(cls, src: H5ReadResult, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> bool:
if src.model is not None:
return True
else:
return False
@classmethod
def apply(cls, src: H5ReadResult, provider:SchemaProvider, completed: Dict[str, H5ReadResult]) -> H5ReadResult:
# gather any results that were left for completion elsewhere
res = {k:v for k,v in src.result.items() if not isinstance(v, HDF5_Path)}
errors = []
completes = []
for path, item in src.result.items():
if isinstance(item, HDF5_Path):
other_item = completed.get(item, None)
if other_item is None:
errors.append(f'Couldnt find {item}')
continue
if isinstance(other_item.result, dict):
# resolve any other children that it might have...
# FIXME: refactor this lmao so bad
for k,v in other_item.result.items():
if isinstance(v, HDF5_Path):
inner_result = completed.get(v, None)
if inner_result is None:
errors.append(f'Couldnt find inner item {v}')
continue
other_item.result[k] = inner_result.result
completes.append(v)
res[other_item.result['name']] = other_item.result
else:
res[path] = other_item.result
completes.append(other_item.path)
#try:
instance = src.model(**res)
return H5ReadResult(
path=src.path,
source=src,
result=instance,
model=src.model,
completed=True,
completes=completes,
neurodata_type=src.neurodata_type,
namespace=src.namespace,
applied=src.applied + ['CompleteModelGroups'],
errors=errors
)
# except ValidationError:
# # didn't get it! try again next time
# return H5ReadResult(
# path=src.path,
# source=src,
# result=src,
# model=src.model,
# completed=True,
# completes=completes,
# neurodata_type=src.neurodata_type,
# namespace=src.namespace,
# applied=src.applied + ['CompleteModelGroups']
# )
@@ -291,7 +462,7 @@ class ReadQueue(BaseModel):
provider: SchemaProvider = Field(
description="SchemaProvider used by each of the items in the read queue"
)
queue: Dict[str,H5SourceItem] = Field(
queue: Dict[str,H5SourceItem|H5ReadResult] = Field(
default_factory=dict,
description="Items left to be instantiated, keyed by hdf5 path",
)
@@ -300,11 +471,16 @@ class ReadQueue(BaseModel):
description="Items that have already been instantiated, keyed by hdf5 path"
)
model_config = ConfigDict(arbitrary_types_allowed=True)
phases_completed: List[ReadPhases] = Field(
default_factory=list,
description="Phases that have already been completed")
def apply_phase(self, phase:ReadPhases):
phase_maps = [m for m in HDF5Map.__subclasses__() if m.phase == phase]
phase_maps = sorted(phase_maps, key=lambda x: x.priority)
# if we've moved to the
results = []
# TODO: Thread/multiprocess this
@@ -316,6 +492,7 @@ class ReadQueue(BaseModel):
break # out of inner iteration
# remake the source queue and save results
completes = []
for res in results:
# remove the original item
del self.queue[res.path]
@@ -327,16 +504,42 @@ class ReadQueue(BaseModel):
# just drop it.
# if we have completed other things, delete them from the queue
for also_completed in res.completes:
try:
del self.queue[also_completed]
except KeyError:
# normal, we might have already deleted this in a previous step
pass
completes.extend(res.completes)
# for also_completed in res.completes:
# try:
# del self.queue[also_completed]
# except KeyError:
# # normal, we might have already deleted this in a previous step
# pass
else:
# if we didn't complete the item (eg. we found we needed more dependencies),
# add the updated source to the queue again
self.queue[res.path] = res.source
if phase != ReadPhases.construct:
self.queue[res.path] = res.source
else:
self.queue[res.path] = res
# delete the ones that were already completed but might have been
# incorrectly added back in the pile
for c in completes:
try:
del self.queue[c]
except KeyError:
pass
# if we have nothing left in our queue, we have completed this phase
# and prepare only ever has one pass
if phase == ReadPhases.plan:
self.phases_completed.append(phase)
return
if len(self.queue) == 0:
self.phases_completed.append(phase)
if phase != ReadPhases.construct:
# if we're not in the last phase, move our completed to our queue
self.queue = self.completed.copy()
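For orientation, a minimal usage sketch of the phase progression this implements; the constructor arguments follow the read() method earlier in this commit and are otherwise illustrative:

    queue = ReadQueue(h5f=path, queue=flatten_hdf(h5f), provider=provider)
    queue.apply_phase(ReadPhases.plan)       # single pass, then returns
    queue.apply_phase(ReadPhases.read)       # sources resolved into dict (or model) results
    queue.apply_phase(ReadPhases.construct)  # models instantiated; unfinished items re-queued as H5ReadResults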
@@ -367,8 +570,8 @@ def flatten_hdf(h5f:h5py.File | h5py.Group, skip='specifications') -> Dict[str,
# get references in attrs and datasets to populate dependencies
#depends = get_references(obj)
#if not name.startswith('/'):
# name = '/' + name
if not name.startswith('/'):
name = '/' + name
attrs = dict(obj.attrs.items())
@@ -384,6 +587,8 @@ def flatten_hdf(h5f:h5py.File | h5py.Group, skip='specifications') -> Dict[str,
)
h5f.visititems(_itemize)
# # then add the root item
# _itemize(h5f.name, h5f)
return items


@@ -1,3 +1,5 @@
from typing import Annotated
HDF5_Path = Annotated[str, """Trivial subclass of string to indicate that it is a reference to a location within an HDF5 file"""]
class HDF5_Path(str):
"""Trivial subclass of string to indicate that it is a reference to a location within an HDF5 file"""
pass
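A quick sketch of why a real str subclass (rather than the Annotated alias it replaces) matters here: the read maps and finish_root_hackily dispatch on isinstance checks, which need a runtime type; the path below is illustrative:

    p = HDF5_Path('/acquisition/test_timeseries')
    isinstance(p, HDF5_Path)  # True: lets the maps tell stashed references apart from ordinary strings
    isinstance(p, str)        # True: still behaves like a plain string everywhere else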


@@ -35,7 +35,8 @@ from nptyping.ndarray import NDArrayMeta
from nptyping import Shape, Number
from nptyping.shape_expression import check_shape
from nwb_linkml.maps.dtype import np_to_python
from nwb_linkml.maps.dtype import np_to_python, allowed_precisions
class NDArray(_NDArray):
@@ -59,13 +60,15 @@ class NDArray(_NDArray):
def validate_dtype(value: np.ndarray) -> np.ndarray:
if dtype is Any:
return value
assert value.dtype == dtype, f"Invalid dtype! expected {dtype}, got {value.dtype}"
assert value.dtype == dtype or value.dtype.name in allowed_precisions[dtype.__name__], f"Invalid dtype! expected {dtype}, got {value.dtype}"
return value
def validate_array(value: Any) -> np.ndarray:
if isinstance(value, np.ndarray):
assert cls.__instancecheck__(value), f'Invalid shape! expected shape {shape.prepared_args}, got shape {value.shape}'
elif isinstance(value, DaskArray):
assert shape is Any or check_shape(value.shape, shape), f'Invalid shape! expected shape {shape.prepared_args}, got shape {value.shape}'
# not using instancecheck because nwb doesnt actually validate precision
# this step is now just validating shape
# if isinstance(value, np.ndarray):
# assert cls.__instancecheck__(value), f'Invalid shape! expected shape {shape.prepared_args}, got shape {value.shape}'
# elif isinstance(value, DaskArray):
assert shape is Any or check_shape(value.shape, shape), f'Invalid shape! expected shape {shape.prepared_args}, got shape {value.shape}'
return value
@@ -146,7 +149,8 @@ class NDArray(_NDArray):
core_schema.no_info_plain_validator_function(coerce_list),
core_schema.union_schema([
core_schema.is_instance_schema(cls=np.ndarray),
core_schema.is_instance_schema(cls=DaskArray)
core_schema.is_instance_schema(cls=DaskArray),
core_schema.is_instance_schema(cls=NDArrayProxy)
]),
core_schema.no_info_plain_validator_function(validate_dtype),
core_schema.no_info_plain_validator_function(validate_array)