# Source code for proteopy.pp.filtering

import warnings
from pathlib import Path
from typing import Callable
import numpy as np
import pandas as pd
import scipy.sparse as sp
from Bio import SeqIO

from proteopy.utils.functools import partial_with_docsig
from proteopy.utils.anndata import check_proteodata, is_proteodata


def filter_axis(
    adata,
    axis,
    min_fraction=None,
    min_count=None,
    group_by=None,
    zero_to_na=False,
    inplace=True,
):
    """
    Filter observations or variables based on non-missing value content.

    This function filters the AnnData object along a specified axis (observations
    or variables) based on the fraction or number of non-missing (np.nan) values.
    Filtering can be performed globally or within groups defined by the `group_by`
    parameter.

    Parameters
    ----------
    adata : anndata.AnnData
        The annotated data matrix to filter.
    axis : int
        The axis to filter on. `0` for observations, `1` for variables.
    min_fraction : float, optional
        The minimum fraction of non-missing values required to keep an observation
        or variable. If `group_by` is provided, this threshold is applied to the
        maximum completeness across all groups.
    min_count : int, optional
        The minimum number of non-missing values required to keep an observation
        or variable. If `group_by` is provided, this threshold is applied to the
        maximum count across all groups.
    group_by : str, optional
        A column key in `adata.obs` (if `axis=1`) or `adata.var` (if `axis=0`)
        used for grouping before applying the filter. The maximum completeness or
        count across the groups is used for filtering.
    zero_to_na : bool, optional
        If True, zeros in the data matrix are treated as missing values (NaN).
    inplace : bool, optional
        If True, modifies the `adata` object in place. Otherwise, returns a
        filtered copy.

    Returns
    -------
    anndata.AnnData or None
        If `inplace=False`, returns a new filtered AnnData object. Otherwise,
        returns `None`.

    Raises
    ------
    KeyError
        If the `group_by` key is not found in the corresponding annotation
        DataFrame.
    """
    check_proteodata(adata)

    if min_fraction is None and min_count is None:
        warnings.warn(
            "Neither `min_fraction` nor `min_count` were provided, so "
            "the function does nothing."
        )
        return None if inplace else adata.copy()

    X = adata.X.copy()
    if zero_to_na:
        if sp.issparse(X):
            X.data[X.data == 0] = np.nan
        else:
            X[X == 0] = np.nan

    if sp.issparse(X):
        X.eliminate_zeros()

    axis_i = 1 - axis
    axis_labels = adata.obs_names if axis == 0 else adata.var_names
    completeness = None # assigned below when min_fraction is set

    if group_by is not None:
        metadata = adata.obs if axis == 1 else adata.var
        if group_by not in metadata.columns:
            raise KeyError(
                f'`group_by`="{group_by}" not present in '
                f'adata.{"obs" if axis == 1 else "var"}'
            )
        grouping = metadata[group_by]
        unique_groups = grouping.dropna().unique()

        counts_by_group = []
        completeness_by_group = []
        for label in unique_groups:
            mask = (grouping == label).values
            subset = X[mask, :] if axis == 1 else X[:, mask]

            if subset.shape[axis_i] == 0:
                continue

            group_size = subset.shape[axis_i]

            if sp.issparse(subset):
                group_counts = subset.getnnz(axis=axis_i)
            else:
                group_counts = np.count_nonzero(~np.isnan(subset), axis=axis_i)

            df_counts = pd.DataFrame(group_counts, index=axis_labels)
            counts_by_group.append(df_counts)
            if min_fraction is not None:
                df_completeness = df_counts / group_size
                completeness_by_group.append(df_completeness)

        if not counts_by_group:
            counts = pd.Series(0, index=axis_labels, dtype=float)
        else:
            counts = pd.concat(counts_by_group, axis=1).max(axis=1)
        if min_fraction is not None:
            if not completeness_by_group:
                completeness = pd.Series(0, index=axis_labels, dtype=float)
            else:
                completeness = pd.concat(completeness_by_group, axis=1).max(axis=1)
    else:
        if sp.issparse(X):
            counts = pd.Series(X.getnnz(axis=axis_i), index=axis_labels)
        else:
            counts = pd.Series(
                np.count_nonzero(~np.isnan(X), axis=axis_i), index=axis_labels
            )
        if min_fraction is not None:
            num_total = adata.shape[axis_i]
            completeness = counts / num_total

    mask_filt = pd.Series(True, index=axis_labels)
    if min_fraction is not None:
        mask_filt &= completeness >= min_fraction

    if min_count is not None:
        mask_filt &= counts >= min_count

    n_removed = (~mask_filt).sum()
    axis_name = ["obs", "var"][axis]
    print(f"{n_removed} {axis_name} removed")

    if inplace:
        if axis == 0:
            adata._inplace_subset_obs(mask_filt.values)
        else:
            adata._inplace_subset_var(mask_filt.values)
        check_proteodata(adata)
        return None
    else:
        adata_filtered = adata[mask_filt, :] if axis == 0 else adata[:, mask_filt]
        check_proteodata(adata_filtered)
        return adata_filtered


