"""
Functions to compute power spectral density and mean power spectral density.
"""
import inspect
from functools import wraps
from numbers import Number
import numpy as np
import pandas as pd
from numba import njit
from scipy import signal
from .. import core as nap
def _validate_spectrum_inputs(func):
    """
    Decorate a spectrum function so that its arguments are type-checked.

    The wrapper binds the received positional and keyword arguments to the
    signature of `func`, checks every recognised parameter against its
    expected type, and then calls `func` with all bound arguments passed
    as keywords.

    Parameters
    ----------
    func : callable
        Function whose arguments should be validated. Only the parameter
        names listed in the type table below are checked; any other
        parameter is forwarded untouched.

    Returns
    -------
    callable
        The wrapping function, carrying the metadata of `func`.

    Raises
    ------
    TypeError
        Raised by the wrapper when a checked argument is not an instance of
        its expected type. An explicit ``None`` is accepted for parameters
        whose declared default is ``None``.
    """
    # The signature does not change between calls, so inspect it only once.
    signature = inspect.signature(func)

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Map positional and keyword arguments to parameter names; arguments
        # that were not supplied are absent and therefore never validated.
        bound = signature.bind_partial(*args, **kwargs).arguments

        # Built at call time so that the `nap` attributes are looked up lazily.
        parameters_type = {
            "sig": (nap.Tsd, nap.TsdFrame),
            "fs": Number,
            "ep": nap.IntervalSet,
            "full_range": bool,
            "norm": bool,
            "n": int,
            "time_unit": str,
            "interval_size": Number,
            "overlap": float,
        }

        for param, param_type in parameters_type.items():
            if param not in bound:
                continue
            value = bound[param]
            # Passing None explicitly must behave like omitting the argument
            # when None is the declared default (e.g. `fs`, `ep`, `n`).
            if value is None and signature.parameters[param].default is None:
                continue
            if not isinstance(value, param_type):
                raise TypeError(
                    f"Invalid type. Parameter {param} must be of type {param_type}."
                )

        # Call the original function with validated inputs
        return func(**bound)

    return wrapper
[docs]
@_validate_spectrum_inputs
def compute_power_spectral_density(
    sig, fs=None, ep=None, full_range=False, norm=False, n=None
):
    """
    Compute the FFT-based spectrum of a signal over a single epoch.

    Runs `numpy.fft.fft` on `sig` restricted to `ep`. The frequency axis is
    built assuming a constant sampling rate for the signal.

    Parameters
    ----------
    sig : pynapple.Tsd or pynapple.TsdFrame
        Time series.
    fs : float, optional
        Sampling rate, in Hz. If None, `sig.rate` is used.
    ep : None or pynapple.IntervalSet, optional
        The epoch to calculate the fft on. Must be length 1.
        If None, `sig.time_support` is used.
    full_range : bool, optional
        If True, return the full fft frequency range, otherwise return only
        the non-negative frequencies.
    norm : bool, optional
        Whether the FFT result is divided by the length of the transformed
        axis to normalize the amplitude.
    n : int, optional
        Length of the transformed axis of the output. If n is smaller than
        the length of the input, the input is cropped. If it is larger, the
        input is padded with zeros. If n is not given, the number of samples
        of the restricted signal (axis 0) is used.

    Returns
    -------
    pandas.DataFrame
        Frequency representation of the input signal. The index holds the
        frequencies (sorted in increasing order) and the values are the FFT
        coefficients returned by `numpy.fft.fft`, normalized if `norm` is True.

    Raises
    ------
    ValueError
        If `ep` (or the signal time support when `ep` is None) does not
        contain exactly one interval.

    Notes
    -----
    This function computes the fft on a single epoch of data only. This epoch
    can be given with the `ep` parameter; otherwise `sig.time_support` is
    used, but in both cases it must contain a single epoch.
    """
    if ep is None:
        ep = sig.time_support
    if len(ep) != 1:
        raise ValueError("Given epoch (or signal time_support) must have length 1")
    if fs is None:
        fs = sig.rate

    # Restrict once and reuse it for both the transform and the default length.
    restricted = sig.restrict(ep)
    fft_result = np.fft.fft(restricted.values, n=n, axis=0)
    if n is None:
        n = len(restricted)
    fft_freq = np.fft.fftfreq(n, 1 / fs)

    if norm:
        fft_result = fft_result / fft_result.shape[0]

    ret = pd.DataFrame(fft_result, fft_freq)
    # fftfreq lists the non-negative frequencies first, then the negative
    # ones; sorting gives a monotonically increasing frequency index.
    ret.sort_index(inplace=True)
    if not full_range:
        return ret.loc[ret.index >= 0]
    return ret
