# Source code for pycrostates.viz.cluster_centers

"""Visualization module for plotting cluster centers."""

from __future__ import annotations  # c.f. PEP 563, PEP 649

from typing import TYPE_CHECKING

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from mne import Info
from mne.channels.layout import _find_topomap_coords
from mne.viz import plot_topomap

from ..utils._checks import _check_axes, _check_type, _ensure_valid_show
from ..utils._docs import fill_doc
from ..utils._logs import logger, verbose

if TYPE_CHECKING:
    from typing import Any

    from .._typing import AxesArray, ScalarFloatArray
    from ..io import ChInfo


_GRADIENT_KWARGS_DEFAULTS: dict[str, str] = {
    "color": "black",
    "linestyle": "-",
    "marker": "P",
}


@fill_doc
@verbose
def plot_cluster_centers(
    cluster_centers: ScalarFloatArray,
    info: Info | ChInfo,
    cluster_names: list[str] | None = None,
    axes: Axes | AxesArray | None = None,
    show_gradient: bool = False,
    gradient_kwargs: dict[str, Any] = _GRADIENT_KWARGS_DEFAULTS,
    *,
    block: bool = False,
    show: bool | None = None,
    verbose: str | None = None,
    **kwargs,
):
    """Create topographic maps for cluster centers.

    Parameters
    ----------
    %(cluster_centers)s
    info : Info | ChInfo
        Info instance with a montage used to plot the topographic maps.
    %(cluster_names)s
    %(axes_topo)s
    show_gradient : bool
        If True, plot a line between channel locations with highest and lowest
        values.
    gradient_kwargs : dict
        Additional keyword arguments passed to :meth:`matplotlib.axes.Axes.plot` to
        plot gradient line.
    %(block)s
    %(show)s
    %(verbose)s
    **kwargs
        Additional keyword arguments are passed to :func:`mne.viz.plot_topomap`.

    Returns
    -------
    fig : Figure | list of Figure
        Matplotlib figure(s) on which topographic maps are plotted. A list is
        returned only when the provided ``axes`` belong to more than one figure.

    Raises
    ------
    ValueError
        If the number of cluster names or of provided axes does not match the
        number of cluster centers.
    """
    # imported locally; the module-level import is only under TYPE_CHECKING
    from ..io import ChInfo

    _check_type(cluster_centers, (np.ndarray,), "cluster_centers")
    _check_type(info, (Info, ChInfo), "info")
    _check_type(cluster_names, (None, list, tuple), "cluster_names")
    if axes is not None:
        _check_axes(axes)
    _check_type(show_gradient, (bool,), "show_gradient")
    _check_type(gradient_kwargs, (dict,), "gradient_kwargs")
    if gradient_kwargs != _GRADIENT_KWARGS_DEFAULTS and not show_gradient:
        logger.warning(
            "The argument 'gradient_kwargs' has not effect when the argument "
            "'show_gradient' is set to False."
        )
    _check_type(block, (bool,), "block")
    show = _ensure_valid_show(show)

    n_clusters = cluster_centers.shape[0]

    # check cluster_names
    if cluster_names is None:
        cluster_names = [str(k) for k in range(1, n_clusters + 1)]
    if len(cluster_names) != n_clusters:
        raise ValueError(
            "Argument 'cluster_centers' and 'cluster_names' should have the same "
            "number of elements."
        )

    # create axes if needed, and retrieve figure
    if axes is None:
        f, axes = plt.subplots(
            1, n_clusters, figsize=(3 * n_clusters, 3), layout="constrained"
        )
        if isinstance(axes, Axes):
            axes = np.array([axes])  # wrap in an array-like
        # sanity-check
        assert axes.ndim == 1
        # axes formatting
        for ax in axes:
            ax.set_axis_off()
    else:
        # make sure we have enough ax to plot
        if isinstance(axes, Axes):
            if n_clusters != 1:
                raise ValueError(
                    "Argument 'cluster_centers' and 'axes' must contain the "
                    f"same number of clusters and Axes. Provided: {n_clusters} "
                    "microstates maps and only 1 axes."
                )
            # A bare Axes has no .size / .flatten(); wrap it so that the code
            # below can treat every input as an array of Axes.
            axes = np.array([axes])
        elif axes.size != n_clusters:
            raise ValueError(
                "Argument 'cluster_centers' and 'axes' must contain the same "
                f"number of clusters and Axes. Provided: {n_clusters} "
                f"microstates maps and {axes.size} axes."
            )
        figs = [ax.get_figure() for ax in axes.flatten()]
        f = figs[0] if len(set(figs)) == 1 else figs
        del figs

    # channel positions do not depend on the map, compute them only once
    pos = _find_topomap_coords(info, picks="all") if show_gradient else None

    # plot cluster centers; flattening walks a 2D grid of axes row by row, which
    # matches the (k // n_cols, k % n_cols) placement of the k-th map
    for ax, center, name in zip(
        axes.flatten(), cluster_centers, cluster_names, strict=True
    ):
        plot_topomap(center, info, axes=ax, show=False, **kwargs)
        # add min max vector
        if show_gradient:
            i_min = np.argmin(center)
            i_max = np.argmax(center)
            ax.plot(
                [pos[i_min, 0], pos[i_max, 0]],
                [pos[i_min, 1], pos[i_max, 1]],
                **gradient_kwargs,
            )
        ax.set_title(name)

    if show:
        plt.show(block=block)
    return f