# ``filter_axis`` bound to the observation axis (axis=0); the header text
# below is handed to ``partial_with_docsig`` as ``docstr_header``.
docstr_header = """
Filter observations based on non-missing value content.

This function filters the AnnData object along the `obs` axis based on the
fraction or number of non-missing values (np.nan). Filtering can be performed
globally or within groups defined by the `group_by` parameter.
"""
filter_samples = partial_with_docsig(
    filter_axis,
    axis=0,
    docstr_header=docstr_header,
)

# ``filter_axis`` bound to the observation axis (axis=0) with ``min_count``
# fixed to None; the header text below is handed to ``partial_with_docsig``
# as ``docstr_header``.
docstr_header = """
Filter observations based on data completeness.

This function filters the AnnData object along the obs axis based on the
fraction of non-missing (non-NaN) values. Filtering can be performed globally
or within groups defined by the `group_by` parameter.
"""
filter_samples_completeness = partial_with_docsig(
    filter_axis,
    axis=0,
    min_count=None,
    docstr_header=docstr_header,
)

# ``filter_axis`` bound to the variable axis (axis=1); the header text
# below is handed to ``partial_with_docsig`` as ``docstr_header``.
docstr_header = """
Filter variables based on non-missing value content.

This function filters the AnnData object along the `var` axis based on the
fraction or number of non-missing values (np.nan). Filtering can be performed
globally or within groups defined by the `group_by` parameter.
"""
filter_var = partial_with_docsig(
    filter_axis,
    axis=1,
    docstr_header=docstr_header,
)

# ``filter_axis`` bound to the variable axis (axis=1) with ``min_count``
# fixed to None; the header text below is handed to ``partial_with_docsig``
# as ``docstr_header``.
docstr_header = """
Filter variables based on data completeness.

This function filters the AnnData object along the var axis based on the
fraction of non-missing (non-NaN) values. Filtering can be performed globally
or within groups defined by the `group_by` parameter.
"""
filter_var_completeness = partial_with_docsig(
    filter_axis,
    axis=1,
    min_count=None,
    docstr_header=docstr_header,
)


