Source code for pynapple.io.interface_nwb

"""
Pynapple class to interface with NWB files.
Data are lazy-loaded by default.
The object behaves like a dictionary.
"""

import errno
import importlib
import os
import warnings
from collections import UserDict
from numbers import Number
from pathlib import Path

import numpy as np
from tabulate import tabulate

from .. import core as nap


def _get_unique_identifier(full_path_to_key):
    out, count = np.unique(list(full_path_to_key.values()), return_counts=True)
    if len(out) != len(full_path_to_key):
        key_to_change = out[count > 1]
        # Filter for ambiguous keys only
        update_dict = {
            key: val
            for key, val in full_path_to_key.items()
            if full_path_to_key[key] in key_to_change
        }
        for full_path, key in update_dict.items():
            # Adding the most immediate parent path until disambiguation
            base_parts = full_path.split("/")
            relative_parts = key.split("/")
            new_key = "/".join(base_parts[-len(relative_parts) - 1 :])
            if new_key.startswith("/"):
                new_key = new_key[1:]
            update_dict[full_path] = new_key
        update_dict = _get_unique_identifier(update_dict)
        full_path_to_key.update(update_dict)
    return full_path_to_key


def _get_full_path(path, obj):
    if hasattr(obj, "parent"):  # Better be safe here
        if obj.parent is None:
            return "/" + path
        else:
            if hasattr(obj.parent, "name"):  # and extra safe
                if obj.parent.name == "root":
                    return "/" + path
                else:
                    return _get_full_path(obj.parent.name + "/" + path, obj.parent)
            else:
                return "/" + path
    else:
        return "/" + path


def iterate_over_nwb(nwbfile):
    """Iterate over the objects of an NWB file that map to a pynapple type.

    Parameters
    ----------
    nwbfile : pynwb.file.NWBFile
        Instance of NWB file; its ``objects`` mapping is scanned.

    Yields
    ------
    tuple
        ``(obj, {"id": object_id, "type": pynapple_type_name})`` for every
        object that matches one of the rules below. Other objects are skipped.
    """
    pynwb = importlib.import_module("pynwb")

    def has_column_ending_with(table, suffix):
        return any(column.name.endswith(suffix) for column in table.columns)

    for oid, obj in nwbfile.objects.items():
        is_table = isinstance(obj, pynwb.misc.DynamicTable)
        kind = None

        # The order of the checks matters: the first matching rule wins
        if is_table and has_column_ending_with(obj, "_times_index"):
            kind = "TsGroup"
        elif isinstance(obj, pynwb.epoch.TimeIntervals):
            # Supposedly IntervalSets
            kind = "IntervalSet"
        elif is_table and has_column_ending_with(obj, "_times"):
            # Supposedly timestamps
            kind = "Ts"
        elif isinstance(obj, pynwb.misc.AnnotationSeries):
            # Old timestamps version
            kind = "Ts"
        elif isinstance(obj, pynwb.misc.TimeSeries):
            # The time series flavour depends on the dimensionality of the data
            n_dims = len(obj.data.shape)
            if n_dims > 2:
                kind = "TsdTensor"
            elif n_dims == 2:
                kind = "TsdFrame"
            elif n_dims == 1:
                kind = "Tsd"

        if kind is not None:
            yield obj, {"id": oid, "type": kind}


def _extract_compatible_data_from_nwbfile(nwbfile):
    """Collect every NWB object that can be converted to a pynapple object.

    Objects are keyed by the full path returned by ``_get_full_path``, which
    prepends the names of their parents to the object name.

    Parameters
    ----------
    nwbfile : pynwb.file.NWBFile
        Instance of NWB file

    Returns
    -------
    dict
        Mapping from full path to the ``{"id", "type"}`` description yielded
        by ``iterate_over_nwb`` for that object.
    """
    compatible = {}
    for nwb_object, description in iterate_over_nwb(nwbfile):
        full_path = _get_full_path(nwb_object.name, nwb_object)
        compatible[full_path] = description
    return compatible


