#!/usr/bin/env python3
"""
W.I.N.G.S. — Publication figure generator.
Produces parallel figure sets for the Agent-Based Model (ABM) and
the Wright-Fisher fixed-size Model (WFM), organised by biologically
meaningful combination subsets.
Colour scheme: Paul Tol's qualitative palette (colourblind-safe).
Line plots differentiate combos via colour + dash pattern + line width.
All figures exported as both PNG (300 dpi) and SVG (text as text objects).
Figure catalogue
----------------
For EACH model (ABM / WFM):
1–2. infection_over_time_{subset}.{png,svg}
3–4. final_infection_{subset}.{png,svg}
5–6. time_to_fixation_{subset}.{png,svg}
7. heatmap_infection.{png,svg}
8. heatmap_fixation_pct.{png,svg}
9. heatmap_fixation_time.{png,svg}
ABM-only:
10–11. population_over_time_{subset}.{png,svg}
12. heatmap_population.{png,svg}
Usage
-----
python plot_wings.py --model abm --input data/combined_abm.csv
python plot_wings.py --model wfm --input data/combined_wfm.csv
"""
import argparse
import os
import sys
import warnings
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
import pandas as pd
# Suppress every FutureWarning process-wide so deprecation notices do not
# clutter this script's console output.
warnings.filterwarnings("ignore", category=FutureWarning)
# ======================================================================
# Style & constants
# ======================================================================
# Global matplotlib defaults applied to every figure produced below.
plt.rcParams.update({
    # Interactive/figure resolution vs. resolution used when saving.
    "figure.dpi": 200,
    "savefig.dpi": 300,
    # Crop saved files to the drawn content.
    "savefig.bbox": "tight",
    # Arial preferred; DejaVu Sans as fallback when Arial is unavailable.
    "font.family": "sans-serif",
    "font.sans-serif": ["Arial", "DejaVu Sans"],
    "font.size": 10,
    "axes.titlesize": 12,
    "axes.labelsize": 11,
    "xtick.labelsize": 9,
    "ytick.labelsize": 9,
    "legend.fontsize": 8,
    "legend.framealpha": 0.85,
    "legend.edgecolor": "0.8",
    # Open (top/right spine-less) axes.
    "axes.spines.top": False,
    "axes.spines.right": False,
    "figure.facecolor": "white",
    # SVG: keep text as text objects (editable in Illustrator)
    "svg.fonttype": "none",
})
# Infection rate at or above which a replicate counts as having reached
# fixation (used by the fixation-% and time-to-fixation figures).
FIXATION_THRESHOLD = 0.99
# Day marked with a vertical "eggs hatch" line on ABM time-series plots.
EGG_HATCH_DAY = 23 # Tribolium egg hatching ≈ 552 hours ≈ 23 days
# ======================================================================
# Paul Tol qualitative palette (colourblind-safe)
# ======================================================================
# Source: https://personal.sron.nl/~pault/data/colourschemes.pdf
#
# We use the "bright" scheme (7 colours) extended with the "vibrant"
# scheme for additional combos. Each combo also gets a unique dash
# pattern and line width for redundant coding.
#
# The user's biological data convention:
# skyblue = uninfected, bright orange = infected, dark green = treated
# We keep these compatible: "None" (no Wolbachia effect) gets grey,
# infection-driving combos use warm/cool tones.
# -- Tol bright --
# -- Tol bright --
_TOL_BLUE = "#4477AA"
_TOL_CYAN = "#66CCEE"
_TOL_GREEN = "#228833"
_TOL_YELLOW = "#CCBB44"
_TOL_RED = "#EE6677"
_TOL_PURPLE = "#AA3377"
_TOL_GREY = "#BBBBBB"
# -- Tol vibrant (supplements) --
_TOL_ORANGE = "#EE7733"
_TOL_TEAL = "#009988"
_TOL_MAGENTA = "#EE3377"
_TOL_DBLUE = "#0077BB"
# Combo → (colour, dash, linewidth)
# Dash patterns: solid, dashed, dotted, dash-dot, long-dash, etc.
# Tuples are matplotlib (offset, (on, off, ...)) dash specs in points.
_SOLID = "solid"
_DASH = (0, (5, 2))
_DOT = (0, (1.5, 1.5))
_DASHDOT = (0, (5, 2, 1.5, 2))
_LDASH = (0, (8, 3))
_LDASHDOT = (0, (8, 3, 1.5, 3))  # defined but not referenced in COMBO_STYLE
_DDASH = (0, (3, 1.5))
_DDASHDOT = (0, (3, 1.5, 1.5, 1.5))
# Style lookup keyed by the label produced by combo_label(); labels not
# listed here fall back to grey/solid in get_style().
COMBO_STYLE = {
    # --- Subset A: Individual effects ---
    "None": (_TOL_GREY, _SOLID, 1.6),
    "CI": (_TOL_BLUE, _SOLID, 2.2),
    "MK": (_TOL_RED, _SOLID, 2.0),
    "ER": (_TOL_GREEN, _SOLID, 2.0),
    "IE": (_TOL_PURPLE, _SOLID, 2.0),
    "CI+MK+ER+IE": ("#222222", _SOLID, 2.4), # near-black for "all"
    # --- Subset B: ER-centric ---
    "ER+IE": (_TOL_YELLOW, _DASH, 2.0),
    "MK+ER": (_TOL_ORANGE, _DASH, 2.0),
    "CI+ER": (_TOL_TEAL, _DASH, 2.2),
    "MK+ER+IE": (_TOL_MAGENTA, _DASHDOT, 2.0),
    "CI+ER+IE": (_TOL_CYAN, _DASHDOT, 2.0),
    # --- Remaining combos (for heatmap annotations, etc.) ---
    # These reuse colours from the subsets above (CI+MK/MK+ER orange,
    # MK+IE/MK red, CI+MK+ER/CI+ER teal, CI+MK+IE/ER+IE yellow) and are
    # told apart only by dash pattern and line width.
    "CI+IE": (_TOL_DBLUE, _DOT, 1.8),
    "CI+MK": (_TOL_ORANGE, _DOT, 1.8),
    "MK+IE": (_TOL_RED, _DDASH, 1.8),
    "CI+MK+ER": (_TOL_TEAL, _LDASH, 1.8),
    "CI+MK+IE": (_TOL_YELLOW, _DDASHDOT, 1.8),
}
[docs]
def get_style(label):
    """Look up the plotting style for a combination label.

    Args:
        label (str): Combo label such as ``"CI+ER"`` or ``"None"``.

    Returns:
        tuple: ``(colour, dash, linewidth)``; unknown labels get a grey,
        solid, 1.5-wide default.
    """
    try:
        return COMBO_STYLE[label]
    except KeyError:
        return (_TOL_GREY, _SOLID, 1.5)
