"""
Functions to compute the wavelet decomposition of a time series.
The main function for wavelet decomposition is `nap.compute_wavelet_transform`.
For now, pynapple only implements Morlet wavelets. To check the shape and quality of the wavelets, use
the function `nap.generate_morlet_filterbank`, which generates the wavelets so that they can be plotted.
"""
import numpy as np
from .. import core as nap
def _morlet(M=1024, gaussian_width=1.5, window_length=1.0, precision=8):
    """
    Build a complex Morlet wavelet kernel sampled on a symmetric grid.

    The kernel is the product of a normalization constant, a Gaussian
    envelope and a complex exponential carrier, all evaluated on the same
    evenly spaced grid.

    Parameters
    ----------
    M : int
        Number of samples in the returned kernel.
    gaussian_width : float
        Width parameter of the Gaussian envelope.
    window_length : float
        Multiplier applied to the argument of the complex exponential,
        i.e. it sets how fast the carrier oscillates along the grid.
    precision : int
        Half-extent of the sampling grid: the kernel is evaluated at ``M``
        evenly spaced points in ``[-precision, precision]``. Default is 8.

    Returns
    -------
    np.ndarray
        Complex array of length ``M`` holding the Morlet wavelet kernel.
    """
    grid = np.linspace(-precision, precision, M)

    # The three factors are kept separate for readability; they are
    # multiplied left to right as (normalization * envelope) * carrier.
    normalization = (np.pi * gaussian_width) ** (-0.25)
    envelope = np.exp(-(grid**2) / gaussian_width)
    carrier = np.exp(1j * 2 * np.pi * window_length * grid)

    return normalization * envelope * carrier
[docs]
def generate_morlet_filterbank(
    freqs, fs, gaussian_width=1.5, window_length=1.0, precision=16
):
    """
    Generates a Morlet filterbank using the given frequencies and parameters.

    This function can be used purely for visualization, or to convolve with a pynapple Tsd,
    TsdFrame, or TsdTensor as part of a wavelet decomposition process.

    Parameters
    ----------
    freqs : 1d array
        Frequency values to estimate with Morlet wavelets. Must be a non-empty
        ``np.ndarray`` of strictly positive values.
    fs : float or int
        Sampling rate, in Hz. Must be strictly positive.
    gaussian_width : float
        Defines width of Gaussian to be used in wavelet creation. Must be positive.
    window_length : float
        The length of window to be used for wavelet creation. Must be positive.
    precision : int
        Precision of wavelet to use. The mother wavelet is evaluated at
        ``2**precision`` timepoints. Must be a positive ``int``.

    Returns
    -------
    filter_bank : pynapple.TsdFrame
        One column per entry of `freqs`, each holding the corresponding Morlet
        wavelet filter, zero-padded so that all columns share the length of the
        longest wavelet. The time index is a symmetric axis built from the
        frequency that produced the longest wavelet.

    Raises
    ------
    TypeError
        If `freqs` is not a ``np.ndarray``, if `fs`, `gaussian_width` or
        `window_length` is not a real number, or if `precision` is not an ``int``.
    ValueError
        If `freqs` is empty or contains non-positive values, or if `fs`,
        `gaussian_width`, `window_length` or `precision` is not strictly positive.

    Notes
    -----
    This algorithm first computes a single, finely sampled wavelet using the provided hyperparameters.
    Wavelets of different frequencies are generated by resampling this mother wavelet with an appropriate step size.
    The step size is determined based on the desired frequency and the sampling rate.
    """
    # --- Input validation -------------------------------------------------
    if not isinstance(freqs, np.ndarray):
        raise TypeError("`freqs` must be a ndarray")
    if len(freqs) == 0:
        raise ValueError("Given list of freqs cannot be empty.")
    if np.min(freqs) <= 0:
        raise ValueError("All frequencies in freqs must be strictly positive")

    if not isinstance(fs, (int, float, np.number)):
        raise TypeError("`fs` must be of type float or int ndarray")
    # A non-positive (or NaN) sampling rate yields a non-positive/NaN scale
    # below, which would otherwise surface as an obscure indexing error.
    # `not fs > 0` is used instead of `fs <= 0` so that NaN is rejected too.
    if not fs > 0:
        raise ValueError("`fs` must be a strictly positive number.")

    if not isinstance(gaussian_width, (int, float, np.number)):
        raise TypeError("gaussian_width must be a float or int instance.")
    if gaussian_width <= 0:
        raise ValueError("gaussian_width must be a positive number.")

    if not isinstance(window_length, (int, float, np.number)):
        raise TypeError("window_length must be a float or int instance.")
    if window_length <= 0:
        raise ValueError("window_length must be a positive number.")

    # NOTE: only `int` is accepted here even though the message mentions
    # float; the message text is kept as-is for backward compatibility.
    if not isinstance(precision, int):
        raise TypeError("precision must be a float or int instance.")
    if precision <= 0:
        raise ValueError("precision must be a positive number.")

    # --- Mother wavelet ---------------------------------------------------
    cutoff = 8  # Half-width of the grid on which the mother wavelet is sampled
    n_samples = int(2**precision)

    # Compute a single, finely sampled Morlet wavelet. The grid half-width is
    # passed explicitly so that it always matches the `x` axis built below.
    morlet_f = np.conj(
        _morlet(
            n_samples,
            gaussian_width=gaussian_width,
            window_length=window_length,
            precision=cutoff,
        )
    )
    x = np.linspace(-cutoff, cutoff, n_samples)
    # Loop-invariant quantities of the sampling grid.
    span = x[-1] - x[0]
    step = x[1] - x[0]

    # --- Resample the mother wavelet for every requested frequency --------
    filter_bank = []
    max_len = -1  # Track maximum length of wavelet
    longest_freq = None  # Frequency that produced the longest wavelet
    for freq in freqs:
        scale = window_length / (freq / fs)
        # Calculate the indices for subsampling the wavelet and achieve the right frequency.
        # After the slicing the size will be reduced, therefore we will pad with 0s.
        j = np.arange(scale * span + 1) / (scale * step)
        j = np.ceil(j).astype(int)  # Ceil the values to get integer indices
        # Drop indices that fall past the end of the mother wavelet.
        j = j[j < morlet_f.size]
        scaled_morlet = morlet_f[j][::-1]  # Scale and reverse wavelet
        if len(scaled_morlet) > max_len:
            max_len = len(scaled_morlet)
            longest_freq = freq
        filter_bank.append(scaled_morlet)

    # Time axis of the returned frame: symmetric around zero, sized after the
    # longest wavelet and scaled by the frequency that produced it.
    half_duration = cutoff * window_length / longest_freq
    time = np.linspace(-half_duration, half_duration, max_len)

    # Pad wavelets to ensure all are of the same length
    filter_bank = [
        np.pad(
            arr,
            ((max_len - len(arr)) // 2, (max_len - len(arr) + 1) // 2),
            constant_values=0.0,
        )
        for arr in filter_bank
    ]

    # Return filter bank as a TsdFrame
    return nap.TsdFrame(d=np.array(filter_bank).transpose(), t=time)