fix array casting for dtypes that have a shape attr but nothing in it

sneakers-the-rat 2024-08-06 20:41:00 -07:00
parent edea802ff1
commit 3ee7c68e15
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
3 changed files with 61 additions and 90 deletions

@@ -23,7 +23,6 @@ from pydantic import BaseModel, ConfigDict, Field
 from nwb_linkml.annotations import unwrap_optional
 from nwb_linkml.maps import Map
 from nwb_linkml.maps.hdmf import dynamictable_to_model
-from nwb_linkml.types.hdf5 import HDF5_Path
 
 
 if sys.version_info.minor >= 11:
@@ -234,63 +233,64 @@ class PruneEmpty(HDF5Map):
         return H5ReadResult.model_construct(path=src.path, source=src, completed=True)
 
 
-class ResolveDynamicTable(HDF5Map):
-    """
-    Handle loading a dynamic table!
-
-    Dynamic tables are sort of odd in that their models don't include their fields
-    (except as a list of strings in ``colnames`` ),
-    so we need to create a new model that includes fields for each column,
-    and then we include the datasets as :class:`~numpydantic.interface.hdf5.H5ArrayPath`
-    objects which lazy load the arrays in a thread/process safe way.
-
-    This map also resolves the child elements,
-    indicating so by the ``completes`` field in the :class:`.ReadResult`
-    """
-
-    phase = ReadPhases.read
-    priority = 1
-
-    @classmethod
-    def check(
-        cls, src: H5SourceItem, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
-    ) -> bool:
-        if src.h5_type == "dataset":
-            return False
-        if "neurodata_type" in src.attrs:
-            if src.attrs["neurodata_type"] == "DynamicTable":
-                return True
-            # otherwise, see if it's a subclass
-            model = provider.get_class(src.attrs["namespace"], src.attrs["neurodata_type"])
-            # just inspect the MRO as strings rather than trying to check subclasses because
-            # we might replace DynamicTable in the future, and there isn't a stable DynamicTable
-            # class to inherit from anyway because of the whole multiple versions thing
-            parents = [parent.__name__ for parent in model.__mro__]
-            return "DynamicTable" in parents
-        else:
-            return False
-
-    @classmethod
-    def apply(
-        cls, src: H5SourceItem, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
-    ) -> H5ReadResult:
-        with h5py.File(src.h5f_path, "r") as h5f:
-            obj = h5f.get(src.path)
-
-            # make a populated model :)
-            base_model = provider.get_class(src.namespace, src.neurodata_type)
-            model = dynamictable_to_model(obj, base=base_model)
-
-            completes = [HDF5_Path(child.name) for child in obj.values()]
-
-            return H5ReadResult(
-                path=src.path,
-                source=src,
-                result=model,
-                completes=completes,
-                completed=True,
-                applied=["ResolveDynamicTable"],
-            )
+#
+# class ResolveDynamicTable(HDF5Map):
+#     """
+#     Handle loading a dynamic table!
+#
+#     Dynamic tables are sort of odd in that their models don't include their fields
+#     (except as a list of strings in ``colnames`` ),
+#     so we need to create a new model that includes fields for each column,
+#     and then we include the datasets as :class:`~numpydantic.interface.hdf5.H5ArrayPath`
+#     objects which lazy load the arrays in a thread/process safe way.
+#
+#     This map also resolves the child elements,
+#     indicating so by the ``completes`` field in the :class:`.ReadResult`
+#     """
+#
+#     phase = ReadPhases.read
+#     priority = 1
+#
+#     @classmethod
+#     def check(
+#         cls, src: H5SourceItem, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
+#     ) -> bool:
+#         if src.h5_type == "dataset":
+#             return False
+#         if "neurodata_type" in src.attrs:
+#             if src.attrs["neurodata_type"] == "DynamicTable":
+#                 return True
+#             # otherwise, see if it's a subclass
+#             model = provider.get_class(src.attrs["namespace"], src.attrs["neurodata_type"])
+#             # just inspect the MRO as strings rather than trying to check subclasses because
+#             # we might replace DynamicTable in the future, and there isn't a stable DynamicTable
+#             # class to inherit from anyway because of the whole multiple versions thing
+#             parents = [parent.__name__ for parent in model.__mro__]
+#             return "DynamicTable" in parents
+#         else:
+#             return False
+#
+#     @classmethod
+#     def apply(
+#         cls, src: H5SourceItem, provider: "SchemaProvider", completed: Dict[str, H5ReadResult]
+#     ) -> H5ReadResult:
+#         with h5py.File(src.h5f_path, "r") as h5f:
+#             obj = h5f.get(src.path)
+#
+#             # make a populated model :)
+#             base_model = provider.get_class(src.namespace, src.neurodata_type)
+#             model = dynamictable_to_model(obj, base=base_model)
+#
+#             completes = [HDF5_Path(child.name) for child in obj.values()]
+#
+#             return H5ReadResult(
+#                 path=src.path,
+#                 source=src,
+#                 result=model,
+#                 completes=completes,
+#                 completed=True,
+#                 applied=["ResolveDynamicTable"],
+#             )
 
 
 class ResolveModelGroup(HDF5Map):
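
For context, the pattern the commented-out docstring describes — build a fresh model with one field per name in ``colnames``, then hand each column over as a lazy array reference — can be sketched roughly as below. This is a hedged sketch, not the project's actual dynamictable_to_model: sketch_dynamictable_model, the (file, path) tuples, and the Optional[Any] field types are hypothetical stand-ins; only h5py, pydantic's create_model, and the HDF5 colnames attribute are real APIs here.

from typing import Any, Optional

import h5py
from pydantic import create_model


def sketch_dynamictable_model(h5f_path: str, table_path: str):
    # the base model only knows its columns as a list of names in `colnames`,
    # so build a new model with one loosely-typed field per column
    with h5py.File(h5f_path, "r") as h5f:
        group = h5f[table_path]
        colnames = [
            c.decode() if isinstance(c, bytes) else c for c in group.attrs["colnames"]
        ]
        # record (file, path) references instead of reading arrays eagerly,
        # in the spirit of numpydantic's H5ArrayPath lazy loading
        values = {name: (h5f_path, group[name].name) for name in colnames}
    fields = {name: (Optional[Any], None) for name in colnames}
    Model = create_model("SketchedDynamicTable", **fields)
    return Model(**values)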

@@ -1,12 +1,7 @@
 from __future__ import annotations
-from datetime import datetime, date
-from decimal import Decimal
-from enum import Enum
-import re
-import sys
-from ...hdmf_common.v1_8_0.hdmf_common_base import Data, Container
+from ...hdmf_common.v1_8_0.hdmf_common_base import Data
 from pandas import DataFrame, Series
-from typing import Any, ClassVar, List, Literal, Dict, Optional, Union, overload, Tuple
+from typing import Any, ClassVar, List, Dict, Optional, Union, overload, Tuple
 from pydantic import (
     BaseModel,
     ConfigDict,
@@ -282,7 +277,7 @@ class DynamicTableMixin(BaseModel):
                 # special case where pandas will unpack a pydantic model
                 # into {n_fields} rows, rather than keeping it in a dict
                 val = Series([val])
-            elif isinstance(rows, int) and hasattr(val, "shape") and len(val) > 1:
+            elif isinstance(rows, int) and hasattr(val, "shape") and val.shape and val.shape[0] > 1:
                 # special case where we are returning a row in a ragged array,
                 # same as above - prevent pandas pivoting to long
                 val = Series([val])
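
This one-line change is the fix named in the commit message: numpy scalars and zero-dimensional arrays pass the hasattr(val, "shape") check but carry an empty shape tuple, so the old len(val) > 1 raised TypeError. A minimal sketch of the edge case, assuming a numpy scalar as the offending value (needs_wrap is a hypothetical helper that mirrors the committed condition):

import numpy as np


def needs_wrap(val) -> bool:
    # mirrors the committed guard, with `rows` already known to be an int
    return bool(hasattr(val, "shape") and val.shape and val.shape[0] > 1)


scalar = np.float64(1.0)       # has a shape attr, but nothing in it: shape == ()
# len(scalar) raises TypeError, which is where the old check blew up
print(needs_wrap(scalar))      # False: the empty (falsy) tuple short-circuits

ragged_row = np.arange(3)      # a row from a ragged array, shape == (3,)
print(needs_wrap(ragged_row))  # True: still wrapped as Series([val]) so pandas
                               # doesn't pivot the row to long format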

@@ -1,24 +0,0 @@
-import time
-
-import h5py
-import pytest
-
-from nwb_linkml.maps.hdmf import dynamictable_to_model, model_from_dynamictable
-
-NWBFILE = "/Users/jonny/Dropbox/lab/p2p_ld/data/nwb/sub-738651046_ses-760693773.nwb"
-
-
-@pytest.mark.xfail()
-@pytest.mark.parametrize("dataset", ["aibs.nwb"])
-def test_make_dynamictable(data_dir, dataset):
-    nwbfile = data_dir / dataset
-    h5f = h5py.File(nwbfile, "r")
-    group = h5f["units"]
-
-    start_time = time.time()
-    model = model_from_dynamictable(group)
-    data = dynamictable_to_model(group, model)
-
-    _ = data.model_dump_json()
-    end_time = time.time()
-    total_time = end_time - start_time