# -- Combo subsets --
# Each subset entry is a (ci, mk, er, ie) boolean tuple; list order is the
# plotting/legend order (see get_ordered_labels).
SUBSET_A_NAME = "Individual Effects"
# Baseline, each mechanic on its own, and all four together.
SUBSET_A = [
    (False, False, False, False), # None
    (False, False, True, False), # ER
    (False, False, False, True), # IE
    (False, True, False, False), # MK
    (True, False, False, False), # CI
    (True, True, True, True), # CI+MK+ER+IE
]
SUBSET_B_NAME = "ER-Centric Combinations"
# Baseline plus every listed combination that includes ER.
SUBSET_B = [
    (False, False, False, False), # None
    (False, False, True, False), # ER
    (False, False, True, True), # ER+IE
    (False, True, True, False), # MK+ER
    (True, False, True, False), # CI+ER
    (False, True, True, True), # MK+ER+IE
    (True, False, True, True), # CI+ER+IE
    (True, True, True, True), # CI+MK+ER+IE
]
# Map mechanic abbreviations to their position in the (ci, mk, er, ie) tuple
MECHANIC_INDEX = {"CI": 0, "MK": 1, "ER": 2, "IE": 3}
[docs]
def parse_exclude(exclude_str):
    """Parse a comma-separated exclusion string into a set of mechanic names.

    Args:
        exclude_str (str or None): Comma-separated mechanic abbreviations,
            e.g. ``"MK"`` or ``"MK,IE"``. Case-insensitive; whitespace
            around entries and empty entries (e.g. a trailing comma) are
            ignored.

    Returns:
        set[str]: Uppercase mechanic abbreviations to exclude (empty when
        ``exclude_str`` is empty or ``None``).

    Raises:
        SystemExit: Exit code 1 if an unrecognised abbreviation is
            provided; the error is reported on stderr.
    """
    if not exclude_str:
        return set()
    # Skip empty tokens so "MK," or "MK,,IE" are not rejected as unknown.
    names = {s.strip().upper() for s in exclude_str.split(",") if s.strip()}
    unknown = names - set(MECHANIC_INDEX)
    if unknown:
        print(f" ERROR: Unknown mechanic(s): {unknown}", file=sys.stderr)
        print(f" Valid options: {', '.join(sorted(MECHANIC_INDEX.keys()))}",
              file=sys.stderr)
        sys.exit(1)
    return names
[docs]
def filter_subset_by_exclusion(subset, excluded):
    """Drop every combination that switches on an excluded mechanic.

    Args:
        subset (list[tuple]): ``(ci, mk, er, ie)`` boolean tuples.
        excluded (set[str]): Mechanic abbreviations to drop (e.g. ``{"MK"}``).

    Returns:
        list[tuple]: Combinations whose excluded flags are all off. When
        ``excluded`` is empty the original ``subset`` object is returned.
    """
    if not excluded:
        return subset
    return [
        combo
        for combo in subset
        if not any(combo[MECHANIC_INDEX[name]] for name in excluded)
    ]
[docs]
def filter_heatmap_configs(row_configs, col_configs, excluded):
    """Drop heatmap rows/columns whose flags switch on an excluded mechanic.

    Rows carry the CI/MK flags and columns the ER/IE flags, so each
    exclusion only affects one axis.

    Args:
        row_configs (list[tuple]): Row definitions ``(ci, mk, label)``.
        col_configs (list[tuple]): Column definitions ``(er, ie, label)``.
        excluded (set[str]): Mechanic abbreviations to exclude.

    Returns:
        tuple[list, list]: ``(row_configs, col_configs)`` after filtering;
        the inputs are returned as-is when ``excluded`` is empty.
    """
    if not excluded:
        return row_configs, col_configs
    drop_ci, drop_mk = "CI" in excluded, "MK" in excluded
    drop_er, drop_ie = "ER" in excluded, "IE" in excluded
    kept_rows = [
        (ci, mk, label)
        for ci, mk, label in row_configs
        if not ((drop_ci and ci) or (drop_mk and mk))
    ]
    kept_cols = [
        (er, ie, label)
        for er, ie, label in col_configs
        if not ((drop_er and er) or (drop_ie and ie))
    ]
    return kept_rows, kept_cols
# ======================================================================
# Helpers
# ======================================================================
[docs]
def combo_label(ci, mk, er, ie):
    """Join the names of the active effects into a readable label.

    Args:
        ci (bool): Cytoplasmic incompatibility active.
        mk (bool): Male killing active.
        er (bool): Increased exploration rate active.
        ie (bool): Increased eggs active.

    Returns:
        str: Active effects joined with ``"+"`` in CI, MK, ER, IE order
        (e.g. ``"CI+ER"``), or ``"None"`` when no effect is active.
    """
    flags = (("CI", ci), ("MK", mk), ("ER", er), ("IE", ie))
    active = [name for name, on in flags if on]
    return "+".join(active) or "None"
[docs]
def save_fig(fig, path_stem):
    """Save a matplotlib figure as PNG (300 dpi) and SVG (editable text).

    The parent directory of ``path_stem`` is created if missing. The
    figure is always closed afterwards, even if saving fails, so repeated
    calls do not accumulate open figures.

    Args:
        fig (matplotlib.figure.Figure): The figure to save.
        path_stem (str): Output path without extension (e.g.
            ``"figures/infection_rate"``). Produces ``.png`` and ``.svg``.
    """
    stem = Path(path_stem)
    try:
        stem.parent.mkdir(parents=True, exist_ok=True)
        # DPI, bbox and SVG font handling come from the module rcParams.
        fig.savefig(f"{path_stem}.png")
        fig.savefig(f"{path_stem}.svg")
    finally:
        plt.close(fig)
    print(f" ✓ {stem.name}.{{png,svg}}")
