update models

This commit is contained in:
sneakers-the-rat 2024-08-15 01:47:55 -07:00
parent 24494b8ee4
commit f5a4173494
Signed by untrusted user who does not match committer: jonny
GPG key ID: 6DCB96EF1E4D232D
4 changed files with 96 additions and 60 deletions

View file

@ -134,9 +134,9 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin):
for index, span, and timeseries for index, span, and timeseries
""" """
idx_start: NDArray[Any, int] idx_start: NDArray[Shape["*"], int]
count: NDArray[Any, int] count: NDArray[Shape["*"], int]
timeseries: NDArray[Any, BaseModel] timeseries: NDArray
@model_validator(mode="after") @model_validator(mode="after")
def ensure_equal_length(self) -> "TimeSeriesReferenceVectorDataMixin": def ensure_equal_length(self) -> "TimeSeriesReferenceVectorDataMixin":
@ -175,11 +175,11 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin):
) )
if isinstance(item, (int, np.integer)): if isinstance(item, (int, np.integer)):
return self.timeseries[self._slice_helper(item)] return self.timeseries[item][self._slice_helper(item)]
elif isinstance(item, slice): elif isinstance(item, (slice, Iterable)):
return [self.timeseries[subitem] for subitem in self._slice_helper(item)] if isinstance(item, slice):
elif isinstance(item, Iterable): item = range(*item.indices(len(self.idx_start)))
return [self.timeseries[self._slice_helper(subitem)] for subitem in item] return [self.timeseries[subitem][self._slice_helper(subitem)] for subitem in item]
else: else:
raise ValueError( raise ValueError(
f"Dont know how to index with {item}, must be an int, slice, or iterable" f"Dont know how to index with {item}, must be an int, slice, or iterable"
@ -192,13 +192,22 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin):
" never done in the core schema." " never done in the core schema."
) )
if isinstance(key, (int, np.integer)): if isinstance(key, (int, np.integer)):
self.timeseries[self._slice_helper(key)] = value self.timeseries[key][self._slice_helper(key)] = value
elif isinstance(key, slice): elif isinstance(key, (slice, Iterable)):
for subitem in self._slice_helper(key): if isinstance(key, slice):
self.timeseries[subitem] = value key = range(*key.indices(len(self.idx_start)))
elif isinstance(key, Iterable):
if isinstance(value, Iterable):
if len(key) != len(value):
raise ValueError(
"Can only assign equal-length iterable to a slice, manually index the"
" target Timeseries object if you need more control"
)
for subitem, subvalue in zip(key, value):
self.timeseries[subitem][self._slice_helper(subitem)] = subvalue
else:
for subitem in key: for subitem in key:
self.timeseries[self._slice_helper(subitem)] = value self.timeseries[subitem][self._slice_helper(subitem)] = value
else: else:
raise ValueError( raise ValueError(
f"Dont know how to index with {key}, must be an int, slice, or iterable" f"Dont know how to index with {key}, must be an int, slice, or iterable"

View file

@ -145,9 +145,9 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin):
for index, span, and timeseries for index, span, and timeseries
""" """
idx_start: NDArray[Any, int] idx_start: NDArray[Shape["*"], int]
count: NDArray[Any, int] count: NDArray[Shape["*"], int]
timeseries: NDArray[Any, BaseModel] timeseries: NDArray
@model_validator(mode="after") @model_validator(mode="after")
def ensure_equal_length(self) -> "TimeSeriesReferenceVectorDataMixin": def ensure_equal_length(self) -> "TimeSeriesReferenceVectorDataMixin":
@ -186,11 +186,11 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin):
) )
if isinstance(item, (int, np.integer)): if isinstance(item, (int, np.integer)):
return self.timeseries[self._slice_helper(item)] return self.timeseries[item][self._slice_helper(item)]
elif isinstance(item, slice): elif isinstance(item, (slice, Iterable)):
return [self.timeseries[subitem] for subitem in self._slice_helper(item)] if isinstance(item, slice):
elif isinstance(item, Iterable): item = range(*item.indices(len(self.idx_start)))
return [self.timeseries[self._slice_helper(subitem)] for subitem in item] return [self.timeseries[subitem][self._slice_helper(subitem)] for subitem in item]
else: else:
raise ValueError( raise ValueError(
f"Dont know how to index with {item}, must be an int, slice, or iterable" f"Dont know how to index with {item}, must be an int, slice, or iterable"
@ -203,13 +203,22 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin):
" never done in the core schema." " never done in the core schema."
) )
if isinstance(key, (int, np.integer)): if isinstance(key, (int, np.integer)):
self.timeseries[self._slice_helper(key)] = value self.timeseries[key][self._slice_helper(key)] = value
elif isinstance(key, slice): elif isinstance(key, (slice, Iterable)):
for subitem in self._slice_helper(key): if isinstance(key, slice):
self.timeseries[subitem] = value key = range(*key.indices(len(self.idx_start)))
elif isinstance(key, Iterable):
if isinstance(value, Iterable):
if len(key) != len(value):
raise ValueError(
"Can only assign equal-length iterable to a slice, manually index the"
" target Timeseries object if you need more control"
)
for subitem, subvalue in zip(key, value):
self.timeseries[subitem][self._slice_helper(subitem)] = subvalue
else:
for subitem in key: for subitem in key:
self.timeseries[self._slice_helper(subitem)] = value self.timeseries[subitem][self._slice_helper(subitem)] = value
else: else:
raise ValueError( raise ValueError(
f"Dont know how to index with {key}, must be an int, slice, or iterable" f"Dont know how to index with {key}, must be an int, slice, or iterable"

View file

@ -145,9 +145,9 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin):
for index, span, and timeseries for index, span, and timeseries
""" """
idx_start: NDArray[Any, int] idx_start: NDArray[Shape["*"], int]
count: NDArray[Any, int] count: NDArray[Shape["*"], int]
timeseries: NDArray[Any, BaseModel] timeseries: NDArray
@model_validator(mode="after") @model_validator(mode="after")
def ensure_equal_length(self) -> "TimeSeriesReferenceVectorDataMixin": def ensure_equal_length(self) -> "TimeSeriesReferenceVectorDataMixin":
@ -186,11 +186,11 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin):
) )
if isinstance(item, (int, np.integer)): if isinstance(item, (int, np.integer)):
return self.timeseries[self._slice_helper(item)] return self.timeseries[item][self._slice_helper(item)]
elif isinstance(item, slice): elif isinstance(item, (slice, Iterable)):
return [self.timeseries[subitem] for subitem in self._slice_helper(item)] if isinstance(item, slice):
elif isinstance(item, Iterable): item = range(*item.indices(len(self.idx_start)))
return [self.timeseries[self._slice_helper(subitem)] for subitem in item] return [self.timeseries[subitem][self._slice_helper(subitem)] for subitem in item]
else: else:
raise ValueError( raise ValueError(
f"Dont know how to index with {item}, must be an int, slice, or iterable" f"Dont know how to index with {item}, must be an int, slice, or iterable"
@ -203,13 +203,22 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin):
" never done in the core schema." " never done in the core schema."
) )
if isinstance(key, (int, np.integer)): if isinstance(key, (int, np.integer)):
self.timeseries[self._slice_helper(key)] = value self.timeseries[key][self._slice_helper(key)] = value
elif isinstance(key, slice): elif isinstance(key, (slice, Iterable)):
for subitem in self._slice_helper(key): if isinstance(key, slice):
self.timeseries[subitem] = value key = range(*key.indices(len(self.idx_start)))
elif isinstance(key, Iterable):
if isinstance(value, Iterable):
if len(key) != len(value):
raise ValueError(
"Can only assign equal-length iterable to a slice, manually index the"
" target Timeseries object if you need more control"
)
for subitem, subvalue in zip(key, value):
self.timeseries[subitem][self._slice_helper(subitem)] = subvalue
else:
for subitem in key: for subitem in key:
self.timeseries[self._slice_helper(subitem)] = value self.timeseries[subitem][self._slice_helper(subitem)] = value
else: else:
raise ValueError( raise ValueError(
f"Dont know how to index with {key}, must be an int, slice, or iterable" f"Dont know how to index with {key}, must be an int, slice, or iterable"

View file

@ -145,9 +145,9 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin):
for index, span, and timeseries for index, span, and timeseries
""" """
idx_start: NDArray[Any, int] idx_start: NDArray[Shape["*"], int]
count: NDArray[Any, int] count: NDArray[Shape["*"], int]
timeseries: NDArray[Any, BaseModel] timeseries: NDArray
@model_validator(mode="after") @model_validator(mode="after")
def ensure_equal_length(self) -> "TimeSeriesReferenceVectorDataMixin": def ensure_equal_length(self) -> "TimeSeriesReferenceVectorDataMixin":
@ -186,11 +186,11 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin):
) )
if isinstance(item, (int, np.integer)): if isinstance(item, (int, np.integer)):
return self.timeseries[self._slice_helper(item)] return self.timeseries[item][self._slice_helper(item)]
elif isinstance(item, slice): elif isinstance(item, (slice, Iterable)):
return [self.timeseries[subitem] for subitem in self._slice_helper(item)] if isinstance(item, slice):
elif isinstance(item, Iterable): item = range(*item.indices(len(self.idx_start)))
return [self.timeseries[self._slice_helper(subitem)] for subitem in item] return [self.timeseries[subitem][self._slice_helper(subitem)] for subitem in item]
else: else:
raise ValueError( raise ValueError(
f"Dont know how to index with {item}, must be an int, slice, or iterable" f"Dont know how to index with {item}, must be an int, slice, or iterable"
@ -203,13 +203,22 @@ class TimeSeriesReferenceVectorDataMixin(VectorDataMixin):
" never done in the core schema." " never done in the core schema."
) )
if isinstance(key, (int, np.integer)): if isinstance(key, (int, np.integer)):
self.timeseries[self._slice_helper(key)] = value self.timeseries[key][self._slice_helper(key)] = value
elif isinstance(key, slice): elif isinstance(key, (slice, Iterable)):
for subitem in self._slice_helper(key): if isinstance(key, slice):
self.timeseries[subitem] = value key = range(*key.indices(len(self.idx_start)))
elif isinstance(key, Iterable):
if isinstance(value, Iterable):
if len(key) != len(value):
raise ValueError(
"Can only assign equal-length iterable to a slice, manually index the"
" target Timeseries object if you need more control"
)
for subitem, subvalue in zip(key, value):
self.timeseries[subitem][self._slice_helper(subitem)] = subvalue
else:
for subitem in key: for subitem in key:
self.timeseries[self._slice_helper(subitem)] = value self.timeseries[subitem][self._slice_helper(subitem)] = value
else: else:
raise ValueError( raise ValueError(
f"Dont know how to index with {key}, must be an int, slice, or iterable" f"Dont know how to index with {key}, must be an int, slice, or iterable"