[docs] def filter_proteins_by_peptide_count( adata, min_count=None, max_count=None, protein_col="protein_id", inplace=True, ): """ Filter proteins by their peptide count. Parameters ---------- adata : anndata.AnnData Annotated data matrix with a protein identifier column in ``adata.var``. min_count : int or None, optional Keep peptides whose proteins have at least this many peptides. max_count : int or None, optional Keep peptides whose proteins have at most this many peptides. protein_col : str, optional (default: "protein_id") Column in ``adata.var`` containing protein identifiers. inplace : bool, optional (default: True) If True, modify ``adata`` in place. Otherwise, return a filtered view. Returns ------- None or anndata.AnnData ``None`` if ``inplace=True``; otherwise the filtered AnnData view. """ check_proteodata(adata) if is_proteodata(adata)[1] != "peptide": raise ValueError(( "`AnnData` object must be in ProteoData peptide format." )) if min_count is None and max_count is None: warnings.warn("Pass at least one argument: min_count | max_count") adata_copy = None if inplace else adata.copy() if adata_copy is not None: check_proteodata(adata_copy) return adata_copy if min_count is not None: if min_count < 0: raise ValueError("`min_count` must be non-negative.") if max_count is not None: if max_count < 0: raise ValueError("`max_count` must be non-negative.") if (min_count is not None and max_count is not None) and (min_count > max_count): raise ValueError("`min_count` cannot be greater than `max_count`.") if protein_col not in adata.var.columns: raise KeyError(f"`protein_col`='{protein_col}' not found in adata.var") proteins = adata.var[protein_col] counts = proteins.value_counts() keep_mask = pd.Series(True, index=counts.index) if min_count is not None: keep_mask &= counts >= min_count if max_count is not None: keep_mask &= counts <= max_count protein_ids_keep = counts.index[keep_mask] var_keep_mask = proteins.isin(protein_ids_keep) if inplace: 
adata._inplace_subset_var(var_keep_mask.values) check_proteodata(adata) n_proteins_removed = len(counts.index) - len(protein_ids_keep) n_peptides_removed = int((~var_keep_mask).sum()) print( f"Removed {n_proteins_removed} proteins and " f"{n_peptides_removed} peptides." ) return None else: new_adata = adata[:, var_keep_mask] check_proteodata(new_adata) n_proteins_removed = len(counts.index) - len(protein_ids_keep) n_peptides_removed = int((~var_keep_mask).sum()) print( f"Removed {n_proteins_removed} proteins and " f"{n_peptides_removed} peptides." ) return new_adata
[docs] def filter_samples_by_category_count( adata, category_col, min_count=None, max_count=None, inplace=True, ): """ Filter observations by the frequency of their category value in a ``.vars`` metadata column. Parameters ---------- adata : anndata.AnnData Annotated data matrix. category_col : str Column in ``adata.obs`` containing the categories to count. min_count : int or None, optional Keep categories with at least this many observations. max_count : int or None, optional Keep categories with at most this many observations. inplace : bool, optional (default: True) If True, modify ``adata`` in place. Otherwise, return a filtered copy. Returns ------- None or anndata.AnnData ``None`` if ``inplace=True``; otherwise the filtered AnnData. """ check_proteodata(adata) if min_count is None and max_count is None: raise ValueError( "At least one argument must be passed: min_count | max_count" ) if min_count is not None and min_count < 0: raise ValueError("`min_count` must be non-negative.") if max_count is not None and max_count < 0: raise ValueError("`max_count` must be non-negative.") if ( min_count is not None and max_count is not None and min_count > max_count ): raise ValueError("`min_count` cannot be greater than `max_count`.") if category_col not in adata.obs.columns: raise KeyError(f"`category_col`='{category_col}' not found in adata.obs") obs_series = adata.obs[category_col] counts = obs_series.value_counts(dropna=False) counts_filt = counts if min_count is not None: counts_filt = counts_filt[counts_filt >= min_count] if max_count is not None: counts_filt = counts_filt[counts_filt <= max_count] obs_keep_mask = obs_series.isin(counts_filt.index) removed = int((~obs_keep_mask).sum()) print(f"Removed {removed} observations.") if inplace: adata._inplace_subset_obs(obs_keep_mask.values) check_proteodata(adata) return None new_adata = adata[obs_keep_mask, :].copy() check_proteodata(new_adata) return new_adata
def remove_zero_variance_vars(
    adata,
    group_by=None,
    atol=1e-8,
    inplace=True,
):
    """
    Drop variables whose NaN-ignoring variance is at most ``atol``.

    Without grouping, the population variance (``ddof=0``) of every variable
    is computed over all observations and variables with variance ``<= atol``
    are dropped. With ``group_by``, the variance is computed separately inside
    each group of observations and a variable is dropped globally as soon as
    it is near-constant (variance ``<= atol``) in **any** group.

    Parameters
    ----------
    adata : anndata.AnnData
        Annotated data matrix.
    group_by : str or None, optional (default: None)
        Column of ``adata.obs`` defining observation groups. Observations
        whose group label is missing do not belong to any group.
    atol : float, optional (default: 1e-8)
        Variance threshold; variances at or below it count as zero.
    inplace : bool, optional (default: True)
        If True, subset ``adata`` in place and return None. Otherwise, return
        a new AnnData object holding only the retained variables.

    Returns
    -------
    None or anndata.AnnData
        ``None`` when ``inplace=True``; otherwise the filtered copy.

    Raises
    ------
    KeyError
        If ``group_by`` is not a column of ``adata.obs``.

    Notes
    -----
    - Variances come from ``np.nanvar`` (NaNs ignored, ``ddof=0``).
    - Sparse matrices are densified: once for the whole matrix without
      grouping, or one group slice at a time with grouping.
    - A variable that is entirely NaN has NaN variance. Without grouping it
      is dropped (NaN is not ``> atol``); with grouping an all-NaN group does
      not trigger removal (NaN is not ``<= atol``).
    - With grouping, a group containing a single observation gives variance
      0 for every variable with a value in it, so such variables are dropped.
    """
    check_proteodata(adata)

    matrix = adata.X

    if group_by is None:
        dense = matrix.toarray() if sp.issparse(matrix) else np.asarray(matrix)
        keep_mask = np.nanvar(dense, axis=0, ddof=0) > atol
    else:
        if group_by not in adata.obs.columns:
            raise KeyError(f"`group_by`='{group_by}' not found in adata.obs")

        labels = adata.obs[group_by].astype("category")
        flagged = np.zeros(adata.n_vars, dtype=bool)
        for category in labels.cat.categories:
            rows = np.flatnonzero(labels.values == category)
            if rows.size == 0:
                # Unused category: nothing to evaluate.
                continue
            chunk = matrix[rows, :]
            if sp.issparse(chunk):
                chunk = chunk.toarray()
            else:
                chunk = np.asarray(chunk)
            flagged |= np.nanvar(chunk, axis=0, ddof=0) <= atol
        keep_mask = ~flagged

    n_dropped = int(np.count_nonzero(~keep_mask))
    print(f"Removed {n_dropped} variables.")

    if not inplace:
        result = adata[:, keep_mask].copy()
        check_proteodata(result)
        return result

    adata._inplace_subset_var(keep_mask)
    check_proteodata(adata)
    return None
