# -*- coding: utf-8 -*-
# @Author: Guillaume Viejo
# @Date: 2023-08-01 11:54:45
# @Last Modified by: Guillaume Viejo
# @Last Modified time: 2024-05-21 15:28:27
"""
Pynapple class to interface with NWB files.
Data are always lazy-loaded.
Object behaves like dictionary.
"""
import errno
import os
import warnings
from collections import UserDict
from numbers import Number
from pathlib import Path
import numpy as np
import pynwb
from pynwb import NWBHDF5IO
from tabulate import tabulate
from .. import core as nap
def _extract_compatible_data_from_nwbfile(nwbfile):
    """Extract all the NWB objects that can be converted to a pynapple object.

    Parameters
    ----------
    nwbfile : pynwb.file.NWBFile
        Instance of NWB file

    Returns
    -------
    dict
        Dictionary mapping names to ``{"id": object_id, "type": pynapple type}``
        for every compatible object found.
    """
    data = {}

    for oid, obj in nwbfile.objects.items():
        # Units table: a DynamicTable carrying at least one "*_times_index" column.
        if isinstance(obj, pynwb.misc.DynamicTable) and any(
            i.name.endswith("_times_index") for i in obj.columns
        ):
            data["units"] = {"id": oid, "type": "TsGroup"}
        elif isinstance(obj, pynwb.epoch.TimeIntervals):
            # Supposedly IntervalsSets
            data[obj.name] = {"id": oid, "type": "IntervalSet"}
        elif isinstance(obj, pynwb.misc.DynamicTable) and any(
            i.name.endswith("_times") for i in obj.columns
        ):
            # Supposedly Timestamps
            data[obj.name] = {"id": oid, "type": "Ts"}
        elif isinstance(obj, pynwb.misc.AnnotationSeries):
            # Old timestamps version
            data[obj.name] = {"id": oid, "type": "Ts"}
        elif isinstance(obj, pynwb.misc.TimeSeries):
            # Dispatch on data dimensionality: 1D -> Tsd, 2D -> TsdFrame, >2D -> TsdTensor.
            ndim = len(obj.data.shape)
            if ndim > 2:
                data[obj.name] = {"id": oid, "type": "TsdTensor"}
            elif ndim == 2:
                data[obj.name] = {"id": oid, "type": "TsdFrame"}
            elif ndim == 1:
                data[obj.name] = {"id": oid, "type": "Tsd"}

    return data
def _make_interval_set(obj, **kwargs):
"""Helper function to make IntervalSet
Parameters
----------
obj : pynwb.epoch.TimeIntervals
NWB object
Returns
-------
IntervalSet or dict of IntervalSet or pandas.DataFrame
If contains multiple epochs, a dictionary of IntervalSet is returned.
It too many metadata, the function returns the output of nwbfile.trials.to_dataframe()
"""
if hasattr(obj, "to_dataframe"):
df = obj.to_dataframe()
if hasattr(df, "start_time") and hasattr(df, "stop_time"):
if df.shape[1] == 2:
data = nap.IntervalSet(start=df["start_time"], end=df["stop_time"])
return data
group_by_key = None
if "tags" in df.columns:
group_by_key = "tags"
elif df.shape[1] == 3: # assuming third column is the tag
group_by_key = df.columns[2]
if group_by_key:
for i in df.index:
if isinstance(df.loc[i, group_by_key], (list, tuple, np.ndarray)):
df.loc[i, group_by_key] = "-".join(
[str(j) for j in df.loc[i, group_by_key]]
)
data = {}
for k, subdf in df.groupby(group_by_key):
data[k] = nap.IntervalSet(
start=subdf["start_time"], end=subdf["stop_time"]
)
if len(data) == 1:
return data[list(data.keys())[0]]
else:
return data
else:
warnings.warn(
"Too many metadata. Returning pandas.DataFrame, not IntervalSet",
stacklevel=2,
)
return df # Too many metadata to split the epoch
else:
return obj
def _make_tsd(obj, lazy_loading=True):
    """Helper function to make Tsd

    Parameters
    ----------
    obj : pynwb.misc.TimeSeries
        NWB object
    lazy_loading : bool
        If True return a memory-view of the data, load otherwise.

    Returns
    -------
    Tsd
    """
    # Keep the on-disk handle for lazy loading, otherwise pull into memory.
    values = obj.data if lazy_loading else obj.data[:]

    # Timestamps are either stored explicitly or reconstructed from
    # starting_time and a fixed sampling rate.
    if obj.timestamps is None:
        times = obj.starting_time + np.arange(obj.num_samples) / obj.rate
    else:
        times = obj.timestamps[:]

    return nap.Tsd(t=times, d=values, load_array=not lazy_loading)
