"""Visualisation module for plotting segmentations."""
from __future__ import annotations # c.f. PEP 563, PEP 649
from typing import TYPE_CHECKING
import numpy as np
from matplotlib import colormaps, colors
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from mne import BaseEpochs
from mne.io import BaseRaw
from ..utils._checks import _check_type, _ensure_valid_show
from ..utils._docs import fill_doc
from ..utils._logs import logger, verbose
if TYPE_CHECKING:
from .._typing import ScalarFloatArray, ScalarIntArray
[docs]
@fill_doc
def plot_raw_segmentation(
    labels: ScalarIntArray,
    raw: BaseRaw,
    n_clusters: int,
    cluster_names: list[str] | None = None,
    tmin: int | float | None = None,
    tmax: int | float | None = None,
    cmap: str | None = None,
    axes: Axes | None = None,
    cbar_axes: Axes | None = None,
    *,
    block: bool = False,
    show: bool | None = None,
    verbose: str | None = None,
    **kwargs,
):
    """Plot raw segmentation.

    Parameters
    ----------
    %(labels_raw)s
    raw : Raw
        MNE `~mne.io.Raw` instance.
    %(n_clusters)s
    %(cluster_names)s
    %(tmin_raw)s
    %(tmax_raw)s
    %(cmap)s
    %(axes_seg)s
    %(axes_cbar)s
    %(block)s
    %(show)s
    %(verbose)s
    **kwargs
        Kwargs are passed to ``axes.plot``.

    Returns
    -------
    fig : Figure
        Matplotlib figure on which the segmentation is plotted.
    """
    _check_type(labels, (np.ndarray,), "labels")  # 1D array (n_times, )
    if labels.ndim != 1:
        raise ValueError("Argument 'labels' should be a 1D array.")
    _check_type(raw, (BaseRaw,), "raw")
    _check_type(block, (bool,), "block")
    show = _ensure_valid_show(show)

    data = raw.get_data(tmin=tmin, tmax=tmax)
    gfp = np.std(data, axis=0)
    sfreq = raw.info["sfreq"]

    # build times array instead of using raw.times because the time-based
    # selection in MNE can be a bit funky.
    if tmin is None:
        tmin = raw.times[0]
    # Work from integer sample offsets: a float-step ``np.arange`` can return
    # one element too many, and truncating ``times * sfreq`` with
    # ``astype(int)`` maps values such as 28.999999999999996 to sample 28
    # instead of 29, which duplicates/skips labels.
    sample_offsets = np.arange(gfp.size)
    times = tmin + sample_offsets / sfreq
    # 'tmin' is snapped to the nearest sample of 'labels'.
    first_sample = int(np.round(tmin * sfreq))
    labels = labels[first_sample + sample_offsets]

    # make sure shapes are correct (kept as a cheap safety net; an index past
    # the end of 'labels' already raises in the selection above)
    if data.shape[1] != labels.size:
        raise ValueError(
            "Argument 'labels' and 'raw' do not have the same number of samples."
        )

    fig, axes = _plot_segmentation(
        labels,
        gfp,
        times,
        n_clusters,
        cluster_names,
        cmap,
        axes,
        cbar_axes,
        verbose=verbose,
        **kwargs,
    )
    # format
    axes.set_xlabel("Time (s)")

    if show:
        plt.show(block=block)
    return fig
