"""
The class `IntervalSet` deals with non-overlapping epochs. `IntervalSet` objects can interact with each other or with the time series objects.
"""
import importlib
import warnings
from numbers import Number
import numpy as np
import pandas as pd
from numpy.lib.mixins import NDArrayOperatorsMixin
from tabulate import tabulate
from ._jitted_functions import (
_jitfix_iset,
jitdiff,
jitin_interval,
jitintersect,
jitunion,
)
from .config import nap_config
from .metadata_class import _MetadataMixin, add_meta_docstring
from .time_index import TsIndex
from .utils import (
_convert_iter_to_str,
_get_terminal_size,
_IntervalSetSliceHelper,
check_filename,
convert_to_numpy_array,
is_array_like,
)
# Warning messages that ``IntervalSet.__init__`` may emit while fixing epochs.
# The order is significant: the array is mask-indexed with the flags returned by
# ``_jitfix_iset`` (``all_warnings[to_warn]``), and ``__init__`` treats any flag
# after the first one as a change to the epochs that invalidates metadata.
all_warnings = np.array(
    (
        "Some starts and ends are equal. Removing 1 microsecond!",
        "Some ends precede the relative start. Dropping them!",
        "Some starts precede the previous end. Joining them!",
        "Some epochs have no duration",
    )
)
class IntervalSet(NDArrayOperatorsMixin, _MetadataMixin):
    """
    A class representing an (irregular) set of time intervals in elapsed time, with relative operations.

    The `IntervalSet` object behaves like a numpy ndarray with the limitation that the object is not mutable.

    If start and end are not aligned, meaning that:

    1. end[i] < start[i]
    2. start[i+1] < end[i]
    3. start and end are not sorted,

    IntervalSet will try to "fix" the data by eliminating some of the start and end data points
    (a warning describes each correction). Starts and ends of different lengths are not fixed:
    they raise an ``AssertionError``.

    Parameters
    ----------
    start : numpy.ndarray or number or pandas.DataFrame or pandas.Series or iterable of (start, end) pairs
        Beginning of intervals.
        Alternatively, the `end` argument can be left out and `start` can be one of the following:

        - IntervalSet
        - pandas.DataFrame with columns ["start", "end"]
        - iterable of (start, end) pairs (an empty iterable gives an empty IntervalSet)
        - a single (start, end) pair
    end : numpy.ndarray or number or pandas.Series, optional
        Ends of intervals.
    time_units : str, optional
        Time unit of the intervals ('us', 'ms', 's' [default])
    metadata: pandas.DataFrame or dict, optional
        Metadata associated with each interval. Metadata names are pulled from DataFrame columns or dictionary keys.
        The length of the metadata should match the length of the intervals.
        Metadata is dropped (with a warning) when starts or ends must be sorted or the epochs are changed.

    Raises
    ------
    RuntimeError
        If `start` and `end` arguments are of unknown type.
    ValueError
        If `end` is omitted and `start` cannot be interpreted as (start, end) pairs.
    AssertionError
        If `start` and `end` do not have the same length, or if a DataFrame lacks
        the "start" and "end" columns.

    Examples
    --------
    Initialize an IntervalSet with a list of start and end times:

    >>> import pynapple as nap
    >>> import numpy as np
    >>> start = [0, 10, 20]
    >>> end = [5, 12, 33]
    >>> ep = nap.IntervalSet(start=start, end=end)
    >>> ep
      index    start    end
          0        0      5
          1       10     12
          2       20     33
    shape: (3, 2), time unit: sec.

    Initialize an IntervalSet with an array of start and end pairs:

    >>> times = np.array([[0, 5], [10, 12], [20, 33]])
    >>> ep = nap.IntervalSet(times)
    >>> ep
      index    start    end
          0        0      5
          1       10     12
          2       20     33
    shape: (3, 2), time unit: sec.

    Initialize an IntervalSet with metadata:

    >>> start = [0, 10, 20]
    >>> end = [5, 12, 33]
    >>> metadata = {"label": ["a", "b", "c"]}
    >>> ep = nap.IntervalSet(start=start, end=end, metadata=metadata)
    >>> ep
      index    start    end  label
          0        0      5  a
          1       10     12  b
          2       20     33  c
    shape: (3, 2), time unit: sec.

    Initialize an IntervalSet with a pandas DataFrame:

    >>> import pandas as pd
    >>> df = pd.DataFrame(data={"start": [0, 10, 20], "end": [5, 12, 33], "label": ["a", "b", "c"]})
    >>> ep = nap.IntervalSet(df)
    >>> ep
      index    start    end  label
          0        0      5  a
          1       10     12  b
          2       20     33  c
    shape: (3, 2), time unit: sec.

    Apply numpy functions to an IntervalSet:

    >>> ep = nap.IntervalSet(start=[0, 10], end=[5,20])
    >>> ep
      index    start    end
          0        0      5
          1       10     20
    shape: (2, 2), time unit: sec.

    >>> np.diff(ep, 1)
    UserWarning: Converting IntervalSet to numpy.array
    array([[ 5.],
           [10.]])

    Slicing an IntervalSet:

    >>> ep[:,0]
    array([ 0., 10.])
    >>> ep[0]
      index    start    end
          0        0      5
    shape: (1, 2), time unit: sec.

    Modifying the `IntervalSet` will raise an error:

    >>> ep[0,0] = 1
    RuntimeError: IntervalSet is immutable. Starts and ends have been already sorted.

    """

    start: np.ndarray
    """The start times of each interval"""
    end: np.ndarray
    """The end times of each interval"""
    values: np.ndarray
    """Array of start and end times"""
    index: np.ndarray
    """Index of each interval, automatically set from 0 to n_intervals"""
    columns: np.ndarray
    """Column names of the IntervalSet, which are always ["start", "end"]"""
    nap_class: str
    """The pynapple class name"""

    def __init__(
        self,
        start,
        end=None,
        time_units="s",
        metadata=None,
    ):
        # Parameters are documented in the class docstring.

        # set directly in __dict__ to avoid infinite recursion in __setattr__
        self.__dict__["_initialized"] = False

        if isinstance(start, IntervalSet):
            end = start.end.astype(np.float64)
            start = start.start.astype(np.float64)

        elif isinstance(start, pd.DataFrame):
            assert (
                "start" in start.columns and "end" in start.columns
            ), """
                DataFrame must contain columns name "start" and "end" for start and end times.
                """
            # try sorting the DataFrame by start times, preserving its end pair, as an effort to preserve metadata
            # since metadata would be dropped if starts and ends are sorted separately
            # note that if end times are still not sorted, metadata will be dropped
            if np.any(start["start"].diff() < 0):
                warnings.warn(
                    "DataFrame is not sorted by start times. Sorting it.", stacklevel=2
                )
                start = start.sort_values("start").reset_index(drop=True)

            # every column other than start/end becomes metadata
            metadata = start.drop(columns=["start", "end"])
            end = start["end"].values.astype(np.float64)
            start = start["start"].values.astype(np.float64)

        else:
            if end is None:
                # Require iterable of (start, end) tuples
                try:
                    start_end_array = np.array(list(start)).reshape(-1, 2)
                except (TypeError, ValueError) as err:
                    raise ValueError(
                        "Unable to Interpret the input. Please provide a list of start-end pairs."
                    ) from err
                # Column slicing (instead of ``zip(*rows)``) also handles an empty
                # iterable, which now yields an empty IntervalSet.
                start = start_end_array[:, 0]
                end = start_end_array[:, 1]

            args = {"start": start, "end": end}

            # normalize both arguments to 1-D numpy arrays
            for arg, data in args.items():
                if isinstance(data, Number):
                    args[arg] = np.array([data])
                elif isinstance(data, (list, tuple)):
                    args[arg] = np.ravel(np.array(data))
                elif isinstance(data, pd.Series):
                    args[arg] = data.values
                elif isinstance(data, np.ndarray):
                    args[arg] = np.ravel(data)
                elif is_array_like(data):
                    args[arg] = convert_to_numpy_array(data, arg)
                else:
                    raise RuntimeError(
                        "Unknown format for {}. Accepted formats are numpy.ndarray, list, tuple or any array-like objects.".format(
                            arg
                        )
                    )

            start = args["start"]
            end = args["end"]

            assert len(start) == len(end), "Starts end ends are not of the same length"

        start = TsIndex.format_timestamps(start, time_units)
        end = TsIndex.format_timestamps(end, time_units)

        # Sorting starts and ends independently breaks the row <-> metadata
        # correspondence, so metadata is dropped whenever either is re-sorted.
        drop_meta = False
        if not (np.diff(start) > 0).all():
            if metadata is not None:
                msg1 = "Cannot add metadata to unsorted start times. "
                msg2 = " and dropping metadata"
            else:
                msg1 = ""
                msg2 = ""
            warnings.warn(
                "start is not sorted. " + msg1 + "Sorting it" + msg2 + ".", stacklevel=2
            )
            start = np.sort(start)
            drop_meta = True

        if not (np.diff(end) > 0).all():
            if metadata is not None:
                msg1 = "Cannot add metadata to unsorted end times. "
                msg2 = " and dropping metadata"
            else:
                msg1 = ""
                msg2 = ""
            warnings.warn(
                "end is not sorted. " + msg1 + "Sorting it" + msg2 + ".", stacklevel=2
            )
            end = np.sort(end)
            drop_meta = True

        data, to_warn = _jitfix_iset(start, end)

        if np.any(to_warn):
            msg = "\n".join(all_warnings[to_warn])
            warnings.warn(msg, stacklevel=2)
            # every flag but the first one changes the epochs themselves
            if np.any(to_warn[1:]) and (metadata is not None):
                drop_meta = True
                warnings.warn("epochs have changed, dropping metadata.", stacklevel=2)

        self.values = data
        self.index = np.arange(data.shape[0], dtype="int")
        self.columns = np.array(["start", "end"])
        self.nap_class = self.__class__.__name__
        # initialize metadata to get all attributes before setting metadata
        _MetadataMixin.__init__(self)
        self._class_attributes = self.__dir__()  # get list of all attributes
        self._class_attributes.append("_class_attributes")  # add this property
        self._initialized = True
        if drop_meta is False:
            self.set_info(metadata)

    def __repr__(self):
        # Determine how many rows fit in the terminal; this can be unique for each object.
        _, term_rows = _get_terminal_size()
        max_rows = np.maximum(term_rows - 10, 2)

        try:
            metadata = self._metadata
            col_names = metadata.columns
        except Exception:
            # Necessary for backward compatibility when saving IntervalSet as pickle
            metadata = pd.DataFrame(index=self.index)
            col_names = []

        headers = ["index", "start", "end"] + list(col_names)
        bottom = f"shape: {self.shape}, time unit: sec."

        def _rows(rows_slice):
            # Slice index, start/end and metadata with the same slice so every
            # printed row stays aligned.
            return np.hstack(
                (
                    self.index[rows_slice, None],
                    self.values[rows_slice],
                    _convert_iter_to_str(metadata.values[rows_slice]),
                ),
                dtype=object,
            )

        # We rarely want to print everything as it can be very big.
        if len(self) > max_rows:
            n_rows = max_rows // 2
            data = np.vstack(
                (
                    _rows(slice(0, n_rows)),
                    np.array([["..." for _ in range(len(headers))]], dtype=object),
                    _rows(slice(-n_rows, None)),
                )
            )
        else:
            data = _rows(slice(None))

        return tabulate(data, headers=headers, tablefmt="plain") + "\n" + bottom

    def __str__(self):
        return self.__repr__()

    def __len__(self):
        return len(self.values)

    def __setattr__(self, name, value):
        # necessary setter to allow metadata to be set as an attribute
        if self._initialized:
            # existing attributes are frozen; any new name is routed to metadata
            if name in self._class_attributes:
                raise AttributeError(
                    f"Cannot set attribute '{name}'; IntervalSet is immutable. Use 'set_info()' to set '{name}' as metadata."
                )
            else:
                _MetadataMixin.__setattr__(self, name, value)
        else:
            # during __init__, set attributes normally
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        # Necessary for backward compatibility with pickle
        # avoid infinite recursion when pickling due to
        # self._metadata.column having attributes '__reduce__', '__reduce_ex__'
        if name in ("__getstate__", "__setstate__", "__reduce__", "__reduce_ex__"):
            raise AttributeError(name)

        try:
            metadata = self._metadata
        except Exception:
            # objects without `_metadata` (e.g. older pickles) get an empty frame
            metadata = pd.DataFrame(index=self.index)

        if name == "_metadata":
            return metadata
        elif name in metadata.columns:
            return _MetadataMixin.__getattr__(self, name)
        else:
            return super().__getattr__(name)

    def __setitem__(self, key, value):
        # Only string keys that do not name the start/end columns can be
        # assigned (as metadata). Checking the type first avoids an element-wise
        # numpy comparison of arbitrary keys (tuples, slices, ...) against
        # ``self.columns``.
        if isinstance(key, str) and key not in self.columns:
            _MetadataMixin.__setitem__(self, key, value)
        else:
            raise RuntimeError(
                "IntervalSet is immutable. Starts and ends have been already sorted."
            )

    def __getitem__(self, key):
        if isinstance(key, str):
            # self[str]
            if key == "start":
                return self.values[:, 0]
            elif key == "end":
                return self.values[:, 1]
            elif key in self._metadata.columns:
                return _MetadataMixin.__getitem__(self, key)
            else:
                raise IndexError(
                    f"Unknown string argument. Should be in {['start', 'end'] + list(self._metadata.keys())}"
                )
        elif isinstance(key, list) and all(isinstance(x, str) for x in key):
            # self[[*str]]
            # easiest to convert to dataframe and then slice
            # in case of mixing ["start", "end"] with metadata columns
            df = self.as_dataframe()
            if all(x in key for x in ["start", "end"]):
                return IntervalSet(df[key])
            else:
                return df[key]
        elif isinstance(key, Number):
            # self[Number]
            try:
                metadata = _MetadataMixin.__getitem__(self, key)
            except Exception:
                metadata = pd.DataFrame(index=self.index)
            output = self.values.__getitem__(key)
            return IntervalSet(start=output[0], end=output[1], metadata=metadata)
        elif isinstance(key, (slice, list, np.ndarray, pd.Series)):
            # self[array_like]
            output = self.values.__getitem__(key)
            metadata = _MetadataMixin.__getitem__(self, key).reset_index(drop=True)
            return IntervalSet(start=output[:, 0], end=output[:, 1], metadata=metadata)
        elif isinstance(key, tuple):
            if len(key) != 2:
                raise IndexError(
                    "too many indices for IntervalSet: IntervalSet is 2-dimensional"
                )

            if isinstance(key[1], Number):
                # self[Any, Number]
                # allow number indexing for start and end times for backward compatibility
                return self.values.__getitem__(key)

            elif isinstance(key[1], str):
                # self[Any, str]
                if key[1] == "start":
                    return self.values[key[0], 0]
                elif key[1] == "end":
                    return self.values[key[0], 1]
                elif key[1] in self._metadata.columns:
                    return _MetadataMixin.__getitem__(self, key)
                else:
                    # previously fell through and silently returned None
                    raise IndexError(
                        f"Unknown string argument. Should be in {['start', 'end'] + list(self._metadata.keys())}"
                    )

            elif isinstance(key[1], (list, np.ndarray)):
                if all(isinstance(x, str) for x in key[1]):
                    # self[Any, [*str]]
                    # easiest to convert to dataframe and then slice
                    # in case of mixing ["start", "end"] with metadata columns
                    df = self.as_dataframe()
                    if all(x in key[1] for x in ["start", "end"]):
                        return IntervalSet(df.loc[key])
                    else:
                        return df.loc[key]

                elif all(isinstance(x, Number) for x in key[1]):
                    if all(x in [0, 1] for x in key[1]):
                        # self[Any, [0,1]]
                        # allow number indexing for start and end times for backward compatibility
                        output = self.values.__getitem__(key[0])
                        if isinstance(key[0], Number):
                            return IntervalSet(start=output[0], end=output[1])
                        else:
                            return IntervalSet(start=output[:, 0], end=output[:, 1])
                    else:
                        raise IndexError(
                            f"index {key[1]} out of bounds for IntervalSet axis 1 with size 2"
                        )
                else:
                    raise IndexError(f"unknown index {key[1]} for index 2")

            elif isinstance(key[1], slice):
                if key[1] == slice(None, None, None):
                    # self[Any, :]
                    output = self.values.__getitem__(key[0])
                    metadata = _MetadataMixin.__getitem__(self, key[0])

                    if isinstance(key[0], Number):
                        return IntervalSet(
                            start=output[0], end=output[1], metadata=metadata
                        )
                    else:
                        return IntervalSet(
                            start=output[:, 0],
                            end=output[:, 1],
                            metadata=metadata.reset_index(drop=True),
                        )

                elif key[1] == slice(0, 2, None):
                    # self[Any, :2]
                    # allow number indexing for start and end times for backward compatibility
                    output = self.values.__getitem__(key[0])

                    if isinstance(key[0], Number):
                        return IntervalSet(start=output[0], end=output[1])
                    else:
                        return IntervalSet(start=output[:, 0], end=output[:, 1])

                else:
                    raise IndexError(
                        f"index {key[1]} out of bounds for IntervalSet axis 1 with size 2"
                    )

            else:
                raise IndexError(f"unknown type {type(key[1])} for index 2")

        else:
            raise IndexError(f"unknown type {type(key)} for index")

    def __array__(self, dtype=None, copy=None):
        # ``copy`` is accepted for the NumPy 2 ``__array__`` protocol (which
        # otherwise emits a DeprecationWarning). A copy is always returned so the
        # (immutable) IntervalSet can never be modified through the result.
        return self.values.astype(dtype)

    def __array_ufunc__(self, ufunc, method, *args, **kwargs):
        # Apply the ufunc to the raw numpy values; the result is a plain numpy object.
        if method == "at" and args and isinstance(args[0], self.__class__):
            # ufunc.at operates in place on its first operand, which would
            # mutate the values of this immutable IntervalSet.
            raise RuntimeError(
                "IntervalSet is immutable. Starts and ends have been already sorted."
            )

        new_args = [a.values if isinstance(a, self.__class__) else a for a in args]

        # `method` selects the ufunc variant ("__call__", "reduce", "accumulate",
        # "outer", ...); calling `ufunc` directly would silently ignore it.
        out = getattr(ufunc, method)(*new_args, **kwargs)

        if not nap_config.suppress_conversion_warnings:
            warnings.warn(
                "Converting IntervalSet to numpy.array",
                UserWarning,
                stacklevel=2,
            )
        return out

    def __array_function__(self, func, types, args, kwargs):
        # Run the numpy function on the raw values; the result is a plain numpy object.
        new_args = [a.values if isinstance(a, self.__class__) else a for a in args]

        out = func._implementation(*new_args, **kwargs)

        if not nap_config.suppress_conversion_warnings:
            warnings.warn(
                "Converting IntervalSet to numpy.array",
                UserWarning,
            )
        return out

    @property
    def start(self):
        """Start time of each interval (first column of ``values``)."""
        return self.values[:, 0]

    @property
    def end(self):
        """End time of each interval (second column of ``values``)."""
        return self.values[:, 1]

    @property
    def shape(self):
        """Shape of ``values``: ``(n_intervals, 2)``."""
        return self.values.shape

    @property
    def ndim(self):
        """Number of dimensions of ``values``."""
        return self.values.ndim

    @property
    def size(self):
        """Number of elements in ``values`` (twice the number of intervals)."""
        return self.values.size

    @property
    def starts(self):
        """Return the starts of the IntervalSet as a Ts object

        Deprecated; use ``start`` instead.

        Returns
        -------
        Ts
            The starts of the IntervalSet
        """
        warnings.warn(
            "starts is a deprecated function. It will be removed in future versions",
            category=DeprecationWarning,
            stacklevel=2,
        )
        # imported lazily, presumably to avoid a circular import with time_series
        time_series = importlib.import_module(".time_series", "pynapple.core")
        return time_series.Ts(t=self.values[:, 0])

    @property
    def ends(self):
        """Return the ends of the IntervalSet as a Ts object

        Deprecated; use ``end`` instead.

        Returns
        -------
        Ts
            The ends of the IntervalSet
        """
        warnings.warn(
            "ends is a deprecated function. It will be removed in future versions",
            category=DeprecationWarning,
            stacklevel=2,
        )
        # imported lazily, presumably to avoid a circular import with time_series
        time_series = importlib.import_module(".time_series", "pynapple.core")
        return time_series.Ts(t=self.values[:, 1])

    @property
    def loc(self):
        """
        Slicing function to add compatibility with pandas DataFrame after removing it as a super class of IntervalSet
        """
        return _IntervalSetSliceHelper(self)

    @classmethod
    def _from_npz_reader(cls, file):
        """Load an IntervalSet object from a npz file.

        The file should contain the keys 'start', 'end' and 'type', and may
        contain '_metadata' (written by `save`).
        The 'type' key should be 'IntervalSet'.

        Parameters
        ----------
        file : NPZFile object
            opened npz file interface.

        Returns
        -------
        IntervalSet
            The IntervalSet object
        """
        ep = cls(start=file["start"], end=file["end"])
        # NOTE(review): `save` stores `_metadata` as a dict, which numpy writes as an
        # object array; reading it presumably requires allow_pickle=True -- confirm in the loader.
        if "_metadata" in file:  # load metadata if it exists
            if file["_metadata"]:  # check that metadata is not empty
                metadata = pd.DataFrame.from_dict(file["_metadata"].item())
                ep.set_info(metadata)
        return ep

    def time_span(self):
        """
        Time span of the interval set.

        Metadata is not carried over (a warning is issued if there is any).

        Returns
        -------
        out: IntervalSet
            an IntervalSet with a single interval encompassing the whole IntervalSet,
            or an empty IntervalSet if this one is empty
        """
        if len(self.metadata_columns):
            warnings.warn(
                "metadata incompatible with time_span method. dropping metadata from result",
                UserWarning,
                stacklevel=2,
            )
        if len(self) == 0:
            # nothing to span; indexing values[0, 0] would raise IndexError
            return IntervalSet(start=[], end=[])
        s = self.values[0, 0]
        e = self.values[-1, 1]
        return IntervalSet(s, e)

    def tot_length(self, time_units="s"):
        """
        Total elapsed time in the set (sum of all interval durations).

        Parameters
        ----------
        time_units : str, optional
            The time units to return the result in ('us', 'ms', 's' [default])

        Returns
        -------
        out: float
            The summed duration of all intervals, in `time_units`.
        """
        tot_l = np.sum(self.values[:, 1] - self.values[:, 0])
        return TsIndex.return_timestamps(np.array([tot_l]), time_units)[0]

    def intersect(self, a):
        """
        Set intersection of IntervalSet

        Metadata of each resulting interval is taken from the intervals of `self`
        and `a` it came from and combined with ``pandas.DataFrame.join``; both
        sets therefore must not share a metadata column name (pandas raises on
        overlapping columns without a suffix).

        Parameters
        ----------
        a : IntervalSet
            the IntervalSet to intersect self with

        Returns
        -------
        out: IntervalSet
            The intersection of `self` and `a`.
        """
        start1 = self.values[:, 0]
        end1 = self.values[:, 1]
        start2 = a.values[:, 0]
        end2 = a.values[:, 1]
        s, e, m = jitintersect(start1, end1, start2, end2)
        # m[:, 0] / m[:, 1]: row of self / row of `a` for each resulting interval
        m1 = self._metadata.loc[m[:, 0]].reset_index(drop=True)
        m2 = a._metadata.loc[m[:, 1]].reset_index(drop=True)
        return IntervalSet(s, e, metadata=m1.join(m2))

    def union(self, a):
        """
        set union of IntervalSet

        Metadata of either set is dropped (with a warning).

        Parameters
        ----------
        a : IntervalSet
            the IntervalSet to union self with

        Returns
        -------
        out: IntervalSet
            The union of `self` and `a`.
        """
        # warn when either operand carries metadata, since both are dropped
        if len(self.metadata_columns) or len(getattr(a, "metadata_columns", ())):
            warnings.warn(
                "metadata incompatible with union method. dropping metadata from result",
                UserWarning,
                stacklevel=2,
            )
        start1 = self.values[:, 0]
        end1 = self.values[:, 1]
        start2 = a.values[:, 0]
        end2 = a.values[:, 1]
        s, e = jitunion(start1, end1, start2, end2)
        return IntervalSet(s, e)

    def set_diff(self, a):
        """
        set difference of IntervalSet

        Each resulting interval keeps the metadata of the interval of `self` it came from.

        Parameters
        ----------
        a : IntervalSet
            the IntervalSet to set-substract from self

        Returns
        -------
        out: IntervalSet
            The parts of `self` not covered by `a`.
        """
        start1 = self.values[:, 0]
        end1 = self.values[:, 1]
        start2 = a.values[:, 0]
        end2 = a.values[:, 1]
        s, e, m = jitdiff(start1, end1, start2, end2)
        # m: row of self for each resulting interval
        m1 = self._metadata.loc[m].reset_index(drop=True)
        return IntervalSet(s, e, metadata=m1)

    def in_interval(self, tsd):
        """
        finds out in which element of the interval set each point in a time series fits.

        NaNs for those that don't fit an interval

        Parameters
        ----------
        tsd : Tsd
            The tsd to be binned

        Returns
        -------
        out: numpy.ndarray
            an array with the interval index labels for each time stamp (NaN) for timestamps not in IntervalSet
        """
        times = tsd.index.values
        starts = self.values[:, 0]
        ends = self.values[:, 1]

        return jitin_interval(times, starts, ends)

    def drop_short_intervals(self, threshold, time_units="s"):
        """
        Drops the intervals whose duration is shorter than or equal to `threshold`.

        Parameters
        ----------
        threshold : numeric
            Time threshold for "short" intervals
        time_units : str, optional
            The time units for the threshold ('us', 'ms', 's' [default])

        Returns
        -------
        out: IntervalSet
            A copied IntervalSet with the dropped intervals
        """
        threshold = TsIndex.format_timestamps(
            np.array([threshold], dtype=np.float64), time_units
        )[0]
        return self[(self.values[:, 1] - self.values[:, 0]) > threshold]

    def drop_long_intervals(self, threshold, time_units="s"):
        """
        Drops the intervals whose duration is longer than or equal to `threshold`.

        Parameters
        ----------
        threshold : numeric
            Time threshold for "long" intervals
        time_units : str, optional
            The time units for the threshold ('us', 'ms', 's' [default])

        Returns
        -------
        out: IntervalSet
            A copied IntervalSet with the dropped intervals
        """
        threshold = TsIndex.format_timestamps(
            np.array([threshold], dtype=np.float64), time_units
        )[0]
        return self[(self.values[:, 1] - self.values[:, 0]) < threshold]

    def as_units(self, units="s"):
        """
        returns a pandas DataFrame with time expressed in the desired unit

        Parameters
        ----------
        units : str, optional
            'us', 'ms', or 's' [default]. Microseconds are returned as int64.

        Returns
        -------
        out: pandas.DataFrame
            DataFrame with adjusted times (metadata is not included)
        """
        data = self.values.copy()
        data = TsIndex.return_timestamps(data, units)
        if units == "us":
            data = data.astype(np.int64)

        df = pd.DataFrame(index=self.index, data=data, columns=self.columns)

        return df

    def merge_close_intervals(self, threshold, time_units="s"):
        """
        Merges intervals that are very close.

        Consecutive intervals are merged when the gap between the end of one and
        the start of the next is smaller than or equal to `threshold`.
        Metadata is dropped (with a warning).

        Parameters
        ----------
        threshold : numeric
            time threshold for the closeness of the intervals
        time_units : str, optional
            time units for the threshold ('us', 'ms', 's' [default])

        Returns
        -------
        out: IntervalSet
            a copied IntervalSet with merged intervals

        """
        if len(self.metadata_columns):
            warnings.warn(
                "metadata incompatible with merge_close_intervals method. dropping metadata from result",
                UserWarning,
                stacklevel=2,
            )

        if len(self) == 0:
            return IntervalSet(start=[], end=[])

        threshold = TsIndex.format_timestamps(
            np.array((threshold,), dtype=np.float64).ravel(), time_units
        )[0]
        start = self.values[:, 0]
        end = self.values[:, 1]
        # True where the gap to the next interval is large enough to keep them separate
        keep_gap = (start[1:] - end[0:-1]) > threshold
        start = np.hstack((start[0], start[1:][keep_gap]))
        end = np.hstack((end[0:-1][keep_gap], end[-1]))

        return IntervalSet(start=start, end=end)

    def get_intervals_center(self, alpha=0.5):
        """
        Returns by default the centers of each intervals.

        It is possible to bias the midpoint by changing the alpha parameter between [0, 1]
        For each epoch:
        t = start + (end-start)*alpha

        Parameters
        ----------
        alpha : float, optional
            The midpoint within each interval. Must be a float; values outside
            [0, 1] are clipped.

        Returns
        -------
        Ts
            Timestamps object

        Raises
        ------
        RuntimeError
            If `alpha` is not a float.
        """
        time_series = importlib.import_module(".time_series", "pynapple.core")
        starts = self.values[:, 0]
        ends = self.values[:, 1]

        if not isinstance(alpha, float):
            raise RuntimeError("Parameter alpha should be float type")

        alpha = np.clip(alpha, 0, 1)
        t = starts + (ends - starts) * alpha
        return time_series.Ts(t=t, time_support=self)

    def as_dataframe(self):
        """
        Convert the `IntervalSet` object to a pandas.DataFrame object.

        Returns
        -------
        out: pandas.DataFrame
            DataFrame with "start" and "end" columns followed by the metadata columns.
        """
        df = pd.DataFrame(data=self.values, columns=["start", "end"])
        return pd.concat([df, self._metadata], axis=1)

    def save(self, filename):
        """
        Save IntervalSet object in npz format. The file will contain the starts and ends.

        The main purpose of this function is to save small/medium sized IntervalSet
        objects. For example, you determined some epochs for one session that you want to save
        to avoid recomputing them.

        You can load the object with `nap.load_file`. Keys are 'start', 'end', 'type'
        and '_metadata' (the metadata converted with ``DataFrame.to_dict``).
        See the example below.

        Parameters
        ----------
        filename : str
            The filename

        Examples
        --------
        >>> import pynapple as nap
        >>> import numpy as np
        >>> ep = nap.IntervalSet(start=[0, 10, 20], end=[5, 12, 33])
        >>> ep.save("my_ep.npz")

        To load you file, you can use the `nap.load_file` function :

        >>> ep = nap.load_file("my_ep.npz")
        >>> ep
          index    start    end
              0        0      5
              1       10     12
              2       20     33
        shape: (3, 2), time unit: sec.

        Raises
        ------
        RuntimeError
            If filename is not str, path does not exist or filename is a directory.
        """
        np.savez(
            check_filename(filename),
            start=self.values[:, 0],
            end=self.values[:, 1],
            type=np.array(["IntervalSet"], dtype=np.str_),
            _metadata=self._metadata.to_dict(),  # save metadata as dictionary
        )

    def split(self, interval_size, time_units="s"):
        """Split `IntervalSet` to a new `IntervalSet` with each interval being of size `interval_size`.

        Used mostly for chunking very large dataset or looping throught multiple epoch of same duration.

        This function skips the epochs that are shorter than `interval_size`; epochs exactly
        `interval_size` long are kept as a single interval. Leftover pieces shorter than
        `interval_size` at the end of an epoch are discarded. Each new interval keeps the
        metadata of the epoch it came from.

        Note that intervals are strictly non-overlapping in pynapple. One microsecond is removed from contiguous intervals.

        Parameters
        ----------
        interval_size : Number
            Length of each new interval, expressed in `time_units`.
        time_units : str, optional
            time units for the `interval_size` ('us', 'ms', 's' [default])

        Returns
        -------
        IntervalSet
            New `IntervalSet` with equal sized intervals

        Raises
        ------
        IOError
            If `interval_size` is not a Number or is below 0
            If `time_units` is not a string
        """
        if not isinstance(interval_size, Number):
            raise IOError("Argument interval_size should of type float or int")
        if not interval_size > 0:
            raise IOError("Argument interval_size should be strictly larger than 0")
        if not isinstance(time_units, str):
            raise IOError("Argument time_units should be of type str")

        if len(self) == 0:
            return IntervalSet(start=[], end=[])

        interval_size = TsIndex.format_timestamps(
            np.array((interval_size,), dtype=np.float64).ravel(), time_units
        )[0]
        interval_size = np.round(interval_size, nap_config.time_index_precision)

        # Durations are compared after rounding to the time index precision.
        # ">=" so that an epoch exactly `interval_size` long is not dropped.
        durations = np.round(self.end - self.start, nap_config.time_index_precision)
        idxs = np.where(durations >= interval_size)[0]
        # number of chunk boundaries per epoch (= chunks + 1); matches np.arange's length rule
        size_tmp = (
            np.ceil((self.end[idxs] - self.start[idxs]) / interval_size)
        ).astype(int) + 1

        n_new = size_tmp.sum() - size_tmp.shape[0]
        new_starts = np.full(n_new, np.nan)
        new_ends = np.full(n_new, np.nan)
        # row of the source epoch for each new interval (every slot is filled below)
        new_meta = np.zeros(n_new, dtype=int)

        i0 = 0
        for cnt, idx in enumerate(idxs):
            n_chunks = size_tmp[cnt] - 1
            # repeat metainfo for each new interval
            new_meta[i0 : i0 + n_chunks] = idx
            new_starts[i0 : i0 + n_chunks] = np.arange(
                self.start[idx], self.end[idx], interval_size
            )
            # each chunk ends where the next one starts; the last ends with the epoch
            new_ends[i0 : i0 + n_chunks - 1] = new_starts[i0 + 1 : i0 + n_chunks]
            new_ends[i0 + n_chunks - 1] = self.end[idx]
            i0 += n_chunks

        new_starts = np.round(new_starts, nap_config.time_index_precision)
        new_ends = np.round(new_ends, nap_config.time_index_precision)

        # drop trailing pieces shorter than interval_size
        durations = np.round(new_ends - new_starts, nap_config.time_index_precision)
        tokeep = durations >= interval_size
        new_starts = new_starts[tokeep]
        new_ends = new_ends[tokeep]
        new_meta = new_meta[tokeep]
        metadata = self._metadata.loc[new_meta].reset_index(drop=True)

        # Removing 1 microsecond to have strictly non-overlapping intervals for intervals coming from the same epoch
        new_ends -= 1e-6

        return IntervalSet(new_starts, new_ends, metadata=metadata)

    @add_meta_docstring("set_info")
    def set_info(self, metadata=None, **kwargs):
        """
        Examples
        --------
        >>> import pynapple as nap
        >>> import numpy as np
        >>> times = np.array([[0, 5], [10, 12], [20, 33]])
        >>> ep = nap.IntervalSet(times)

        To add metadata with a pandas.DataFrame:

        >>> import pandas as pd
        >>> metadata = pd.DataFrame(data=['left','right','left'], columns=['choice'])
        >>> ep.set_info(metadata)
        >>> ep
          index    start    end  choice
              0        0      5  left
              1       10     12  right
              2       20     33  left
        shape: (3, 2), time unit: sec.

        To add metadata with a dictionary:

        >>> metadata = {"reward": [1, 0, 1]}
        >>> ep.set_info(metadata)
        >>> ep
          index    start    end  choice      reward
              0        0      5  left             1
              1       10     12  right            0
              2       20     33  left             1
        shape: (3, 2), time unit: sec.

        To add metadata with a keyword argument (pd.Series, numpy.ndarray, list or tuple):

        >>> stim = pd.Series(data = [10, -23, 12])
        >>> ep.set_info(stim=stim)
        >>> ep
          index    start    end  choice      reward    stim
              0        0      5  left             1      10
              1       10     12  right            0     -23
              2       20     33  left             1      12
        shape: (3, 2), time unit: sec.

        To add metadata as an attribute:

        >>> ep.label = ["a", "b", "c"]
        >>> ep
          index    start    end  choice      reward  label
              0        0      5  left             1  a
              1       10     12  right            0  b
              2       20     33  left             1  c
        shape: (3, 2), time unit: sec.

        To add metadata as a key:

        >>> ep["error"] = [0, 0, 0]
        >>> ep
          index    start    end  choice      reward  label      error
              0        0      5  left             1  a              0
              1       10     12  right            0  b              0
              2       20     33  left             1  c              0
        shape: (3, 2), time unit: sec.

        Metadata can be overwritten:

        >>> ep.set_info(label=["x", "y", "z"])
        >>> ep
          index    start    end  choice      reward  label      error
              0        0      5  left             1  x              0
              1       10     12  right            0  y              0
              2       20     33  left             1  z              0
        shape: (3, 2), time unit: sec.
        """
        _MetadataMixin.set_info(self, metadata, **kwargs)

    @add_meta_docstring("get_info")
    def get_info(self, key):
        """
        Examples
        --------
        >>> import pynapple as nap
        >>> import numpy as np
        >>> times = np.array([[0, 5], [10, 12], [20, 33]])
        >>> metadata = {"l1": [1, 2, 3], "l2": ["x", "x", "y"]}
        >>> ep = nap.IntervalSet(times, metadata=metadata)

        To access a single metadata column:

        >>> ep.get_info("l1")
        0    1
        1    2
        2    3
        Name: l1, dtype: int64

        To access multiple metadata columns:

        >>> ep.get_info(["l1", "l2"])
           l1 l2
        0   1  x
        1   2  x
        2   3  y

        To access metadata of a single index:

        >>> ep.get_info(0)
        l1    1
        l2    x
        Name: 0, dtype: object

        To access metadata of multiple indices:

        >>> ep.get_info([0, 1])
           l1 l2
        0   1  x
        1   2  x

        To access metadata of a single index and column:

        >>> ep.get_info((0, "l1"))
        np.int64(1)

        To access metadata as an attribute:

        >>> ep.l1
        0    1
        1    2
        2    3
        Name: l1, dtype: int64

        To access metadata as a key:

        >>> ep["l1"]
        0    1
        1    2
        2    3
        Name: l1, dtype: int64

        Multiple metadata columns can be accessed as keys:

        >>> ep[["l1", "l2"]]
           l1 l2
        0   1  x
        1   2  x
        2   3  y
        """
        return _MetadataMixin.get_info(self, key)