def _make_tsd_tensor(obj, lazy_loading=True):
    """Helper function to make TsdTensor

    Parameters
    ----------
    obj : pynwb.misc.TimeSeries
        NWB object
    lazy_loading : bool
        If True return a memory-view of the data, load otherwise.

    Returns
    -------
    TsdTensor
    """
    d = obj.data
    if not lazy_loading:
        d = d[:]

    # Timestamps are either stored explicitly or reconstructed from
    # starting_time and a fixed sampling rate.
    if obj.timestamps is not None:
        t = obj.timestamps[:]
    else:
        t = obj.starting_time + np.arange(obj.num_samples) / obj.rate

    data = nap.TsdTensor(t=t, d=d, load_array=not lazy_loading)

    return data
def _make_tsd_frame(obj, lazy_loading=True):
    """Helper function to make TsdFrame

    Parameters
    ----------
    obj : pynwb.misc.TimeSeries
        NWB object
    lazy_loading : bool
        If True return a memory-view of the data, load otherwise.

    Returns
    -------
    TsdFrame
    """
    d = obj.data
    if not lazy_loading:
        d = d[:]

    # Timestamps are either stored explicitly or reconstructed from
    # starting_time and a fixed sampling rate.
    if obj.timestamps is not None:
        t = obj.timestamps[:]
    else:
        t = obj.starting_time + np.arange(obj.num_samples) / obj.rate

    # Column labels depend on the specialized TimeSeries subtype.
    if isinstance(obj, pynwb.behavior.SpatialSeries):
        if obj.data.shape[1] == 2:
            columns = ["x", "y"]
        elif obj.data.shape[1] == 3:
            columns = ["x", "y", "z"]
        else:
            columns = np.arange(obj.data.shape[1])
    elif isinstance(obj, pynwb.ecephys.ElectricalSeries):
        # (channel mapping)
        try:
            df = obj.electrodes.to_dataframe()
            if hasattr(df, "label"):
                columns = df["label"].values
            else:
                columns = df.index.values
        except Exception:
            # Fall back to integer labels if the electrodes table is unusable.
            columns = np.arange(obj.data.shape[1])
    elif isinstance(obj, pynwb.ophys.RoiResponseSeries):
        # (cell number)
        try:
            columns = obj.rois["id"][:]
        except Exception:
            columns = np.arange(obj.data.shape[1])
    else:
        columns = np.arange(obj.data.shape[1])

    if len(columns) >= d.shape[1]:  # Weird sometimes if background ID added
        columns = columns[0 : obj.data.shape[1]]
    else:
        columns = np.arange(obj.data.shape[1])

    data = nap.TsdFrame(t=t, d=d, columns=columns, load_array=not lazy_loading)

    return data
def _make_tsgroup(obj, **kwargs):
    """Helper function to make TsGroup

    Parameters
    ----------
    obj : pynwb.misc.Units
        NWB object

    Returns
    -------
    TsGroup
    """
    index = obj.id[:]

    # One Ts per unit, keyed by the unit id.
    tsgroup = {}
    for i, gr in zip(index, obj.spike_times_index[:]):
        tsgroup[i] = nap.Ts(t=np.array(gr))

    N = len(tsgroup)

    # Collect per-unit metadata columns that can be attached to the TsGroup.
    metainfo = {}
    for coln in obj.colnames:
        if coln == "electrode_group":
            # Flatten selected ElectrodeGroup attributes into scalar columns,
            # keeping only attributes defined for every unit.
            for e in [
                "location",
                "x",
                "y",
                "z",
                "imp",
                "filtering",
                "rel_x",
                "rel_y",
                "rel_z",
                "reference",
            ]:
                tmp = [getattr(eg, e) for eg in obj[coln] if hasattr(eg, e)]
                if len(tmp) == N:
                    metainfo[e] = np.array(tmp)
        if coln not in ["spike_times_index", "spike_times", "electrode_group"]:
            col = obj[coln]
            if len(col) == N:
                if hasattr(col, "to_dataframe"):
                    df = col.to_dataframe()
                    df = df.sort_index()
                    for k in df.columns:
                        # Skip non-scalar metadata (lists, dicts, electrode groups, ...)
                        if not isinstance(
                            df[k].values[0],
                            (list, tuple, dict, set, pynwb.ecephys.ElectrodeGroup),
                        ):
                            metainfo[k] = df[k].values
                elif isinstance(col[0], (Number, str)):
                    metainfo[coln] = np.array(col[:])

    tsgroup = nap.TsGroup(tsgroup, **metainfo)

    return tsgroup
