Source code for proteopy.pl.clustering

"""Clustering visualization tools for proteomics data."""

import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import anndata as ad
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from scipy.cluster.hierarchy import fcluster
from sklearn.metrics import silhouette_score

from proteopy.utils.anndata import check_proteodata
from proteopy.utils.parsers import (
    _resolve_hclustv_keys,
    _resolve_hclustv_profile_key,
)


def _compute_wcss(X: np.ndarray, labels: np.ndarray) -> float:
    """
    Compute within-cluster sum of squares.

    Parameters
    ----------
    X : np.ndarray
        Data matrix with samples as rows.
    labels : np.ndarray
        Cluster labels for each sample.

    Returns
    -------
    float
        Total within-cluster sum of squares.
    """
    wcss = 0.0
    unique_labels = np.unique(labels)
    for label in unique_labels:
        cluster_points = X[labels == label]
        centroid = cluster_points.mean(axis=0)
        wcss += np.sum((cluster_points - centroid) ** 2)
    return wcss


[docs] def hclustv_silhouette( adata: ad.AnnData, linkage_key: str = 'auto', values_key: str = 'auto', k: int = 15, figsize: tuple[float, float] = (6.0, 4.0), show: bool = True, ax: bool = False, save: str | Path | None = None, verbose: bool = True, ) -> Axes | None: """ Plot silhouette scores for hierarchical clustering. Evaluates clustering quality by computing the average silhouette score for cluster counts ranging from 2 to ``k``. Higher silhouette scores indicate better-defined clusters. Parameters ---------- adata : AnnData :class:`~anndata.AnnData` with clustering results from :func:`proteopy.tl.hclustv_tree` stored in ``.uns``. linkage_key : str Key in ``adata.uns`` for the linkage matrix. When ``'auto'``, auto-detects keys matching ``hclustv_linkage;*``. values_key : str Key in ``adata.uns`` for the profile values DataFrame. When ``'auto'``, auto-detects keys matching ``hclustv_values;*``. k : int Maximum number of clusters to evaluate. Silhouette scores are computed for cluster counts from 2 to ``k`` (inclusive). figsize : tuple[float, float] Matplotlib figure size in inches. show : bool Display the figure. ax : bool Return the Matplotlib Axes object instead of displaying. save : str | Path | None File path for saving the figure. verbose : bool Print status messages including auto-detected keys. Returns ------- Axes | None Axes object when ``ax`` is ``True``; otherwise ``None``. Raises ------ ValueError If no clustering results are found in ``adata.uns``, if multiple candidates exist and keys are not specified, or if ``k < 2``. KeyError If the specified ``linkage_key`` or ``values_key`` is not found. Examples -------- >>> import proteopy as pp >>> adata = pp.datasets.example_peptide_data() >>> pr.tl.hclustv_tree(adata, group_by="condition") >>> pr.pl.hclustv_silhouette(adata, k=5) """ check_proteodata(adata) if k < 2: raise ValueError("k must be at least 2 to compute silhouette scores.") linkage_key, values_key = _resolve_hclustv_keys( adata, linkage_key, values_key, verbose ) Z = adata.uns[linkage_key] profile_df = adata.uns[values_key] # profile_df has observations/groups as rows, variables as columns # For silhouette_score, we need samples (variables) as rows X = profile_df.T.values n_vars = X.shape[0] # Limit k to valid range max_k = n_vars - 1 if k > max_k: if verbose: print( f"k={k} exceeds maximum valid clusters ({max_k}). " f"Limiting to k={max_k}." ) k = max_k # Compute silhouette scores for k from 2 to k k_values = list(range(2, k + 1)) silhouette_scores_list = [] for n_clusters in k_values: labels = fcluster(Z, t=n_clusters, criterion="maxclust") score = silhouette_score(X, labels) silhouette_scores_list.append(score) # Create plot fig, _ax = plt.subplots(figsize=figsize) _ax.plot(k_values, silhouette_scores_list, marker="o", linewidth=1.5) _ax.set_xlabel("Number of clusters (k)") _ax.set_ylabel("Average silhouette score") _ax.set_title("Silhouette analysis for hierarchical clustering") # Set x-axis to show integer ticks only _ax.set_xticks(k_values) plt.tight_layout() if save is not None: fig.savefig(save, dpi=300, bbox_inches="tight") if verbose: print(f"Figure saved to: {save}") if show: plt.show() if ax: return _ax if not show and save is None and not ax: warnings.warn( "Function does not do anything. Enable `show`, provide a `save` " "path, or set `ax=True`." ) plt.close(fig) return None
[docs] def hclustv_elbow( adata: ad.AnnData, linkage_key: str = 'auto', values_key: str = 'auto', k: int = 15, figsize: tuple[float, float] = (6.0, 4.0), show: bool = True, ax: bool = False, save: str | Path | None = None, verbose: bool = True, ) -> Axes | None: """ Plot within-cluster sum of squares (elbow plot) for hierarchical clustering. Evaluates clustering by computing WCSS for cluster counts ranging from 1 to ``k``. The "elbow" point where WCSS reduction diminishes suggests an optimal cluster count. Parameters ---------- adata : AnnData :class:`~anndata.AnnData` with clustering results from :func:`proteopy.tl.hclustv_tree` stored in ``.uns``. linkage_key : str Key in ``adata.uns`` for the linkage matrix. When ``'auto'``, auto-detects keys matching ``hclustv_linkage;*``. values_key : str Key in ``adata.uns`` for the profile values DataFrame. When ``'auto'``, auto-detects keys matching ``hclustv_values;*``. k : int Maximum number of clusters to evaluate. WCSS is computed for cluster counts from 1 to ``k`` (inclusive). figsize : tuple[float, float] Matplotlib figure size in inches. show : bool Display the figure. ax : bool Return the Matplotlib Axes object instead of displaying. save : str | Path | None File path for saving the figure. verbose : bool Print status messages including auto-detected keys. Returns ------- Axes | None Axes object when ``ax`` is ``True``; otherwise ``None``. Raises ------ ValueError If no clustering results are found in ``adata.uns``, if multiple candidates exist and keys are not specified, or if ``k < 1``. KeyError If the specified ``linkage_key`` or ``values_key`` is not found. Examples -------- >>> import proteopy as pr >>> adata = pr.datasets.example_peptide_data() >>> pr.tl.hclustv_tree(adata, group_by="condition") >>> pr.pl.hclustv_elbow(adata, k=10) """ check_proteodata(adata) if k < 1: raise ValueError("k must be at least 1 to compute WCSS.") linkage_key, values_key = _resolve_hclustv_keys( adata, linkage_key, values_key, verbose ) Z = adata.uns[linkage_key] profile_df = adata.uns[values_key] # profile_df has observations/groups as rows, variables as columns # For WCSS, we need samples (variables) as rows X = profile_df.T.values n_vars = X.shape[0] # Limit k to valid range max_k = n_vars if k > max_k: if verbose: print( f"k={k} exceeds maximum valid clusters ({max_k}). " f"Limiting to k={max_k}." ) k = max_k # Compute WCSS for k from 1 to k k_values = list(range(1, k + 1)) wcss_list = [] for n_clusters in k_values: labels = fcluster(Z, t=n_clusters, criterion="maxclust") wcss = _compute_wcss(X, labels) wcss_list.append(wcss) # Create plot fig, _ax = plt.subplots(figsize=figsize) _ax.plot(k_values, wcss_list, marker="o", linewidth=1.5) _ax.set_xlabel("Number of clusters (k)") _ax.set_ylabel("Within-cluster sum of squares (WCSS)") _ax.set_title("Elbow plot for hierarchical clustering") # Set x-axis to show integer ticks only _ax.set_xticks(k_values) plt.tight_layout() if save is not None: fig.savefig(save, dpi=300, bbox_inches="tight") if verbose: print(f"Figure saved to: {save}") if show: plt.show() if ax: return _ax if not show and save is None and not ax: warnings.warn( "Function does not do anything. Enable `show`, provide a `save` " "path, or set `ax=True`." ) plt.close(fig) return None
[docs] def hclustv_profile_intensities( adata: ad.AnnData, profiles: str | list[str] | None = None, profile_key: str = 'auto', group_by: str | pd.Series | dict | None = None, sort_by: str | pd.Series | dict | None = None, order: list[str] | None = None, n_cols: int = 2, n_rows: int = 3, title: str | None = None, titles: list[str] | dict[str, str] | None = None, xlabel_rotation: float = 45, sort_by_label_rotation: float = 0, ylabel: str = "Intensity", marker: str = 'o', markersize: float = 6, linewidth: float = 1.5, errorbar: str | tuple = 'se', color: str | None = None, figsize: tuple[float, float] | None = None, show: bool = True, ax: bool = False, save: str | Path | None = None, verbose: bool = True, ) -> list[Axes] | None: """ Plot cluster profile intensities across observations. Displays line plots for each cluster profile showing how intensity varies across observations. When ``group_by`` is specified, observations are grouped and error bars are displayed. Parameters ---------- adata : AnnData :class:`~anndata.AnnData` with cluster profiles stored in ``.uns`` from :func:`proteopy.tl.hclustv_profiles`. profiles : str | list[str] | None Profile column(s) to plot from the profiles DataFrame. When ``None``, plots the first 6 profiles or fewer if not available. Can be a single profile name (e.g., ``"01"``) or a list of names. profile_key : str Key in ``adata.uns`` for the profiles DataFrame. When ``'auto'``, auto-detects keys matching ``hclustv_profiles;*``. group_by : str | pd.Series | dict | None Grouping for x-axis observations. When ``str``, uses the column from ``adata.obs`` to group observations and display error bars. When ``pd.Series``, uses Series index as observation keys and values as group labels. When ``dict``, maps observation indices to group labels directly. When ``None``, plots individual observations without grouping. If the column or Series is categorical, the category order is respected for x-axis ordering. Mutually exclusive with ``sort_by``. sort_by : str | pd.Series | dict | None Sort individual observations by group membership without aggregating. When ``str``, uses the column from ``adata.obs``. When ``pd.Series``, uses Series index as observation keys and values as sort groups. When ``dict``, maps observation indices to sort groups directly. Observations are ordered by their group, with group order determined by ``order`` (if provided) or categorical order (if categorical). Mutually exclusive with ``group_by``. order : list[str] | None Order of groups on the x-axis. When ``group_by`` is specified, controls the order of grouped categories. When ``sort_by`` is specified, controls the order in which sort groups appear. When ``None``, uses categorical order if available, otherwise sorted alphabetically. n_cols : int Number of columns in the subplot grid. n_rows : int Number of rows in the subplot grid. title : str | None Overall figure title. When ``None``, no suptitle is added. titles : list[str] | dict[str, str] | None Custom titles for each subplot. When ``list``, must have the same length as the number of plotted profiles. When ``dict``, maps profile/cluster names to custom titles. When ``None``, uses default titles (``"Cluster {profile_name}"``). xlabel_rotation : float Rotation angle (degrees) for x-axis tick labels. sort_by_label_rotation : float Rotation angle (degrees) for sort group labels when ``sort_by`` is used. ylabel : str Label for the y-axis of each subplot. marker : str Marker style for data points. markersize : float Size of data point markers. linewidth : float Width of connecting lines. errorbar : str | tuple Error bar style for grouped data. Passed to ``sns.lineplot``. Common options: ``'se'`` (standard error), ``'sd'`` (standard deviation), ``'ci'`` (confidence interval), ``('ci', 95)``. color : str | None Color for the line and markers. When ``None``, uses default palette. figsize : tuple[float, float] | None Figure size. When ``None``, auto-computed based on grid dimensions. show : bool Display the figure. ax : bool Return the Matplotlib Axes objects instead of displaying. save : str | Path | None File path for saving the figure. verbose : bool Print status messages including auto-detected keys. Returns ------- list[Axes] | None List of Axes objects when ``ax`` is ``True``; otherwise ``None``. Raises ------ ValueError If no cluster profiles are found in ``adata.uns``, if multiple candidates exist and ``profile_key`` is not specified, or if specified profiles are not found in the DataFrame. KeyError If the specified ``profile_key`` is not found, or if ``group_by`` column is not found in ``adata.obs``. TypeError If the profiles data is not a pandas DataFrame. Examples -------- >>> import proteopy as pr >>> adata = pr.datasets.karayel_2020() >>> pr.tl.hclustv_tree(adata, group_by="condition") >>> pr.tl.hclustv_cluster_ann(adata, k=5) >>> pr.tl.hclustv_profiles(adata) >>> pr.pl.hclustv_profile_intensities(adata) Plot with grouping and error bars: >>> pr.pl.hclustv_profile_intensities(adata, group_by="condition") Plot specific profiles: >>> pr.pl.hclustv_profile_intensities(adata, profiles=["01", "03"]) """ import seaborn as sns check_proteodata(adata) # Resolve profiles key resolved_key = _resolve_hclustv_profile_key( adata, profile_key, verbose ) profiles_df = adata.uns[resolved_key] # Validate profiles DataFrame if not isinstance(profiles_df, pd.DataFrame): raise TypeError( f"Expected profiles data to be DataFrame, " f"got {type(profiles_df).__name__}." ) if profiles_df.empty: raise ValueError("Profiles DataFrame is empty.") available_profiles = profiles_df.columns.tolist() # Determine which profiles to plot if profiles is None: max_profiles = n_cols * n_rows selected_profiles = available_profiles[:min(6, max_profiles)] elif isinstance(profiles, str): selected_profiles = [profiles] else: selected_profiles = list(profiles) # Validate selected profiles exist missing_profiles = [ p for p in selected_profiles if p not in available_profiles ] if missing_profiles: raise ValueError( f"Profiles not found in DataFrame: {missing_profiles}. " f"Available profiles: {available_profiles}" ) if not selected_profiles: raise ValueError("No profiles to plot.") # Limit to grid capacity max_plots = n_cols * n_rows if len(selected_profiles) > max_plots: if verbose: print( f"Only plotting first {max_plots} profiles " f"(grid capacity: {n_rows}x{n_cols})." ) selected_profiles = selected_profiles[:max_plots] # Validate titles parameter if titles is not None: if isinstance(titles, list): if len(titles) != len(selected_profiles): raise ValueError( f"titles list length ({len(titles)}) must match " f"number of profiles ({len(selected_profiles)})." ) elif not isinstance(titles, dict): raise TypeError( f"titles must be list, dict, or None, " f"got {type(titles).__name__}." ) # Validate mutually exclusive parameters if group_by is not None and sort_by is not None: raise ValueError( "group_by and sort_by are mutually exclusive. " "Use group_by to aggregate observations, or sort_by to " "order individual observations by group membership." ) # Helper to extract mapping and category order def _extract_mapping(param, param_name): mapping = None cat_order = None if param is None: pass elif isinstance(param, str): if param not in adata.obs.columns: raise KeyError( f"{param_name} column '{param}' not found in adata.obs." ) obs_col_data = adata.obs[param] if hasattr(obs_col_data, 'cat'): cat_order = obs_col_data.cat.categories.tolist() obs_in_profiles = profiles_df.index.intersection(adata.obs_names) mapping = adata.obs.loc[obs_in_profiles, param].to_dict() elif isinstance(param, pd.Series): if hasattr(param, 'cat'): cat_order = param.cat.categories.tolist() mapping = param.to_dict() elif isinstance(param, dict): mapping = param else: raise TypeError( f"{param_name} must be str, pd.Series, dict, or None, " f"got {type(param).__name__}." ) return mapping, cat_order group_mapping, group_category_order = _extract_mapping(group_by, 'group_by') sort_mapping, sort_category_order = _extract_mapping(sort_by, 'sort_by') # Build long-form DataFrame for seaborn plot_data = profiles_df[selected_profiles].copy() plot_data = plot_data.reset_index() plot_data = plot_data.melt( id_vars=[plot_data.columns[0]], var_name='profile', value_name='intensity', ) obs_col = plot_data.columns[0] # Determine x variable and apply grouping/sorting if group_mapping is not None: plot_data['group'] = plot_data[obs_col].map(group_mapping) plot_data = plot_data.dropna(subset=['group']) x_var = 'group' category_order = group_category_order elif sort_mapping is not None: plot_data['_sort_group'] = plot_data[obs_col].map(sort_mapping) plot_data = plot_data.dropna(subset=['_sort_group']) x_var = obs_col category_order = sort_category_order else: x_var = obs_col category_order = None # Determine group/sort order if order is not None: group_order = order elif category_order is not None: if group_mapping is not None: present_values = set(plot_data['group'].unique()) elif sort_mapping is not None: present_values = set(plot_data['_sort_group'].unique()) else: present_values = set() group_order = [c for c in category_order if c in present_values] elif group_mapping is not None: group_order = sorted(plot_data['group'].unique()) elif sort_mapping is not None: group_order = sorted(plot_data['_sort_group'].unique()) else: group_order = None # Filter to only include specified groups if group_order is not None: if group_mapping is not None: plot_data = plot_data[plot_data['group'].isin(group_order)] elif sort_mapping is not None: plot_data = plot_data[plot_data['_sort_group'].isin(group_order)] # Determine x-axis order if group_mapping is not None: x_order = group_order elif sort_mapping is not None: # Sort observations by their group membership plot_data['_sort_group'] = pd.Categorical( plot_data['_sort_group'], categories=group_order, ordered=True ) sorted_obs = ( plot_data[[obs_col, '_sort_group']] .drop_duplicates() .sort_values('_sort_group')[obs_col] .tolist() ) x_order = sorted_obs else: x_order = profiles_df.index.tolist() # Convert x variable to categorical with specified order plot_data[x_var] = pd.Categorical( plot_data[x_var], categories=x_order, ordered=True ) plot_data = plot_data.sort_values(x_var) # Determine figure size if figsize is None: fig_width = 4 * n_cols fig_height = 3 * n_rows figsize = (fig_width, fig_height) # Create figure and axes n_profiles = len(selected_profiles) actual_rows = min(n_rows, (n_profiles + n_cols - 1) // n_cols) fig, axes_array = plt.subplots( actual_rows, n_cols, figsize=figsize, squeeze=False, ) axes_flat = axes_array.flatten() returned_axes = [] for idx, profile_name in enumerate(selected_profiles): _ax = axes_flat[idx] profile_data = plot_data[plot_data['profile'] == profile_name] if group_mapping is not None: sns.lineplot( data=profile_data, x=x_var, y='intensity', err_style='bars', errorbar=errorbar, err_kws={'capsize': 4}, marker=marker, markersize=markersize, linewidth=linewidth, color=color, ax=_ax, sort=False, ) else: sns.lineplot( data=profile_data, x=x_var, y='intensity', errorbar=None, marker=marker, markersize=markersize, linewidth=linewidth, color=color if color else '#4C78A8', ax=_ax, sort=False, ) # Add sort group labels below plot area if sort_mapping is not None and group_order is not None: # Extend y-axis to make room for labels ymin, ymax = _ax.get_ylim() y_range = ymax - ymin new_ymin = ymin - 0.15 * y_range _ax.set_ylim(new_ymin, ymax) # Position for labels (just below original ymin) label_y = ymin - 0.05 * y_range # Build mapping of obs position in x_order obs_to_pos = {obs: i for i, obs in enumerate(x_order)} for group_label in group_order: # Find observations belonging to this group group_obs = [ obs for obs in x_order if sort_mapping.get(obs) == group_label ] if not group_obs: continue # Get x positions for this group's observations positions = [obs_to_pos[obs] for obs in group_obs] center_x = (min(positions) + max(positions)) / 2 # Add label below plot area _ax.text( center_x, label_y, str(group_label), ha='center', va='top', fontsize=9, rotation=sort_by_label_rotation, ) # Set x-axis tick labels with rotation _ax.tick_params(axis='x', rotation=xlabel_rotation) for label in _ax.get_xticklabels(): label.set_ha('right') _ax.set_xlabel('') _ax.set_ylabel(ylabel) # Set subplot title if titles is None: subplot_title = f"Profile {profile_name}" elif isinstance(titles, list): subplot_title = titles[idx] else: subplot_title = titles.get(profile_name, f"Profile {profile_name}") _ax.set_title(subplot_title) returned_axes.append(_ax) # Hide unused axes for idx in range(n_profiles, len(axes_flat)): axes_flat[idx].set_visible(False) # Add overall title if title is not None: fig.suptitle(title, fontsize=12, y=1.02) plt.tight_layout() if save is not None: fig.savefig(save, dpi=300, bbox_inches="tight") if verbose: print(f"Figure saved to: {save}") if show: plt.show() if ax: return returned_axes if not show and save is None and not ax: warnings.warn( "Function does not do anything. Enable `show`, provide a `save` " "path, or set `ax=True`." ) plt.close(fig) return None