def _make_interval_set(obj, **kwargs):
    """Helper function to make IntervalSet

    Parameters
    ----------
    obj : pynwb.epoch.TimeIntervals
        NWB object

    Returns
    -------
    IntervalSet or dict of IntervalSet or pandas.DataFrame
        If contains multiple epochs, a dictionary of IntervalSet is returned.
        It too many metadata, the function returns the output of nwbfile.trials.to_dataframe()
    """
    if hasattr(obj, "to_dataframe"):
        df = obj.to_dataframe()

        if hasattr(df, "start_time") and hasattr(df, "stop_time"):
            df = df.rename(columns={"start_time": "start", "stop_time": "end"})
            # create from full dataframe to ensure that metadata is associated correctly
            data = nap.IntervalSet(df)
            return data

    else:
        return obj


def _make_tsd(obj, lazy_loading=True):
    """Helper function to make Tsd

    Parameters
    ----------
    obj : pynwb.misc.TimeSeries
        NWB object
    lazy_loading: bool
        If True the data object of the time series is passed as is,
        otherwise it is sliced with ``[:]`` first.

    Returns
    -------
    Tsd

    """
    values = obj.data if lazy_loading else obj.data[:]

    if obj.timestamps is None:
        # No explicit timestamps: rebuild them from the start time and the rate
        time_index = obj.starting_time + np.arange(obj.num_samples) / obj.rate
    else:
        time_index = obj.timestamps[:]

    return nap.Tsd(t=time_index, d=values, load_array=not lazy_loading)


def _make_tsd_tensor(obj, lazy_loading=True):
    """Helper function to make TsdTensor

    Parameters
    ----------
    obj : pynwb.misc.TimeSeries
        NWB object
    lazy_loading: bool
        If True the data object of the time series is passed as is,
        otherwise it is sliced with ``[:]`` first.

    Returns
    -------
    TsdTensor

    """
    values = obj.data if lazy_loading else obj.data[:]

    if obj.timestamps is None:
        # No explicit timestamps: rebuild them from the start time and the rate
        time_index = obj.starting_time + np.arange(obj.num_samples) / obj.rate
    else:
        time_index = obj.timestamps[:]

    return nap.TsdTensor(t=time_index, d=values, load_array=not lazy_loading)


def _make_tsd_frame(obj, lazy_loading=True):
    """Helper function to make TsdFrame

    Column names are chosen from the type of the series:
    ``x``/``y``(/``z``) for 2 or 3 column spatial series, the electrode
    ``label`` column (or the electrode table index) for electrical series,
    the ROI ``id`` for ROI response series, and the column position otherwise
    or whenever the lookup fails.

    Parameters
    ----------
    obj : pynwb.misc.TimeSeries
        NWB object
    lazy_loading: bool
        If True the data object of the time series is passed as is,
        otherwise it is sliced with ``[:]`` first.

    Returns
    -------
    TsdFrame

    """
    pynwb = importlib.import_module("pynwb")

    values = obj.data if lazy_loading else obj.data[:]

    if obj.timestamps is None:
        # No explicit timestamps: rebuild them from the start time and the rate
        time_index = obj.starting_time + np.arange(obj.num_samples) / obj.rate
    else:
        time_index = obj.timestamps[:]

    n_columns = obj.data.shape[1]

    # ``None`` means "no meaningful labels found, use the column positions"
    labels = None
    if isinstance(obj, pynwb.behavior.SpatialSeries):
        labels = {2: ["x", "y"], 3: ["x", "y", "z"]}.get(n_columns)

    elif isinstance(obj, pynwb.ecephys.ElectricalSeries):
        # (channel mapping)
        try:
            electrodes = obj.electrodes.to_dataframe()
            if hasattr(electrodes, "label"):
                labels = electrodes["label"].values
            else:
                labels = electrodes.index.values
        except Exception:
            labels = None

    elif isinstance(obj, pynwb.ophys.RoiResponseSeries):
        # (cell number)
        try:
            labels = obj.rois["id"][:]
        except Exception:
            labels = None

    if labels is None or len(labels) < values.shape[1]:
        labels = np.arange(n_columns)
    else:
        # Weird sometimes if background ID added: drop the extra labels
        labels = labels[0:n_columns]

    return nap.TsdFrame(
        t=time_index, d=values, columns=labels, load_array=not lazy_loading
    )