[docs]
def load_data(path):
    """Load a combined simulation CSV produced by :func:`ingest.ingest_directory`.

    Normalises boolean effect columns, adds a ``Combo`` label column and an
    ``Infected Count`` column (population × infection rate, rounded).

    Args:
        path (str): Path to the combined CSV.

    Returns:
        pandas.DataFrame: Loaded data with boolean columns and combo labels.
    """
    effect_cols = [
        "Cytoplasmic Incompatibility",
        "Male Killing",
        "Increased Exploration Rate",
        "Increased Eggs",
    ]
    df = pd.read_csv(path)
    for col in effect_cols:
        # Real booleans and "True"/"true"/" TRUE " strings become True;
        # anything else (including NaN, stringified as "nan") becomes False.
        df[col] = df[col].astype(str).str.strip().str.lower() == "true"
    # A comprehension over the flag columns is far faster than a row-wise
    # DataFrame.apply and also works on an empty CSV (apply(axis=1) on zero
    # rows returns a DataFrame, which cannot be assigned to one column).
    df["Combo"] = [
        combo_label(ci, mk, er, ie)
        for ci, mk, er, ie in zip(*(df[col] for col in effect_cols))
    ]
    # Absolute infected count = population × infection rate
    df["Infected Count"] = (
        df["Population Size"] * df["Infection Rate"]
    ).round().astype(int)
    return df
[docs]
def filter_combos(df, subset):
    """Filter a DataFrame to only the specified effect combinations.

    Args:
        df (pandas.DataFrame): Combined simulation data.
        subset (list[tuple]): List of ``(ci, mk, er, ie)`` boolean tuples.
            An empty list yields an empty frame with the same columns.

    Returns:
        pandas.DataFrame: Filtered copy containing only matching rows.
    """
    # OR together one mask per combination; starting from all-False makes
    # an empty subset valid (pd.concat([]) would raise ValueError).
    keep = pd.Series(False, index=df.index)
    for ci, mk, er, ie in subset:
        keep |= (
            (df["Cytoplasmic Incompatibility"] == ci)
            & (df["Male Killing"] == mk)
            & (df["Increased Exploration Rate"] == er)
            & (df["Increased Eggs"] == ie)
        )
    return df[keep].copy()
[docs]
def get_ordered_labels(subset):
    """Translate each ``(ci, mk, er, ie)`` tuple into its combo label.

    Args:
        subset (list[tuple]): Effect combinations, in display order.

    Returns:
        list[str]: Labels from :func:`combo_label`, preserving subset order.
    """
    labels = []
    for flags in subset:
        ci, mk, er, ie = flags
        labels.append(combo_label(ci, mk, er, ie))
    return labels
# ======================================================================
# Summary statistics
# ======================================================================
[docs]
def compute_timeseries_stats(df, time_col="Day"):
    """Summarise replicates at every (combination, time point).

    Statistics produced per group:

    * ``inf_*`` (Infection Rate): median, q25, q75, q05, q95
    * ``pop_*`` (Population Size): median, q25, q75 only
    * ``ninf_*`` (Infected Count): median, q25, q75, q05, q95

    Args:
        df (pandas.DataFrame): Simulation data.
        time_col (str): Name of the time column. Defaults to ``"Day"``.

    Returns:
        pandas.DataFrame: One row per ``Combo`` × ``time_col`` with the
        columns listed above.
    """
    def _pct(q):
        # Bind q now; a lambda created directly in the loop would capture
        # only the last value.
        return lambda values: values.quantile(q)

    wide = (("q25", 0.25), ("q75", 0.75), ("q05", 0.05), ("q95", 0.95))
    iqr_only = wide[:2]
    spec = {}
    for prefix, column, quantiles in (
        ("inf", "Infection Rate", wide),
        ("pop", "Population Size", iqr_only),
        ("ninf", "Infected Count", wide),
    ):
        spec[f"{prefix}_median"] = (column, "median")
        for suffix, q in quantiles:
            spec[f"{prefix}_{suffix}"] = (column, _pct(q))

    grouped = df.groupby(["Combo", time_col])
    return grouped.agg(**spec).reset_index()
[docs]
def compute_final_values(df, time_col="Day"):
    """Pick the row at each replicate's latest time point.

    Args:
        df (pandas.DataFrame): Simulation data.
        time_col (str): Name of the time column. Defaults to ``"Day"``.

    Returns:
        pandas.DataFrame: A copy with one row per ``Combo`` ×
        ``Replicate ID``, holding every column of that final row.
    """
    # idxmax returns the index label of the first row with the group's
    # maximum time value.
    last_row_labels = (
        df.groupby(["Combo", "Replicate ID"])[time_col].idxmax()
    )
    final_rows = df.loc[last_row_labels]
    return final_rows.copy()
[docs]
def compute_time_to_fixation(df, time_col="Day", threshold=FIXATION_THRESHOLD):
    """Find the first time point where infection rate ≥ threshold.

    Args:
        df (pandas.DataFrame): Simulation data.
        time_col (str): Time column name. Defaults to ``"Day"``.
        threshold (float): Fixation threshold. Defaults to
            ``FIXATION_THRESHOLD`` (0.99).

    Returns:
        pandas.DataFrame: Columns ``Combo``, ``Replicate ID``,
        ``t_fix`` (NaN if fixation never reached). The columns are
        present even when ``df`` is empty.
    """
    records = []
    for (combo, rep), grp in df.groupby(["Combo", "Replicate ID"]):
        hit_times = grp.loc[grp["Infection Rate"] >= threshold, time_col]
        t_fix = hit_times.min() if len(hit_times) > 0 else np.nan
        records.append({"Combo": combo, "Replicate ID": rep, "t_fix": t_fix})
    # Explicit columns keep the schema stable for empty input, so callers
    # can still select ttf["Combo"] / ttf["t_fix"].
    return pd.DataFrame(records, columns=["Combo", "Replicate ID", "t_fix"])
