"""
The class `IntervalSet` deals with non-overlaping epochs. `IntervalSet` objects can interact with each other or with the time series objects.
"""
import importlib
import warnings
from numbers import Number
import numpy as np
import pandas as pd
from numpy.lib.mixins import NDArrayOperatorsMixin
from tabulate import tabulate
from ._jitted_functions import (
_jitfix_iset,
jitdiff,
jitin_interval,
jitintersect,
jitunion,
)
from .config import nap_config
from .time_index import TsIndex
from .utils import (
_get_terminal_size,
_IntervalSetSliceHelper,
check_filename,
convert_to_numpy_array,
is_array_like,
)
all_warnings = np.array(
[
"Some starts and ends are equal. Removing 1 microsecond!",
"Some ends precede the relative start. Dropping them!",
"Some starts precede the previous end. Joining them!",
"Some epochs have no duration",
]
)
[docs]
class IntervalSet(NDArrayOperatorsMixin):
"""
A class representing a (irregular) set of time intervals in elapsed time, with relative operations
The `IntervalSet` object behaves like a numpy ndarray with the limitation that the object is not mutable.
You can still apply any numpy array function to it :
>>> import pynapple as nap
>>> import numpy as np
>>> ep = nap.IntervalSet(start=[0, 10], end=[5,20])
start end
0 0 5
1 10 20
shape: (1, 2)
>>> np.diff(ep, 1)
UserWarning: Converting IntervalSet to numpy.array
array([[ 5.],
[10.]])
You can slice :
>>> ep[:,0]
array([ 0., 10.])
>>> ep[0]
start end
0 0 5
shape: (1, 2)
But modifying the `IntervalSet` with raise an error:
>>> ep[0,0] = 1
RuntimeError: IntervalSet is immutable. Starts and ends have been already sorted.
"""
[docs]
def __init__(self, start, end=None, time_units="s"):
"""
If start and end are not aligned, meaning that:
1. len(start) != len(end)
2. end[i] > start[i]
3. start[i+1] > end[i]
4. start and end are not sorted,
IntervalSet will try to "fix" the data by eliminating some of the start and end data points.
Parameters
----------
start : numpy.ndarray or number or pandas.DataFrame or pandas.Series or iterable of (start, end) pairs
Beginning of intervals. Alternatively, the `end` argument can be left out and `start` can be one of the
following:
- IntervalSet
- pandas.DataFrame with columns ["start", "end"]
- iterable of (start, end) pairs
- a single (start, end) pair
end : numpy.ndarray or number or pandas.Series, optional
Ends of intervals.
time_units : str, optional
Time unit of the intervals ('us', 'ms', 's' [default]).
Raises
------
RuntimeError
If `start` and `end` arguments are of unknown type.
"""
if isinstance(start, IntervalSet):
end = start.end.astype(np.float64)
start = start.start.astype(np.float64)
elif isinstance(start, pd.DataFrame):
assert (
"start" in start.columns
and "end" in start.columns
and start.shape[-1] == 2
), """
Wrong dataframe format. Expected format if passing a pandas dataframe is :
- 2 columns
- column names are ["start", "end"]
"""
end = start["end"].values.astype(np.float64)
start = start["start"].values.astype(np.float64)
else:
if end is None:
# Require iterable of (start, end) tuples
try:
start_end_array = np.array(list(start)).reshape(-1, 2)
start, end = zip(*start_end_array)
except (TypeError, ValueError):
raise ValueError(
"Unable to Interpret the input. Please provide a list of start-end pairs."
)
args = {"start": start, "end": end}
for arg, data in args.items():
if isinstance(data, Number):
args[arg] = np.array([data])
elif isinstance(data, (list, tuple)):
args[arg] = np.ravel(np.array(data))
elif isinstance(data, pd.Series):
args[arg] = data.values
elif isinstance(data, np.ndarray):
args[arg] = np.ravel(data)
elif is_array_like(data):
args[arg] = convert_to_numpy_array(data, arg)
else:
raise RuntimeError(
"Unknown format for {}. Accepted formats are numpy.ndarray, list, tuple or any array-like objects.".format(
arg
)
)
start = args["start"]
end = args["end"]
assert len(start) == len(end), "Starts end ends are not of the same length"
start = TsIndex.format_timestamps(start, time_units)
end = TsIndex.format_timestamps(end, time_units)
if not (np.diff(start) > 0).all():
warnings.warn("start is not sorted. Sorting it.", stacklevel=2)
start = np.sort(start)
if not (np.diff(end) > 0).all():
warnings.warn("end is not sorted. Sorting it.", stacklevel=2)
end = np.sort(end)
data, to_warn = _jitfix_iset(start, end)
if np.any(to_warn):
msg = "\n".join(all_warnings[to_warn])
warnings.warn(msg, stacklevel=2)
self.values = data
self.index = np.arange(data.shape[0], dtype="int")
self.columns = np.array(["start", "end"])
self.nap_class = self.__class__.__name__
def __repr__(self):
headers = [" " * 6, "start", "end"]
bottom = "shape: {}, time unit: sec.".format(self.shape)
rows = _get_terminal_size()[1]
max_rows = np.maximum(rows - 10, 6)
if len(self) > max_rows:
n_rows = max_rows // 2
with warnings.catch_warnings():
warnings.simplefilter("ignore")
return (
tabulate(
np.hstack(
(self.index[0:n_rows][:, None], self.values[0:n_rows])
),
headers=headers,
tablefmt="plain",
colalign=("left", "center", "center"),
)
+ "\n"
+ " " * 10
+ "..."
+ tabulate(
np.hstack(
(self.index[-n_rows:][:, None], self.values[-n_rows:])
),
headers=[
" " * 6,
" " * 5,
" " * 3,
], # To align properly the columns
tablefmt="plain",
colalign=("left", "center", "center"),
)
+ "\n"
+ bottom
)
else:
return (
tabulate(
self.values, headers=headers, showindex="always", tablefmt="plain"
)
+ "\n"
+ bottom
)
def __str__(self):
return self.__repr__()
def __len__(self):
return len(self.values)
# def __iter__(self):
# pass
def __setitem__(self, key, value):
raise RuntimeError(
"IntervalSet is immutable. Starts and ends have been already sorted."
)
def __getitem__(self, key, *args, **kwargs):
if isinstance(key, str):
if key == "start":
return self.values[:, 0]
elif key == "end":
return self.values[:, 1]
else:
raise IndexError("Unknown string argument. Should be 'start' or 'end'")
elif isinstance(key, Number):
output = self.values.__getitem__(key)
return IntervalSet(start=output[0], end=output[1])
elif isinstance(key, (list, slice, np.ndarray)):
output = self.values.__getitem__(key)
return IntervalSet(start=output[:, 0], end=output[:, 1])
elif isinstance(key, pd.Series):
output = self.values.__getitem__(key)
return IntervalSet(start=output[:, 0], end=output[:, 1])
elif isinstance(key, tuple):
if len(key) == 2:
if isinstance(key[1], Number):
return self.values.__getitem__(key)
elif key[1] == slice(None, None, None) or key[1] == slice(0, 2, None):
output = self.values.__getitem__(key)
return IntervalSet(start=output[:, 0], end=output[:, 1])
else:
return self.values.__getitem__(key)
else:
raise IndexError(
"too many indices for IntervalSet: IntervalSet is 2-dimensional"
)
else:
return self.values.__getitem__(key)
def __array__(self, dtype=None):
return self.values.astype(dtype)
def __array_ufunc__(self, ufunc, method, *args, **kwargs):
new_args = []
for a in args:
if isinstance(a, self.__class__):
new_args.append(a.values)
else:
new_args.append(a)
out = ufunc(*new_args, **kwargs)
if not nap_config.suppress_conversion_warnings:
warnings.warn(
"Converting IntervalSet to numpy.array",
UserWarning,
)
return out
def __array_function__(self, func, types, args, kwargs):
new_args = []
for a in args:
if isinstance(a, self.__class__):
new_args.append(a.values)
else:
new_args.append(a)
out = func._implementation(*new_args, **kwargs)
if not nap_config.suppress_conversion_warnings:
warnings.warn(
"Converting IntervalSet to numpy.array",
UserWarning,
)
return out
@property
def start(self):
return self.values[:, 0]
@property
def end(self):
return self.values[:, 1]
@property
def shape(self):
return self.values.shape
@property
def ndim(self):
return self.values.ndim
@property
def size(self):
return self.values.size
@property
def starts(self):
"""Return the starts of the IntervalSet as a Ts object
Returns
-------
Ts
The starts of the IntervalSet
"""
warnings.warn(
"starts is a deprecated function. It will be removed in future versions",
category=DeprecationWarning,
stacklevel=2,
)
time_series = importlib.import_module(".time_series", "pynapple.core")
return time_series.Ts(t=self.values[:, 0])
@property
def ends(self):
"""Return the ends of the IntervalSet as a Ts object
Returns
-------
Ts
The ends of the IntervalSet
"""
warnings.warn(
"ends is a deprecated function. It will be removed in future versions",
category=DeprecationWarning,
stacklevel=2,
)
time_series = importlib.import_module(".time_series", "pynapple.core")
return time_series.Ts(t=self.values[:, 1])
@property
def loc(self):
"""
Slicing function to add compatibility with pandas DataFrame after removing it as a super class of IntervalSet
"""
return _IntervalSetSliceHelper(self)
@classmethod
def _from_npz_reader(cls, file):
"""Load an IntervalSet object from a npz file.
The file should contain the keys 'start', 'end' and 'type'.
The 'type' key should be 'IntervalSet'.
Parameters
----------
file : NPZFile object
opened npz file interface.
Returns
-------
IntervalSet
The IntervalSet object
"""
return cls(start=file["start"], end=file["end"])
[docs]
def time_span(self):
"""
Time span of the interval set.
Returns
-------
out: IntervalSet
an IntervalSet with a single interval encompassing the whole IntervalSet
"""
s = self.values[0, 0]
e = self.values[-1, 1]
return IntervalSet(s, e)
[docs]
def tot_length(self, time_units="s"):
"""
Total elapsed time in the set.
Parameters
----------
time_units : None, optional
The time units to return the result in ('us', 'ms', 's' [default])
Returns
-------
out: float
_
"""
tot_l = np.sum(self.values[:, 1] - self.values[:, 0])
return TsIndex.return_timestamps(np.array([tot_l]), time_units)[0]
[docs]
def intersect(self, a):
"""
Set intersection of IntervalSet
Parameters
----------
a : IntervalSet
the IntervalSet to intersect self with
Returns
-------
out: IntervalSet
_
"""
start1 = self.values[:, 0]
end1 = self.values[:, 1]
start2 = a.values[:, 0]
end2 = a.values[:, 1]
s, e = jitintersect(start1, end1, start2, end2)
return IntervalSet(s, e)
[docs]
def union(self, a):
"""
set union of IntervalSet
Parameters
----------
a : IntervalSet
the IntervalSet to union self with
Returns
-------
out: IntervalSet
_
"""
start1 = self.values[:, 0]
end1 = self.values[:, 1]
start2 = a.values[:, 0]
end2 = a.values[:, 1]
s, e = jitunion(start1, end1, start2, end2)
return IntervalSet(s, e)
[docs]
def set_diff(self, a):
"""
set difference of IntervalSet
Parameters
----------
a : IntervalSet
the IntervalSet to set-substract from self
Returns
-------
out: IntervalSet
_
"""
start1 = self.values[:, 0]
end1 = self.values[:, 1]
start2 = a.values[:, 0]
end2 = a.values[:, 1]
s, e = jitdiff(start1, end1, start2, end2)
return IntervalSet(s, e)
[docs]
def in_interval(self, tsd):
"""
finds out in which element of the interval set each point in a time series fits.
NaNs for those that don't fit an interval
Parameters
----------
tsd : Tsd
The tsd to be binned
Returns
-------
out: numpy.ndarray
an array with the interval index labels for each time stamp (NaN) for timestamps not in IntervalSet
"""
times = tsd.index.values
starts = self.values[:, 0]
ends = self.values[:, 1]
return jitin_interval(times, starts, ends)
[docs]
def drop_short_intervals(self, threshold, time_units="s"):
"""
Drops the short intervals in the interval set with duration shorter than `threshold`.
Parameters
----------
threshold : numeric
Time threshold for "short" intervals
time_units : None, optional
The time units for the treshold ('us', 'ms', 's' [default])
Returns
-------
out: IntervalSet
A copied IntervalSet with the dropped intervals
"""
threshold = TsIndex.format_timestamps(
np.array([threshold], dtype=np.float64), time_units
)[0]
return self[(self.values[:, 1] - self.values[:, 0]) > threshold]
[docs]
def drop_long_intervals(self, threshold, time_units="s"):
"""
Drops the long intervals in the interval set with duration longer than `threshold`.
Parameters
----------
threshold : numeric
Time threshold for "long" intervals
time_units : None, optional
The time units for the treshold ('us', 'ms', 's' [default])
Returns
-------
out: IntervalSet
A copied IntervalSet with the dropped intervals
"""
threshold = TsIndex.format_timestamps(
np.array([threshold], dtype=np.float64), time_units
)[0]
return self[(self.values[:, 1] - self.values[:, 0]) < threshold]
[docs]
def as_units(self, units="s"):
"""
returns a pandas DataFrame with time expressed in the desired unit
Parameters
----------
units : None, optional
'us', 'ms', or 's' [default]
Returns
-------
out: pandas.DataFrame
DataFrame with adjusted times
"""
data = self.values.copy()
data = TsIndex.return_timestamps(data, units)
if units == "us":
data = data.astype(np.int64)
df = pd.DataFrame(index=self.index, data=data, columns=self.columns)
return df
[docs]
def merge_close_intervals(self, threshold, time_units="s"):
"""
Merges intervals that are very close.
Parameters
----------
threshold : numeric
time threshold for the closeness of the intervals
time_units : None, optional
time units for the threshold ('us', 'ms', 's' [default])
Returns
-------
out: IntervalSet
a copied IntervalSet with merged intervals
"""
if len(self) == 0:
return IntervalSet(start=[], end=[])
threshold = TsIndex.format_timestamps(
np.array((threshold,), dtype=np.float64).ravel(), time_units
)[0]
start = self.values[:, 0]
end = self.values[:, 1]
tojoin = (start[1:] - end[0:-1]) > threshold
start = np.hstack((start[0], start[1:][tojoin]))
end = np.hstack((end[0:-1][tojoin], end[-1]))
return IntervalSet(start=start, end=end)
[docs]
def get_intervals_center(self, alpha=0.5):
"""
Returns by default the centers of each intervals.
It is possible to bias the midpoint by changing the alpha parameter between [0, 1]
For each epoch:
t = start + (end-start)*alpha
Parameters
----------
alpha : float, optional
The midpoint within each interval.
Returns
-------
Ts
Timestamps object
"""
time_series = importlib.import_module(".time_series", "pynapple.core")
starts = self.values[:, 0]
ends = self.values[:, 1]
if not isinstance(alpha, float):
raise RuntimeError("Parameter alpha should be float type")
alpha = np.clip(alpha, 0, 1)
t = starts + (ends - starts) * alpha
return time_series.Ts(t=t, time_support=self)
[docs]
def as_dataframe(self):
"""
Convert the `IntervalSet` object to a pandas.DataFrame object.
Returns
-------
out: pandas.DataFrame
_
"""
return pd.DataFrame(data=self.values, columns=["start", "end"])
[docs]
def save(self, filename):
"""
Save IntervalSet object in npz format. The file will contain the starts and ends.
The main purpose of this function is to save small/medium sized IntervalSet
objects. For example, you determined some epochs for one session that you want to save
to avoid recomputing them.
You can load the object with `nap.load_file`. Keys are 'start', 'end' and 'type'.
See the example below.
Parameters
----------
filename : str
The filename
Examples
--------
>>> import pynapple as nap
>>> import numpy as np
>>> ep = nap.IntervalSet(start=[0, 10, 20], end=[5, 12, 33])
>>> ep.save("my_ep.npz")
To load you file, you can use the `nap.load_file` function :
>>> ep = nap.load_file("my_path/my_ep.npz")
>>> ep
start end
0 0.0 5.0
1 10.0 12.0
2 20.0 33.0
Raises
------
RuntimeError
If filename is not str, path does not exist or filename is a directory.
"""
np.savez(
check_filename(filename),
start=self.values[:, 0],
end=self.values[:, 1],
type=np.array(["IntervalSet"], dtype=np.str_),
)
return
[docs]
def split(self, interval_size, time_units="s"):
"""Split `IntervalSet` to a new `IntervalSet` with each interval being of size `interval_size`.
Used mostly for chunking very large dataset or looping throught multiple epoch of same duration.
This function skips the epochs that are shorter than `interval_size`.
Note that intervals are strictly non-overlapping in pynapple. One microsecond is removed from contiguous intervals.
Parameters
----------
interval_size : Number
Description
time_units : str, optional
time units for the `interval_size` ('us', 'ms', 's' [default])
Returns
-------
IntervalSet
New `IntervalSet` with equal sized intervals
Raises
------
IOError
If `interval_size` is not a Number or is below 0
If `time_units` is not a string
"""
if not isinstance(interval_size, Number):
raise IOError("Argument interval_size should of type float or int")
if not interval_size > 0:
raise IOError("Argument interval_size should be strictly larger than 0")
if not isinstance(time_units, str):
raise IOError("Argument time_units should be of type str")
if len(self) == 0:
return IntervalSet(start=[], end=[])
interval_size = TsIndex.format_timestamps(
np.array((interval_size,), dtype=np.float64).ravel(), time_units
)[0]
interval_size = np.round(interval_size, nap_config.time_index_precision)
durations = np.round(self.end - self.start, nap_config.time_index_precision)
idxs = np.where(durations > interval_size)[0]
size_tmp = (
np.ceil((self.end[idxs] - self.start[idxs]) / interval_size)
).astype(int) + 1
new_starts = np.full(size_tmp.sum() - size_tmp.shape[0], np.nan)
new_ends = np.full(size_tmp.sum() - size_tmp.shape[0], np.nan)
i0 = 0
for cnt, idx in enumerate(idxs):
new_starts[i0 : i0 + size_tmp[cnt] - 1] = np.arange(
self.start[idx], self.end[idx], interval_size
)
new_ends[i0 : i0 + size_tmp[cnt] - 2] = new_starts[
i0 + 1 : i0 + size_tmp[cnt] - 1
]
new_ends[i0 + size_tmp[cnt] - 2] = self.end[idx]
i0 += size_tmp[cnt] - 1
new_starts = np.round(new_starts, nap_config.time_index_precision)
new_ends = np.round(new_ends, nap_config.time_index_precision)
durations = np.round(new_ends - new_starts, nap_config.time_index_precision)
tokeep = durations >= interval_size
new_starts = new_starts[tokeep]
new_ends = new_ends[tokeep]
# Removing 1 microsecond to have strictly non-overlapping intervals for intervals coming from the same epoch
new_ends -= 1e-6
return IntervalSet(new_starts, new_ends)