[docs]
@fill_doc
def plot_epoch_segmentation(
    labels: ScalarIntArray,
    epochs: BaseEpochs,
    n_clusters: int,
    cluster_names: list[str] | None = None,
    cmap: str | None = None,
    axes: Axes | None = None,
    cbar_axes: Axes | None = None,
    *,
    block: bool = False,
    show: bool | None = None,
    verbose: str | None = None,
    **kwargs,
):
    """
    Plot epochs segmentation.

    Parameters
    ----------
    %(labels_epo)s
    epochs : Epochs
        MNE `~mne.Epochs` instance.
    %(n_clusters)s
    %(cluster_names)s
    %(cmap)s
    %(axes_seg)s
    %(axes_cbar)s
    %(block)s
    %(show)s
    %(verbose)s
    **kwargs
        Kwargs are passed to ``axes.plot``.

    Returns
    -------
    fig : Figure
        Matplotlib figure on which the segmentation is plotted.
    """
    _check_type(labels, (np.ndarray,), "labels")  # 2D array (n_epochs, n_times)
    if labels.ndim != 2:
        raise ValueError("Argument labels should be a 2D array.")
    _check_type(epochs, (BaseEpochs,), "epochs")
    _check_type(block, (bool,), "block")
    show = _ensure_valid_show(show)

    epochs_data = epochs.get_data(copy=False)  # (n_epochs, n_channels, n_times)
    n_epochs, _, n_times = epochs_data.shape
    # concatenate the epochs one after the other: (n_channels, n_epochs * n_times)
    data = epochs_data.swapaxes(0, 1)
    data = data.reshape(data.shape[0], -1)
    gfp = np.std(data, axis=0)
    times = np.arange(0, data.shape[-1])

    # make sure shapes are correct
    if data.shape[1] != labels.size:
        raise ValueError(
            "Argument 'labels' and 'epochs' do not have the same number of samples."
        )
    # The data is flattened epoch after epoch, so 'labels' must be laid out as
    # (n_epochs, n_times). A transposed array has the right size but would be
    # flattened in the wrong order and silently misalign labels and GFP.
    if labels.shape != (n_epochs, n_times):
        raise ValueError(
            "Argument 'labels' should have the shape (n_epochs, n_times) = "
            f"{(n_epochs, n_times)}. Provided: {labels.shape}."
        )
    labels = labels.reshape(-1)

    fig, axes = _plot_segmentation(
        labels,
        gfp,
        times,
        n_clusters,
        cluster_names,
        cmap,
        axes,
        cbar_axes,
        verbose=verbose,
        **kwargs,
    )

    # format: one tick in the middle of each epoch, labelled 1..n_epochs
    x_ticks = np.linspace(
        epochs.times.size // 2,
        data.shape[-1] - epochs.times.size // 2,
        len(epochs),
    )
    x_tick_labels = [str(i) for i in range(1, len(epochs) + 1)]
    axes.set_xticks(x_ticks, x_tick_labels)
    axes.set_xlabel("Epochs")
    # add epoch lines: one dashed separator at every boundary between epochs
    x = np.linspace(
        epochs.times.size,
        data.shape[-1] - epochs.times.size,
        len(epochs) - 1,
    )
    axes.vlines(x, 0, gfp.max(), linestyles="dashed", colors="black")

    if show:
        plt.show(block=block)
    return fig
