Trial-aligned decoding with International Brain Lab data#
This example works through a basic pipeline for decoding the mouse’s choice from spiking activity in the International Brain Lab’s decision task, including loading the data from DANDI, processing and trial-aligning the neural activity, and fitting a logistic regression with cross-validation using scikit-learn.
The International Brain Lab’s Brain Wide Map (BWM) dataset is available as Dandiset 000409. The IBL’s BWM website links to their preprint and additional documentation. The IBL also has an excellent decoding demonstration in the COSYNE section of their events webpage under “Tutorial 2: Advanced analysis”, along with many other relevant demos.
For a more detailed tutorial on data loading with DANDI, see the “Streaming data from DANDI” example!
Caveats: This example is meant as a simple starting point for working with trial-aligned data and with data from the IBL. It does not faithfully replicate the IBL’s quality control and filtering criteria, and the decoding here is simpler than the analyses carried out in the IBL’s own work.
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pynapple as nap
import scipy.stats
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RandomizedSearchCV
Loading data#
See also the “Streaming data from DANDI” example, which covers data loading in more detail.
# These modules are used for data loading
from dandi.dandiapi import DandiAPIClient
import fsspec
from fsspec.implementations.cached import CachingFileSystem
import h5py
from pynwb import NWBHDF5IO
# The BWM dandiset:
dandiset_id = "000409"
# This is a randomly chosen recording session.
asset_path = "sub-CSH-ZAD-026/sub-CSH-ZAD-026_ses-15763234-d21e-491f-a01b-1238eb96d389_behavior+ecephys+image.nwb"
with DandiAPIClient() as client:
    asset = client.get_dandiset(dandiset_id, "draft").get_asset_by_path(asset_path)
    s3_url = asset.get_content_url(follow_redirects=1, strip_query=True)
fs = CachingFileSystem(
    fs=fsspec.filesystem("http"),
    cache_storage=str(Path("~/.caches/nwb-cache").expanduser()),
)
io = NWBHDF5IO(file=h5py.File(fs.open(s3_url, "rb")), load_namespaces=True)
nwb = nap.NWBFile(io.read())
/home/runner/.local/lib/python3.12/site-packages/hdmf/spec/namespace.py:535: UserWarning: Ignoring cached namespace 'hdmf-common' version 1.5.1 because version 1.8.0 is already loaded.
warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
/home/runner/.local/lib/python3.12/site-packages/hdmf/spec/namespace.py:535: UserWarning: Ignoring cached namespace 'core' version 2.5.0 because version 2.8.0 is already loaded.
warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
/home/runner/.local/lib/python3.12/site-packages/hdmf/spec/namespace.py:535: UserWarning: Ignoring cached namespace 'hdmf-experimental' version 0.2.0 because version 0.5.0 is already loaded.
warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
This NWB file, loaded into a pynapple object, contains neural activity, behavioral data, and raw electrophysiological traces. They’re lazily loaded, so only what we use will be downloaded.
nwb
15763234-d21e-491f-a01b-1238eb96d389
┍━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━┑
│ Keys │ Type │
┝━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━┥
│ units │ TsGroup │
│ trials │ IntervalSet │
│ WheelMovementIntervals │ IntervalSet │
│ WheelVelocity │ Tsd │
│ WheelAcceleration │ Tsd │
│ RightSmoothedPupilDiameter │ Tsd │
│ RightRawPupilDiameter │ Tsd │
│ RightCameraMotionEnergy │ Tsd │
│ PoseEstimationRightCamera/tube_top │ TsdFrame │
│ PoseEstimationRightCamera/tube_bottom │ TsdFrame │
│ PoseEstimationRightCamera/tongue_end_r │ TsdFrame │
│ PoseEstimationRightCamera/tongue_end_l │ TsdFrame │
│ PoseEstimationRightCamera/pupil_top_r │ TsdFrame │
│ PoseEstimationRightCamera/pupil_right_r │ TsdFrame │
│ PoseEstimationRightCamera/pupil_left_r │ TsdFrame │
│ PoseEstimationRightCamera/pupil_bottom_r │ TsdFrame │
│ PoseEstimationRightCamera/paw_r │ TsdFrame │
│ PoseEstimationRightCamera/paw_l │ TsdFrame │
│ PoseEstimationRightCamera/nose_tip │ TsdFrame │
│ PoseEstimationLeftCamera/tube_top │ TsdFrame │
│ PoseEstimationLeftCamera/tube_bottom │ TsdFrame │
│ PoseEstimationLeftCamera/tongue_end_r │ TsdFrame │
│ PoseEstimationLeftCamera/tongue_end_l │ TsdFrame │
│ PoseEstimationLeftCamera/pupil_top_r │ TsdFrame │
│ PoseEstimationLeftCamera/pupil_right_r │ TsdFrame │
│ PoseEstimationLeftCamera/pupil_left_r │ TsdFrame │
│ PoseEstimationLeftCamera/pupil_bottom_r │ TsdFrame │
│ PoseEstimationLeftCamera/paw_r │ TsdFrame │
│ PoseEstimationLeftCamera/paw_l │ TsdFrame │
│ PoseEstimationLeftCamera/nose_tip │ TsdFrame │
│ tail_start │ TsdFrame │
│ LeftSmoothedPupilDiameter │ Tsd │
│ LeftRawPupilDiameter │ Tsd │
│ LeftCameraMotionEnergy │ Tsd │
│ WheelPositionSeries │ Tsd │
│ BodyCameraMotionEnergy │ Tsd │
│ OriginalVideoRightCamera │ TsdTensor │
│ OriginalVideoLeftCamera │ TsdTensor │
│ OriginalVideoBodyCamera │ TsdTensor │
│ ElectricalSeriesLf │ TsdFrame │
│ ElectricalSeriesAp │ TsdFrame │
┕━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━┙
The only fields we’ll use here are the neural activity (‘units’) and the trial information (‘trials’).
spikes = nwb['units']
trials = nwb['trials']
/home/runner/.local/lib/python3.12/site-packages/pynapple/core/base_class.py:51: UserWarning: Some epochs have no duration
self.time_support = IntervalSet(start=self.index[0], end=self.index[-1])
/home/runner/.local/lib/python3.12/site-packages/pynapple/core/base_class.py:53: RuntimeWarning: divide by zero encountered in scalar divide
self.rate = self.index.shape[0] / np.sum(
spikes
Index rate unit_name presence_ratio_standard_deviation contamination noise_cutoff mean_relative_depth sliding_refractory_period_violation ...
------- -------- ----------- ----------------------------------- ------------------- ------------------- --------------------- ------------------------------------- -----
0 15.13893 153 13.666350165351007 0.07900398430955985 -0.7963905293815864 20.0 1.0 ...
1 12.00565 48 20.769287422255918 0.11892361023182427 79.8204538815869 20.0 1.0 ...
2 10.96317 66 24.43411286786127 0.24423269776093587 309.40347763690806 20.0 1.0 ...
3 14.71296 15 15.965379830365684 0.17897759314496184 224.17425800396978 20.0 1.0 ...
4 1.60872 167 6.918140924507259 0.52036432396544 71.20151297069303 20.0 0.0 ...
5 26.26409 185 1.4261649779526435 0.0 16.74737163643705 20.0 0.0 ...
6 8.91826 479 0.8739017663273241 0.0 16.813435702191622 40.0 0.0 ...
... ... ... ... ... ... ... ... ...
704 0.15972 644 2.730299038382819 0.0 -1.304210347384608 2540.0 0.0 ...
705 0.00049 656 0.7696007901166254 0.0 -0.2590816235691618 3140.0 0.0 ...
706 0.00551 698 0.5889345249665567 0.0 8.498365855987974 3140.0 0.0 ...
707 0.56916 485 2.6945401464618075 2.2246960189648073 -0.7736837677230682 3160.0 0.0 ...
708 0.00016 213 5.195267757446576 0.0 3.539121883631855 880.0 0.0 ...
709 0.08367 572 6.059575867463959 0.0 1.5310167774901269 1520.0 0.0 ...
710 1.33274 328 1.1597099909096578 0.0 -1.093987436234209 3140.0 0.0 ...
trials
index start end choice feedback_type reward_volume contrast_left contrast_right probability_left feedback_time response_time stim_off_time stim_on_time go_cue_time first_movement_time
0 36.78385458 40.007964165 -1.0 1.0 1.5 nan 1.0 0.5 38.44465515003844 38.44455848697059 39.507848820039506 37.6934969400377 37.694296860037696 37.92137529000062
1 40.37566203 43.274770522 1.0 1.0 1.5 0.125 nan 0.5 41.70776214004171 41.70766501274109 42.77468877004277 41.39339358004139 41.3942601600414 41.58637529000062
2 43.63566933 46.292868485 1.0 1.0 1.5 1.0 nan 0.5 44.70686220004471 44.70676778002906 45.79278693004579 44.393426880044395 44.39422680004439 44.53837529000062
3 46.66926594 49.190979041 1.0 1.0 1.5 0.125 nan 0.5 47.61247161004761 47.61237009154725 48.69086376004869 47.35786374004736 47.35883031004736 47.507375290000624
4 49.55207763 55.159877451 -1.0 1.0 1.5 nan 0.125 0.5 53.58847395005359 53.58837606185574 54.65983347005466 50.10965520005011 50.11048845005011 53.27937529000063
5 55.57487529 60.826513102 -1.0 1.0 1.5 nan 0.25 0.5 59.24850789005925 59.248400340812395 60.326400090060325 56.25757368005626 56.25840693005626 58.90437529000064
6 61.19131359 66.742723383 -1.0 1.0 1.5 nan 0.25 0.5 65.17321536006517 65.17311846216711 66.24267507006624 61.89067698006189 61.89164355006189 65.03037529000063
... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
887 4665.141205895 4674.910151903 -1.0 -1.0 0.0 0.0 nan 0.8 4672.342252384672 4672.341238330095 4674.410112244675 4665.629123764666 4665.629990344666 4672.208375290005
888 4675.984954745 4690.743070591 1.0 1.0 1.5 0.0625 nan 0.8 4689.187767664689 4689.187667131305 4690.24306212469 4676.878932004677 4676.879865244677 4688.641375290005
889 4691.834669615 4701.794999598 -1.0 -1.0 0.0 0.125 nan 0.8 4699.211265214699 4699.210189595457 4701.294923494701 4693.228663534694 4693.229663434693 4698.9773752900055
890 4702.867999505 4765.977571354 0.0 -1.0 0.0 0.0625 nan 0.2 4763.478537844763 4763.477562625614 4763.644121284764 4703.476272004704 4703.477238574704 nan
891 4767.122273435 4830.17841857 0.0 -1.0 0.0 nan 0.0625 0.2 4827.679150474828 4827.6784108947095 4827.957289324828 4767.191766484767 4767.678151174768 4820.196375290006
892 4831.334918195 4834.507429035 1.0 1.0 1.5 0.125 nan 0.2 4832.960822254833 4832.960722212809 4834.007350924834 4831.376214064831 4832.025082504832 4832.638375290006
893 4835.563328645 4845.873856126 -1.0 1.0 1.5 nan 1.0 0.2 4844.314253464844 4844.314150271431 4861.906727374862 4836.192465724836 4836.1933656348365 4843.874375290005
shape: (894, 2), time unit: sec.
Trial alignment and binning#
Here, we start by discarding trials where no choice was made (choice equal to 0); a fuller analysis may include more criteria. Next, we create an IntervalSet aligned to the stimulus onset time within each trial, where the window around each stimulus extends back 0.5s and forward 2s.
valid_choice = trials.choice != 0
trials = trials[valid_choice]
stim_on_intervals = nap.IntervalSet(
    start=trials.stim_on_time - 0.5,
    end=trials.stim_on_time + 2.0,
)
Now, use build_tensor to align the neural data to these stimulus windows and bin it in 0.1s time bins.
trial_aligned_binned_spikes = nap.build_tensor(
    spikes, stim_on_intervals, bin_size=0.1
)
trial_aligned_binned_spikes.shape
(711, 892, 25)
The result is n_neurons x n_trials x n_bins. Let’s visualize the neural activity in a single trial.
trial = 645
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(
    trial_aligned_binned_spikes[:, trial],
    aspect='auto',
    cmap=plt.cm.gist_yarg,
    interpolation='none',
)
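# The window starts 0.5s before the stimulus, so bin 5 begins at stimulus onset.
# imshow centers bin i at x = i, which puts the left edge of bin 5 at x = 4.5.
ax.axvline(4.5, lw=1, color='r', label='stimulus time')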
ax.set_ylabel('units')
ax.set_xlabel('time bin (0.1s)')
ax.legend(frameon=False, loc='lower right')
ax.set_title(f'binned spiking activity in trial {trial}')
Text(0.5, 1.0, 'binned spiking activity in trial 645')
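To make the alignment step more concrete, here is a minimal NumPy sketch of the same idea: count spikes in fixed-width bins inside a window around each event. It is not pynapple's implementation of build_tensor; the function name bin_spikes_around_events and the toy spike times are hypothetical and chosen only for illustration.

```python
import numpy as np


def bin_spikes_around_events(spike_times, event_times, before=0.5, after=2.0, bin_size=0.1):
    """Count spikes in fixed-width bins in a window around each event.

    spike_times: list of 1-D arrays of spike times in seconds, one per unit
    event_times: 1-D array of alignment times in seconds, one per trial
    Returns an array of shape (n_units, n_trials, n_bins).
    """
    n_bins = int(round((before + after) / bin_size))
    # Bin edges relative to the event: -before, -before + bin_size, ..., +after
    rel_edges = -before + bin_size * np.arange(n_bins + 1)
    counts = np.zeros((len(spike_times), len(event_times), n_bins))
    for u, st in enumerate(spike_times):
        for t, ev in enumerate(event_times):
            # Shift spike times so the event sits at 0, then histogram them
            counts[u, t], _ = np.histogram(st - ev, bins=rel_edges)
    return counts


# Toy data: two units, two "stimulus onsets" at 10s and 20s
toy_spikes = [
    np.array([9.55, 10.05, 10.15, 21.95]),
    np.array([19.75, 20.05]),
]
toy_events = np.array([10.0, 20.0])
tensor = bin_spikes_around_events(toy_spikes, toy_events)
print(tensor.shape)  # (2, 2, 25)
```

With a 0.5s pre-stimulus window and 0.1s bins, bins 0 to 4 are pre-stimulus and bin 5 is the first post-stimulus bin, which matches the layout of the tensor used in the rest of this example.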

Decoding choice#
Let’s use scikit-learn to fit a logistic regression. We can use its cross-validation tools to pick the regularization parameter.
# map the choice labels from {-1, 1} to {0, 1}
y = (trials.choice + 1) / 2
# transpose the data to lead with the trial dimension
X = trial_aligned_binned_spikes.transpose(1, 0, 2)
# z-score each neuron's counts across time bins within each trial
# (the std is floored at 1e-6 to avoid dividing by zero for silent neurons)
X = (X - X.mean(2, keepdims=True)) / X.std(2, keepdims=True).clip(min=1e-6)
# reshape to n_trials x n_features
X = X.reshape(len(X), -1)
# randomized search over C (log-uniform) with cross-validated logistic regression
clf = RandomizedSearchCV(
    LogisticRegression(),
    param_distributions={'C': scipy.stats.loguniform(1e-3, 1e3)},
    random_state=0,
)
clf.fit(X, y)
RandomizedSearchCV(estimator=LogisticRegression(), param_distributions={'C': <scipy.stats._distn_infrastructure.rv_continuous_frozen object at 0x7f28e1537bc0>}, random_state=0)
LogisticRegression(C=np.float64(0.348280208702833))
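The preprocessing above can be checked on a small synthetic array. The sketch below uses made-up Poisson counts (not IBL data) with the same n_neurons x n_trials x n_bins layout, and applies the same transpose, per-trial z-scoring, and flattening. One thing it makes visible: because each neuron is z-scored across time bins within each trial, the neuron's overall firing rate on that trial is removed, so the features describe only the temporal profile.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic spike counts with layout (n_neurons, n_trials, n_bins)
counts = rng.poisson(2.0, size=(4, 6, 25)).astype(float)
counts[0, 0] = 0.0  # a neuron that is silent on one trial, to exercise the std floor

# Lead with the trial dimension: (n_trials, n_neurons, n_bins)
X = counts.transpose(1, 0, 2)
mean = X.mean(axis=2, keepdims=True)
std = X.std(axis=2, keepdims=True).clip(min=1e-6)  # avoid division by zero
X_z = (X - mean) / std

# Flatten neurons and time bins into one feature axis per trial
X_flat = X_z.reshape(len(X_z), -1)
print(X_flat.shape)  # (6, 100)
```

The silent neuron ends up as a row of zeros rather than NaNs, which is the purpose of the clip on the standard deviation.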
cv_df = pd.DataFrame(clf.cv_results_)
best_row = cv_df[cv_df.rank_test_score == 1]
best_row = best_row.iloc[best_row.param_C.argmin()]
plt.figure(figsize=(2, 4))
plt.boxplot(best_row[[c for c in best_row.keys() if c.startswith('split')]])
plt.ylabel('accuracy in fold')
plt.xticks([])
plt.title("choice decoding accuracy\nin 5-fold randomized CV")
Text(0.5, 1.0, 'choice decoding accuracy\nin 5-fold randomized CV')
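When reading per-fold accuracies like these, a useful reference point is the accuracy of always predicting the more frequent choice, since left and right choices need not be balanced. The sketch below computes that baseline with NumPy; the helper name majority_class_accuracy and the toy labels are hypothetical, and in this example the function would be applied to y.

```python
import numpy as np


def majority_class_accuracy(labels):
    """Accuracy obtained by always predicting the most frequent label."""
    labels = np.asarray(labels)
    _, label_counts = np.unique(labels, return_counts=True)
    return label_counts.max() / label_counts.sum()


# Toy labels: 7 trials of class 1 and 3 trials of class 0
y_toy = np.array([0, 1, 1, 1, 0, 1, 1, 0, 1, 1])
baseline = majority_class_accuracy(y_toy)
print(baseline)  # 0.7
```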

Authors
Charlie Windolf