"""
Pynapple class to interface with NWB files.
Data are always lazy-loaded.
Object behaves like dictionary.
"""
import errno
import importlib
import os
import warnings
from collections import UserDict
from numbers import Number
from pathlib import Path
import numpy as np
from tabulate import tabulate
from .. import core as nap
def _get_unique_identifier(full_path_to_key):
    """Make every value of a ``full_path -> key`` mapping unique.

    Keys that collide are recursively extended with one more parent
    folder taken from their full path, until all keys differ.

    Parameters
    ----------
    full_path_to_key : dict
        Mapping from a full NWB path to its (possibly ambiguous) short key.

    Returns
    -------
    dict
        Same mapping, with colliding keys disambiguated. Mutated in place.
    """
    names, counts = np.unique(list(full_path_to_key.values()), return_counts=True)
    if len(names) == len(full_path_to_key):
        # All keys already unique; nothing to do.
        return full_path_to_key

    duplicated = names[counts > 1]
    # Restrict the work to the ambiguous entries only.
    ambiguous = {
        path: key for path, key in full_path_to_key.items() if key in duplicated
    }
    for path in ambiguous:
        # Prepend the most immediate parent folder not yet in the key.
        path_parts = path.split("/")
        depth = len(ambiguous[path].split("/"))
        candidate = "/".join(path_parts[-(depth + 1):])
        if candidate.startswith("/"):
            candidate = candidate[1:]
        ambiguous[path] = candidate

    # Recurse until the extended keys are unique, then merge back.
    full_path_to_key.update(_get_unique_identifier(ambiguous))
    return full_path_to_key
def _get_full_path(path, obj):
    """Build the absolute path of ``obj`` by walking up its parent chain.

    Parameters
    ----------
    path : str
        Path accumulated so far (starts as the object's name).
    obj : object
        NWB object; its ``parent`` attribute (if any) is followed upward.

    Returns
    -------
    str
        Slash-separated path rooted at "/".
    """
    parent = getattr(obj, "parent", None)
    # No parent, or an unnamed parent, or the root itself: stop here.
    if parent is None or not hasattr(parent, "name") or parent.name == "root":
        return "/" + path
    return _get_full_path(parent.name + "/" + path, parent)
def iterate_over_nwb(nwbfile):
    """Yield ``(nwb_object, info)`` pairs for every convertible NWB object.

    ``info`` is ``{"id": object_id, "type": pynapple_type_name}``. The
    isinstance checks are ordered: ragged spike-time tables must be
    detected before generic DynamicTables / TimeIntervals.

    Parameters
    ----------
    nwbfile : pynwb.file.NWBFile
        Instance of NWB file.

    Yields
    ------
    tuple
        (NWB object, dict with its id and pynapple type).
    """
    pynwb = importlib.import_module("pynwb")

    def has_column_suffix(table, suffix):
        # True when any column name of the DynamicTable ends with ``suffix``.
        return any(col.name.endswith(suffix) for col in table.columns)

    for oid, obj in nwbfile.objects.items():
        info = {"id": oid}
        if isinstance(obj, pynwb.misc.DynamicTable) and has_column_suffix(
            obj, "_times_index"
        ):
            # Ragged spike-time column -> a group of spike trains.
            info["type"] = "TsGroup"
        elif isinstance(obj, pynwb.epoch.TimeIntervals):
            # Supposedly IntervalSets.
            info["type"] = "IntervalSet"
        elif isinstance(obj, pynwb.misc.DynamicTable) and has_column_suffix(
            obj, "_times"
        ):
            # Supposedly timestamps.
            info["type"] = "Ts"
        elif isinstance(obj, pynwb.misc.AnnotationSeries):
            # Old timestamps version.
            info["type"] = "Ts"
        elif isinstance(obj, pynwb.misc.TimeSeries):
            ndim = len(obj.data.shape)
            if ndim > 2:
                info["type"] = "TsdTensor"
            elif ndim == 2:
                info["type"] = "TsdFrame"
            elif ndim == 1:
                info["type"] = "Tsd"
            else:
                continue  # 0-d series: nothing to convert
        else:
            continue
        yield obj, info
