Trial-aligned decoding with International Brain Lab data

This example works through a basic pipeline for decoding the mouse’s choice from spiking activity in the International Brain Lab’s decision task, including loading the data from DANDI, processing and trial-aligning the neural activity, and fitting a logistic regression with cross-validation using scikit-learn.

The International Brain Lab’s Brain Wide Map (BWM) dataset is available as Dandiset 000409. The IBL’s BWM website includes links to their preprint and additional documentation. The IBL also has an excellent decoding demonstration in the COSYNE section of their events webpage under “Tutorial 2: Advanced analysis”, along with many other relevant demos.

For a more detailed tutorial on data loading with DANDI, see the “Streaming data from DANDI” example!

Caveats: This example is meant as a simple starting point for working with trial-aligned data and with data from the IBL. It does not faithfully replicate the IBL’s quality control and filtering criteria, and the decoding here is simpler than the analyses in the IBL’s preprint and tutorials.

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pynapple as nap
import scipy.stats
import seaborn as sns

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RandomizedSearchCV
custom_params = {"axes.spines.right": False, "axes.spines.top": False}
sns.set_theme(style="ticks", palette="colorblind", font_scale=1.5, rc=custom_params)

Loading data

See also the “Streaming data from DANDI” example, which covers data loading in more detail.

# These modules are used for data loading
from dandi.dandiapi import DandiAPIClient
import fsspec
from fsspec.implementations.cached import CachingFileSystem
import h5py
from pynwb import NWBHDF5IO
# The BWM dandiset:
dandiset_id = "000409"
# This is a randomly chosen recording session.
asset_path = "sub-CSH-ZAD-026/sub-CSH-ZAD-026_ses-15763234-d21e-491f-a01b-1238eb96d389_desc-processed_behavior+ecephys.nwb"
with DandiAPIClient() as client:
    asset = client.get_dandiset(dandiset_id, "draft").get_asset_by_path(asset_path)
    s3_url = asset.get_content_url(follow_redirects=1, strip_query=True)
fs = CachingFileSystem(
    fs=fsspec.filesystem("http"),
    cache_storage=str(Path("~/.caches/nwb-cache").expanduser()),
)
io = NWBHDF5IO(file=h5py.File(fs.open(s3_url, "rb")), load_namespaces=True)
nwb = nap.NWBFile(io.read())

This NWB file, loaded into a pynapple object, contains neural activity, behavioral data, and raw electrophysiological traces. They’re lazily loaded, so only what we use will be downloaded.

nwb
15763234-d21e-491f-a01b-1238eb96d389
┍━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━┑
│ Keys                                 │ Type        │
┝━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━┥
│ units                                │ TsGroup     │
│ trials                               │ IntervalSet │
│ epochs                               │ IntervalSet │
│ WheelMovementIntervals               │ IntervalSet │
│ WheelVelocitySmoothed                │ Tsd         │
│ WheelPositionSmoothed                │ Tsd         │
│ WheelPosition                        │ Tsd         │
│ WheelAccelerationSmoothed            │ Tsd         │
│ RightPupilDiameterSmoothed           │ Tsd         │
│ RightPupilDiameter                   │ Tsd         │
│ PoseEstimationSeriesTubeTop          │ TsdFrame    │
│ PoseEstimationSeriesTubeBottom       │ TsdFrame    │
│ PoseEstimationSeriesRightTongueEnd   │ TsdFrame    │
│ PoseEstimationSeriesRightPupilTop    │ TsdFrame    │
│ PoseEstimationSeriesRightPupilRight  │ TsdFrame    │
│ PoseEstimationSeriesRightPupilLeft   │ TsdFrame    │
│ PoseEstimationSeriesRightPupilBottom │ TsdFrame    │
│ PoseEstimationSeriesRightPaw         │ TsdFrame    │
│ PoseEstimationSeriesNoseTip          │ TsdFrame    │
│ PoseEstimationSeriesLeftTongueEnd    │ TsdFrame    │
│ PoseEstimationSeriesLeftPaw          │ TsdFrame    │
│ PoseEstimationSeriesTailStart        │ TsdFrame    │
│ passive_task_replay                  │ IntervalSet │
│ passive_intervals                    │ IntervalSet │
│ gabor_table                          │ IntervalSet │
│ RightCameraMotionEnergy              │ Tsd         │
│ BodyCameraMotionEnergy               │ Tsd         │
┕━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━┙

The only fields we’ll use here are the neural activity (‘units’) and the trial information (‘trials’).

spikes = nwb['units']
trials = nwb['trials']
spikes
Index    rate      unit_name    noise_cutoff    spike_amplitude_std_dB    firing_rate    max_spike_amplitude_uV    spike_count    ...
-------  --------  -----------  --------------  ------------------------  -------------  ------------------------  -------------  -----
0        0.82357   probe00_0    -0.42           1.36                      0.82           562.12                    5079           ...
1        10.99946  probe00_1    259.91          1.54                      11.0           321.68                    67834          ...
2        15.07013  probe00_2    204.26          1.46                      15.07          121.67                    92938          ...
3        5.16569   probe00_3    148.37          1.24                      5.17           135.71                    31857          ...
4        25.96824  probe00_4    44.18           1.47                      25.97          239.12                    160147         ...
5        5.13391   probe00_5    311.32          1.47                      5.13           229.67                    31661          ...
6        15.9644   probe00_6    242.57          1.38                      15.96          227.14                    98453          ...
...      ...       ...          ...             ...                       ...            ...                       ...            ...
903      2.0428    probe00_903  15.81           1.49                      2.04           123.44                    12598          ...
904      0.09989   probe00_904  -0.73           1.48                      0.1            139.46                    616            ...
905      0.16069   probe00_905  -0.85           1.09                      0.16           111.41                    991            ...
906      0.17091   probe00_906  -1.15           1.43                      0.17           82.67                     1054           ...
907      0.05351   probe00_907  4.93            1.37                      0.05           128.79                    330            ...
908      0.05351   probe00_908  -1.5            2.74                      0.05           95.75                     330            ...
909      0.1041    probe00_909  -0.94           1.11                      0.1            164.79                    642            ...
trials
index    start           end             quiescence_period    gabor_stimulus_onset_time    auditory_cue_time    wheel_movement_onset_time    choice_registration_time    feedback_time    ...
0        36.78385458     40.007964165    0.46                 37.69                        37.69                37.92                        38.44                       38.44            ...
1        40.37566203     43.274770522    0.46                 41.39                        41.39                41.59                        41.71                       41.71            ...
2        43.63566933     46.292868485    0.68                 44.39                        44.39                44.54                        44.71                       44.71            ...
3        46.66926594     49.190979041    0.61                 47.36                        47.36                47.51                        47.61                       47.61            ...
4        49.55207763     55.159877451    0.49                 50.11                        50.11                53.28                        53.59                       53.59            ...
5        55.57487529     60.826513102    0.6                  56.26                        56.26                58.9                         59.25                       59.25            ...
6        61.19131359     66.742723383    0.62                 61.89                        61.89                65.03                        65.17                       65.17            ...
...      ...             ...             ...                  ...                          ...                  ...                          ...                         ...              ...
887      4665.141205895  4674.910151903  0.43                 4665.63                      4665.63              4672.21                      4672.34                     4672.34          ...
888      4675.984954745  4690.743070591  0.68                 4676.88                      4676.88              4688.64                      4689.19                     4689.19          ...
889      4691.834669615  4701.794999598  0.61                 4693.23                      4693.23              4698.98                      4699.21                     4699.21          ...
890      4702.867999505  4765.977571354  0.53                 4703.48                      4703.48              nan                          4763.48                     4763.48          ...
891      4767.122273435  4830.17841857   0.48                 4767.68                      4767.68              4820.2                       4827.68                     4827.68          ...
892      4831.334918195  4834.507429035  0.64                 4832.02                      4832.03              4832.64                      4832.96                     4832.96          ...
893      4835.563328645  4845.873856126  0.59                 4836.19                      4836.19              4843.87                      4844.31                     4844.31          ...
shape: (894, 2), time unit: sec.

Trial alignment and binning

Here, we start by discarding the trials in which no choice was made (a more complete analysis may apply additional criteria).

valid_choice = trials.mouse_wheel_choice != "none"
trials = trials[valid_choice]
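
The filtering step above is plain boolean masking. As a minimal sketch on toy NumPy arrays rather than a pynapple IntervalSet: the arrays below are made up, and the label "counterclockwise" is assumed for illustration (only "clockwise" and "none" appear in this example).

```python
import numpy as np

# Toy stand-ins for two trial metadata columns (values are made up).
choice = np.array(["clockwise", "none", "counterclockwise", "clockwise", "none"])
start = np.array([0.0, 3.0, 6.0, 9.0, 12.0])

# True wherever the mouse registered a choice.
valid = choice != "none"

# Apply the same mask to every per-trial array so they stay aligned.
kept_choice = choice[valid]
kept_start = start[valid]
n_dropped = int((~valid).sum())
```

Applying one mask to all per-trial quantities keeps the labels and the event times in register, which matters later when the labels are paired with the trial-aligned spikes.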

Now, we can use compute_perievent to align the spikes to the stimulus onset times. We will choose a window around each stimulus that extends back 0.5s and forward 1s.

stimulus_onsets = nap.Ts(t=trials.gabor_stimulus_onset_time.values)
window=(-0.5, 1.0)
trial_aligned_spikes = nap.compute_perievent(data=spikes, events=stimulus_onsets, window=window)
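
To make the alignment concrete, here is a conceptual NumPy sketch of perievent alignment for a single unit. It is not pynapple's implementation: the helper name `align_to_events` is hypothetical, the spike and event times are toy values, and the handling of spikes that fall exactly on a window edge is an assumption.

```python
import numpy as np

def align_to_events(spike_times, event_times, window):
    """Return one array of event-relative spike times per event."""
    spike_times = np.sort(np.asarray(spike_times, dtype=float))
    aligned = []
    for t0 in event_times:
        # Indices of the spikes falling inside [t0 + window[0], t0 + window[1]].
        lo = np.searchsorted(spike_times, t0 + window[0], side="left")
        hi = np.searchsorted(spike_times, t0 + window[1], side="right")
        # Re-reference those spikes so the event sits at time 0.
        aligned.append(spike_times[lo:hi] - t0)
    return aligned

spike_times = [0.2, 0.9, 1.1, 1.8, 4.6, 5.0, 5.5, 7.0]  # toy spike train (s)
event_times = [1.0, 5.0]                                # toy stimulus onsets (s)
aligned = align_to_events(spike_times, event_times, window=(-0.5, 1.0))
```

Each entry of `aligned` corresponds to one row of a raster plot: spike times for one trial, measured relative to that trial's stimulus onset.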

The result is a dictionary of TsGroup objects, one per unit, each containing that unit’s spikes relative to the stimulus onset times. We can easily visualize one unit as follows:

example_unit = 42
plt.plot(trial_aligned_spikes[example_unit].to_tsd(), "|", markersize=5)
plt.xlabel("time from stim (s)")
plt.ylabel("stimulus")
plt.xlim(*window)
plt.axvline(0.0, color="red");
[Figure: spike raster for example unit 42; x-axis "time from stim (s)", y-axis "stimulus", red vertical line at stimulus onset.]

Note

See the perievent user guide for how this works and other visualizations!

We can then bin and count these spikes as follows:

bin_size = 0.1
trial_aligned_binned_spikes = np.stack([trial_aligned_spikes[unit].count(bin_size) for unit in spikes], axis=1)
trial_aligned_binned_spikes.shape
(15, 910, 892)
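
The first dimension of that shape follows from the window and bin size: 1.5 s divided into 0.1 s bins gives 15 bins; the other two dimensions are units and trials. The sketch below reproduces the counting step for one unit with `np.histogram` on toy event-relative spike times. It illustrates the idea only, and pynapple's exact bin-edge conventions may differ.

```python
import numpy as np

window = (-0.5, 1.0)
bin_size = 0.1

# Number of bins spanned by the window, and the corresponding bin edges.
n_bins = int(round((window[1] - window[0]) / bin_size))
edges = np.linspace(window[0], window[1], n_bins + 1)

# Toy event-relative spike times for one unit, one array per trial.
aligned = [
    np.array([-0.45, -0.05, 0.02, 0.03, 0.95]),
    np.array([0.31, 0.38]),
]

# Count spikes per bin in each trial, then stack to (time bins, trials).
counts = np.stack([np.histogram(a, bins=edges)[0] for a in aligned], axis=1)
```

Repeating this for every unit and stacking along a new middle axis would give the (time bins, units, trials) layout used in the rest of this example.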

Let’s visualize the neural activity in a single trial:

example_trial = 42
plt.imshow(
    trial_aligned_binned_spikes[:, :, example_trial].values.T,
    aspect="auto",
    cmap="Grays",
    interpolation="none",
    extent=(
        trial_aligned_binned_spikes.times()[0],
        trial_aligned_binned_spikes.times()[-1],
        0,
        trial_aligned_binned_spikes.shape[1],
    ),
)
plt.axvline(0, color='red')
plt.ylabel('unit')
plt.xlabel('time (s)');
[Figure: binned spike counts for all units in example trial 42; x-axis "time (s)", y-axis "unit", red vertical line at stimulus onset.]

Decoding choice

Let’s use scikit-learn to fit a logistic regression. We can use its cross-validation tools to pick the regularization parameter C.

# encode the choice as 0/1 (1 = clockwise)
y = (trials.mouse_wheel_choice == "clockwise").astype(int)
# transpose the data to lead with the trial dimension
X = trial_aligned_binned_spikes.swapaxes(0, 2)
# standardize each neuron within each trial, across time bins
X = (X - X.mean(2, keepdims=True)) / X.std(2, keepdims=True).clip(min=1e-6)
# reshape to n_trials x n_features
X = X.reshape(len(X), -1)
# loguniform randomized search CV logistic regression
clf = RandomizedSearchCV(
    LogisticRegression(),
    param_distributions={'C': scipy.stats.loguniform(1e-3, 1e3)},
    random_state=0,
)
clf.fit(X, y);
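
The preprocessing before the fit is easier to follow with shapes spelled out. The sketch below repeats the same three steps (swap axes, standardize across time bins, flatten) on a small random array. The array sizes and the silent unit are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy spike counts laid out as (time bins, units, trials).
counts = rng.poisson(lam=2.0, size=(15, 4, 6)).astype(float)
counts[:, 0, :] = 0.0  # a silent unit: zero variance across time bins

# Lead with the trial dimension: (trials, units, time bins).
X = counts.swapaxes(0, 2)

# Z-score each unit within each trial across time bins; clipping the
# standard deviation avoids division by zero for constant rows.
mean = X.mean(axis=2, keepdims=True)
std = X.std(axis=2, keepdims=True).clip(min=1e-6)
Z = (X - mean) / std

# Flatten units and time bins into one feature axis per trial.
features = Z.reshape(len(Z), -1)
```

With this layout each trial becomes one row of `units * time bins` features, ordered unit by unit, and a silent unit simply contributes zeros instead of NaNs.
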
cv_df = pd.DataFrame(clf.cv_results_)
best_row = cv_df[cv_df.rank_test_score == 1]
best_row = best_row.iloc[best_row.param_C.argmin()]
plt.figure(figsize=(2, 4))
plt.boxplot(best_row[[c for c in best_row.keys() if c.startswith('split')]])
plt.ylabel('accuracy in fold')
plt.xticks([])
plt.title("choice decoding accuracy\nin 5-fold randomized CV");
[Figure: boxplot of per-fold choice decoding accuracy for the best regularization setting; y-axis "accuracy in fold".]
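
For readers who want to try the search without downloading the recording, here is a self-contained sketch of the same pattern on synthetic data. The data, the `max_iter` setting, and the variable names are assumptions for illustration. `n_iter=10` and `cv=5` are written out explicitly and match scikit-learn's defaults, which the code above relies on.

```python
import numpy as np
from scipy.stats import loguniform
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RandomizedSearchCV

rng = np.random.default_rng(0)
n_trials, n_features = 200, 20

# Synthetic features, with labels driven by the first feature plus noise.
X = rng.normal(size=(n_trials, n_features))
y = (X[:, 0] + 0.5 * rng.normal(size=n_trials) > 0).astype(int)

# Sample 10 values of C log-uniformly and score each with 5-fold CV.
search = RandomizedSearchCV(
    LogisticRegression(max_iter=1000),
    param_distributions={"C": loguniform(1e-3, 1e3)},
    n_iter=10,
    cv=5,
    random_state=0,
)
search.fit(X, y)

# Per-fold test accuracies of the best-ranked candidate.
results = search.cv_results_
best = search.best_index_
fold_scores = np.array([results[f"split{k}_test_score"][best] for k in range(5)])
```

Here the "randomized" part is the sampling of C, not the fold assignment. Note also that `best_index_` points to a single top-ranked candidate, whereas the example above breaks ties among rank-1 candidates by taking the smallest C, that is, the strongest regularization.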

Authors

Charlie Windolf