# ======================================================================
# Plot: time series with ribbons
# ======================================================================
[docs]
def plot_timeseries(
    df, subset, subset_name, metric, ylabel, title,
    path_stem, time_col="Day", is_abm=True, skip_before=None
):
    """Line + ribbon plot of a metric over time for each combination.

    For every combination in ``subset`` this draws the median line, an
    IQR (25th–75th percentile) ribbon and, when the statistics contain
    ``{metric}_q05``/``{metric}_q95`` columns, a fainter 5th–95th
    percentile band. ABM plots use a symlog x-axis with fixed day ticks
    and an egg-hatch marker; WFM plots use a linear x-axis.

    Args:
        df (pandas.DataFrame): Simulation data.
        subset (list[tuple]): Effect combinations to include.
        subset_name (str): Human-readable subset name (not used for
            drawing; kept for interface compatibility).
        metric (str): Column prefix from :func:`compute_timeseries_stats`
            (``"inf"``, ``"pop"`` or ``"ninf"``).
        ylabel (str): Y-axis label.
        title (str): Figure title.
        path_stem (str): Output path without extension.
        time_col (str): Time column name. Defaults to ``"Day"``.
        is_abm (bool): Use ABM-specific formatting. Defaults to ``True``.
        skip_before (int, optional): Drop time points earlier than this
            value; also suppresses the egg-hatch marker.
    """
    sub = filter_combos(df, subset)
    labels = get_ordered_labels(subset)
    stats = compute_timeseries_stats(sub, time_col)

    fig, ax = plt.subplots(figsize=(8, 4.5))
    for label in labels:
        s = stats[stats["Combo"] == label].sort_values(time_col)
        if skip_before is not None:
            s = s[s[time_col] >= skip_before]
        if s.empty:
            continue
        t = s[time_col].values
        color, dash, lw = get_style(label)
        med = s[f"{metric}_median"].values
        q25 = s[f"{metric}_q25"].values
        q75 = s[f"{metric}_q75"].values
        # Light ribbon: 5th–95th (only for metrics that provide it)
        if f"{metric}_q05" in s.columns:
            q05 = s[f"{metric}_q05"].values
            q95 = s[f"{metric}_q95"].values
            ax.fill_between(t, q05, q95, alpha=0.07, color=color, linewidth=0)
        # IQR ribbon
        ax.fill_between(t, q25, q75, alpha=0.18, color=color, linewidth=0)
        # Median line with dash pattern
        ax.plot(t, med, color=color, linewidth=lw, linestyle=dash, label=label)

    # Semilog x for ABM only
    if is_abm:
        ax.set_xscale("symlog", linthresh=10)
        ax.xaxis.set_major_locator(mticker.FixedLocator(
            [1, 5, 10, 25, 50, 100, 200, 365]
        ))
        ax.xaxis.set_major_formatter(mticker.ScalarFormatter())
        ax.set_xlim(left=max(1, skip_before or 1))
    else:
        ax.set_xlim(0, df[time_col].max())
        ax.xaxis.set_major_locator(mticker.MaxNLocator(integer=True))

    # Egg hatch marker (ABM only)
    if is_abm and skip_before is None:
        ax.axvline(EGG_HATCH_DAY, color="#cccccc", ls="--", lw=0.8, zorder=0)
        # x in data units, y as an axes fraction: the label stays just
        # above the bottom spine for any y-range (a data-space y such as
        # ylim_top * 0.02 falls off-axis when the axis does not start at 0).
        ax.text(
            EGG_HATCH_DAY + 1, 0.02, "eggs hatch",
            fontsize=7, color="#999999", va="bottom",
            transform=ax.get_xaxis_transform(),
        )

    ax.set_xlabel("Generation" if not is_abm else "Day")
    ax.set_ylabel(ylabel)
    ax.set_title(title, fontweight="bold", pad=10)
    # Legend with line styles shown (skipped when nothing was drawn, which
    # would otherwise trigger a "no artists with labels" warning).
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(loc="best", ncol=2 if len(labels) > 4 else 1,
                  handlelength=3.0)
    if metric == "inf":
        ax.set_ylim(-0.02, 1.05)
        ax.axhline(1.0, color="#eeeeee", ls=":", lw=0.7, zorder=0)
    fig.tight_layout()
    save_fig(fig, path_stem)