[docs]
@_validate_spectrum_inputs
def compute_mean_power_spectral_density(
    sig,
    interval_size,
    fs=None,
    overlap=0.25,
    ep=None,
    full_range=False,
    norm=False,
    time_unit="s",
):
    """
    Compute the FFT-based spectrum accumulated over multiple segments of same size.

    The epochs in `ep` are split into (possibly overlapping) segments of
    duration `interval_size`. Each segment is truncated to the shortest
    segment length in samples, multiplied by a Hamming window and
    transformed with `numpy.fft.fft`; the transforms are then summed.
    Note that this function assumes a constant sampling rate for `sig`.

    Parameters
    ----------
    sig : pynapple.Tsd or pynapple.TsdFrame
        Signal with equispaced samples
    interval_size : Number
        Duration of the segments the FFT is accumulated across,
        expressed in `time_unit`.
    fs : Number, optional
        Sampling frequency of `sig`. If `None`, `fs` is equal to `sig.rate`
    overlap : float, optional
        Fraction of overlap between successive segments.
        `0.0 <= overlap < 1.0`. Default is 0.25
    ep : None or pynapple.IntervalSet, optional
        The `IntervalSet` to calculate the fft on. Can be any length.
        If `None`, `sig.time_support` is used.
    full_range : bool, optional
        If True, return the full fft frequency range, otherwise return only
        the non-negative frequencies.
    norm : bool, optional
        If True, the summed FFT is divided by the segment length (in
        samples) times the number of segments. If False, the plain sum over
        segments is returned.
    time_unit : str, optional
        Time units for parameter `interval_size`. Can be ('s'[default], 'ms', 'us')

    Returns
    -------
    pandas.DataFrame
        Accumulated spectrum. The index holds the frequencies (sorted in
        increasing order) and the values are complex FFT coefficients.

    Raises
    ------
    RuntimeError
        If splitting the epoch with `interval_size` results in an empty set,
        or if one of the segments contains no sample of `sig`.
    ValueError
        If overlap is not within [0, 1).

    Examples
    --------
    >>> import numpy as np
    >>> import pynapple as nap
    >>> t = np.arange(0, 1, 1/1000)
    >>> tsd = nap.Tsd(d=np.sin(t * 50 * np.pi * 2), t=t)
    >>> mpsd = nap.compute_mean_power_spectral_density(tsd, 0.1)
    """
    if not (0.0 <= overlap < 1.0):
        raise ValueError("Overlap should be in intervals [0.0, 1.0).")

    if ep is None:
        ep = sig.time_support
    if fs is None:
        fs = sig.rate

    # Express the segment duration in the unit used by the timestamps.
    interval_size = nap.TsIndex.format_timestamps(np.array([interval_size]), time_unit)[
        0
    ]

    empty_split_msg = f"Splitting epochs with interval_size={interval_size} generated an empty IntervalSet. Try decreasing interval_size"

    # Check if at least one epoch is larger than the interval size
    # (an empty `ep` cannot produce any segment either).
    if len(ep) == 0 or np.max(ep.end - ep.start) < interval_size:
        raise RuntimeError(empty_split_msg)

    split_ep = _overlap_split(ep.start, ep.end, interval_size, overlap)

    # The split keeps only segments that end strictly before the epoch end, so
    # an epoch exactly as long as `interval_size` passes the check above and
    # still yields nothing. Report it with the documented error instead of
    # letting the reduction below fail on an empty array.
    if len(split_ep) == 0:
        raise RuntimeError(empty_split_msg)

    # Get the sample slices of each segment
    slices = np.zeros((len(split_ep), 2), dtype=int)
    for i, (seg_start, seg_end) in enumerate(split_ep):
        sl = sig.get_slice(seg_start, seg_end)
        slices[i, 0] = sl.start
        slices[i, 1] = sl.stop

    # Common signal length: the shortest segment, in samples
    N = int(np.min(np.diff(slices, axis=1)))
    if N == 0:
        raise RuntimeError(
            "One interval doesn't have any signal associated. Check the parameter ep or the time support if no epoch is passed."
        )

    # Get the freqs
    fft_freq = np.fft.fftfreq(N, 1 / fs)

    # Get the Hamming window; add a trailing axis so it broadcasts over columns
    window = signal.windows.hamming(N)
    if sig.ndim == 2:
        window = window[:, np.newaxis]

    # Accumulate the fft of every windowed segment, truncated to N samples
    fft_result = np.zeros((N, *sig.shape[1:]), dtype=complex)
    for first, last in slices:
        segment = sig[first:last].values[0:N] * window
        fft_result += np.fft.fft(segment, axis=0)

    if norm:
        fft_result = fft_result / (float(N) * float(len(slices)))

    ret = pd.DataFrame(fft_result, fft_freq)
    ret.sort_index(inplace=True)
    if not full_range:
        return ret.loc[ret.index >= 0]
    return ret
@njit
def _overlap_split(start, end, interval_size, overlap):
    """
    Split epochs into overlapping windows of fixed duration.

    For every epoch ``[start[k], end[k]]`` windows of length `interval_size`
    are laid out from the epoch start, each one shifted by
    ``(1 - overlap) * interval_size`` from the previous one. A window is kept
    only while its end lies strictly before the end of the epoch.

    Parameters
    ----------
    start : numpy.ndarray
        Start times of the epochs.
    end : numpy.ndarray
        End times of the epochs, same length as `start`.
    interval_size : float
        Duration of each window.
    overlap : float
        Fraction of a window shared with the next one.

    Returns
    -------
    numpy.ndarray
        Array of shape (n_windows, 2) holding the start and end of each window.
    """
    step = (1 - overlap) * interval_size

    # Upper bound on the number of windows, used to preallocate the output.
    max_windows = int(np.ceil(np.sum(end - start) / step))
    windows = np.zeros((max_windows + 1, 2))

    count = 0
    for k in range(len(start)):
        left = start[k]
        epoch_end = end[k]
        while left + interval_size < epoch_end:
            windows[count, 0] = left
            windows[count, 1] = left + interval_size
            left += step
            count += 1

    # Drop the unused preallocated rows.
    return windows[:count]