def _make_tsgroup(obj, **kwargs):
    """Helper function to make TsGroup

    One Ts is built per row of the table from ``spike_times_index``. Columns
    whose length matches the number of units are attached as metadata, and a
    fixed list of attributes is read from the ``electrode_group`` column when
    every group provides them.

    Parameters
    ----------
    obj : pynwb.misc.Units
        NWB object
    **kwargs
        Ignored. Accepted so that all the builders can be called with the
        same keyword arguments (e.g. ``lazy_loading``).

    Returns
    -------
    TsGroup

    """
    pynwb = importlib.import_module("pynwb")

    electrode_fields = (
        "location",
        "x",
        "y",
        "z",
        "imp",
        "filtering",
        "rel_x",
        "rel_y",
        "rel_z",
        "reference",
    )
    # Columns that never become plain metadata columns
    excluded_columns = ("spike_times_index", "spike_times", "electrode_group")
    # Cell values of these types are not stored as metadata
    rejected_types = (list, tuple, dict, set, pynwb.ecephys.ElectrodeGroup)

    spike_trains = {
        unit_id: nap.Ts(t=np.array(times))
        for unit_id, times in zip(obj.id[:], obj.spike_times_index[:])
    }
    n_units = len(spike_trains)

    metadata = {}
    for column_name in obj.colnames:
        if column_name == "electrode_group":
            for field in electrode_fields:
                found = [
                    getattr(group, field)
                    for group in obj[column_name]
                    if hasattr(group, field)
                ]
                # Only keep attributes available for every unit
                if len(found) == n_units:
                    metadata[field] = np.array(found)
            continue

        if column_name in excluded_columns:
            continue

        column = obj[column_name]
        if len(column) != n_units:
            continue

        if hasattr(column, "to_dataframe"):
            frame = column.to_dataframe().sort_index()
            for name in frame.columns:
                if not isinstance(frame[name].values[0], rejected_types):
                    metadata[name] = frame[name].values
        elif isinstance(column[0], (Number, str)):
            metadata[column_name] = np.array(column[:])

    return nap.TsGroup(spike_trains, metadata=metadata)


def _make_ts(obj, **kwargs):
    """Helper function to make Ts

    Parameters
    ----------
    obj : pynwb.misc.AnnotationSeries or pynwb.misc.DynamicTable
        NWB object
    **kwargs
        Ignored. Accepted so that all the builders can be called with the
        same keyword arguments (e.g. ``lazy_loading``).

    Returns
    -------
    Ts or dict of Ts
        A single Ts when the object has a ``timestamps`` attribute or exactly
        one column ending with ``_times``; otherwise a dictionary with one Ts
        per column ending with ``_times``.

    """
    if hasattr(obj, "timestamps"):
        return nap.Ts(obj.timestamps[:])

    frame = obj.to_dataframe()
    event_trains = {
        name: nap.Ts(frame[name].values)
        for name in frame.columns
        if isinstance(name, str) and name.endswith("_times")
    }

    # A single column is returned directly instead of a one-entry dictionary
    if len(event_trains) == 1:
        return next(iter(event_trains.values()))
    return event_trains


