"""Functions for highpass, lowpass, bandpass or bandstop filtering."""
import inspect
from collections.abc import Iterable
from functools import wraps
from numbers import Number
import numpy as np
import pandas as pd
from scipy.signal import butter, sosfiltfilt, sosfreqz
from .. import core as nap
def _validate_filtering_inputs(func):
@wraps(func)
def wrapper(*args, **kwargs):
# Validate each positional argument
sig = inspect.signature(func)
kwargs = sig.bind_partial(*args, **kwargs).arguments
cutoff = kwargs["cutoff"]
filter_type = kwargs["filter_type"]
if filter_type in ["lowpass", "highpass"] and not isinstance(cutoff, Number):
raise ValueError(
f"{filter_type} filter require a single number. {cutoff} provided instead."
)
if filter_type in ["bandpass", "bandstop"]:
if (
not isinstance(cutoff, Iterable)
or len(cutoff) != 2
or not all(isinstance(fq, Number) for fq in cutoff)
):
raise ValueError(
f"{filter_type} filter require a tuple of two numbers. {cutoff} provided instead."
)
if "fs" in kwargs:
if kwargs["fs"] is not None and not isinstance(kwargs["fs"], Number):
raise ValueError(
"Invalid value for 'fs'. Parameter 'fs' should be of type float or int"
)
if "order" in kwargs:
if not isinstance(kwargs["order"], int):
raise ValueError(
"Invalid value for 'order': Parameter 'order' should be of type int"
)
if "transition_bandwidth" in kwargs:
if not isinstance(kwargs["transition_bandwidth"], float):
raise ValueError(
"Invalid value for 'transition_bandwidth'. 'transition_bandwidth' should be of type float"
)
# Call the original function with validated inputs
return func(**kwargs)
return wrapper
def _get_butter_coefficients(cutoff, filter_type, sampling_frequency, order=4):
"""Calls scipy butter"""
return butter(order, cutoff, btype=filter_type, fs=sampling_frequency, output="sos")
def _compute_butterworth_filter(
    data, cutoff, sampling_frequency=None, filter_type="bandpass", order=4
):
    """
    Apply a Butterworth filter to the provided signal.

    The filter is designed with `_get_butter_coefficients`. With the "jax"
    backend the filtering is delegated to `pynajax`; otherwise
    `scipy.signal.sosfiltfilt` is applied along the first axis, separately
    for each epoch of `data.time_support`.

    Parameters
    ----------
    data : Tsd, TsdFrame or TsdTensor
        The signal to filter.
    cutoff : float or array-like of float
        Cutoff frequency (or pair of frequencies) passed to the filter design.
    sampling_frequency : float, optional
        Sampling frequency passed to the filter design.
    filter_type : str, optional
        Filter type passed to the filter design. Default is "bandpass".
    order : int, optional
        Order of the Butterworth filter. Default is 4.

    Returns
    -------
    Tsd, TsdFrame or TsdTensor
        A new object built with `data._define_instance`, holding the filtered
        values on the same timestamps and time support.
    """
    sos = _get_butter_coefficients(cutoff, filter_type, sampling_frequency, order)

    if nap.utils.get_backend() == "jax":
        # Imported lazily so pynajax is only needed with the jax backend.
        from pynajax.jax_process_filtering import jax_sosfiltfilt

        out = jax_sosfiltfilt(
            sos,
            data.index.values,
            data.values,
            data.time_support.start,
            data.time_support.end,
        )
    else:
        values = data.d
        # Allocate a floating-point buffer: with `zeros_like` alone, an
        # integer-valued input would give an integer buffer and the filtered
        # samples would be truncated on assignment. Inexact (float/complex)
        # input dtypes are kept as they are.
        in_dtype = np.result_type(values)
        out_dtype = in_dtype if np.issubdtype(in_dtype, np.inexact) else float
        out = np.zeros_like(values, dtype=out_dtype)
        # Filter each epoch independently.
        for ep in data.time_support:
            slc = data.get_slice(start=ep.start[0], end=ep.end[0])
            out[slc] = sosfiltfilt(sos, values[slc], axis=0)

    return data._define_instance(data.t, data.time_support, values=out)