# ======================================================================
# Plot: final infection — strip + bar hybrid
# ======================================================================
[docs]
def plot_final_infection(
    df, subset, subset_name, title, path_stem, time_col="Day"
):
    """Strip plot of final infection rate with fixation-% annotation.

    Each replicate is a jittered dot; a diamond shows the median;
    a vertical bar shows the IQR. Above each combination the percentage
    of replicates whose *final* infection rate is ≥ ``FIXATION_THRESHOLD``
    is printed. Also writes ``{path_stem}.csv`` with one row per replicate
    (``mechanic``, ``final_infection_rate``).

    Args:
        df (pandas.DataFrame): Simulation data.
        subset (list[tuple]): Effect combinations to include.
        subset_name (str): Subset label (not used for drawing; kept for
            interface compatibility).
        title (str): Figure title.
        path_stem (str): Output path without extension.
        time_col (str): Time column. Defaults to ``"Day"``.
    """
    sub = filter_combos(df, subset)
    labels = get_ordered_labels(subset)
    finals = compute_final_values(sub, time_col)
    # Per-combo final values, looked up once and reused for plot and CSV.
    vals_by_label = {
        label: finals.loc[finals["Combo"] == label, "Infection Rate"].values
        for label in labels
    }

    fig, ax = plt.subplots(figsize=(8, 4.5))
    rng = np.random.default_rng(42)  # fixed seed → reproducible jitter
    for i, label in enumerate(labels):
        vals = vals_by_label[label]
        if len(vals) == 0:
            continue
        color, _, _ = get_style(label)
        # Jittered strip
        jitter = rng.uniform(-0.25, 0.25, len(vals))
        ax.scatter(
            np.full(len(vals), i) + jitter, vals,
            color=color, s=10, alpha=0.35, zorder=3, edgecolors="none",
        )
        # Median diamond
        med = np.median(vals)
        ax.scatter(
            [i], [med], color=color, s=70, zorder=5,
            marker="D", edgecolors="white", linewidths=0.8,
        )
        # IQR bar
        q25, q75 = np.percentile(vals, [25, 75])
        ax.plot([i, i], [q25, q75], color=color, lw=2.5, solid_capstyle="round",
                zorder=4, alpha=0.7)
        # Fixation % annotation above (data y=1.07, inside the 1.18 ylim)
        n_fixed = (vals >= FIXATION_THRESHOLD).sum()
        pct = 100 * n_fixed / len(vals)
        ax.text(
            i, 1.07, f"{pct:.0f}%",
            ha="center", va="bottom", fontsize=7.5, color=color,
            fontweight="bold",
        )

    ax.set_xticks(range(len(labels)))
    ax.set_xticklabels(labels, rotation=30, ha="right")
    ax.set_ylabel("Final Infection Rate")
    ax.set_ylim(-0.05, 1.18)
    ax.axhline(1.0, color="#eeeeee", ls=":", lw=0.7, zorder=0)
    # Subtle "% fixed" header, horizontally centred on the axes. The yaxis
    # transform has x as an axes fraction and y in data units, so the
    # header sits at data y=1.14, just above the percentages and inside
    # the plot (the xaxis transform put it at x=0.5 *data* and 1.14 axes
    # heights, i.e. above the title).
    ax.text(
        0.5, 1.14, "% reaching fixation",
        ha="center", va="bottom", fontsize=7, color="#888888",
        style="italic", transform=ax.get_yaxis_transform(),
    )
    ax.set_title(title, fontweight="bold", pad=10)
    fig.tight_layout()
    save_fig(fig, path_stem)

    # --- CSV export (explicit columns keep the header when empty) ---
    csv_rows = [
        {"mechanic": label, "final_infection_rate": v}
        for label in labels
        for v in vals_by_label[label]
    ]
    pd.DataFrame(
        csv_rows, columns=["mechanic", "final_infection_rate"]
    ).to_csv(f"{path_stem}.csv", index=False)
    print(f" ✓ {Path(path_stem).name}.csv")
# ======================================================================
# Plot: generic final-value strip plot (reusable for any metric)
# ======================================================================
[docs]
def plot_final_strip(
    df, subset, subset_name, metric_col, ylabel, title, path_stem,
    time_col="Day", csv_col_name=None, fixation_annot=False,
):
    """Strip + median diamond + IQR bar for any final-timepoint metric.

    Each replicate's value at its last time point is a jittered dot; a
    diamond marks the per-combination median and a vertical bar the IQR.
    Exports a companion CSV with columns ``mechanic``, ``{csv_col_name}``.

    Args:
        df (pandas.DataFrame): Simulation data.
        subset (list[tuple]): Effect combinations to include.
        subset_name (str): Subset label (not used for drawing; kept for
            interface compatibility).
        metric_col (str): Column of ``df`` to plot.
        ylabel (str): Y-axis label.
        title (str): Figure title.
        path_stem (str): Output path without extension.
        time_col (str): Time column. Defaults to ``"Day"``.
        csv_col_name (str, optional): Value column name in the CSV.
            Defaults to ``metric_col`` lower-cased with spaces → ``_``.
        fixation_annot (bool): If True, print above each combination the
            percentage of replicates whose final *infection rate* is
            ≥ ``FIXATION_THRESHOLD`` (independent of ``metric_col``).
            Previously this flag was accepted but ignored.
    """
    sub = filter_combos(df, subset)
    labels = get_ordered_labels(subset)
    finals = compute_final_values(sub, time_col)
    csv_name = csv_col_name or metric_col.lower().replace(" ", "_")
    # Per-combo final values, looked up once and reused for plot and CSV.
    vals_by_label = {
        label: finals.loc[finals["Combo"] == label, metric_col].values
        for label in labels
    }

    fig, ax = plt.subplots(figsize=(8, 4.5))
    rng = np.random.default_rng(42)  # fixed seed → reproducible jitter
    for i, label in enumerate(labels):
        vals = vals_by_label[label]
        if len(vals) == 0:
            continue
        color, _, _ = get_style(label)
        jitter = rng.uniform(-0.25, 0.25, len(vals))
        ax.scatter(
            np.full(len(vals), i) + jitter, vals,
            color=color, s=10, alpha=0.35, zorder=3, edgecolors="none",
        )
        med = np.median(vals)
        ax.scatter(
            [i], [med], color=color, s=70, zorder=5,
            marker="D", edgecolors="white", linewidths=0.8,
        )
        q25, q75 = np.percentile(vals, [25, 75])
        ax.plot([i, i], [q25, q75], color=color, lw=2.5,
                solid_capstyle="round", zorder=4, alpha=0.7)
        if fixation_annot:
            inf = finals.loc[finals["Combo"] == label, "Infection Rate"].values
            pct = 100 * (inf >= FIXATION_THRESHOLD).sum() / len(inf)
            # x in data units, y as an axes fraction: works for any metric
            # scale, placed just above the top of the axes.
            ax.text(
                i, 1.01, f"{pct:.0f}%",
                ha="center", va="bottom", fontsize=7.5, color=color,
                fontweight="bold", transform=ax.get_xaxis_transform(),
            )

    ax.set_xticks(range(len(labels)))
    ax.set_xticklabels(labels, rotation=30, ha="right")
    ax.set_ylabel(ylabel)
    if fixation_annot:
        ax.text(
            0.5, 1.06, "% reaching fixation",
            ha="center", va="bottom", fontsize=7, color="#888888",
            style="italic", transform=ax.transAxes,
        )
    # Extra title padding leaves room for the annotations above the axes.
    ax.set_title(title, fontweight="bold", pad=24 if fixation_annot else 10)
    fig.tight_layout()
    save_fig(fig, path_stem)

    # --- CSV export (explicit columns keep the header when empty) ---
    csv_rows = [
        {"mechanic": label, csv_name: v}
        for label in labels
        for v in vals_by_label[label]
    ]
    pd.DataFrame(csv_rows, columns=["mechanic", csv_name]).to_csv(
        f"{path_stem}.csv", index=False)
    print(f" ✓ {Path(path_stem).name}.csv")