def _extract_compatible_data_from_nwbfile(nwbfile):
    """Extract all the NWB objects that can be converted to a pynapple object.

    Each convertible object is indexed by its full path; objects sharing a
    name are therefore still distinguishable.

    Parameters
    ----------
    nwbfile : pynwb.file.NWBFile
        Instance of NWB file.

    Returns
    -------
    dict
        Mapping from full path to ``{"id", "type"}`` for every object found.
    """
    data = {}
    for obj, info in iterate_over_nwb(nwbfile):
        data[_get_full_path(obj.name, obj)] = info
    return data
def _make_interval_set(obj, **kwargs):
    """Helper function to build an IntervalSet from a TimeIntervals table.

    Parameters
    ----------
    obj : pynwb.epoch.TimeIntervals
        NWB object.

    Returns
    -------
    IntervalSet or object
        The converted IntervalSet, or the raw NWB object when it cannot be
        turned into a dataframe. Extra table columns become IntervalSet
        metadata.
    """
    if not hasattr(obj, "to_dataframe"):
        return obj
    df = obj.to_dataframe()
    if hasattr(df, "start_time") and hasattr(df, "stop_time"):
        # Rename to pynapple's expected column names; the full dataframe
        # is passed so metadata columns are associated correctly.
        df = df.rename(columns={"start_time": "start", "stop_time": "end"})
        return nap.IntervalSet(df)
    # NOTE(review): a table without start/stop columns falls through and
    # yields None, matching the original behavior.
def _make_tsd(obj, lazy_loading=True):
    """Build a pynapple Tsd from a 1-dimensional NWB TimeSeries.

    Parameters
    ----------
    obj : pynwb.misc.TimeSeries
        NWB object.
    lazy_loading : bool
        If True keep a memory-view of the data; otherwise load it fully.

    Returns
    -------
    Tsd
    """
    values = obj.data if lazy_loading else obj.data[:]
    if obj.timestamps is None:
        # Regularly sampled: rebuild the time axis from offset and rate.
        times = obj.starting_time + np.arange(obj.num_samples) / obj.rate
    else:
        times = obj.timestamps[:]
    return nap.Tsd(t=times, d=values, load_array=not lazy_loading)
def _make_tsd_tensor(obj, lazy_loading=True):
    """Build a pynapple TsdTensor from an NWB TimeSeries with ndim > 2.

    Parameters
    ----------
    obj : pynwb.misc.TimeSeries
        NWB object.
    lazy_loading : bool
        If True keep a memory-view of the data; otherwise load it fully.

    Returns
    -------
    TsdTensor
    """
    values = obj.data if lazy_loading else obj.data[:]
    if obj.timestamps is None:
        # Regularly sampled: rebuild the time axis from offset and rate.
        times = obj.starting_time + np.arange(obj.num_samples) / obj.rate
    else:
        times = obj.timestamps[:]
    return nap.TsdTensor(t=times, d=values, load_array=not lazy_loading)
def _make_tsd_frame(obj, lazy_loading=True):
    """Build a pynapple TsdFrame from a 2-dimensional NWB TimeSeries.

    Column labels are inferred from the series subtype: spatial axes for
    SpatialSeries, electrode labels for ElectricalSeries, ROI ids for
    RoiResponseSeries, or plain integer indices otherwise.

    Parameters
    ----------
    obj : pynwb.misc.TimeSeries
        NWB object.
    lazy_loading : bool
        If True keep a memory-view of the data; otherwise load it fully.

    Returns
    -------
    TsdFrame
    """
    pynwb = importlib.import_module("pynwb")

    values = obj.data if lazy_loading else obj.data[:]
    if obj.timestamps is None:
        # Regularly sampled: rebuild the time axis from offset and rate.
        times = obj.starting_time + np.arange(obj.num_samples) / obj.rate
    else:
        times = obj.timestamps[:]

    n_cols = obj.data.shape[1]
    default_columns = np.arange(n_cols)

    if isinstance(obj, pynwb.behavior.SpatialSeries):
        # Canonical axis names for 2D/3D spatial data.
        axis_names = {2: ["x", "y"], 3: ["x", "y", "z"]}
        columns = axis_names.get(n_cols, default_columns)
    elif isinstance(obj, pynwb.ecephys.ElectricalSeries):
        # Channel mapping from the electrode table when available.
        try:
            electrodes = obj.electrodes.to_dataframe()
            if hasattr(electrodes, "label"):
                columns = electrodes["label"].values
            else:
                columns = electrodes.index.values
        except Exception:
            columns = default_columns
    elif isinstance(obj, pynwb.ophys.RoiResponseSeries):
        # Cell numbers from the ROI table when available.
        try:
            columns = obj.rois["id"][:]
        except Exception:
            columns = default_columns
    else:
        columns = default_columns

    if len(columns) >= values.shape[1]:
        # Weird sometimes if background ID added: trim to the data width.
        columns = columns[0:n_cols]
    else:
        columns = default_columns

    return nap.TsdFrame(t=times, d=values, columns=columns, load_array=not lazy_loading)