@verbose
def _plot_segmentation(
    labels: ScalarIntArray,
    gfp: ScalarFloatArray,
    times: ScalarFloatArray,
    n_clusters: int,
    cluster_names: list[str] | None = None,
    cmap: str | colors.Colormap | None = None,
    axes: Axes | None = None,
    cbar_axes: Axes | None = None,
    *,
    verbose: str | None = None,
    **kwargs,
) -> tuple[plt.Figure, Axes]:
    """Code snippet to plot segmentation for raw and epochs.

    The GFP is drawn as a line over ``times`` and the area under it is filled
    with one color per state. The state ``-1`` is displayed as "unlabeled" and
    the states ``0`` to ``n_clusters - 1`` use the names in ``cluster_names``.
    A colorbar with one tick per state is added to the figure.

    Parameters
    ----------
    labels : array
        State of each sample, compared against ``-1`` and
        ``range(n_clusters)``.
    gfp : array
        Values plotted against ``times``.
    times : array
        X-axis values.
    n_clusters : int
        Number of clusters, strictly positive.
    cluster_names : list | tuple | None
        One name per cluster. If None, the names are "1" to "n_clusters".
    cmap : str | Colormap | None
        Colormap, resampled to ``n_clusters + 1`` colors. If None, "viridis"
        is used.
    axes : Axes | None
        Axes to draw on. If None, a new figure and axes are created.
    cbar_axes : Axes | None
        Axes passed as ``cax`` to the colorbar.
    verbose : str | None
        Not used in the body; presumably consumed by the ``verbose``
        decorator.
    **kwargs
        Passed to ``axes.plot``. "color" defaults to "black" and "linewidth"
        to 0.2 when absent.

    Returns
    -------
    fig : Figure
        Figure holding ``axes``.
    axes : Axes
        Axes on which the segmentation is drawn.

    Raises
    ------
    ValueError
        If ``n_clusters`` is not strictly positive or if the number of
        cluster names differs from ``n_clusters``.
    """
    _check_type(labels, (np.ndarray,), "labels")  # 1D array (n_times, )
    _check_type(gfp, (np.ndarray,), "gfp")  # 1D array (n_times, )
    _check_type(times, (np.ndarray,), "times")  # 1D array (n_times, )
    _check_type(n_clusters, ("int",), "n_clusters")
    if n_clusters <= 0:
        raise ValueError(
            f"Provided number of clusters {n_clusters} is invalid. The number of "
            "clusters must be strictly positive."
        )
    _check_type(cluster_names, (None, list, tuple), "cluster_names")
    _check_type(cmap, (None, str, colors.Colormap), "cmap")
    _check_type(axes, (None, Axes), "ax")
    _check_type(cbar_axes, (None, Axes), "cbar_ax")

    # check cluster_names
    if cluster_names is None:
        cluster_names = [str(k) for k in range(1, n_clusters + 1)]
    if len(cluster_names) != n_clusters:
        raise ValueError(
            "Argument 'cluster_names' should have the 'n_clusters' elements. "
            f"Provided: {len(cluster_names)} names for {n_clusters} clusters."
        )

    if axes is None:
        fig, axes = plt.subplots(1, 1, layout="constrained")
    else:
        fig = axes.get_figure()

    # add color and linewidth if absent from kwargs
    if "color" not in kwargs:
        kwargs["color"] = "black"
    if "linewidth" not in kwargs:
        kwargs["linewidth"] = 0.2

    # define states and colors
    state_labels = [-1] + list(range(n_clusters))
    # unpack instead of concatenating: 'cluster_names' may be a tuple, and
    # ``list + tuple`` raises a TypeError.
    cluster_names = ["unlabeled", *cluster_names]
    n_colors = n_clusters + 1
    cmap = _compatibility_cmap(cmap, n_colors)
    # Index the resampled colormap with integers rather than reading
    # ``cmap.colors``: that attribute only exists on ListedColormap, so
    # segmented colormaps (e.g. "jet") would raise an AttributeError.
    state_colors = cmap(np.arange(n_colors))

    # plot
    axes.plot(times, gfp, **kwargs)
    for state, color in zip(state_labels, state_colors, strict=False):
        pos = np.where(labels == state)[0]
        if len(pos):
            # also select the sample following each matching sample so that
            # the filled area extends up to the next sample
            pos = np.unique([pos, pos + 1])
            mask = np.zeros(labels.shape, dtype=bool)
            # drop the index past the end introduced by 'pos + 1'
            if pos[-1] == labels.size:
                pos = pos[:-1]
            mask[pos] = True
            axes.fill_between(
                times,
                gfp,
                color=color,
                where=mask,
                step=None,
                interpolate=False,
                linewidth=0,
            )
    logger.info(
        "For visualization purposes, the last segment appears truncated by 1 sample. "
        "In the case where the last segment is 1 sample long, it does not appear."
    )

    # common formatting
    axes.set_title("Segmentation")
    axes.autoscale(tight=True)
    # color bar: one unit-wide band per state, ticks centered on each band
    norm = colors.Normalize(vmin=0, vmax=n_colors)
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    colorbar = plt.colorbar(
        sm, cax=cbar_axes, ax=axes, ticks=[i + 0.5 for i in range(n_colors)]
    )
    colorbar.ax.set_yticklabels(cluster_names)
    return fig, axes
def _compatibility_cmap(
    cmap: str | colors.Colormap | None,
    n_colors: int,
):
    """Resolve the 'cmap' argument and resample it to ``n_colors`` entries.

    ``None`` selects "viridis", a string is looked up in the matplotlib
    colormap registry, and any other value is used as provided.

    Matplotlib 3.6 introduced a deprecation of plt.cm.get_cmap().
    When support for the 3.6 version is dropped, this checker can be removed.
    """
    requested = "viridis" if cmap is None else cmap
    if isinstance(requested, str):
        # the cmap name is checked by matplotlib
        resolved = colormaps[requested]
    else:
        resolved = requested
    return resolved.resampled(n_colors)