[docs] def remove_contaminants( adata, contaminant_path, protein_key="protein_id", header_parser: Callable[[str], str] | None = None, inplace=False, ): """ Remove variables whose protein identifier matches a contaminant FASTA entry. Parameters ---------- adata : anndata.AnnData Annotated data. contaminant_path : str | Path Path to the contaminant list. The file can be in FASTA format, in which case the headers are parsed to extract the contaminant ids (see param: header_parser); or tabular format TSV/CSV files, in which case the first column is extracted as contaminant ids.. protein_key : str, optional (default: "protein_id") Column in ``adata.var`` containing protein identifiers to match. header_parser : callable, optional Function to extract protein IDs from FASTA headers. Defaults to splitting the header on ``"|"`` and returning the second element, falling back to the full header if not present. inplace : bool, optional (default: False) If True, modify ``adata`` in place. Otherwise, return a filtered view. Returns ------- None or anndata.AnnData ``None`` if ``inplace=True``; otherwise the filtered AnnData view. 
""" check_proteodata(adata) if header_parser is None: def header_parser(header: str) -> str: parts = header.split("|") return parts[1] if len(parts) > 1 else header def _load_contaminant_ids_from_fasta(fasta_path: Path) -> set[str]: contaminant_ids = set() for record in SeqIO.parse(fasta_path, "fasta"): parsed = header_parser(record.id) if parsed == "": warnings.warn( f"Header parser returned empty ID for record '{record.id}'.", ) continue contaminant_ids.add(parsed) return contaminant_ids def _load_contaminant_ids_from_table(table_path: Path, sep: str) -> set[str]: series = pd.read_csv(table_path, sep=sep, usecols=[0]).iloc[:, 0] series = series.dropna().astype(str) return set(series.tolist()) cont_path = Path(contaminant_path) if not cont_path.exists(): raise FileNotFoundError(f"Contaminant file not found at {cont_path}") if protein_key not in adata.var.columns: raise KeyError(f"`protein_key`='{protein_key}' not found in adata.var") suffix = cont_path.suffix.lower() match suffix: case ".fasta" | ".fa" | ".faa": contaminant_ids = _load_contaminant_ids_from_fasta(cont_path) case ".csv": contaminant_ids = _load_contaminant_ids_from_table(cont_path, ",") case ".tsv": contaminant_ids = _load_contaminant_ids_from_table(cont_path, "\t") case _: raise ValueError( "Unsupported contaminant file type. 
Use FASTA (.fasta/.fa/.faa), " "CSV (.csv), or TSV (.tsv).", ) proteins = adata.var[protein_key] keep_mask = ~proteins.isin(contaminant_ids) _, level = is_proteodata(adata) if level == "peptide": removed_peptides = int((~keep_mask).sum()) removed_proteins = int(proteins[~keep_mask].nunique()) print( f"Removed {removed_peptides} contaminating peptides and " f"{removed_proteins} contaminating proteins.", ) elif level == "protein": removed_proteins = int((~keep_mask).sum()) print(f"Removed {removed_proteins} contaminating proteins.") else: removed = int((~keep_mask).sum()) print(f"Removed {removed} contaminating variables.") if inplace: adata._inplace_subset_var(keep_mask.values) check_proteodata(adata) return None new_adata = adata[:, keep_mask] check_proteodata(new_adata) return new_adata