Source code for mkado.io.plotting

"""Plotting functions for MK test results."""

from __future__ import annotations

import logging
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    from mkado.analysis.asymptotic import AsymptoticMKResult
    from mkado.analysis.mk_test import MKResult
    from mkado.analysis.polarized import PolarizedMKResult

logger = logging.getLogger(__name__)


[docs] def create_volcano_plot( results: list[tuple[str, MKResult | PolarizedMKResult]], output_path: Path, alpha: float = 0.05, ) -> None: """Create a volcano plot from batch MK test results. The volcano plot shows -log10(NI) on the X-axis and -log10(p-value) on the Y-axis. A horizontal line indicates the Bonferroni-corrected significance threshold. Args: results: List of (gene_name, result) tuples from batch MK tests output_path: Path to save the plot (PNG, PDF, or SVG) alpha: Significance level for Bonferroni correction (default 0.05) """ import matplotlib.pyplot as plt import seaborn as sns from mkado.analysis.mk_test import MKResult from mkado.analysis.polarized import PolarizedMKResult # Set seaborn dark grid style for modern look sns.set_theme(style="darkgrid") # Extract NI and p-values, tracking why genes are excluded ni_values = [] p_values = [] gene_names = [] skipped_ni_none = 0 # NI undefined (Dn=0 or Ds=0 or Ps=0) skipped_ni_zero = 0 # NI=0 (Pn=0) skipped_pval = 0 # p-value invalid for name, result in results: if isinstance(result, MKResult): ni = result.ni pval = result.p_value elif isinstance(result, PolarizedMKResult): ni = result.ni_ingroup pval = result.p_value_ingroup else: continue # Skip if NI is None or invalid if ni is None: skipped_ni_none += 1 logger.debug("volcano: skipping %s (NI undefined — Dn, Ds, or Ps is zero)", name) continue if ni <= 0: skipped_ni_zero += 1 logger.debug("volcano: skipping %s (NI=0 — Pn is zero)", name) continue if pval <= 0: skipped_pval += 1 continue ni_values.append(ni) p_values.append(pval) gene_names.append(name) n_skipped = skipped_ni_none + skipped_ni_zero + skipped_pval if n_skipped > 0: parts = [] if skipped_ni_none > 0: parts.append(f"{skipped_ni_none} with undefined NI (Dn, Ds, or Ps = 0)") if skipped_ni_zero > 0: parts.append(f"{skipped_ni_zero} with NI = 0 (Pn = 0)") if skipped_pval > 0: parts.append(f"{skipped_pval} with invalid p-value") logger.warning( "Volcano plot: %d/%d genes excluded — %s", n_skipped, len(results), "; ".join(parts), ) if not ni_values: raise ValueError("No valid NI/p-value pairs found in results") # Convert to numpy arrays ni_arr = np.array(ni_values) p_arr = np.array(p_values) # Calculate -log10 values neg_log10_ni = -np.log10(ni_arr) neg_log10_p = -np.log10(p_arr) # Bonferroni correction threshold n_tests = len(results) bonferroni_threshold = alpha / n_tests neg_log10_bonf = -np.log10(bonferroni_threshold) # FDR threshold (nominal significance level) fdr_threshold = alpha neg_log10_fdr = -np.log10(fdr_threshold) # Create figure fig, ax = plt.subplots(figsize=(10, 8)) # Determine point colors based on significance # Red: significant after Bonferroni, Orange: significant at FDR but not Bonferroni, Blue: not significant sig_bonf = p_arr < bonferroni_threshold sig_fdr = (p_arr < fdr_threshold) & ~sig_bonf colors = np.where(sig_bonf, "#e74c3c", np.where(sig_fdr, "#e67e22", "#3498db")) ax.scatter( neg_log10_ni, neg_log10_p, c=colors, alpha=0.7, edgecolors="white", linewidth=0.5, s=60, ) # Add Bonferroni threshold line ax.axhline( y=neg_log10_bonf, color="#e74c3c", linestyle="--", linewidth=1.5, label=f"Bonferroni (p = {bonferroni_threshold:.2e})", ) # Add FDR threshold line ax.axhline( y=neg_log10_fdr, color="#e67e22", linestyle="--", linewidth=1.5, label=f"Nominal (p = {fdr_threshold})", ) # Add vertical line at -log10(NI) = 0, i.e. NI = 1 (neutral expectation) ax.axvline( x=0, color="#2c3e50", linestyle=":", linewidth=1.5, label="Neutral (NI = 1)", ) # Labels and title ax.set_xlabel("-log$_{10}$(NI)", fontsize=12) ax.set_ylabel("-log$_{10}$(p-value)", fontsize=12) # Add legend ax.legend(loc="upper right", framealpha=0.9) # Add annotation for interpretation xlim = ax.get_xlim() ax.text( xlim[0] + 0.02 * (xlim[1] - xlim[0]), neg_log10_bonf + 0.3, f"n = {len(ni_values)} genes, {sum(sig_bonf)} Bonferroni, {sum(sig_fdr)} nominal", fontsize=9, color="#7f8c8d", ) # Tight layout and save plt.tight_layout() plt.savefig(output_path, dpi=150, bbox_inches="tight") plt.close(fig)
[docs] def create_asymptotic_plot( result: AsymptoticMKResult, output_path: Path, ) -> None: """Create an asymptotic MK test plot showing alpha(x) vs derived allele frequency. This plot follows the style of Messer & Petrov (2013), showing: - Scatter points of alpha at each frequency bin - The fitted curve (exponential or linear) - A horizontal line at alpha_asymptotic with confidence interval band Args: result: AsymptoticMKResult from asymptotic MK test output_path: Path to save the plot (PNG, PDF, or SVG) """ import matplotlib.pyplot as plt import seaborn as sns # Set seaborn dark grid style for modern look sns.set_theme(style="darkgrid") # Extract data - use alpha_x_values if available, otherwise fall back to frequency_bins if result.alpha_x_values and len(result.alpha_x_values) == len(result.alpha_by_freq): x_data = np.array(result.alpha_x_values) else: x_data = np.array(result.frequency_bins) y_data = np.array(result.alpha_by_freq) if len(x_data) == 0 or len(y_data) == 0: raise ValueError("No frequency bin data available for plotting") if len(x_data) != len(y_data): raise ValueError(f"Mismatched data: {len(x_data)} x values vs {len(y_data)} alpha values") # Create figure fig, ax = plt.subplots(figsize=(10, 8)) # Plot scatter points ax.scatter( x_data, y_data, c="#2c3e50", alpha=0.8, edgecolors="white", linewidth=0.5, s=80, zorder=3, label="Observed α(x)", ) # Plot confidence interval band for alpha_asymptotic ax.fill_between( [0, 1], [result.ci_low, result.ci_low], [result.ci_high, result.ci_high], color="#95a5a6", alpha=0.3, zorder=1, label=f"95% CI ({result.ci_low:.2f} - {result.ci_high:.2f})", ) # Plot alpha_asymptotic horizontal line ax.axhline( y=result.alpha_asymptotic, color="#e74c3c", linestyle="--", linewidth=2, zorder=2, label=f"α$_{{asym}}$ = {result.alpha_asymptotic:.2f}", ) # Plot fitted curve x_curve = np.linspace(0.05, 1.0, 100) if result.model_type == "exponential": y_curve = result.fit_a + result.fit_b * np.exp(-result.fit_c * x_curve) fit_label = "Fit: a + b·e$^{-cx}$" else: y_curve = result.fit_a + result.fit_b * x_curve fit_label = "Fit: a + b·x" ax.plot( x_curve, y_curve, color="#e74c3c", linewidth=2, zorder=2, label=fit_label, ) # Labels and title — Uricchio et al. 2019 cumulative SFS uses α(>x) instead. ax.set_xlabel("Derived allele frequency x", fontsize=12) if result.sfs_mode == "above": ax.set_ylabel("MK α(>x)", fontsize=12) ax.set_title("Asymptotic MK Test: α(>x) vs Frequency", fontsize=14, fontweight="bold") else: ax.set_ylabel("MK α(x)", fontsize=12) ax.set_title("Asymptotic MK Test: α(x) vs Frequency", fontsize=14, fontweight="bold") # Set axis limits ax.set_xlim(0, 1) # Add legend ax.legend(loc="lower right", framealpha=0.9) # Add annotation with fit parameters if result.model_type == "exponential": param_text = f"a = {result.fit_a:.3f}, b = {result.fit_b:.3f}, c = {result.fit_c:.3f}" else: param_text = f"a = {result.fit_a:.3f}, b = {result.fit_b:.3f}" # Add gene count if aggregated if result.num_genes > 0: param_text += f"\nn = {result.num_genes} genes" ax.text( 0.02, 0.98, param_text, transform=ax.transAxes, fontsize=9, verticalalignment="top", color="#7f8c8d", family="monospace", ) # Tight layout and save plt.tight_layout() plt.savefig(output_path, dpi=150, bbox_inches="tight") plt.close(fig)