class NWBFile(UserDict):
    """Class for reading NWB Files.

    The object behaves like a dictionary. Each entry initially holds a small
    ``{"id", "type"}`` description; the corresponding pynapple object is
    built the first time the key is accessed and then cached.

    Examples
    --------
    >>> import pynapple as nap
    >>> data = nap.load_file("my_file.nwb")
    >>> data["units"]
      Index    rate  location      group
    -------  ------  ----------  -------
          0     1.0  brain             0
          1     1.0  brain             0
          2     1.0  brain             0

    """

    # Builder used for each pynapple type name
    _f_eval = {
        "IntervalSet": _make_interval_set,
        "Tsd": _make_tsd,
        "Ts": _make_ts,
        "TsdFrame": _make_tsd_frame,
        "TsdTensor": _make_tsd_tensor,
        "TsGroup": _make_tsgroup,
    }

    def __init__(self, file, lazy_loading=True):
        """
        Parameters
        ----------
        file : str or pynwb.file.NWBFile
            Path to a NWB file, or an already opened NWB file object
        lazy_loading: bool
            If True return a memory-view of the data, load otherwise.

        Raises
        ------
        FileNotFoundError
            If ``file`` is a path that does not exist
        """
        # TODO: do we really need to have instantiation from file and object in the same place?
        pynwb = importlib.import_module("pynwb")
        NWBHDF5IO = pynwb.NWBHDF5IO

        if isinstance(file, pynwb.file.NWBFile):
            self.nwb = file
            self.name = self.nwb.session_id
            # No file handle is owned in this case; `close` becomes a no-op
            self.io = None
        else:
            path = Path(file)
            if not path.exists():
                raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), file)
            self.path = path
            self.name = path.stem
            self.io = NWBHDF5IO(path, "r")
            try:
                self.nwb = self.io.read()
            except Exception:
                # Do not leak the open file handle if the file cannot be read
                self.io.close()
                raise

        # Get a dictionary with full_path -> {'id', 'type'}
        self.data = _extract_compatible_data_from_nwbfile(self.nwb)

        # Several objects may share the same name: make the short keys unique
        self.full_path_to_key = _get_unique_identifier(
            {p: os.path.basename(p) for p in self.data.keys()}
        )
        # Creating the reverse mapping for the user : key -> full_path and key -> {'id', 'type'}
        self.key_to_full_path = {v: k for k, v in self.full_path_to_key.items()}
        self.data = {self.full_path_to_key[p]: self.data[p] for p in self.data.keys()}

        # Mapping unique path identifier to id
        self.key_to_id = {k: self.data[k]["id"] for k in self.data.keys()}

        # Rows (key, type) displayed by __str__
        self._view = [[k, self.data[k]["type"]] for k in self.data.keys()]

        self._lazy_loading = lazy_loading

        UserDict.__init__(self, self.data)

    def __str__(self):
        """Return the name of the file followed by a table of keys and types"""
        title = self.name if isinstance(self.name, str) else "-"
        headers = ["Keys", "Type"]
        return (
            title
            + "\n"
            + tabulate(self._view, headers=headers, tablefmt="mixed_outline")
        )

    def __repr__(self):
        """View of the object"""
        return self.__str__()

    def __getitem__(self, key):
        """Get object from NWB

        The pynapple object is built on first access and cached. If the
        conversion fails, a warning is issued and the NWB object itself is
        returned (and cached) instead.

        Parameters
        ----------
        key : str
            Either a key of the dictionary or the full path of an object
            (starting with "/")

        Returns
        -------
        (Ts, Tsd, TsdFrame, TsGroup, IntervalSet or dict of IntervalSet)

        Raises
        ------
        KeyError
            If key is not in the dictionary
        """
        # Only strings can be full paths; other keys go through the regular
        # lookup so that they raise KeyError instead of AttributeError.
        if isinstance(key, str) and key.startswith("/"):
            # allow user to specify the full path to the object
            if key in self.full_path_to_key:
                return self[self.full_path_to_key[key]]
            raise KeyError("Can't find key {} in group index.".format(key))

        if key not in self:
            raise KeyError("Can't find key {} in group index.".format(key))

        entry = self.data[key]
        # Entries that were already built are returned as they are
        if not (isinstance(entry, dict) and "id" in entry):
            return entry

        obj = self.nwb.objects[entry["id"]]
        try:
            data = self._f_eval[entry["type"]](obj, lazy_loading=self._lazy_loading)
        except Exception:
            # Best effort: give the raw NWB object back rather than failing
            warnings.warn(
                "Failed to build {}.\n Returning the NWB object for manual inspection".format(
                    entry["type"]
                ),
                stacklevel=2,
            )
            data = obj

        self.data[key] = data
        return data

    def close(self):
        """Close the NWB file

        Does nothing when the object was created from an in-memory NWB file
        object, since no file handle was opened in that case.
        """
        io = getattr(self, "io", None)
        if io is not None:
            io.close()

    def keys(self):
        """
        Return keys of NWBFile

        Returns
        -------
        list
            List of keys
        """
        return list(self.data.keys())

    def items(self):
        """
        Return a list of key/object.

        Entries that have not been accessed yet are returned as their
        ``{"id", "type"}`` description.

        Returns
        -------
        list
            List of tuples
        """
        return list(self.data.items())

    def values(self):
        """
        Return a list of all the objects

        Entries that have not been accessed yet are returned as their
        ``{"id", "type"}`` description.

        Returns
        -------
        list
            List of objects
        """
        return list(self.data.values())