# ======================================================================
# Plot: time to fixation — violin + strip
# ======================================================================
[docs]
def plot_time_to_fixation(
    df, subset, subset_name, title, path_stem,
    time_col="Day", is_abm=True
):
    """Violin + strip plot of time-to-fixation across replicates.

    Only combinations with at least one fixing replicate get a column;
    those that never reach fixation are listed in an italic note, and a
    ``n_fixed/total`` label marks columns where some replicates never
    fixed. Also writes ``{path_stem}.csv`` with one row per replicate
    (``mechanic``, ``time_to_fixation``; NaN if never fixed).

    Args:
        df (pandas.DataFrame): Simulation data.
        subset (list[tuple]): Effect combinations.
        subset_name (str): Subset label (not used for drawing; kept for
            interface compatibility).
        title (str): Figure title.
        path_stem (str): Output path without extension.
        time_col (str): Time column. Defaults to ``"Day"``.
        is_abm (bool): ABM formatting. Defaults to ``True``.
    """
    sub = filter_combos(df, subset)
    labels = get_ordered_labels(subset)
    ttf = compute_time_to_fixation(sub, time_col)

    fig, ax = plt.subplots(figsize=(8, 4.5))
    rng = np.random.default_rng(42)  # fixed seed → reproducible jitter
    positions_used = []
    labels_used = []
    never_fixed = []
    pos = 0
    for label in labels:
        group = ttf[ttf["Combo"] == label]
        total = len(group)
        if total == 0:
            # No replicates at all: nothing to plot, and "No fixation"
            # would be misleading.
            continue
        vals = group["t_fix"].dropna().values
        n_fixed = len(vals)
        color, _, _ = get_style(label)
        if n_fixed == 0:
            never_fixed.append(label)
            continue
        positions_used.append(pos)
        labels_used.append(label)
        # Violin needs >= 3 points AND non-zero spread: matplotlib's KDE
        # cannot invert a zero covariance, which happens whenever every
        # replicate fixes on the same day/generation.
        if n_fixed >= 3 and np.ptp(vals) > 0:
            parts = ax.violinplot(vals, positions=[pos], showextrema=False, widths=0.6)
            for pc in parts["bodies"]:
                pc.set_facecolor(color)
                pc.set_alpha(0.35)
                pc.set_edgecolor(color)
                pc.set_linewidth(0.5)
        # Median marker
        ax.scatter([pos], [np.median(vals)], color=color,
                   s=50, zorder=5, marker="D",
                   edgecolors="white", linewidths=0.6)
        # Jittered strip
        jitter = rng.uniform(-0.15, 0.15, len(vals))
        ax.scatter(
            np.full(len(vals), pos) + jitter, vals,
            color=color, s=6, alpha=0.3, zorder=4, edgecolors="none",
        )
        # Annotate n_fixed / total. x in data units, y as an axes fraction,
        # so the label stays at the bottom of the axes however later
        # columns rescale the y-axis.
        if n_fixed < total:
            ax.text(
                pos, 0.01, f"{n_fixed}/{total}",
                ha="center", va="bottom", fontsize=6.5, color=color,
                transform=ax.get_xaxis_transform(),
            )
        pos += 1

    if never_fixed:
        note = "No fixation: " + ", ".join(never_fixed)
        ax.text(
            0.02, 0.98, note, transform=ax.transAxes,
            fontsize=7.5, va="top", color="#888888", style="italic",
        )
    ax.set_xticks(positions_used)
    ax.set_xticklabels(labels_used, rotation=30, ha="right")
    ax.set_ylabel("Generation" if not is_abm else "Day")
    ax.set_title(title, fontweight="bold", pad=10)
    fig.tight_layout()
    save_fig(fig, path_stem)

    # --- CSV export (all replicates, NaN if never fixed) ---
    csv_rows = [
        {"mechanic": label, "time_to_fixation": t}
        for label in labels
        for t in ttf.loc[ttf["Combo"] == label, "t_fix"]
    ]
    pd.DataFrame(
        csv_rows, columns=["mechanic", "time_to_fixation"]
    ).to_csv(f"{path_stem}.csv", index=False)
    print(f" ✓ {Path(path_stem).name}.csv")