def _make_tsgroup(obj, **kwargs):
    """Helper function to make TsGroup.

    Spike times are read per unit from the ragged ``spike_times_index``
    column; every other table column of matching length and scalar type
    is attached as per-unit metadata.

    Parameters
    ----------
    obj : pynwb.misc.Units
        NWB object

    Returns
    -------
    TsGroup
    """
    pynwb = importlib.import_module("pynwb")
    # Unit ids become the TsGroup indices.
    index = obj.id[:]
    tsgroup = {}
    for i, gr in zip(index, obj.spike_times_index[:]):
        tsgroup[i] = nap.Ts(t=np.array(gr))
    # Number of units; candidate metadata columns must match this length.
    N = len(tsgroup)
    metainfo = {}
    for coln in obj.colnames:
        if coln == "electrode_group":
            # Flatten selected ElectrodeGroup attributes into per-unit arrays.
            for e in [
                "location",
                "x",
                "y",
                "z",
                "imp",
                "filtering",
                "rel_x",
                "rel_y",
                "rel_z",
                "reference",
            ]:
                tmp = [eg.__getattribute__(e) for eg in obj[coln] if hasattr(eg, e)]
                # Keep the attribute only when it is defined for every unit.
                if len(tmp) == N:
                    metainfo[e] = np.array(tmp)
        if coln not in ["spike_times_index", "spike_times", "electrode_group"]:
            col = obj[coln]
            if len(col) == N:
                if hasattr(col, "to_dataframe"):
                    # Table-like column (presumably a DynamicTableRegion —
                    # TODO confirm): expand its sub-table into columns.
                    df = col.to_dataframe()
                    df = df.sort_index()
                    for k in df.columns:
                        # Skip container-valued entries; only scalar values
                        # make valid per-unit metadata.
                        if not isinstance(
                            df[k].values[0],
                            (list, tuple, dict, set, pynwb.ecephys.ElectrodeGroup),
                        ):
                            metainfo[k] = df[k].values
                elif isinstance(col[0], (Number, str)):
                    metainfo[coln] = np.array(col[:])
                else:
                    # Any other column type is deliberately ignored.
                    pass
    tsgroup = nap.TsGroup(tsgroup, metadata=metainfo)
    return tsgroup
def _make_ts(obj, **kwargs):
    """Helper function to make Ts.

    Parameters
    ----------
    obj : pynwb.misc.AnnotationSeries or pynwb.misc.DynamicTable
        NWB object.

    Returns
    -------
    Ts or dict of Ts
        A single Ts when the object carries timestamps (or exactly one
        ``*_times`` column); otherwise a dict of Ts, one per column.
    """
    if hasattr(obj, "timestamps"):
        return nap.Ts(obj.timestamps[:])

    df = obj.to_dataframe()
    data = {}
    for name in df.columns:
        # Only string columns ending in "_times" hold event times.
        if isinstance(name, str) and name.endswith("_times"):
            data[name] = nap.Ts(df[name].values)
    if len(data) == 1:
        # Unwrap a singleton dict to the bare Ts.
        (data,) = data.values()
    return data
