Trial-aligned decoding with International Brain Lab data

This example works through a basic pipeline for decoding the mouse’s choice from spiking activity in the International Brain Lab’s decision task, including loading the data from DANDI, processing and trial-aligning the neural activity, and fitting a logistic regression with cross-validation using scikit-learn.

The International Brain Lab’s Brain Wide Map (BWM) dataset is available as Dandiset 000409. The IBL’s BWM website links to the accompanying preprint and additional documentation. The IBL also offers an excellent decoding demonstration, “Tutorial 2: Advanced analysis”, in the COSYNE section of its events webpage, along with many other relevant demos.

For a more detailed tutorial on data loading with DANDI, see the “Streaming data from DANDI” example!

Caveats: This example is meant as a simple starting point for working with trial-aligned data and with IBL data. It does not replicate the IBL’s quality-control and filtering criteria, and its decoding analysis is simpler than those in the IBL’s own publications.

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pynapple as nap
import scipy.stats

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RandomizedSearchCV

Loading data

See also the “Streaming data from DANDI” example, which covers this step in more detail.

# These modules are used for data loading
from dandi.dandiapi import DandiAPIClient
import fsspec
from fsspec.implementations.cached import CachingFileSystem
import h5py
from pynwb import NWBHDF5IO
# The BWM dandiset:
dandiset_id = "000409"
# This is a randomly chosen recording session.
asset_path = "sub-CSH-ZAD-026/sub-CSH-ZAD-026_ses-15763234-d21e-491f-a01b-1238eb96d389_behavior+ecephys+image.nwb"
# Resolve the asset's download URL through the DANDI API
with DandiAPIClient() as client:
    asset = client.get_dandiset(dandiset_id, "draft").get_asset_by_path(asset_path)
    s3_url = asset.get_content_url(follow_redirects=1, strip_query=True)
# Cache downloaded data locally so that repeated runs are faster
fs = CachingFileSystem(
    fs=fsspec.filesystem("http"),
    cache_storage=str(Path("~/.caches/nwb-cache").expanduser()),
)
# Open the remote HDF5 file and wrap the NWB contents in a pynapple object
io = NWBHDF5IO(file=h5py.File(fs.open(s3_url, "rb")), load_namespaces=True)
nwb = nap.NWBFile(io.read())

Loaded into a pynapple object, this NWB file contains neural activity, behavioral data, and raw electrophysiological traces. Fields are loaded lazily, so only the data we actually access is downloaded.

nwb
15763234-d21e-491f-a01b-1238eb96d389
┍━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━┑
│ Keys                                     │ Type        │
┝━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━┥
│ units                                    │ TsGroup     │
│ trials                                   │ IntervalSet │
│ WheelMovementIntervals                   │ IntervalSet │
│ WheelVelocity                            │ Tsd         │
│ WheelAcceleration                        │ Tsd         │
│ RightSmoothedPupilDiameter               │ Tsd         │
│ RightRawPupilDiameter                    │ Tsd         │
│ RightCameraMotionEnergy                  │ Tsd         │
│ PoseEstimationRightCamera/tube_top       │ TsdFrame    │
│ PoseEstimationRightCamera/tube_bottom    │ TsdFrame    │
│ PoseEstimationRightCamera/tongue_end_r   │ TsdFrame    │
│ PoseEstimationRightCamera/tongue_end_l   │ TsdFrame    │
│ PoseEstimationRightCamera/pupil_top_r    │ TsdFrame    │
│ PoseEstimationRightCamera/pupil_right_r  │ TsdFrame    │
│ PoseEstimationRightCamera/pupil_left_r   │ TsdFrame    │
│ PoseEstimationRightCamera/pupil_bottom_r │ TsdFrame    │
│ PoseEstimationRightCamera/paw_r          │ TsdFrame    │
│ PoseEstimationRightCamera/paw_l          │ TsdFrame    │
│ PoseEstimationRightCamera/nose_tip       │ TsdFrame    │
│ PoseEstimationLeftCamera/tube_top        │ TsdFrame    │
│ PoseEstimationLeftCamera/tube_bottom     │ TsdFrame    │
│ PoseEstimationLeftCamera/tongue_end_r    │ TsdFrame    │
│ PoseEstimationLeftCamera/tongue_end_l    │ TsdFrame    │
│ PoseEstimationLeftCamera/pupil_top_r     │ TsdFrame    │
│ PoseEstimationLeftCamera/pupil_right_r   │ TsdFrame    │
│ PoseEstimationLeftCamera/pupil_left_r    │ TsdFrame    │
│ PoseEstimationLeftCamera/pupil_bottom_r  │ TsdFrame    │
│ PoseEstimationLeftCamera/paw_r           │ TsdFrame    │
│ PoseEstimationLeftCamera/paw_l           │ TsdFrame    │
│ PoseEstimationLeftCamera/nose_tip        │ TsdFrame    │
│ tail_start                               │ TsdFrame    │
│ LeftSmoothedPupilDiameter                │ Tsd         │
│ LeftRawPupilDiameter                     │ Tsd         │
│ LeftCameraMotionEnergy                   │ Tsd         │
│ WheelPositionSeries                      │ Tsd         │
│ BodyCameraMotionEnergy                   │ Tsd         │
│ OriginalVideoRightCamera                 │ TsdTensor   │
│ OriginalVideoLeftCamera                  │ TsdTensor   │
│ OriginalVideoBodyCamera                  │ TsdTensor   │
│ ElectricalSeriesLf                       │ TsdFrame    │
│ ElectricalSeriesAp                       │ TsdFrame    │
┕━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━┙

The only fields we’ll use here are the neural activity (‘units’) and the trial information (‘trials’).

spikes = nwb['units']
trials = nwb['trials']
spikes
Index    rate      unit_name    presence_ratio_standard_deviation    contamination        noise_cutoff         mean_relative_depth    sliding_refractory_period_violation    ...
-------  --------  -----------  -----------------------------------  -------------------  -------------------  ---------------------  -------------------------------------  -----
0        15.13893  153          13.666350165351007                   0.07900398430955985  -0.7963905293815864  20.0                   1.0                                    ...
1        12.00565  48           20.769287422255918                   0.11892361023182427  79.8204538815869     20.0                   1.0                                    ...
2        10.96317  66           24.43411286786127                    0.24423269776093587  309.40347763690806   20.0                   1.0                                    ...
3        14.71296  15           15.965379830365684                   0.17897759314496184  224.17425800396978   20.0                   1.0                                    ...
4        1.60872   167          6.918140924507259                    0.52036432396544     71.20151297069303    20.0                   0.0                                    ...
5        26.26409  185          1.4261649779526435                   0.0                  16.74737163643705    20.0                   0.0                                    ...
6        8.91826   479          0.8739017663273241                   0.0                  16.813435702191622   40.0                   0.0                                    ...
...      ...       ...          ...                                  ...                  ...                  ...                    ...                                    ...
704      0.15972   644          2.730299038382819                    0.0                  -1.304210347384608   2540.0                 0.0                                    ...
705      0.00049   656          0.7696007901166254                   0.0                  -0.2590816235691618  3140.0                 0.0                                    ...
706      0.00551   698          0.5889345249665567                   0.0                  8.498365855987974    3140.0                 0.0                                    ...
707      0.56916   485          2.6945401464618075                   2.2246960189648073   -0.7736837677230682  3160.0                 0.0                                    ...
708      0.00016   213          5.195267757446576                    0.0                  3.539121883631855    880.0                  0.0                                    ...
709      0.08367   572          6.059575867463959                    0.0                  1.5310167774901269   1520.0                 0.0                                    ...
710      1.33274   328          1.1597099909096578                   0.0                  -1.093987436234209   3140.0                 0.0                                    ...
trials
index    start           end             choice    feedback_type    reward_volume    contrast_left    contrast_right    probability_left    feedback_time      response_time       stim_off_time       stim_on_time        go_cue_time         first_movement_time
0        36.78385458     40.007964165    -1.0      1.0              1.5              nan              1.0               0.5                 38.44465515003844  38.44455848697059   39.507848820039506  37.6934969400377    37.694296860037696  37.92137529000062
1        40.37566203     43.274770522    1.0       1.0              1.5              0.125            nan               0.5                 41.70776214004171  41.70766501274109   42.77468877004277   41.39339358004139   41.3942601600414    41.58637529000062
2        43.63566933     46.292868485    1.0       1.0              1.5              1.0              nan               0.5                 44.70686220004471  44.70676778002906   45.79278693004579   44.393426880044395  44.39422680004439   44.53837529000062
3        46.66926594     49.190979041    1.0       1.0              1.5              0.125            nan               0.5                 47.61247161004761  47.61237009154725   48.69086376004869   47.35786374004736   47.35883031004736   47.507375290000624
4        49.55207763     55.159877451    -1.0      1.0              1.5              nan              0.125             0.5                 53.58847395005359  53.58837606185574   54.65983347005466   50.10965520005011   50.11048845005011   53.27937529000063
5        55.57487529     60.826513102    -1.0      1.0              1.5              nan              0.25              0.5                 59.24850789005925  59.248400340812395  60.326400090060325  56.25757368005626   56.25840693005626   58.90437529000064
6        61.19131359     66.742723383    -1.0      1.0              1.5              nan              0.25              0.5                 65.17321536006517  65.17311846216711   66.24267507006624   61.89067698006189   61.89164355006189   65.03037529000063
...      ...             ...             ...       ...              ...              ...              ...               ...                 ...                ...                 ...                 ...                 ...                 ...
887      4665.141205895  4674.910151903  -1.0      -1.0             0.0              0.0              nan               0.8                 4672.342252384672  4672.341238330095   4674.410112244675   4665.629123764666   4665.629990344666   4672.208375290005
888      4675.984954745  4690.743070591  1.0       1.0              1.5              0.0625           nan               0.8                 4689.187767664689  4689.187667131305   4690.24306212469    4676.878932004677   4676.879865244677   4688.641375290005
889      4691.834669615  4701.794999598  -1.0      -1.0             0.0              0.125            nan               0.8                 4699.211265214699  4699.210189595457   4701.294923494701   4693.228663534694   4693.229663434693   4698.9773752900055
890      4702.867999505  4765.977571354  0.0       -1.0             0.0              0.0625           nan               0.2                 4763.478537844763  4763.477562625614   4763.644121284764   4703.476272004704   4703.477238574704   nan
891      4767.122273435  4830.17841857   0.0       -1.0             0.0              nan              0.0625            0.2                 4827.679150474828  4827.6784108947095  4827.957289324828   4767.191766484767   4767.678151174768   4820.196375290006
892      4831.334918195  4834.507429035  1.0       1.0              1.5              0.125            nan               0.2                 4832.960822254833  4832.960722212809   4834.007350924834   4831.376214064831   4832.025082504832   4832.638375290006
893      4835.563328645  4845.873856126  -1.0      1.0              1.5              nan              1.0               0.2                 4844.314253464844  4844.314150271431   4861.906727374862   4836.192465724836   4836.1933656348365  4843.874375290005
shape: (894, 2), time unit: sec.

Trial alignment and binning

Here, we start by discarding trials in which no choice was made (choice == 0); a fuller analysis might apply more criteria. Next, we create an IntervalSet of windows aligned to stimulus onset in each trial, with each window extending from 0.5 s before to 2 s after the stimulus.

valid_choice = trials.choice != 0
trials = trials[valid_choice]
stim_on_intervals = nap.IntervalSet(
    start=trials.stim_on_time - 0.5,
    end=trials.stim_on_time + 2.0,
)
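For readers new to pynapple, the same filtering and window arithmetic can be written with plain NumPy arrays. This is a minimal sketch that assumes the choice codes and stimulus-onset times are available as arrays. The toy values are rounded from rows 0 to 2 and the two no-choice rows (890, 891) of the trials table above, and the sketch produces (start, end) pairs rather than a pynapple IntervalSet.

```python
import numpy as np

# Toy stimulus-onset times (s) and choices, rounded from rows 0-2 and 890-891
# of the trials table above. Choice is -1 or +1 for a response, 0 for no choice.
stim_on_time = np.array([37.69, 41.39, 44.39, 4703.48, 4767.19])
choice = np.array([-1.0, 1.0, 1.0, 0.0, 0.0])

# Keep only trials with a choice (the same mask as trials.choice != 0).
valid_choice = choice != 0
stim_on_valid = stim_on_time[valid_choice]

# Each window runs from 0.5 s before to 2.0 s after stimulus onset.
pre, post = 0.5, 2.0
window_starts = stim_on_valid - pre
window_ends = stim_on_valid + post

for start, end in zip(window_starts, window_ends):
    print(f"[{start:.2f}, {end:.2f}] s")
```

Every remaining window is 2.5 s long, and the two no-choice trials are dropped before any window is built.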

Now, use nap.build_tensor to align the neural data to these stimulus windows and bin it into 0.1 s time bins.

trial_aligned_binned_spikes = nap.build_tensor(
    spikes, stim_on_intervals, bin_size=0.1
)
trial_aligned_binned_spikes.shape
(711, 892, 25)
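Conceptually, this step counts each unit’s spikes in consecutive 0.1 s bins inside each window. The NumPy sketch below illustrates that idea using hypothetical spike trains. It is not pynapple’s implementation, and its bin-edge conventions may differ: np.histogram closes the last bin on the right. The helper name bin_aligned_spikes is made up for this illustration.

```python
import numpy as np


def bin_aligned_spikes(spike_times_per_unit, window_starts, bin_size, n_bins):
    """Hypothetical helper: count spikes per unit, per window and per time bin.

    Returns an integer array of shape (n_units, n_windows, n_bins).
    """
    counts = np.zeros(
        (len(spike_times_per_unit), len(window_starts), n_bins), dtype=np.int64
    )
    # Bin edges relative to the start of each window.
    edge_offsets = np.arange(n_bins + 1) * bin_size
    for u, spike_times in enumerate(spike_times_per_unit):
        for w, start in enumerate(window_starts):
            counts[u, w], _ = np.histogram(spike_times, bins=start + edge_offsets)
    return counts


rng = np.random.default_rng(0)
# Two hypothetical units with spikes scattered uniformly over 0-20 s.
toy_spike_times = [
    np.sort(rng.uniform(0, 20, size=200)),
    np.sort(rng.uniform(0, 20, size=50)),
]
toy_stim_on = np.array([2.0, 8.0, 14.0])
pre, post, bin_size = 0.5, 2.0, 0.1
n_bins = int(round((pre + post) / bin_size))  # 25 bins per window

tensor = bin_aligned_spikes(toy_spike_times, toy_stim_on - pre, bin_size, n_bins)
print(tensor.shape)  # (2, 3, 25)
```

The output axes follow the same units × windows × bins order as the tensor above. The bin count is rounded from the window length rather than truncated, which guards against floating-point round-off in 2.5 / 0.1.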

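The result has shape n_neurons × n_trials × n_bins: 711 units, 892 trials (after dropping the no-choice trials), and 25 bins covering the 2.5 s window. Let’s visualize the neural activity in a single trial.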

trial = 645
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(
    trial_aligned_binned_spikes[:, trial],
    aspect='auto',
    cmap=plt.cm.gist_yarg,
    interpolation='none',
)
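# Stimulus onset is 0.5 s (5 bins) into the window. imshow centers bin i at x = i,
# so the boundary between bins 4 and 5 lies at x = 4.5
ax.axvline(4.5, lw=1, color='r', label='stimulus time')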
ax.set_ylabel('units')
ax.set_xlabel('time bin (0.1s)')
ax.legend(frameon=False, loc='lower right')
ax.set_title(f'binned spiking activity in trial {trial}')
Figure: binned spiking activity in trial 645 (rows: units; columns: 0.1 s time bins; red line: stimulus time).
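It is safer to compute the position of an event marker on an imshow of binned data from the binning than to hard-code it. With Matplotlib’s default extent, column i is centered at x = i and spans [i - 0.5, i + 0.5]. An event that falls exactly on a bin edge therefore sits at a half-integer x. This sketch assumes the 0.5 s pre-stimulus span, 0.1 s bins, and 25 bins used above.

```python
import numpy as np

pre, bin_size, n_bins = 0.5, 0.1, 25

# Left edge of each bin, in seconds relative to stimulus onset.
bin_starts = -pre + np.arange(n_bins) * bin_size
# Index of the first bin that starts at or after stimulus onset
# (small tolerance for floating-point round-off).
first_post_bin = int(np.flatnonzero(bin_starts >= -1e-9)[0])

# With imshow's default extent, bin i is drawn as a column spanning
# x in [i - 0.5, i + 0.5], so the onset sits on the left edge of that bin.
stim_x = first_post_bin - 0.5
print(first_post_bin, stim_x)  # 5 4.5
```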

Decoding choice

Let’s use scikit-learn to fit a logistic regression, using its cross-validation tools to pick the regularization strength C.

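# recode the choice from {-1, 1} to {0, 1}
y = (trials.choice + 1) / 2
# transpose to n_trials x n_neurons x n_bins so that trials come first
X = trial_aligned_binned_spikes.transpose(1, 0, 2)
# z-score each neuron's time course within each trial (over the time-bin axis);
# clipping the std avoids division by zero for neurons silent in a trial
X = (X - X.mean(2, keepdims=True)) / X.std(2, keepdims=True).clip(min=1e-6)
# flatten to n_trials x (n_neurons * n_bins) features
X = X.reshape(len(X), -1)
# randomized search over C, drawn log-uniformly from [1e-3, 1e3]; with the
# defaults, 10 candidates are each scored by 5-fold (stratified) cross-validation
clf = RandomizedSearchCV(
    LogisticRegression(),
    param_distributions={'C': scipy.stats.loguniform(1e-3, 1e3)},
    random_state=0,
)
clf.fit(X, y)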
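The feature construction in the cell above can be checked on a small synthetic tensor. This NumPy-only sketch substitutes made-up Poisson counts for trial_aligned_binned_spikes. It shows the label recoding, the shape after each step, and the reason for clipping the standard deviation: a neuron that is silent for a whole trial has zero standard deviation, which would otherwise produce NaNs.

```python
import numpy as np

rng = np.random.default_rng(0)
n_neurons, n_trials, n_bins = 4, 6, 25
# Hypothetical Poisson spike counts standing in for trial_aligned_binned_spikes.
counts = rng.poisson(2.0, size=(n_neurons, n_trials, n_bins)).astype(float)
counts[0, 0, :] = 0  # make neuron 0 silent in trial 0
# Hypothetical choices coded as -1 / +1, as in the trials table.
choice = np.array([-1.0, 1.0, 1.0, -1.0, 1.0, -1.0])

y = ((choice + 1) / 2).astype(int)      # -1 -> 0, +1 -> 1
X = counts.transpose(1, 0, 2)           # (n_trials, n_neurons, n_bins)
mu = X.mean(axis=2, keepdims=True)
sd = X.std(axis=2, keepdims=True).clip(min=1e-6)  # guard against zero std
X = (X - mu) / sd                        # z-score each neuron within each trial
X = X.reshape(len(X), -1)                # (n_trials, n_neurons * n_bins)
print(X.shape, y)  # (6, 100) [0 1 1 0 1 0]
```

Because the z-scoring happens within each trial, it removes each neuron’s mean firing rate in that trial. The decoder therefore sees only how each response changes over time within the window.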
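cv_df = pd.DataFrame(clf.cv_results_)
# keep the candidates tied for the best mean test accuracy...
best_row = cv_df[cv_df.rank_test_score == 1]
# ...and among those, take the smallest C (strongest regularization)
best_row = best_row.iloc[best_row.param_C.argmin()]
plt.figure(figsize=(2, 4))
# per-fold test accuracies are stored in the split<k>_test_score columns
plt.boxplot(best_row[[c for c in best_row.keys() if c.startswith('split')]])
plt.ylabel('accuracy in fold')
plt.xticks([])
plt.title("choice decoding accuracy\nin 5-fold CV")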
Figure: choice decoding accuracy for the selected C, shown as a boxplot over the 5 cross-validation folds (y-axis: accuracy in fold).
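The selection step above keeps the hyperparameter settings that share the best mean CV score (rank 1 in rank_test_score) and breaks ties by taking the smallest C. In scikit-learn’s LogisticRegression, C is the inverse of the regularization strength, so the smallest C is the most strongly regularized of the equally good models. The NumPy sketch below applies that rule to hypothetical scores rather than real cv_results_ output.

```python
import numpy as np

# Hypothetical CV summary: candidate C values and their mean test accuracies.
param_C = np.array([0.005, 0.03, 0.4, 7.0, 120.0])
mean_test_score = np.array([0.70, 0.78, 0.78, 0.74, 0.72])

# Candidates tied for the best mean score (these would share rank 1).
best = np.flatnonzero(mean_test_score == mean_test_score.max())
# Among the ties, prefer the smallest C, i.e. the strongest regularization.
chosen = best[np.argmin(param_C[best])]
print(param_C[chosen])  # 0.03
```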

Authors

Charlie Windolf