def _compute_spectral_inversion(kernel):
"""
Compute the spectral inversion.
Parameters
----------
kernel: ndarray
Returns
-------
ndarray
"""
kernel *= -1.0
kernel[len(kernel) // 2] = 1.0 + kernel[len(kernel) // 2]
return kernel
def _get_windowed_sinc_kernel(
fc, filter_type, sampling_frequency, transition_bandwidth=0.02
):
"""
Get the windowed-sinc kernel.
Smith, S. (2003). Digital signal processing: a practical guide for engineers and scientists.
Chapter 16, equation 16-4
Parameters
----------
fc: float or tuple of float
Cutting frequency in Hz. Single float for 'lowpass' and 'highpass'. Tuple of float for
'bandpass' and 'bandstop'.
filter_type: str
Either 'lowpass', 'highpass', 'bandstop' or 'bandpass'.
sampling_frequency: float
Sampling frequency in Hz.
transition_bandwidth: float
Percentage between 0 and 0.5
Returns
-------
np.ndarray
"""
M = int(np.rint(4.0 / transition_bandwidth))
x = np.arange(-(M // 2), 1 + (M // 2))
fc = np.transpose(np.atleast_2d(fc / sampling_frequency))
kernel = np.sinc(2 * fc * x)
kernel = kernel * np.blackman(len(x))
kernel = np.transpose(kernel)
kernel = kernel / kernel.sum(0)
if filter_type == "lowpass":
return kernel.flatten()
elif filter_type == "highpass":
return _compute_spectral_inversion(kernel.flatten())
elif filter_type == "bandstop":
kernel[:, 1] = _compute_spectral_inversion(kernel[:, 1])
kernel = np.sum(kernel, axis=1)
return kernel
elif filter_type == "bandpass":
kernel[:, 1] = _compute_spectral_inversion(kernel[:, 1])
kernel = _compute_spectral_inversion(np.sum(kernel, axis=1))
return kernel
else:
raise ValueError
def _compute_windowed_sinc_filter(
    data, freq, filter_type, sampling_frequency, transition_bandwidth=0.02
):
    """
    Filter a signal by convolving it with a windowed-sinc kernel.

    Parameters
    ----------
    data : Tsd, TsdFrame or TsdTensor
        The signal to filter.
    freq : float or tuple of float
        Cutting frequency in Hz. Single float for 'lowpass' and 'highpass'. Tuple of float for
        'bandpass' and 'bandstop'.
    filter_type : str
        Either 'lowpass', 'highpass', 'bandstop' or 'bandpass'.
    sampling_frequency : float
        Sampling frequency in Hz.
    transition_bandwidth : float
        Percentage between 0 and 0.5

    Returns
    -------
    Tsd, TsdFrame or TsdTensor
        The result of `data.convolve` with the kernel.
    """
    return data.convolve(
        _get_windowed_sinc_kernel(
            freq, filter_type, sampling_frequency, transition_bandwidth
        )
    )
@_validate_filtering_inputs
def _compute_filter(
    data,
    cutoff,
    fs=None,
    mode="butter",
    order=4,
    transition_bandwidth=0.02,
    filter_type="bandpass",
):
    """
    Filter the signal.

    Checks the input data, resolves the sampling frequency and dispatches to
    the Butterworth or windowed-sinc implementation depending on `mode`.

    Parameters
    ----------
    data : Tsd, TsdFrame or TsdTensor
        The signal to filter. Must not contain NaN values.
    cutoff : Numeric or (Numeric, Numeric)
        Cutoff frequency or frequencies; converted to a float array.
    fs : float, optional
        Sampling frequency. If None, `data.rate` is used.
    mode : {'butter', 'sinc'}, optional
        Filtering mode. Default is 'butter'.
    order : int, optional
        Order of the Butterworth filter (used when `mode` is 'butter').
    transition_bandwidth : float, optional
        Transition bandwidth of the windowed-sinc kernel (used when `mode`
        is 'sinc').
    filter_type : str, optional
        Filter type forwarded to the implementation. Default is "bandpass".

    Returns
    -------
    Tsd, TsdFrame or TsdTensor
        The filtered signal.

    Raises
    ------
    ValueError
        If `data` is not a time series object, if it contains NaN values, or
        if `mode` is neither 'butter' nor 'sinc'.
    """
    if not isinstance(data, nap.time_series._BaseTsd):
        raise ValueError(
            f"Invalid value: {data}. First argument should be of type Tsd, TsdFrame or TsdTensor"
        )

    contains_nan = np.any(np.isnan(data))
    if contains_nan:
        raise ValueError(
            "The input signal contains NaN values, which are not supported for filtering. "
            "Please remove or handle NaNs before applying the filter. "
            "You can use the `dropna()` method to drop all NaN values."
        )

    # Fall back on the rate of the data when no sampling frequency is given.
    sampling_rate = data.rate if fs is None else fs
    cutoff_hz = np.array(cutoff, dtype=float)

    if mode == "butter":
        return _compute_butterworth_filter(
            data, cutoff_hz, sampling_rate, filter_type=filter_type, order=order
        )
    elif mode == "sinc":
        return _compute_windowed_sinc_filter(
            data,
            cutoff_hz,
            filter_type,
            sampling_rate,
            transition_bandwidth=transition_bandwidth,
        )

    raise ValueError("Unrecognized filter mode. Choose either 'butter' or 'sinc'")
[docs]
def apply_bandpass_filter(
    data, cutoff, fs=None, mode="butter", order=4, transition_bandwidth=0.02
):
    """
    Apply a band-pass filter to the provided signal.

    Two filtering modes are available:

    - `"butter"`: Butterworth filter, whose order is set by `order`.
    - `"sinc"`: windowed-sinc convolution, whose kernel size is set by
      `transition_bandwidth`.

    Parameters
    ----------
    data : Tsd, TsdFrame, or TsdTensor
        The signal to be filtered. It must not contain NaN values.
    cutoff : (Numeric, Numeric)
        The two cutoff frequencies in Hz.
    fs : float, optional
        The sampling frequency of the signal in Hz. If not provided, the rate
        of `data` is used.
    mode : {'butter', 'sinc'}, optional
        Filtering mode. Default is 'butter'.
    order : int, optional
        The order of the Butterworth filter. Higher values result in sharper
        frequency cutoffs. Default is 4.
    transition_bandwidth : float, optional
        The transition bandwidth. 0.2 corresponds to 20% of the frequency band
        between 0 and the sampling frequency. The smaller the transition
        bandwidth, the larger the windowed-sinc kernel. Default is 0.02.

    Returns
    -------
    filtered_data : Tsd, TsdFrame, or TsdTensor
        The filtered signal, with the same data type as the input.

    Raises
    ------
    ValueError
        If `data` is not a Tsd, TsdFrame, or TsdTensor, or contains NaN values.
        If `cutoff` is not a tuple of two numbers.
        If `fs` is not a number or None.
        If `mode` is not "butter" or "sinc".
        If `order` is not an int.
        If `transition_bandwidth` is not a float.

    Notes
    -----
    For the Butterworth filter, the cutoff frequency is defined as the
    frequency at which the amplitude of the signal is reduced by -3 dB
    (decibels).
    """
    options = dict(
        fs=fs,
        mode=mode,
        order=order,
        transition_bandwidth=transition_bandwidth,
        filter_type="bandpass",
    )
    return _compute_filter(data, cutoff, **options)
[docs]
def apply_bandstop_filter(
    data, cutoff, fs=None, mode="butter", order=4, transition_bandwidth=0.02
):
    """
    Apply a band-stop filter to the provided signal.

    Two filtering modes are available:

    - `"butter"`: Butterworth filter, whose order is set by `order`.
    - `"sinc"`: windowed-sinc convolution, whose kernel size is set by
      `transition_bandwidth`.

    Parameters
    ----------
    data : Tsd, TsdFrame, or TsdTensor
        The signal to be filtered. It must not contain NaN values.
    cutoff : (Numeric, Numeric)
        The two cutoff frequencies in Hz.
    fs : float, optional
        The sampling frequency of the signal in Hz. If not provided, the rate
        of `data` is used.
    mode : {'butter', 'sinc'}, optional
        Filtering mode. Default is 'butter'.
    order : int, optional
        The order of the Butterworth filter. Higher values result in sharper
        frequency cutoffs. Default is 4.
    transition_bandwidth : float, optional
        The transition bandwidth. 0.2 corresponds to 20% of the frequency band
        between 0 and the sampling frequency. The smaller the transition
        bandwidth, the larger the windowed-sinc kernel. Default is 0.02.

    Returns
    -------
    filtered_data : Tsd, TsdFrame, or TsdTensor
        The filtered signal, with the same data type as the input.

    Raises
    ------
    ValueError
        If `data` is not a Tsd, TsdFrame, or TsdTensor, or contains NaN values.
        If `cutoff` is not a tuple of two numbers.
        If `fs` is not a number or None.
        If `mode` is not "butter" or "sinc".
        If `order` is not an int.
        If `transition_bandwidth` is not a float.

    Notes
    -----
    For the Butterworth filter, the cutoff frequency is defined as the
    frequency at which the amplitude of the signal is reduced by -3 dB
    (decibels).
    """
    options = dict(
        fs=fs,
        mode=mode,
        order=order,
        transition_bandwidth=transition_bandwidth,
        filter_type="bandstop",
    )
    return _compute_filter(data, cutoff, **options)
[docs]
def apply_highpass_filter(
    data, cutoff, fs=None, mode="butter", order=4, transition_bandwidth=0.02
):
    """
    Apply a high-pass filter to the provided signal.

    Two filtering modes are available:

    - `"butter"`: Butterworth filter, whose order is set by `order`.
    - `"sinc"`: windowed-sinc convolution, whose kernel size is set by
      `transition_bandwidth`.

    Parameters
    ----------
    data : Tsd, TsdFrame, or TsdTensor
        The signal to be filtered. It must not contain NaN values.
    cutoff : Numeric
        Cutoff frequency in Hz.
    fs : float, optional
        The sampling frequency of the signal in Hz. If not provided, the rate
        of `data` is used.
    mode : {'butter', 'sinc'}, optional
        Filtering mode. Default is 'butter'.
    order : int, optional
        The order of the Butterworth filter. Higher values result in sharper
        frequency cutoffs. Default is 4.
    transition_bandwidth : float, optional
        The transition bandwidth. 0.2 corresponds to 20% of the frequency band
        between 0 and the sampling frequency. The smaller the transition
        bandwidth, the larger the windowed-sinc kernel. Default is 0.02.

    Returns
    -------
    filtered_data : Tsd, TsdFrame, or TsdTensor
        The filtered signal, with the same data type as the input.

    Raises
    ------
    ValueError
        If `data` is not a Tsd, TsdFrame, or TsdTensor, or contains NaN values.
        If `cutoff` is not a number.
        If `fs` is not a number or None.
        If `mode` is not "butter" or "sinc".
        If `order` is not an int.
        If `transition_bandwidth` is not a float.

    Notes
    -----
    For the Butterworth filter, the cutoff frequency is defined as the
    frequency at which the amplitude of the signal is reduced by -3 dB
    (decibels).
    """
    options = dict(
        fs=fs,
        mode=mode,
        order=order,
        transition_bandwidth=transition_bandwidth,
        filter_type="highpass",
    )
    return _compute_filter(data, cutoff, **options)
[docs]
def apply_lowpass_filter(
    data, cutoff, fs=None, mode="butter", order=4, transition_bandwidth=0.02
):
    """
    Apply a low-pass filter to the provided signal.

    Two filtering modes are available:

    - `"butter"`: Butterworth filter, whose order is set by `order`.
    - `"sinc"`: windowed-sinc convolution, whose kernel size is set by
      `transition_bandwidth`.

    Parameters
    ----------
    data : Tsd, TsdFrame, or TsdTensor
        The signal to be filtered. It must not contain NaN values.
    cutoff : Numeric
        Cutoff frequency in Hz.
    fs : float, optional
        The sampling frequency of the signal in Hz. If not provided, the rate
        of `data` is used.
    mode : {'butter', 'sinc'}, optional
        Filtering mode. Default is 'butter'.
    order : int, optional
        The order of the Butterworth filter. Higher values result in sharper
        frequency cutoffs. Default is 4.
    transition_bandwidth : float, optional
        The transition bandwidth. 0.2 corresponds to 20% of the frequency band
        between 0 and the sampling frequency. The smaller the transition
        bandwidth, the larger the windowed-sinc kernel. Default is 0.02.

    Returns
    -------
    filtered_data : Tsd, TsdFrame, or TsdTensor
        The filtered signal, with the same data type as the input.

    Raises
    ------
    ValueError
        If `data` is not a Tsd, TsdFrame, or TsdTensor, or contains NaN values.
        If `cutoff` is not a number.
        If `fs` is not a number or None.
        If `mode` is not "butter" or "sinc".
        If `order` is not an int.
        If `transition_bandwidth` is not a float.

    Notes
    -----
    For the Butterworth filter, the cutoff frequency is defined as the
    frequency at which the amplitude of the signal is reduced by -3 dB
    (decibels).
    """
    options = dict(
        fs=fs,
        mode=mode,
        order=order,
        transition_bandwidth=transition_bandwidth,
        filter_type="lowpass",
    )
    return _compute_filter(data, cutoff, **options)
[docs]
@_validate_filtering_inputs
def get_filter_frequency_response(
    cutoff, fs, filter_type, mode, order=4, transition_bandwidth=0.02
):
    """
    Evaluate the frequency response of a particular type of filter.

    The arguments are the same as for `apply_lowpass_filter`,
    `apply_highpass_filter`, `apply_bandpass_filter` and
    `apply_bandstop_filter`. The response is returned as a pandas Series
    whose index holds the frequencies.

    Parameters
    ----------
    cutoff : Numeric or tuple of Numeric
        Cutoff frequency in Hz.
    fs : float
        The sampling frequency of the signal in Hz.
    filter_type : str
        Can be "lowpass", "highpass", "bandpass" or "bandstop"
    mode : str
        Can be "butter" or "sinc".
    order : int, optional
        The order of the Butterworth filter. Higher values result in sharper
        frequency cutoffs. Default is 4.
    transition_bandwidth : float, optional
        The transition bandwidth. 0.2 corresponds to 20% of the frequency band
        between 0 and the sampling frequency. The smaller the transition
        bandwidth, the larger the windowed-sinc kernel. Default is 0.02.

    Returns
    -------
    pandas.Series
        Magnitude of the frequency response, indexed by frequency. In
        "butter" mode it is evaluated with `sosfreqz` at 1024 points; in
        "sinc" mode it is the magnitude of the FFT of the kernel at the
        non-negative frequencies.

    Raises
    ------
    ValueError
        If `mode` is neither "butter" nor "sinc", or if an argument fails the
        input validation.
    """
    cutoff = np.array(cutoff)

    if mode == "butter":
        sections = _get_butter_coefficients(cutoff, filter_type, fs, order)
        freqs, response = sosfreqz(sections, worN=1024, fs=fs)
        return pd.Series(index=freqs, data=np.abs(response))
    elif mode == "sinc":
        taps = _get_windowed_sinc_kernel(
            cutoff, filter_type, fs, transition_bandwidth
        )
        spectrum = np.fft.fftshift(np.fft.fft(taps))
        freqs = np.fft.fftshift(np.fft.fftfreq(n=len(taps), d=1 / fs))
        # Keep only the non-negative half of the spectrum.
        keep = freqs >= 0
        return pd.Series(index=freqs[keep], data=np.abs(spectrum[keep]))

    raise ValueError("Unrecognized filter mode. Choose either 'butter' or 'sinc'")