def _make_ts(obj, **kwargs):
"""Helper function to make Ts
Parameters
----------
obj : pynwb.misc.AnnotationSeries or pynwb.misc.DynamicTable
NWB object
Returns
-------
Ts or dict of Ts
"""
if hasattr(obj, "timestamps"):
data = nap.Ts(obj.timestamps[:])
else:
df = obj.to_dataframe()
data = {}
for k in df.columns:
if isinstance(k, str):
if k.endswith("_times"):
data[k] = nap.Ts(df[k].values)
if len(data) == 1:
data = data[list(data.keys())[0]]
return data
class NWBFile(UserDict):
    """Class for reading NWB Files.

    The object behaves like a dictionary: compatible NWB objects appear as
    keys and are converted to pynapple objects only on first access
    (lazy loading), after which the result is cached.

    Examples
    --------
    >>> import pynapple as nap
    >>> data = nap.load_file("my_file.nwb")
    >>> data["units"]
      Index    rate  location      group
    -------  ------  ----------  -------
          0     1.0  brain             0
          1     1.0  brain             0
          2     1.0  brain             0
    """

    # Mapping from pynapple type name to the helper that builds it.
    _f_eval = {
        "IntervalSet": _make_interval_set,
        "Tsd": _make_tsd,
        "Ts": _make_ts,
        "TsdFrame": _make_tsd_frame,
        "TsdTensor": _make_tsd_tensor,
        "TsGroup": _make_tsgroup,
    }

    def __init__(self, file, lazy_loading=True):
        """
        Parameters
        ----------
        file : str or pynwb.file.NWBFile
            Valid file to a NWB file
        lazy_loading : bool
            If True return a memory-view of the data, load otherwise.

        Raises
        ------
        FileNotFoundError
            If path is invalid
        RuntimeError
            If file is not an instance of NWBFile
        """
        # TODO: do we really need to have instantiation from file and object in the same place?
        if isinstance(file, pynwb.file.NWBFile):
            # In-memory NWBFile: no IO handle is created on this path.
            self.nwb = file
            self.name = self.nwb.session_id
        else:
            path = Path(file)
            if path.exists():
                self.path = path
                self.name = path.stem
                self.io = NWBHDF5IO(path, "r")
                self.nwb = self.io.read()
            else:
                raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), file)

        # Index of compatible objects: {key: {"id": ..., "type": ...}}.
        self.data = _extract_compatible_data_from_nwbfile(self.nwb)
        self.key_to_id = {k: self.data[k]["id"] for k in self.data.keys()}

        # Rows for the tabulated __str__ view.
        self._view = [[k, self.data[k]["type"]] for k in self.data.keys()]

        self._lazy_loading = lazy_loading

        UserDict.__init__(self, self.data)

    def __str__(self):
        """Tabulated view of the available keys and their pynapple types."""
        title = self.name if isinstance(self.name, str) else "-"
        headers = ["Keys", "Type"]
        return (
            title
            + "\n"
            + tabulate(self._view, headers=headers, tablefmt="mixed_outline")
        )

    def __repr__(self):
        """View of the object"""
        return self.__str__()

    def __getitem__(self, key):
        """Get object from NWB

        On first access the NWB object is converted to a pynapple object
        and cached in ``self.data``; subsequent accesses return the cache.

        Parameters
        ----------
        key : str

        Returns
        -------
        (Ts, Tsd, TsdFrame, TsGroup, IntervalSet or dict of IntervalSet)

        Raises
        ------
        KeyError
            If key is not in the dictionary
        """
        if key.__hash__:
            if self.__contains__(key):
                if isinstance(self.data[key], dict) and "id" in self.data[key]:
                    obj = self.nwb.objects[self.data[key]["id"]]
                    try:
                        data = self._f_eval[self.data[key]["type"]](
                            obj, lazy_loading=self._lazy_loading
                        )
                    except Exception:
                        # Best effort: if the conversion fails, hand back the
                        # raw NWB object so the user can inspect it manually.
                        warnings.warn(
                            "Failed to build {}.\n Returning the NWB object for manual inspection".format(
                                self.data[key]["type"]
                            ),
                            stacklevel=2,
                        )
                        data = obj
                    # Cache the converted object for subsequent accesses.
                    self.data[key] = data
                    return data
                else:
                    return self.data[key]
            else:
                raise KeyError("Can't find key {} in group index.".format(key))

    def close(self):
        """Close the NWB file.

        No-op when the object was built from an in-memory
        ``pynwb.file.NWBFile`` (no IO handle exists on that path).
        """
        if hasattr(self, "io"):
            self.io.close()