# ======================================================================
# Plot: heatmaps (all 16 combos)
# ======================================================================
[docs]
def plot_heatmap(df, metric_func, cmap, cbar_label, title, path_stem,
                 time_col="Day", fmt=".2f", vmin=None, vmax=None,
                 csv_raw_func=None, csv_value_col="value", excluded=None):
    """Heatmap of effect combinations, with optional mechanic exclusion.

    Rows represent the CI × MK axis; columns represent the ER × IE axis.
    When mechanics are excluded via ``excluded``, the corresponding rows
    or columns are removed and the heatmap shrinks accordingly. Cells with
    no data, or whose metric is NaN, are annotated "—".

    Args:
        df (pandas.DataFrame): Simulation data.
        metric_func (callable): Function ``(df_subset, time_col) → float``
            computing the cell value.
        cmap (str): Matplotlib colourmap name.
        cbar_label (str): Colourbar label.
        title (str): Figure title.
        path_stem (str): Output path without extension.
        time_col (str): Time column. Defaults to ``"Day"``.
        fmt (str): Number format string. Defaults to ``".2f"``.
        vmin (float, optional): Colourmap minimum. Defaults to the smallest
            non-NaN cell value (0 if every cell is NaN).
        vmax (float, optional): Colourmap maximum. Defaults to the largest
            non-NaN cell value (1 if every cell is NaN).
        csv_raw_func (callable, optional): Per-replicate extraction
            function ``(df_subset, time_col) → list[dict]``; when given,
            ``{path_stem}.csv`` is written with ``mechanic`` first.
        csv_value_col (str): Not used by this function; accepted for
            interface compatibility.
        excluded (set[str], optional): Mechanic abbreviations to exclude.
    """
    if excluded is None:
        excluded = set()
    row_configs = [
        (False, False, "—"),
        (False, True, "MK"),
        (True, False, "CI"),
        (True, True, "CI+MK"),
    ]
    col_configs = [
        (False, False, "—"),
        (False, True, "IE"),
        (True, False, "ER"),
        (True, True, "ER+IE"),
    ]
    row_configs, col_configs = filter_heatmap_configs(
        row_configs, col_configs, excluded)
    n_rows = len(row_configs)
    n_cols = len(col_configs)
    if n_rows == 0 or n_cols == 0:
        print(f" [skip] {Path(path_stem).name} — all rows or columns excluded")
        return

    def _cell(ci, mk, er, ie):
        # Rows of df matching one (CI, MK, ER, IE) flag combination.
        return df[
            (df["Cytoplasmic Incompatibility"] == ci)
            & (df["Male Killing"] == mk)
            & (df["Increased Exploration Rate"] == er)
            & (df["Increased Eggs"] == ie)
        ]

    matrix = np.full((n_rows, n_cols), np.nan)
    annot = np.full((n_rows, n_cols), "—", dtype=object)
    for ri, (r_ci, r_mk, _) in enumerate(row_configs):
        for cj, (c_er, c_ie, _) in enumerate(col_configs):
            sub = _cell(r_ci, r_mk, c_er, c_ie)
            if len(sub) == 0:
                continue
            val = metric_func(sub, time_col)
            matrix[ri, cj] = val
            if not np.isnan(val):
                annot[ri, cj] = f"{val:{fmt}}"

    # Colour limits: explicit values win; otherwise use the data range,
    # with a 0–1 fallback so an all-NaN matrix does not yield NaN limits.
    present = matrix[~np.isnan(matrix)]
    if vmin is not None:
        v0 = vmin
    else:
        v0 = present.min() if present.size else 0.0
    if vmax is not None:
        v1 = vmax
    else:
        v1 = present.max() if present.size else 1.0

    fig, ax = plt.subplots(
        figsize=(max(3.5, 1.4 * n_cols + 1.5), max(3, 1.1 * n_rows + 1.5)))
    im = ax.imshow(matrix, cmap=cmap, aspect="auto", vmin=v0, vmax=v1)
    for ri in range(n_rows):
        for cj in range(n_cols):
            val = matrix[ri, cj]
            if np.isnan(val):
                tc = "#999999"
            else:
                # Choose text colour from the cell's actual rendered colour
                # (Rec. 601 luma), so contrast is right whichever end of
                # the colourmap is dark.
                r, g, b, _ = im.cmap(im.norm(val))
                luma = 0.299 * r + 0.587 * g + 0.114 * b
                tc = "white" if luma < 0.5 else "black"
            ax.text(
                cj, ri, annot[ri, cj],
                ha="center", va="center", fontsize=11, fontweight="bold",
                color=tc,
            )
    ax.set_xticks(range(n_cols))
    ax.set_xticklabels([c[2] for c in col_configs])
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels([r[2] for r in row_configs])
    ax.set_xlabel("Exploration / Fecundity axis", fontsize=10)
    ax.set_ylabel("CI / MK axis", fontsize=10)
    ax.set_title(title, fontweight="bold", pad=12)
    cbar = fig.colorbar(im, ax=ax, shrink=0.85, pad=0.08)
    cbar.set_label(cbar_label, fontsize=9)
    fig.tight_layout()
    save_fig(fig, path_stem)

    # --- CSV export: per-replicate raw data for every displayed combo ---
    if csv_raw_func is not None:
        csv_rows = []
        for r_ci, r_mk, _ in row_configs:
            for c_er, c_ie, _ in col_configs:
                sub = _cell(r_ci, r_mk, c_er, c_ie)
                if len(sub) == 0:
                    continue
                label = combo_label(r_ci, r_mk, c_er, c_ie)
                raw = csv_raw_func(sub, time_col)
                for rec in raw:
                    rec["mechanic"] = label
                csv_rows.extend(raw)
        out_df = pd.DataFrame(csv_rows)
        if "mechanic" not in out_df.columns:
            # No rows at all: still emit a CSV with the expected header.
            out_df["mechanic"] = pd.Series(dtype=object)
        # Reorder so mechanic is first
        cols = ["mechanic"] + [c for c in out_df.columns if c != "mechanic"]
        out_df[cols].to_csv(f"{path_stem}.csv", index=False)
        print(f" ✓ {Path(path_stem).name}.csv")
# -- Heatmap metric functions (summary for cell values) --
[docs]
def metric_fixation_pct(df_sub, time_col):
    """Percentage of replicates whose infection rate ever reached fixation.

    A replicate counts as fixed when its maximum ``Infection Rate`` over
    time is ≥ ``FIXATION_THRESHOLD``.

    Args:
        df_sub (pandas.DataFrame): Rows for one effect combination.
        time_col (str): Unused; part of the heatmap metric signature.

    Returns:
        float: Percentage in [0, 100], or NaN when there are no replicates.
    """
    peak_by_replicate = df_sub.groupby("Replicate ID")["Infection Rate"].max()
    if peak_by_replicate.empty:
        return np.nan
    fixed = (peak_by_replicate >= FIXATION_THRESHOLD).sum()
    return 100.0 * fixed / len(peak_by_replicate)
# -- Per-replicate raw extraction functions (for CSV export) --
[docs]
def raw_final_infection(df_sub, time_col):
    """Per-replicate record of the infection rate at the last time point.

    Args:
        df_sub (pandas.DataFrame): Rows for one effect combination.
        time_col (str): Time column used to locate each final row.

    Returns:
        list[dict]: ``{"replicate", "final_infection_rate"}`` per replicate.
    """
    last_idx = df_sub.groupby("Replicate ID")[time_col].idxmax()
    records = []
    for _, row in df_sub.loc[last_idx].iterrows():
        records.append({
            "replicate": int(row["Replicate ID"]),
            "final_infection_rate": row["Infection Rate"],
        })
    return records
