Source code for pycrostates.viz.segmentation

"""Visualisation module for plotting segmentations."""

from __future__ import annotations  # c.f. PEP 563, PEP 649

from typing import TYPE_CHECKING

import numpy as np
from matplotlib import colormaps, colors
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from mne import BaseEpochs
from mne.io import BaseRaw
from mne.utils import check_version

from ..utils._checks import _check_type, _ensure_valid_show
from ..utils._docs import fill_doc
from ..utils._logs import logger, verbose

if TYPE_CHECKING:
    from typing import Optional, Union

    from .._typing import ScalarFloatArray, ScalarIntArray


[docs] @fill_doc def plot_raw_segmentation( labels: ScalarIntArray, raw: BaseRaw, n_clusters: int, cluster_names: list[str] = None, tmin: Optional[Union[int, float]] = None, tmax: Optional[Union[int, float]] = None, cmap: Optional[str] = None, axes: Optional[Axes] = None, cbar_axes: Optional[Axes] = None, *, block: bool = False, show: Optional[bool] = None, verbose: Optional[str] = None, **kwargs, ): """Plot raw segmentation. Parameters ---------- %(labels_raw)s raw : Raw MNE `~mne.io.Raw` instance. %(n_clusters)s %(cluster_names)s %(tmin_raw)s %(tmax_raw)s %(cmap)s %(axes_seg)s %(axes_cbar)s %(block)s %(show)s %(verbose)s **kwargs Kwargs are passed to ``axes.plot``. Returns ------- fig : Figure Matplotlib figure(s) on which topographic maps are plotted. """ _check_type(labels, (np.ndarray,), "labels") # 1D array (n_times, ) if labels.ndim != 1: raise ValueError("Argument 'labels' should be a 1D array.") _check_type(raw, (BaseRaw,), "raw") _check_type(block, (bool,), "block") show = _ensure_valid_show(show) data = raw.get_data(tmin=tmin, tmax=tmax) gfp = np.std(data, axis=0) # build times array instead of using raw.times because the time-based # selection in MNE can be a bit funky. if tmin is None: tmin = raw.times[0] times = np.arange( tmin, tmin + gfp.size / raw.info["sfreq"], 1 / raw.info["sfreq"], ) labels = labels[(times * raw.info["sfreq"]).astype(int)] # make sure shapes are correct if data.shape[1] != labels.size: raise ValueError( "Argument 'labels' and 'raw' do not have the same number of samples." ) fig, axes = _plot_segmentation( labels, gfp, times, n_clusters, cluster_names, cmap, axes, cbar_axes, verbose=verbose, **kwargs, ) # format axes.set_xlabel("Time (s)") if show: plt.show(block=block) return fig
[docs] @fill_doc def plot_epoch_segmentation( labels: ScalarIntArray, epochs: BaseEpochs, n_clusters: int, cluster_names: list[str] = None, cmap: Optional[str] = None, axes: Optional[Axes] = None, cbar_axes: Optional[Axes] = None, *, block: bool = False, show: Optional[bool] = None, verbose: Optional[str] = None, **kwargs, ): """ Plot epochs segmentation. Parameters ---------- %(labels_epo)s epochs : Epochs MNE `~mne.Epochs` instance. %(n_clusters)s %(cluster_names)s %(cmap)s %(axes_seg)s %(axes_cbar)s %(block)s %(show)s %(verbose)s **kwargs Kwargs are passed to ``axes.plot``. Returns ------- fig : Figure Matplotlib figure on which topographic maps are plotted. """ _check_type(labels, (np.ndarray,), "labels") # 1D array (n_times, ) if labels.ndim != 2: raise ValueError("Argument labels should be a 2D array.") _check_type(epochs, (BaseEpochs,), "epochs") _check_type(block, (bool,), "block") show = _ensure_valid_show(show) kwargs_epochs = dict(copy=False) if check_version("mne", "1.6") else dict() data = epochs.get_data(**kwargs_epochs).swapaxes(0, 1) data = data.reshape(data.shape[0], -1) gfp = np.std(data, axis=0) times = np.arange(0, data.shape[-1]) labels = labels.reshape(-1) # make sure shapes are correct if data.shape[1] != labels.size: raise ValueError( "Argument 'labels' and 'epochs' do not have the same number of samples." ) fig, axes = _plot_segmentation( labels, gfp, times, n_clusters, cluster_names, cmap, axes, cbar_axes, verbose=verbose, **kwargs, ) # format x_ticks = np.linspace( epochs.times.size // 2, data.shape[-1] - epochs.times.size // 2, len(epochs), ) x_tick_labels = [str(i) for i in range(1, len(epochs) + 1)] axes.set_xticks(x_ticks, x_tick_labels) axes.set_xlabel("Epochs") # add epoch lines x = np.linspace( epochs.times.size, data.shape[-1] - epochs.times.size, len(epochs) - 1, ) axes.vlines(x, 0, gfp.max(), linestyles="dashed", colors="black") if show: plt.show(block=block) return fig
@verbose def _plot_segmentation( labels: ScalarIntArray, gfp: ScalarFloatArray, times: ScalarFloatArray, n_clusters: int, cluster_names: list[str] = None, cmap: Optional[Union[str, colors.Colormap]] = None, axes: Optional[Axes] = None, cbar_axes: Optional[Axes] = None, *, verbose: Optional[str] = None, **kwargs, ) -> tuple[plt.Figure, Axes]: """Code snippet to plot segmentation for raw and epochs.""" _check_type(labels, (np.ndarray,), "labels") # 1D array (n_times, ) _check_type(gfp, (np.ndarray,), "gfp") # 1D array (n_times, ) _check_type(times, (np.ndarray,), "times") # 1D array (n_times, ) _check_type(n_clusters, ("int",), "n_clusters") if n_clusters <= 0: raise ValueError( f"Provided number of clusters {n_clusters} is invalid. The number of " "clusters must be strictly positive." ) _check_type(cluster_names, (None, list, tuple), "cluster_names") _check_type(cmap, (None, str, colors.Colormap), "cmap") _check_type(axes, (None, Axes), "ax") _check_type(cbar_axes, (None, Axes), "cbar_ax") # check cluster_names if cluster_names is None: cluster_names = [str(k) for k in range(1, n_clusters + 1)] if len(cluster_names) != n_clusters: raise ValueError( "Argument 'cluster_names' should have the 'n_clusters' elements. " f"Provided: {len(cluster_names)} names for {n_clusters} clusters." ) if axes is None: fig, axes = plt.subplots(1, 1, layout="constrained") else: fig = axes.get_figure() # add color and linewidth if absent from kwargs if "color" not in kwargs: kwargs["color"] = "black" if "linewidth" not in kwargs: kwargs["linewidth"] = 0.2 # define states and colors state_labels = [-1] + list(range(n_clusters)) cluster_names = ["unlabeled"] + cluster_names n_colors = n_clusters + 1 cmap = _compatibility_cmap(cmap, n_colors) # plot axes.plot(times, gfp, **kwargs) for state, color in zip(state_labels, cmap.colors): pos = np.where(labels == state)[0] if len(pos): pos = np.unique([pos, pos + 1]) x = np.zeros(labels.shape).astype(bool) if pos[-1] == labels.size: pos = pos[:-1] x[pos] = True axes.fill_between( times, gfp, color=color, where=x, step=None, interpolate=False ) logger.info( "For visualization purposes, the last segment appears truncated by 1 sample. " "In the case where the last segment is 1 sample long, it does not appear." ) # commonm formatting axes.set_title("Segmentation") axes.autoscale(tight=True) # color bar norm = colors.Normalize(vmin=0, vmax=n_colors) sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm) sm.set_array([]) colorbar = plt.colorbar( sm, cax=cbar_axes, ax=axes, ticks=[i + 0.5 for i in range(n_colors)] ) colorbar.ax.set_yticklabels(cluster_names) return fig, axes def _compatibility_cmap( cmap: Optional[Union[str, colors.Colormap]], n_colors: int, ): """Convert the 'cmap' argument to a colormap. Matplotlib 3.6 introduced a deprecation of plt.cm.get_cmap(). When support for the 3.6 version is dropped, this checker can be removed. """ if check_version("matplotlib", "3.6"): if cmap is None: cmap = colormaps["viridis"] elif isinstance(cmap, str): cmap = colormaps[cmap] # the cmap name is checked by matplotlib cmap = cmap.resampled(n_colors) else: if isinstance(cmap, (str, type(None))): cmap = plt.cm.get_cmap(cmap, n_colors) else: raise RuntimeError( "User-defined colormaps are supported as of matplotlib 3.6 " "and above. Please update matplotlib or provide a colormap " "name as a string." ) return cmap