class NWBFile(UserDict):
    """Dictionary-like reader for NWB files.

    Compatible NWB objects are indexed at construction, but each one is
    only converted to a pynapple object on first access (and then cached).

    Examples
    --------
    >>> import pynapple as nap
    >>> data = nap.load_file("my_file.nwb")
    >>> data["units"]
      Index    rate  location      group
    -------  ------  ----------  -------
          0     1.0  brain             0
          1     1.0  brain             0
          2     1.0  brain             0
    """

    # Dispatch table: pynapple type name -> builder helper.
    _f_eval = {
        "IntervalSet": _make_interval_set,
        "Tsd": _make_tsd,
        "Ts": _make_ts,
        "TsdFrame": _make_tsd_frame,
        "TsdTensor": _make_tsd_tensor,
        "TsGroup": _make_tsgroup,
    }

    def __init__(self, file, lazy_loading=True):
        """
        Parameters
        ----------
        file : str or pynwb.file.NWBFile
            Valid path to a NWB file, or an already opened NWB file object.
        lazy_loading : bool
            If True return a memory-view of the data, load otherwise.

        Raises
        ------
        FileNotFoundError
            If path is invalid.
        """
        pynwb = importlib.import_module("pynwb")
        NWBHDF5IO = pynwb.NWBHDF5IO

        if isinstance(file, pynwb.file.NWBFile):
            # Already-open file object: no io handle is created.
            self.nwb = file
            self.name = self.nwb.session_id
        else:
            path = Path(file)
            if not path.exists():
                raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), file)
            self.path = path
            self.name = path.stem
            self.io = NWBHDF5IO(path, "r")
            self.nwb = self.io.read()

        # Mapping full_path -> {'id', 'type'} for every convertible object.
        self.data = _extract_compatible_data_from_nwbfile(self.nwb)

        # Object names may collide; disambiguate keys with parent paths.
        self.full_path_to_key = _get_unique_identifier(
            {p: os.path.basename(p) for p in self.data.keys()}
        )

        # Reverse mappings for the user: key -> full_path, key -> object id.
        self.key_to_full_path = {v: k for k, v in self.full_path_to_key.items()}
        self.data = {self.full_path_to_key[p]: self.data[p] for p in self.data.keys()}
        self.key_to_id = {k: self.data[k]["id"] for k in self.data.keys()}

        # Rows used by __str__ / __repr__.
        self._view = [[k, self.data[k]["type"]] for k in self.data.keys()]
        self._lazy_loading = lazy_loading

        UserDict.__init__(self, self.data)

    def __str__(self):
        """Tabular view of the available keys and their pynapple types."""
        title = self.name if isinstance(self.name, str) else "-"
        headers = ["Keys", "Type"]
        return (
            title
            + "\n"
            + tabulate(self._view, headers=headers, tablefmt="mixed_outline")
        )

    def __repr__(self):
        """View of the object"""
        return self.__str__()

    def __getitem__(self, key):
        """Get object from NWB, converting it to pynapple on first access.

        Parameters
        ----------
        key : str
            Short key, or a full path starting with "/".

        Returns
        -------
        (Ts, Tsd, TsdFrame, TsGroup, IntervalSet or dict of IntervalSet)

        Raises
        ------
        KeyError
            If key is not in the dictionary.
        """
        if not key.__hash__:
            # Fix: unhashable keys previously fell through and returned None.
            raise KeyError("Can't find key {} in group index.".format(key))
        if isinstance(key, str) and key.startswith("/"):
            # Allow the user to specify the full path to the object.
            if key in self.full_path_to_key:
                return self[self.full_path_to_key[key]]
            raise KeyError("Can't find key {} in group index.".format(key))
        if not self.__contains__(key):
            raise KeyError("Can't find key {} in group index.".format(key))
        if isinstance(self.data[key], dict) and "id" in self.data[key]:
            # Not converted yet: build the pynapple object now.
            obj = self.nwb.objects[self.data[key]["id"]]
            try:
                data = self._f_eval[self.data[key]["type"]](
                    obj, lazy_loading=self._lazy_loading
                )
            except Exception:
                # Best-effort: hand back the raw NWB object instead of failing.
                warnings.warn(
                    "Failed to build {}.\n Returning the NWB object for manual inspection".format(
                        self.data[key]["type"]
                    ),
                    stacklevel=2,
                )
                data = obj
            # Cache so conversion happens only once.
            self.data[key] = data
            return data
        return self.data[key]

    def close(self):
        """Close the NWB file.

        No-op when the object was constructed from an already-open
        pynwb NWBFile (no io handle exists in that case).
        """
        if hasattr(self, "io"):
            self.io.close()

    def keys(self):
        """
        Return keys of NWBFile

        Returns
        -------
        list
            List of keys
        """
        return list(self.data.keys())

    def items(self):
        """
        Return a list of key/object.

        Returns
        -------
        list
            List of tuples
        """
        return list(self.data.items())

    def values(self):
        """
        Return a list of all the objects

        Returns
        -------
        list
            List of objects
        """
        return list(self.data.values())