[docs]
def raw_fixation_binary(df_sub, time_col):
    """Per-replicate 1/0 flag: did infection ever reach fixation?

    Args:
        df_sub (pandas.DataFrame): Rows for one effect combination.
        time_col (str): Unused; part of the raw-extraction signature.

    Returns:
        list[dict]: ``{"replicate", "reached_fixation"}`` per replicate,
        where the flag is 1 if the peak infection rate is
        ≥ ``FIXATION_THRESHOLD`` and 0 otherwise.
    """
    peaks = df_sub.groupby("Replicate ID")["Infection Rate"].max()
    return [
        {"replicate": int(rep), "reached_fixation": int(peak >= FIXATION_THRESHOLD)}
        for rep, peak in peaks.items()
    ]
[docs]
def raw_time_to_fixation(df_sub, time_col):
    """Per-replicate first time at which infection reached fixation.

    Args:
        df_sub (pandas.DataFrame): Rows for one effect combination.
        time_col (str): Time column to report.

    Returns:
        list[dict]: ``{"replicate", "time_to_fixation"}`` per replicate;
        the time is NaN when ``FIXATION_THRESHOLD`` is never reached.
    """
    out = []
    for replicate, frame in df_sub.groupby("Replicate ID"):
        hit_times = frame.loc[frame["Infection Rate"] >= FIXATION_THRESHOLD, time_col]
        first = np.nan if hit_times.empty else hit_times.min()
        out.append({"replicate": int(replicate), "time_to_fixation": first})
    return out
[docs]
def raw_final_population(df_sub, time_col):
    """Per-replicate record of the population size at the last time point.

    Args:
        df_sub (pandas.DataFrame): Rows for one effect combination.
        time_col (str): Time column used to locate each final row.

    Returns:
        list[dict]: ``{"replicate", "final_population_size"}`` per replicate.
    """
    last_idx = df_sub.groupby("Replicate ID")[time_col].idxmax()
    records = []
    for _, row in df_sub.loc[last_idx].iterrows():
        records.append({
            "replicate": int(row["Replicate ID"]),
            "final_population_size": row["Population Size"],
        })
    return records
[docs]
def raw_final_infected_count(df_sub, time_col):
    """Per-replicate record of the infected count at the last time point.

    Args:
        df_sub (pandas.DataFrame): Rows for one effect combination.
        time_col (str): Time column used to locate each final row.

    Returns:
        list[dict]: ``{"replicate", "final_infected_count"}`` per replicate,
        with the count cast to ``int``.
    """
    last_idx = df_sub.groupby("Replicate ID")[time_col].idxmax()
    records = []
    for _, row in df_sub.loc[last_idx].iterrows():
        records.append({
            "replicate": int(row["Replicate ID"]),
            "final_infected_count": int(row["Infected Count"]),
        })
    return records
# ======================================================================
# Main figure generation pipeline
# ======================================================================
# ======================================================================
# CLI
# ======================================================================
[docs]
def main():
    """CLI entry point for the publication figure generator.

    Parses ``--model``, ``--input``, ``--outdir`` and ``--exclude``,
    loads the combined CSV, generates all figures and prints a summary of
    the PNG/SVG/CSV files found in the output directory.
    """
    parser = argparse.ArgumentParser(
        description="W.I.N.G.S. — Publication figure generator",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python plot_wings.py --model abm --input data/combined_abm.csv
python plot_wings.py --model wfm --input data/combined_wfm.csv
python plot_wings.py --model wfm --input data.csv --outdir ./my_figs
""",
    )
    parser.add_argument("--model", required=True, choices=["abm", "wfm"])
    parser.add_argument("--input", required=True, help="Combined CSV from ingest_data.py")
    parser.add_argument("--outdir", default=None, help="Output directory (default: figures_{model}/)")
    parser.add_argument("--exclude", default=None,
                        help="Comma-separated Wolbachia mechanics to exclude from "
                             "all figures. Valid: CI, MK, ER, IE. "
                             "Example: --exclude MK or --exclude MK,IE")
    args = parser.parse_args()
    excluded = parse_exclude(args.exclude)
    # Fail early with a usage message rather than a pandas traceback.
    if not os.path.isfile(args.input):
        parser.error(f"input file not found: {args.input}")
    outdir = args.outdir or f"figures_{args.model}"
    print("=" * 56)
    print(" W.I.N.G.S. — Figure Generator")
    print("=" * 56)
    print(f" Model: {args.model.upper()}")
    print(f" Input: {args.input}")
    print(f" Output: {outdir}/")
    if excluded:
        print(f" Exclude: {', '.join(sorted(excluded))}")
    df = load_data(args.input)
    n_combos = df["Combo"].nunique()
    n_reps = df["Replicate ID"].nunique()
    n_time = df["Day"].max()
    print(f" Combos: {n_combos} | Reps: {n_reps} | Max time: {n_time}")
    print("=" * 56)
    # NOTE(review): generate_figures is referenced here but its definition
    # is not visible in this file chunk — confirm it exists in the module.
    generate_figures(df, args.model, outdir, time_col="Day", excluded=excluded)
    # Count each output type separately (SVGs were previously reported
    # using the PNG count); tolerate an output dir that was never created.
    saved = os.listdir(outdir) if os.path.isdir(outdir) else []
    n_png = sum(name.endswith(".png") for name in saved)
    n_svg = sum(name.endswith(".svg") for name in saved)
    n_csv = sum(name.endswith(".csv") for name in saved)
    print(f"\n Done — {n_png} PNG + {n_svg} SVG + {n_csv} CSV saved to {outdir}/")
# Run the CLI only when executed as a script / via ``python -m``.
if __name__ == "__main__":
    main()
# Example usage:
#   python -m wings.analysis.plot_wings --model wfm --input data/combined_wfm.csv --exclude MK
#   python -m wings.analysis.plot_wings --model abm --input data/combined_abm05.csv --exclude MK
#   python -m wings.analysis.plot_wings --model abm --input data/combined_abm.csv --exclude MK
#   python -m wings.analysis.plot_wings --model abm --input data/combined.csv