Source code for pycrostates.cluster.kmeans

"""Class and functions to use modified Kmeans."""

from __future__ import annotations  # c.f. PEP 563, PEP 649

from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
from mne import BaseEpochs
from mne.io import BaseRaw
from mne.parallel import parallel_func
from numpy.random import Generator, RandomState

from ..utils import _corr_vectors
from ..utils._checks import _check_n_jobs, _check_random_state, _check_type
from ..utils._docs import copy_doc, fill_doc
from ..utils._logs import logger
from ._base import _BaseCluster

if TYPE_CHECKING:
    from typing import Any, Optional, Union

    from .._typing import Picks, RandomState, ScalarFloatArray, ScalarIntArray
    from ..io import ChData


@fill_doc
class ModKMeans(_BaseCluster):
    r"""Modified K-Means clustering algorithm.

    See :footcite:t:`Marqui1995` for additional information.

    Parameters
    ----------
    %(n_clusters)s
    n_init : int
        Number of time the k-means algorithm is run with different centroid seeds. The
        final result will be the run with the highest Global Explained Variance (GEV).
    max_iter : int
        Maximum number of iterations of the K-means algorithm for a single run.
    tol : float
        Relative tolerance used to declare convergence. A run has converged once the
        decrease of the estimated residual noise between two consecutive iterations
        is smaller than ``tol`` times the current residual noise.
    %(random_state)s

    References
    ----------
    .. footbibliography::
    """

    def __init__(
        self,
        n_clusters: int,
        n_init: int = 100,
        max_iter: int = 300,
        tol: Union[int, float] = 1e-6,
        random_state: RandomState = None,
    ):
        super().__init__()

        # k-means has a fix number of clusters defined at init
        self._n_clusters = _BaseCluster._check_n_clusters(n_clusters)
        self._cluster_names = [str(k) for k in range(self.n_clusters)]

        # k-means settings
        self._n_init = ModKMeans._check_n_init(n_init)
        self._max_iter = ModKMeans._check_max_iter(max_iter)
        self._tol = ModKMeans._check_tol(tol)
        self._random_state = _check_random_state(random_state)

        # fit variables
        self._GEV_ = None

    def _repr_html_(self, caption=None):
        """Render the HTML representation from the ModKMeans jinja template."""
        from ..html_templates import repr_templates_env

        template = repr_templates_env.get_template("ModKMeans.html.jinja")
        if self.fitted:
            n_samples = self._fitted_data.shape[-1]
            # summarize the channels as e.g. "64 EEG"
            ch_types, ch_counts = np.unique(
                self.get_channel_types(), return_counts=True
            )
            ch_repr = [
                f"{ch_count} {ch_type.upper()}"
                for ch_type, ch_count in zip(ch_types, ch_counts)
            ]
            GEV = f"{self._GEV_ * 100:.2f}"
        else:
            n_samples = None
            ch_repr = None
            GEV = None

        return template.render(
            name=self.__class__.__name__,
            n_clusters=self._n_clusters,
            n_init=self._n_init,
            GEV=GEV,
            cluster_names=self._cluster_names,
            fitted=self._fitted,
            n_samples=n_samples,
            ch_repr=ch_repr,
        )

    @copy_doc(_BaseCluster.__eq__)
    def __eq__(self, other: Any) -> bool:
        # only another ModKMeans which also passes the base-class comparison can be
        # equal
        if not isinstance(other, ModKMeans):
            return False
        if not super().__eq__(other):
            return False

        attributes = (
            "_n_init",
            "_max_iter",
            "_tol",
            # '_random_state',
            # TODO: think about comparison and I/O for random states
            "_GEV_",
        )
        for attribute in attributes:
            try:
                attr1 = getattr(self, attribute)
                attr2 = getattr(other, attribute)
            except AttributeError:
                return False
            if attr1 != attr2:
                return False
        return True

    @copy_doc(_BaseCluster.__ne__)
    def __ne__(self, other: Any) -> bool:
        return not self.__eq__(other)

    @copy_doc(_BaseCluster._check_fit)
    def _check_fit(self):
        super()._check_fit()
        # sanity-check
        assert self.GEV_ is not None

    @fill_doc
    def fit(
        self,
        inst: Union[BaseRaw, BaseEpochs, ChData],
        picks: Picks = "eeg",
        tmin: Optional[Union[int, float]] = None,
        tmax: Optional[Union[int, float]] = None,
        reject_by_annotation: bool = True,
        n_jobs: int = 1,
        *,
        verbose: Optional[str] = None,
    ) -> None:
        """Compute cluster centers.

        Parameters
        ----------
        inst : Raw | Epochs | ChData
            MNE `~mne.io.Raw`, `~mne.Epochs` or `~pycrostates.io.ChData` object from
            which to extract :term:`cluster centers`.
        picks : str | list | slice | None
            Channels to include. Note that all channels selected must have the same
            type. Slices and lists of integers will be interpreted as channel indices.
            In lists, channel name strings (e.g. ``['Fp1', 'Fp2']``) will pick the given
            channels. Can also be the string values ``"all"`` to pick all channels, or
            ``"data"`` to pick data channels. ``"eeg"`` (default) will pick all eeg
            channels. Note that channels in ``info['bads']`` will be included if their
            names or indices are explicitly provided.
        %(tmin_raw)s
        %(tmax_raw)s
        %(reject_by_annotation_raw)s
        %(n_jobs)s
        %(verbose)s
        """
        n_jobs = _check_n_jobs(n_jobs)
        data = super().fit(
            inst,
            picks=picks,
            tmin=tmin,
            tmax=tmax,
            reject_by_annotation=reject_by_annotation,
            verbose=verbose,
        )

        # Draw one integer seed per run. The random_state property documents that a
        # numpy Generator may be stored; Generator has no randint(), its equivalent
        # is integers() (upper bound exclusive in both cases).
        if isinstance(self._random_state, Generator):
            inits = self._random_state.integers(
                low=0, high=100 * self._n_init, size=self._n_init
            )
        else:
            inits = self._random_state.randint(
                low=0, high=100 * self._n_init, size=self._n_init
            )

        if n_jobs == 1:
            best_gev, best_maps, best_segmentation = None, None, None
            count_converged = 0
            for init in inits:
                gev, maps, segmentation, converged = ModKMeans._kmeans(
                    data, self._n_clusters, self._max_iter, init, self._tol
                )
                # runs which did not converge are never candidates
                if not converged:
                    continue
                if best_gev is None or gev > best_gev:
                    best_gev, best_maps, best_segmentation = (
                        gev,
                        maps,
                        segmentation,
                    )
                count_converged += 1
        else:
            parallel, p_fun, _ = parallel_func(
                ModKMeans._kmeans, n_jobs, total=self._n_init
            )
            runs = parallel(
                p_fun(data, self._n_clusters, self._max_iter, init, self._tol)
                for init in inits
            )
            try:
                # non-converged runs are masked with NaN; nanargmax raises a
                # ValueError if every run is masked
                best_run = np.nanargmax([run[0] if run[3] else np.nan for run in runs])
                best_gev, best_maps, best_segmentation, _ = runs[best_run]
                count_converged = sum(run[3] for run in runs)
            except ValueError:
                best_gev, best_maps, best_segmentation = None, None, None
                count_converged = 0

        if best_gev is not None:
            logger.info(
                "Selecting run with highest GEV = %.2f%% after %i/%i "
                "iterations converged.",
                best_gev * 100,
                count_converged,
                self._n_init,
            )
        else:
            logger.error(
                "All the K-means run failed to converge. Please adapt the tolerance "
                "and the maximum number of iteration."
            )
            self.fitted = False  # reset variables related to fit
            return  # break early

        self._GEV_ = best_gev
        self._cluster_centers_ = best_maps
        self._labels_ = best_segmentation
        self._fitted = True
        self._ignore_polarity = True

    @copy_doc(_BaseCluster.save)
    def save(self, fname: Union[str, Path]):
        super().save(fname)
        # TODO: to be replaced by a general writer than infers the writer from
        # the file extension.
        from ..io.fiff import _write_cluster  # pylint: disable=C0415

        _write_cluster(
            fname,
            self._cluster_centers_,
            self._info,
            "ModKMeans",
            self._cluster_names,
            self._fitted_data,
            self._labels_,
            n_init=self._n_init,
            max_iter=self._max_iter,
            tol=self._tol,
            GEV_=self._GEV_,
        )

    # --------------------------------------------------------------------
    @staticmethod
    def _kmeans(
        data: ScalarFloatArray,
        n_clusters: int,
        max_iter: int,
        random_state: Union[int, RandomState, Generator],
        tol: Union[int, float],
    ) -> tuple[float, ScalarFloatArray, ScalarIntArray, bool]:
        """Run the k-means algorithm.

        Compute the maps with ``_compute_maps``, label every sample with the map of
        highest absolute activation and compute the GEV of this segmentation.

        Returns
        -------
        gev : float
            Global Explained Variance of the segmentation.
        maps : array
            Maps returned by ``_compute_maps``.
        segmentation : array
            Index of the map assigned to each sample.
        converged : bool
            Convergence flag returned by ``_compute_maps``.
        """
        gfp_sum_sq = np.sum(data**2)
        maps, converged = ModKMeans._compute_maps(
            data, n_clusters, max_iter, random_state, tol
        )
        activation = maps.dot(data)
        # the absolute value makes the assignment independent of the sign
        segmentation = np.argmax(np.abs(activation), axis=0)
        map_corr = _corr_vectors(data, maps[segmentation].T)
        gev = np.sum((data * map_corr) ** 2) / gfp_sum_sq
        return gev, maps, segmentation, converged

    @staticmethod
    def _compute_maps(
        data: ScalarFloatArray,
        n_clusters: int,
        max_iter: int,
        random_state: Union[int, RandomState, Generator],
        tol: Union[int, float],
    ) -> tuple[ScalarFloatArray, bool]:
        """Compute microstates maps.

        Based on mne_microstates by Marijn van Vliet <w.m.vanvliet@gmail.com>
        https://github.com/wmvanvliet/mne_microstates/blob/master/microstates.py

        Returns
        -------
        maps : array
            One map per cluster, as rows.
        converged : bool
            True if the tolerance criterion was met within ``max_iter`` iterations.
        """
        # A RandomState or a Generator is used as provided, both expose choice().
        # Anything else (e.g. the integer seeds drawn in fit) seeds a new
        # RandomState. A Generator can not be used as a seed for RandomState.
        if not isinstance(
            random_state, (np.random.RandomState, np.random.Generator)
        ):
            random_state = np.random.RandomState(random_state)

        # ------------------------- handle zeros maps -------------------------
        # zero map can be due to non data in the recording, it's unlikely that
        # all channels recorded the same value at the same time (=0 due to
        # average reference)
        # ---------------------------------------------------------------------
        data = data[:, np.linalg.norm(data.T, axis=1) != 0]
        n_channels, n_samples = data.shape
        data_sum_sq = np.sum(data**2)

        # Select random time points for our initial topographic maps
        init_times = random_state.choice(n_samples, size=n_clusters, replace=False)
        maps = data[:, init_times].T
        # Normalize the maps
        maps /= np.linalg.norm(maps, axis=1, keepdims=True)

        prev_residual = np.inf
        for _ in range(max_iter):
            # Assign each sample to the best matching microstate
            activation = maps.dot(data)
            segmentation = np.argmax(np.abs(activation), axis=0)

            # Recompute the topographic maps of the microstates, based on the
            # samples that were assigned to each state.
            for state in range(n_clusters):
                idx = segmentation == state
                if np.sum(idx) == 0:
                    # no sample assigned to this state, zero its map
                    maps[state] = 0
                    continue

                # Find largest eigenvector
                maps[state] = data[:, idx].dot(activation[state, idx])
                maps[state] /= np.linalg.norm(maps[state])

            # Estimate residual noise
            act_sum_sq = np.sum(np.sum(maps[segmentation].T * data, axis=0) ** 2)
            residual = abs(data_sum_sq - act_sum_sq)
            residual /= float(n_samples * (n_channels - 1))

            # check convergence
            if (prev_residual - residual) < (tol * residual):
                converged = True
                break

            prev_residual = residual
        else:
            # max_iter iterations elapsed without meeting the tolerance
            converged = False

        return maps, converged

    # --------------------------------------------------------------------
    @property
    def n_init(self) -> int:  # noqa: D401
        """Number of k-means algorithms run with different centroid seeds.

        :type: `int`
        """
        return self._n_init

    @property
    def max_iter(self) -> int:
        """Maximum number of iterations of the k-means algorithm for a run.

        :type: `int`
        """
        return self._max_iter

    @property
    def tol(self) -> Union[int, float]:
        """Relative tolerance to reach convergence.

        :type: `float`
        """
        return self._tol

    @property
    def random_state(self) -> Union[RandomState, Generator]:
        """Random state to fix seed generation.

        :type: `~numpy.random.RandomState` | `~numpy.random.Generator`
        """
        return self._random_state

    @property
    def GEV_(self) -> float:
        """Global Explained Variance.

        :type: `float`
        """
        if self._GEV_ is None:
            assert not self._fitted  # sanity-check
            logger.warning("Clustering algorithm has not been fitted.")
        return self._GEV_

    @_BaseCluster.fitted.setter
    @copy_doc(_BaseCluster.fitted.setter)
    def fitted(self, fitted):
        # Name the class explicitly: super(self.__class__, self.__class__) resolves
        # to this very setter again when called on an instance of a subclass, which
        # recurses forever.
        super(ModKMeans, ModKMeans).fitted.__set__(self, fitted)
        if not fitted:
            self._GEV_ = None

    # --------------------------------------------------------------------
    # For now, check function are defined as static within KMeans. If they are
    # used outside KMeans, they should be moved to regular function in _base.py
    # --------------------------------------------------------------------
    @staticmethod
    def _check_n_init(n_init: int) -> int:
        """Check that n_init is a positive integer."""
        _check_type(n_init, ("int",), item_name="n_init")
        if n_init <= 0:
            raise ValueError(
                "The number of initialization must be a positive integer. "
                f"Provided: '{n_init}'."
            )
        return n_init

    @staticmethod
    def _check_max_iter(max_iter: int) -> int:
        """Check that max_iter is a positive integer."""
        _check_type(max_iter, ("int",), item_name="max_iter")
        if max_iter <= 0:
            raise ValueError(
                "The number of max iteration must be a positive integer. "
                f"Provided: '{max_iter}'."
            )
        return max_iter

    @staticmethod
    def _check_tol(tol: Union[int, float]) -> Union[int, float]:
        """Check that tol is a positive number."""
        _check_type(tol, ("numeric",), item_name="tol")
        if tol <= 0:
            raise ValueError(
                f"The tolerance must be a positive number. Provided: '{tol}'."
            )
        return tol