Source code for wings.models.gpu_abm

"""
W.I.N.G.S. GPU-Accelerated Simulation Engine
=============================================
Replaces the per-beetle Python loop with fully vectorized PyTorch tensor
operations.  Designed for NVIDIA L40S (48 GB VRAM) but will run on any
CUDA device (or CPU as fallback).

Scaling strategy
----------------
The original ABM is O(F·M) per time-step for the mating search because every
female is checked against every male.  At N = 20 000 that means ~100 M distance
evaluations per step – trivial on an L40S but wasteful in memory for N > 50 000.

This module provides **two mating backends**:

1. ``brute``  – compute the full female × male distance matrix on GPU.
   Memory: O(F·M) floats ≈ 400 MB at N = 20 000.  Fast and simple.

2. ``cell_list`` – partition the toroidal grid into cells of side length
   ≥ mating_distance, then only check the 3×3 neighbourhood of each
   female's cell.  Memory: O(N · k) where k is the mean number of
   neighbours.  Scales to N > 100 000.

Usage
-----
>>> from gpu_simulation import GPUSimulation, SimConfig
>>> cfg = SimConfig(initial_population=20_000, max_population=25_000,
...                 grid_size=500, wolbachia_effects={'cytoplasmic_incompatibility': True,
...                 'male_killing': False, 'increased_exploration_rate': True,
...                 'increased_eggs': True, 'reduced_eggs': False})
>>> sim = GPUSimulation(cfg)
>>> for day in range(365):
...     sim.step_one_day()          # 24 hourly sub-steps
...     print(sim.get_infection_rate())

Author: Adapted from WINGS ABM (Geurten et al.)
"""

import torch
import numpy as np
from dataclasses import dataclass, field
from typing import Dict, Optional, List, Tuple
import time


# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
[docs] @dataclass class SimConfig: """Configuration dataclass for the GPU ABM simulation. All simulation parameters are set here and passed to :class:`GPUSimulation`. Immutable after construction. Attributes: initial_population (int): Starting number of adult beetles. max_population (int): Carrying capacity (K). max_eggs (int): Egg buffer cap. grid_size (int): Side length of the toroidal grid. ci_strength (float): CI intensity (0–1). mortality_mode (str): Density-dependent mortality type. One of ``"none"``, ``"logistic"``, ``"cannibalism"``, ``"contest"``. mortality_beta (float): Exponent for density dependence. cannibalism_rate (float): Egg cannibalism rate at N=K. mating_backend (str): ``"brute"`` (O(F·M) distance matrix) or ``"cell_list"`` (O(N·k) spatial hashing). device (str): ``"cuda"`` or ``"cpu"``. seed (int, optional): Random seed for reproducibility. wolbachia_effects (dict): Boolean toggles for CI, MK, ER, IE, RE. infected_fraction (float): Initial infection prevalence. """ # --- Population --- initial_population: int = 50 max_population: int = 20_000 # carrying capacity for adults max_eggs: int = 800_000 # egg buffer cap (must be large: eggs take 23 days to # hatch, so the pipeline holds ~daily_eggs × 23 days) infected_fraction: float = 0.10 male_to_female_ratio: float = 0.50 # fraction male among initial uninfected # --- Spatial --- grid_size: int = 500 # side length of the toroidal arena mating_distance: float = 5.0 # --- Wolbachia effects (bool flags) --- wolbachia_effects: Dict[str, bool] = field(default_factory=lambda: { 'cytoplasmic_incompatibility': True, 'male_killing': False, 'increased_exploration_rate': False, 'increased_eggs': False, 'reduced_eggs': False, }) # --- Reproduction parameters --- egg_laying_max: int = 15 ci_strength: float = 1.0 # 0.0–1.0 fecundity_increase_factor: float = 1.2 fecundity_decrease_factor: float = 0.85 male_offspring_rate: float = 0.10 # under male-killing exploration_rate_boost: float = 1.4 # radius multiplier for infected females mating_cooldown_female: int = 48 # hours mating_cooldown_male: int = 5 # hours (~48/10) multiple_mating: bool = True # allow up to 2 mates per female per step egg_hatching_age: int = 552 # hours (~23 days) # --- Life expectancy (hours) --- life_expectancy_min: int = 280 * 24 # ~280 days life_expectancy_max: int = 450 * 24 # ~450 days initial_age_min: int = 889 # hours (~37 days) initial_age_max: int = 2500 # hours (~104 days) # --- Levy flight --- levy_alpha: float = 1.5 # Pareto shape parameter # --- Density-dependent mortality mode --- # Controls how population regulation occurs beyond logistic birth suppression. # 'none' : Only natural death (age > max_life) + hard cap. Original ABM behavior. # Logistic birth suppression is applied to prevent runaway growth. # 'logistic' : Per-capita adult death rate increases linearly with N/K. # 'cannibalism' : Beetle-specific: adults destroy eggs proportional to density. # Based on Tribolium literature (Daly & Ryan 1983, Park 1934). # 'contest' : Above K, excess adults die each hour with probability ∝ (N-K)/N. # # IMPORTANT: When mortality_mode != 'none', logistic birth suppression is DISABLED # so that CI can operate at carrying capacity. Population regulation comes from # the density-dependent mortality instead. mortality_mode: str = 'cannibalism' # Exponent for density-dependent effects (higher = sharper response near K) mortality_beta: float = 2.0 # Egg cannibalism rate: per-adult per-hour probability of eating one egg at N = K. # Calibrated so that at steady state (N = K = 20,000), egg survival through the # 552-hour pipeline balances adult mortality: # P(eaten/hr at K) = rate × K × (K/K)^β = rate × K ≈ 0.012 # P(survive 552 hr) ≈ 0.988^552 ≈ 0.1% (matches ~2.3 adults dying/hr) # → rate ≈ 0.012 / K = 6e-7 for K = 20,000 cannibalism_rate: float = 6e-7 # --- Backend --- mating_backend: str = 'cell_list' # 'brute' or 'cell_list' device: str = 'cuda' # 'cuda' or 'cpu' seed: Optional[int] = None
# --------------------------------------------------------------------------- # Tensor-based beetle state (Structure-of-Arrays) # ---------------------------------------------------------------------------
[docs] class PopulationState: """ Holds the entire population (adults + eggs) as flat GPU tensors. We use a Structure-of-Arrays layout so that every operation (move, age, filter, mate) is a single vectorized kernel call instead of a Python loop over beetles. Attributes (all tensors on ``device``): x, y : float32 [N] – positions infected : bool [N] – Wolbachia status is_male : bool [N] – sex (True = male) age : int32 [N] – current age in hours max_life : int32 [N] – sampled life expectancy (hours) last_mate : int32 [N] – sim-time of last mating event is_egg : bool [N] – True while in egg stage (not yet hatched) """ def __init__(self, device: torch.device): self.device = device # Start empty – filled by GPUSimulation.initialize_population() self.x = torch.empty(0, device=device, dtype=torch.float32) self.y = torch.empty(0, device=device, dtype=torch.float32) self.infected = torch.empty(0, device=device, dtype=torch.bool) self.is_male = torch.empty(0, device=device, dtype=torch.bool) self.age = torch.empty(0, device=device, dtype=torch.int32) self.max_life = torch.empty(0, device=device, dtype=torch.int32) self.last_mate = torch.empty(0, device=device, dtype=torch.int32) self.is_egg = torch.empty(0, device=device, dtype=torch.bool) @property def n(self) -> int: return self.x.shape[0] @property def n_adults(self) -> int: return int((~self.is_egg).sum().item()) @property def n_eggs(self) -> int: return int(self.is_egg.sum().item()) # --- Helpers for adding / removing individuals -------------------------
[docs] def append(self, **kwargs): """Concatenate new individuals (given as keyword tensors) onto the state.""" for attr in ('x', 'y', 'infected', 'is_male', 'age', 'max_life', 'last_mate', 'is_egg'): current = getattr(self, attr) new_vals = kwargs[attr] setattr(self, attr, torch.cat([current, new_vals], dim=0))
[docs] def keep(self, mask: torch.Tensor): """Keep only individuals where ``mask`` is True (boolean tensor [N]).""" for attr in ('x', 'y', 'infected', 'is_male', 'age', 'max_life', 'last_mate', 'is_egg'): setattr(self, attr, getattr(self, attr)[mask])
[docs] def subsample(self, max_n: int): """If n > max_n, randomly keep max_n individuals (GPU grim_reaper).""" if self.n <= max_n: return perm = torch.randperm(self.n, device=self.device)[:max_n] for attr in ('x', 'y', 'infected', 'is_male', 'age', 'max_life', 'last_mate', 'is_egg'): setattr(self, attr, getattr(self, attr)[perm])
# --------------------------------------------------------------------------- # Main simulation # ---------------------------------------------------------------------------
[docs] class GPUSimulation: """GPU-accelerated agent-based model of *Wolbachia* spread. Replaces per-beetle Python loops with fully vectorised PyTorch tensor operations. All beetle state (position, sex, infection, age, mating cooldown) is stored as contiguous GPU tensors. Two mating backends are supported: - **brute**: Full F×M distance matrix. Fast for N < 20,000. - **cell_list**: Spatial hashing into grid cells. Scales to N > 100,000. The egg pipeline models *Tribolium*'s 23-day (552-hour) development period as a ring buffer of daily cohorts. Args: config (SimConfig): Simulation configuration. """ def __init__(self, cfg: SimConfig): self.cfg = cfg # Device selection if cfg.device == 'cuda' and torch.cuda.is_available(): self.device = torch.device('cuda') else: if cfg.device == 'cuda': print("CUDA not available – falling back to CPU.") self.device = torch.device('cpu') if cfg.seed is not None: torch.manual_seed(cfg.seed) np.random.seed(cfg.seed) # State self.pop = PopulationState(self.device) self.sim_time: int = 0 # History (recorded per call to step / step_one_day) self.infection_history: List[float] = [] self.population_history: List[int] = [] # Pre-compute cell-list grid (if using that backend) self._cell_size: float = max(cfg.mating_distance, cfg.mating_distance * cfg.exploration_rate_boost) self._n_cells: int = max(1, int(cfg.grid_size / self._cell_size)) self._cell_size = cfg.grid_size / self._n_cells # exact fit # Initialize self._initialize_population() # ------------------------------------------------------------------ # Initialization # ------------------------------------------------------------------ def _initialize_population(self): N = self.cfg.initial_population n_infected = int(round(N * self.cfg.infected_fraction)) # --- Positions: central third of the grid --- lo = self.cfg.grid_size / 3.0 hi = 2.0 * self.cfg.grid_size / 3.0 x = torch.empty(N, device=self.device).uniform_(lo, hi) y = torch.empty(N, device=self.device).uniform_(lo, hi) # --- Infection status --- infected = torch.zeros(N, device=self.device, dtype=torch.bool) infected[:n_infected] = True # Shuffle so infected aren't all in the first slots perm = torch.randperm(N, device=self.device) infected = infected[perm] # --- Sex assignment --- # Among infected: 50/50 male/female # Among uninfected: follow male_to_female_ratio is_male = torch.zeros(N, device=self.device, dtype=torch.bool) inf_mask = infected uninf_mask = ~infected n_inf = int(inf_mask.sum().item()) n_uninf = int(uninf_mask.sum().item()) # Infected: half male inf_indices = inf_mask.nonzero(as_tuple=False).squeeze(-1) n_inf_male = n_inf // 2 inf_male_sel = inf_indices[torch.randperm(n_inf, device=self.device)[:n_inf_male]] is_male[inf_male_sel] = True # Uninfected: male_to_female_ratio fraction male uninf_indices = uninf_mask.nonzero(as_tuple=False).squeeze(-1) n_uninf_male = int(round(n_uninf * self.cfg.male_to_female_ratio)) uninf_male_sel = uninf_indices[torch.randperm(n_uninf, device=self.device)[:n_uninf_male]] is_male[uninf_male_sel] = True # --- Ages & life expectancy --- age = torch.randint(self.cfg.initial_age_min, self.cfg.initial_age_max + 1, (N,), device=self.device, dtype=torch.int32) max_life = torch.randint(self.cfg.life_expectancy_min, self.cfg.life_expectancy_max + 1, (N,), device=self.device, dtype=torch.int32) # --- Mating cooldown: allow immediate mating at t=0 --- last_mate = torch.full((N,), -self.cfg.mating_cooldown_female, device=self.device, dtype=torch.int32) # --- All start as adults (not eggs) --- is_egg = torch.zeros(N, device=self.device, dtype=torch.bool) # Write into state self.pop.x = x self.pop.y = y self.pop.infected = infected self.pop.is_male = is_male self.pop.age = age self.pop.max_life = max_life self.pop.last_mate = last_mate self.pop.is_egg = is_egg # ------------------------------------------------------------------ # Single simulation step (1 hour) # ------------------------------------------------------------------
[docs] def step(self): """Advance the simulation by one hour. Sequence: 1. Age all adults; remove those exceeding max lifespan. 2. Move adults (Lévy flight; ER beetles get 1.4× step size). 3. Find mating pairs within mating distance. 4. Reproduce (apply CI, MK, IE/RE). 5. Add new eggs to the pipeline. 6. Hatch eggs from 23 days ago. 7. Apply density-dependent egg mortality (cannibalism). 8. Record population size and infection rate. """ self._move() self._age() self._hatch_eggs() self._retire_dead() self._mate() self._enforce_capacity() self._record_stats() self.sim_time += 1
[docs] def step_one_day(self): """Run 24 hourly steps, recording stats only at the end of the day.""" for _ in range(24): self._move() self._age() self._hatch_eggs() self._retire_dead() self._mate() self._enforce_capacity() self.sim_time += 1 self._record_stats()
# ------------------------------------------------------------------ # Movement (vectorized Lévy flight) # ------------------------------------------------------------------ def _move(self): """ Lévy flight for all adults (eggs don't move). Step size ~ Pareto(α) + 1 implemented via inverse-CDF: U ~ Uniform(0,1) → step = U^(-1/α) Direction ~ Uniform(0, 2π). Positions wrap toroidally. """ adults = ~self.pop.is_egg n_adults = int(adults.sum().item()) if n_adults == 0: return U = torch.rand(n_adults, device=self.device).clamp(min=1e-7) step_sizes = U.pow(-1.0 / self.cfg.levy_alpha) angles = 2.0 * np.pi * torch.rand(n_adults, device=self.device) dx = step_sizes * torch.cos(angles) dy = step_sizes * torch.sin(angles) self.pop.x[adults] = (self.pop.x[adults] + dx) % self.cfg.grid_size self.pop.y[adults] = (self.pop.y[adults] + dy) % self.cfg.grid_size # ------------------------------------------------------------------ # Aging # ------------------------------------------------------------------ def _age(self): """Increment age of every individual (adults and eggs) by 1 hour.""" self.pop.age += 1 # ------------------------------------------------------------------ # Egg hatching # ------------------------------------------------------------------ def _hatch_eggs(self): """Eggs whose age exceeds the hatching threshold become adults.""" ready = self.pop.is_egg & (self.pop.age > self.cfg.egg_hatching_age) self.pop.is_egg[ready] = False # promote to adult # ------------------------------------------------------------------ # Death / retirement (with density-dependent mortality options) # ------------------------------------------------------------------ def _retire_dead(self): """ Remove dead individuals from the population. Always applies natural death (age > max_life). Then applies the selected density-dependent mortality mode. """ # 1. Natural death: remove adults that exceeded their life expectancy alive = self.pop.is_egg | (self.pop.age <= self.pop.max_life) self.pop.keep(alive) mode = self.cfg.mortality_mode if mode == 'none': return adult_mask = ~self.pop.is_egg n_adults = int(adult_mask.sum().item()) K = self.cfg.max_population if n_adults == 0 or K <= 0: return density_ratio = n_adults / K # N/K if mode == 'logistic': # ── Logistic density-dependent adult mortality ── # Per-capita hourly death probability increases with (N/K)^β. # At N=K, extra mortality ≈ 1/mean_lifespan per hour (doubles natural rate). # At N<<K, extra mortality ≈ 0. if density_ratio <= 0.5: return # negligible at low density base_hourly_mu = 1.0 / ((self.cfg.life_expectancy_min + self.cfg.life_expectancy_max) / 2) extra_mu = base_hourly_mu * (density_ratio ** self.cfg.mortality_beta) extra_mu = min(extra_mu, 0.1) # cap at 10% per hour to avoid instabilities # Each adult dies with probability extra_mu this hour adult_idx = adult_mask.nonzero(as_tuple=False).squeeze(-1) death_roll = torch.rand(n_adults, device=self.device) survivors = death_roll >= extra_mu kill_mask = torch.ones(self.pop.n, device=self.device, dtype=torch.bool) kill_mask[adult_idx[~survivors]] = False self.pop.keep(kill_mask) elif mode == 'cannibalism': # ── Egg cannibalism (Tribolium-style) ── # Adults consume eggs at a rate that scales with adult density. # This is the primary population regulation mechanism in flour beetles # (Daly & Ryan 1983, Park 1934, Sonleitner & Gutherie 1991). # Probability each egg is eaten = cannibalism_rate × N_adults × (N/K)^β egg_mask = self.pop.is_egg n_eggs = int(egg_mask.sum().item()) if n_eggs == 0: return p_eaten = self.cfg.cannibalism_rate * n_adults * (density_ratio ** self.cfg.mortality_beta) p_eaten = min(p_eaten, 0.95) # cap so some eggs always survive egg_idx = egg_mask.nonzero(as_tuple=False).squeeze(-1) death_roll = torch.rand(n_eggs, device=self.device) kill_mask = torch.ones(self.pop.n, device=self.device, dtype=torch.bool) kill_mask[egg_idx[death_roll < p_eaten]] = False self.pop.keep(kill_mask) elif mode == 'contest': # ── Contest competition ── # When N > K, each excess individual has a probability of dying # each hour. Below K, no extra mortality. if n_adults <= K: return excess = n_adults - K # Kill probability per adult = (excess / N) per hour p_die = excess / n_adults adult_idx = adult_mask.nonzero(as_tuple=False).squeeze(-1) death_roll = torch.rand(n_adults, device=self.device) kill_mask = torch.ones(self.pop.n, device=self.device, dtype=torch.bool) kill_mask[adult_idx[death_roll < p_die]] = False self.pop.keep(kill_mask) # ------------------------------------------------------------------ # Carrying-capacity enforcement # ------------------------------------------------------------------ def _enforce_capacity(self): """ Enforce population limits: - Adults: random cull if above max_population (hard carrying capacity). - Eggs: age-priority cull — remove the YOUNGEST eggs first so that eggs close to hatching survive. Without this, random culling kills eggs long before they reach the 552-hour hatching age. """ adult_mask = ~self.pop.is_egg egg_mask = self.pop.is_egg n_adults = int(adult_mask.sum().item()) n_eggs = int(egg_mask.sum().item()) excess_adults = n_adults - self.cfg.max_population excess_eggs = n_eggs - self.cfg.max_eggs if excess_adults > 0 or excess_eggs > 0: keep = torch.ones(self.pop.n, device=self.device, dtype=torch.bool) if excess_adults > 0: adult_idx = adult_mask.nonzero(as_tuple=False).squeeze(-1) remove_idx = adult_idx[torch.randperm(n_adults, device=self.device)[:excess_adults]] keep[remove_idx] = False if excess_eggs > 0: # Age-priority: remove the youngest eggs (smallest age → furthest from hatching) egg_idx = egg_mask.nonzero(as_tuple=False).squeeze(-1) egg_ages = self.pop.age[egg_idx] # Sort eggs by age ascending — first entries are youngest _, age_order = torch.sort(egg_ages) youngest_idx = egg_idx[age_order[:excess_eggs]] keep[youngest_idx] = False self.pop.keep(keep) # ------------------------------------------------------------------ # Mating (dispatcher) # ------------------------------------------------------------------ def _mate(self): if self.cfg.mating_backend == 'brute': self._mate_brute() else: self._mate_cell_list() # ------------------------------------------------------------------ # Mating backend 1: brute-force distance matrix # ------------------------------------------------------------------ def _mate_brute(self): """ Compute the full female × male toroidal distance matrix on GPU. For N = 20 000 (10 K × 10 K) this is ~400 MB float32 — fits comfortably in the L40S's 48 GB. Steps: 1. Identify eligible females and males (adult, off cooldown). 2. Compute pairwise toroidal distances. 3. Build a boolean "within range" mask. 4. For each female, randomly pick one (or two) eligible males. 5. Generate offspring and add as eggs. """ t = self.sim_time pop = self.pop # --- Eligibility masks --- adult = ~pop.is_egg female_mask = adult & ~pop.is_male male_mask = adult & pop.is_male # Cooldown check cd_female = self.cfg.mating_cooldown_female cd_male = self.cfg.mating_cooldown_male female_eligible = female_mask & ((t - pop.last_mate) >= cd_female) male_eligible = male_mask & ((t - pop.last_mate) >= cd_male) fem_idx = female_eligible.nonzero(as_tuple=False).squeeze(-1) mal_idx = male_eligible.nonzero(as_tuple=False).squeeze(-1) nf = fem_idx.shape[0] nm = mal_idx.shape[0] if nf == 0 or nm == 0: return # --- Toroidal distances [nf, nm] --- fx = pop.x[fem_idx].unsqueeze(1) # [nf, 1] fy = pop.y[fem_idx].unsqueeze(1) mx = pop.x[mal_idx].unsqueeze(0) # [1, nm] my = pop.y[mal_idx].unsqueeze(0) gs = self.cfg.grid_size dx = torch.abs(fx - mx) dx = torch.min(dx, gs - dx) dy = torch.abs(fy - my) dy = torch.min(dy, gs - dy) dist = torch.sqrt(dx * dx + dy * dy) # [nf, nm] # --- Per-female mating distance (infected females may have expanded range) --- md = self.cfg.mating_distance if self.cfg.wolbachia_effects.get('increased_exploration_rate', False): fem_infected = pop.infected[fem_idx] # [nf] per_fem_dist = torch.where(fem_infected, torch.tensor(md * self.cfg.exploration_rate_boost, device=self.device), torch.tensor(md, device=self.device)) # [nf] in_range = dist <= per_fem_dist.unsqueeze(1) # [nf, nm] else: in_range = dist <= md # --- Pair assignment: for each female, pick 1 (or 2) random males --- self._assign_mates_and_reproduce(fem_idx, mal_idx, in_range) # ------------------------------------------------------------------ # Mating backend 2: cell-list spatial hashing (for N > 50 000) # ------------------------------------------------------------------ def _mate_cell_list(self): """ Partition beetles into grid cells of side ≥ mating_distance. Each female only checks males in her own cell and the 8 neighbours (with toroidal wrapping). This reduces the distance evaluations from O(F·M) to O(F · k) where k is the mean per-cell male count × 9. At uniform density with 20 000 beetles and cell_size = 5 on a 500×500 grid, each cell has ~2 beetles on average, so k ≈ 18. This scales linearly with N. Implementation: we process one cell-neighbourhood at a time in a vectorized batch. No Python loop over individual beetles. """ t = self.sim_time pop = self.pop gs = self.cfg.grid_size cs = self._cell_size nc = self._n_cells # number of cells per dimension # --- Eligibility --- adult = ~pop.is_egg cd_female = self.cfg.mating_cooldown_female cd_male = self.cfg.mating_cooldown_male female_eligible = adult & ~pop.is_male & ((t - pop.last_mate) >= cd_female) male_eligible = adult & pop.is_male & ((t - pop.last_mate) >= cd_male) fem_idx = female_eligible.nonzero(as_tuple=False).squeeze(-1) mal_idx = male_eligible.nonzero(as_tuple=False).squeeze(-1) nf = fem_idx.shape[0] nm = mal_idx.shape[0] if nf == 0 or nm == 0: return # --- Assign cell indices --- fem_cx = (pop.x[fem_idx] / cs).long() % nc fem_cy = (pop.y[fem_idx] / cs).long() % nc fem_cell = fem_cx * nc + fem_cy # flat cell index mal_cx = (pop.x[mal_idx] / cs).long() % nc mal_cy = (pop.y[mal_idx] / cs).long() % nc mal_cell = mal_cx * nc + mal_cy # --- Sort males by cell for fast lookup --- sort_order = torch.argsort(mal_cell) mal_idx_sorted = mal_idx[sort_order] mal_cell_sorted = mal_cell[sort_order] # Build cell → male index ranges via searchsorted all_cells = torch.arange(nc * nc, device=self.device, dtype=torch.long) cell_starts = torch.searchsorted(mal_cell_sorted, all_cells) cell_ends = torch.searchsorted(mal_cell_sorted, all_cells + 1) # --- For each female, gather candidate males from 3×3 neighbourhood --- # Offsets for the 9 neighbours (including self) offsets_x = torch.tensor([-1, -1, -1, 0, 0, 0, 1, 1, 1], device=self.device, dtype=torch.long) offsets_y = torch.tensor([-1, 0, 1, -1, 0, 1, -1, 0, 1], device=self.device, dtype=torch.long) # For efficiency, we process females in batches grouped by cell. # But a simpler approach that works well for moderate N: # iterate over the 9 neighbour offsets, gather all candidate males # for ALL females at once, compute distances, then merge. # We'll build a sparse list of (female_local_idx, male_global_idx) pairs # that are within mating distance. md = self.cfg.mating_distance boost = self.cfg.exploration_rate_boost if self.cfg.wolbachia_effects.get( 'increased_exploration_rate', False) else 1.0 pair_fem = [] # local indices into fem_idx pair_mal = [] # global indices into pop arrays (via mal_idx_sorted) for ox, oy in zip(offsets_x.tolist(), offsets_y.tolist()): # Neighbour cell for each female nb_cx = (fem_cx + ox) % nc nb_cy = (fem_cy + oy) % nc nb_cell = nb_cx * nc + nb_cy # [nf] # Gather start/end for each female's neighbour cell starts = cell_starts[nb_cell] # [nf] ends = cell_ends[nb_cell] # [nf] lengths = ends - starts # [nf] – males in this neighbour cell max_len = int(lengths.max().item()) if nf > 0 else 0 if max_len == 0: continue # Build a [nf, max_len] index matrix of candidate males arange = torch.arange(max_len, device=self.device).unsqueeze(0) # [1, max_len] offsets_mat = starts.unsqueeze(1) + arange # [nf, max_len] valid = arange < lengths.unsqueeze(1) # [nf, max_len] # Clamp to avoid out-of-bounds (invalid slots will be masked out) offsets_mat = offsets_mat.clamp(max=nm - 1) cand_global = mal_idx_sorted[offsets_mat] # [nf, max_len] – global pop indices # Compute toroidal distances f_x = pop.x[fem_idx].unsqueeze(1).expand_as(offsets_mat.float()) f_y = pop.y[fem_idx].unsqueeze(1).expand_as(offsets_mat.float()) c_x = pop.x[cand_global] c_y = pop.y[cand_global] ddx = torch.abs(f_x - c_x) ddx = torch.min(ddx, gs - ddx) ddy = torch.abs(f_y - c_y) ddy = torch.min(ddy, gs - ddy) d = torch.sqrt(ddx * ddx + ddy * ddy) # Per-female mating distance if boost > 1.0: fem_inf = pop.infected[fem_idx].unsqueeze(1).expand_as(d) threshold = torch.where(fem_inf, torch.tensor(md * boost, device=self.device), torch.tensor(md, device=self.device)) else: threshold = md close = (d <= threshold) & valid # Extract pairs fi_local, mj_local = close.nonzero(as_tuple=True) if fi_local.shape[0] > 0: pair_fem.append(fi_local) pair_mal.append(cand_global[fi_local, mj_local]) if len(pair_fem) == 0: return all_pair_fem = torch.cat(pair_fem) # local indices into fem_idx all_pair_mal_global = torch.cat(pair_mal) # global pop indices of males # Deduplicate: same (female, male) pair may appear from overlapping cells pair_key = all_pair_fem.long() * pop.n + all_pair_mal_global.long() unique_keys, unique_inv = torch.unique(pair_key, return_inverse=True) # Keep first occurrence first_occ = torch.zeros(unique_keys.shape[0], device=self.device, dtype=torch.long) first_occ.scatter_(0, unique_inv, torch.arange(all_pair_fem.shape[0], device=self.device, dtype=torch.long)) all_pair_fem = all_pair_fem[first_occ] all_pair_mal_global = all_pair_mal_global[first_occ] # Build the in_range boolean matrix as a sparse representation, # then delegate to the shared assignment routine. # For the assignment function, we need a [nf, nm] mask. # But that defeats the purpose of the cell list for very large N. # Instead, we call a sparse-pair version of the assignment. self._assign_mates_sparse(fem_idx, mal_idx, all_pair_fem, all_pair_mal_global) # ------------------------------------------------------------------ # Mate assignment + offspring generation (dense matrix version) # ------------------------------------------------------------------ def _assign_mates_and_reproduce(self, fem_idx, mal_idx, in_range): """ Given the boolean [nf, nm] in_range matrix, assign mates randomly and generate offspring. Strategy: - For each female, pick one random in-range male. - If multiple_mating, allow a second mate for females that had a first. - A male can only be claimed once per step (enforced sequentially on unique males, but parallelised over the candidate selection). For 10K females this takes a few ms on GPU. """ nf = fem_idx.shape[0] nm = mal_idx.shape[0] pop = self.pop t = self.sim_time max_mates = 2 if self.cfg.multiple_mating else 1 # Random scores for each candidate pair (used to pick a random in-range male) rand_scores = torch.rand(nf, nm, device=self.device) rand_scores[~in_range] = -1.0 # disqualify out-of-range all_offspring = [] male_claimed = torch.zeros(nm, device=self.device, dtype=torch.bool) for _ in range(max_mates): # Mask out already-claimed males current_scores = rand_scores.clone() current_scores[:, male_claimed] = -1.0 # For each female, find the male with the highest random score best_scores, best_local_mal = current_scores.max(dim=1) # [nf] has_mate = best_scores > 0 # females that found at least one eligible male if not has_mate.any(): break # Indices of matched females (local) and males (local into mal_idx) matched_fem_local = has_mate.nonzero(as_tuple=False).squeeze(-1) matched_mal_local = best_local_mal[matched_fem_local] # Mark these males as claimed male_claimed[matched_mal_local] = True # Global indices gf = fem_idx[matched_fem_local] gm = mal_idx[matched_mal_local] # Update mating times pop.last_mate[gf] = t pop.last_mate[gm] = t # Disable these females from getting another mate in this pass rand_scores[matched_fem_local, :] = -1.0 # Generate offspring offspring = self._generate_offspring_batch(gf, gm) if offspring is not None: all_offspring.append(offspring) # Add all offspring to population self._add_offspring(all_offspring) # ------------------------------------------------------------------ # Mate assignment + offspring generation (sparse pairs version) # ------------------------------------------------------------------ def _assign_mates_sparse(self, fem_idx, mal_idx, pair_fem_local, pair_mal_global): """ Sparse-pair version for the cell-list backend. ``pair_fem_local``: indices into fem_idx for each candidate pair. ``pair_mal_global``: global pop indices for the male in each pair. """ pop = self.pop t = self.sim_time max_mates = 2 if self.cfg.multiple_mating else 1 nf = fem_idx.shape[0] # Random scores for tie-breaking rand_scores = torch.rand(pair_fem_local.shape[0], device=self.device) all_offspring = [] female_mated_count = torch.zeros(nf, device=self.device, dtype=torch.int32) male_claimed = set() # track globally which males are claimed # Sort pairs by female, then iterate in a vectorized-ish way # For each mating round, we pick one male per female. for _ in range(max_mates): if pair_fem_local.shape[0] == 0: break # For each female, pick the pair with the highest random score # Use scatter_max-like logic best_score = torch.full((nf,), -1.0, device=self.device) best_pair_idx = torch.full((nf,), -1, device=self.device, dtype=torch.long) # Simple approach: sort by (female, -score) and take first per female sort_key = pair_fem_local.float() - rand_scores / (rand_scores.max() + 1) order = torch.argsort(sort_key) sorted_fem = pair_fem_local[order] sorted_pair = order # Find first occurrence of each female change = torch.ones(sorted_fem.shape[0], device=self.device, dtype=torch.bool) change[1:] = sorted_fem[1:] != sorted_fem[:-1] first_indices = change.nonzero(as_tuple=False).squeeze(-1) selected_fem_local = sorted_fem[first_indices] selected_pair_idx = sorted_pair[first_indices] selected_mal_global = pair_mal_global[selected_pair_idx] # Filter out females that already reached max mates still_eligible = female_mated_count[selected_fem_local] < max_mates selected_fem_local = selected_fem_local[still_eligible] selected_mal_global = selected_mal_global[still_eligible] if selected_fem_local.shape[0] == 0: break # Deduplicate males (each male only once per round) _, unique_male_inv = torch.unique(selected_mal_global, return_inverse=True) unique_male_first = torch.zeros(selected_mal_global.max().item() + 1, device=self.device, dtype=torch.bool) keep_pair = torch.zeros(selected_fem_local.shape[0], device=self.device, dtype=torch.bool) for i in range(selected_fem_local.shape[0]): mg = selected_mal_global[i].item() if mg not in male_claimed: keep_pair[i] = True male_claimed.add(mg) selected_fem_local = selected_fem_local[keep_pair] selected_mal_global = selected_mal_global[keep_pair] if selected_fem_local.shape[0] == 0: break gf = fem_idx[selected_fem_local] gm = selected_mal_global pop.last_mate[gf] = t pop.last_mate[gm] = t female_mated_count[selected_fem_local] += 1 offspring = self._generate_offspring_batch(gf, gm) if offspring is not None: all_offspring.append(offspring) # Remove used pairs used_fem_set = set(selected_fem_local.tolist()) used_mal_set = male_claimed mask = torch.tensor([(pf.item() not in used_fem_set) and (pm.item() not in used_mal_set) for pf, pm in zip(pair_fem_local, pair_mal_global)], device=self.device, dtype=torch.bool) pair_fem_local = pair_fem_local[mask] pair_mal_global = pair_mal_global[mask] rand_scores = rand_scores[mask] self._add_offspring(all_offspring) # ------------------------------------------------------------------ # Vectorized offspring generation # ------------------------------------------------------------------ def _generate_offspring_batch(self, mother_global_idx, father_global_idx): """ Generate offspring for all mating pairs in a single vectorized call. Parameters ---------- mother_global_idx : Tensor [P] – global population indices of mothers father_global_idx : Tensor [P] – global population indices of fathers Returns ------- dict with tensors for each offspring attribute, or None if no offspring. """ P = mother_global_idx.shape[0] if P == 0: return None pop = self.pop cfg = self.cfg effects = cfg.wolbachia_effects # --- Eggs per pair (Uniform(1, egg_laying_max)) --- eggs = torch.randint(1, cfg.egg_laying_max + 1, (P,), device=self.device, dtype=torch.int32) # --- Logistic birth suppression: reduce clutch size as N → K --- # ONLY applied when mortality_mode == 'none' (hard cap only). # When density-dependent mortality is active (cannibalism/logistic/contest), # the mortality mechanism regulates population. Suppressing births at K # would starve CI of eggs to act on, making Wolbachia invasion impossible # once the population approaches carrying capacity. if cfg.mortality_mode == 'none': n_adults = int((~pop.is_egg).sum().item()) logistic_factor = max(0.0, 1.0 - n_adults / cfg.max_population) if logistic_factor < 1.0: eggs = torch.round(eggs.float() * logistic_factor).to(torch.int32) eggs = eggs.clamp(min=0) # --- Fecundity modifiers (only for infected mothers) --- mom_infected = pop.infected[mother_global_idx] # [P] inc_eggs = effects.get('increased_eggs', False) red_eggs = effects.get('reduced_eggs', False) if inc_eggs and not red_eggs: eggs[mom_infected] = torch.round( eggs[mom_infected].float() * cfg.fecundity_increase_factor ).to(torch.int32) elif red_eggs and not inc_eggs: eggs[mom_infected] = torch.round( eggs[mom_infected].float() * cfg.fecundity_decrease_factor ).to(torch.int32) # If both or neither, no change. # --- Cytoplasmic incompatibility --- if effects.get('cytoplasmic_incompatibility', False): dad_infected = pop.infected[father_global_idx] ci_mask = dad_infected & ~mom_infected # infected ♂ × uninfected ♀ if ci_mask.any(): if cfg.ci_strength >= 1.0: eggs[ci_mask] = 0 else: # Each egg survives independently with prob (1 - ci_strength) ci_idx = ci_mask.nonzero(as_tuple=False).squeeze(-1) max_e = int(eggs[ci_idx].max().item()) if max_e > 0: rand_mat = torch.rand(ci_idx.shape[0], max_e, device=self.device) lengths = eggs[ci_idx].unsqueeze(1) valid = torch.arange(max_e, device=self.device).unsqueeze(0) < lengths survived = ((rand_mat >= cfg.ci_strength) & valid).sum(dim=1).to(torch.int32) eggs[ci_idx] = survived # --- Expand: repeat mother attributes per egg --- total_eggs = int(eggs.sum().item()) if total_eggs == 0: return None # Mother index for each offspring mom_for_egg = mother_global_idx.repeat_interleave(eggs.long()) # [total_eggs] # --- Position near mother (offset ∈ {-1, 0, 1}) --- ox = torch.randint(-1, 2, (total_eggs,), device=self.device, dtype=torch.float32) oy = torch.randint(-1, 2, (total_eggs,), device=self.device, dtype=torch.float32) new_x = (pop.x[mom_for_egg] + ox) % cfg.grid_size new_y = (pop.y[mom_for_egg] + oy) % cfg.grid_size # --- Infection: inherited from mother --- new_infected = pop.infected[mom_for_egg] # --- Sex determination --- if effects.get('male_killing', False): # Infected mothers → mostly female offspring prob_male = torch.where( new_infected, torch.tensor(cfg.male_offspring_rate, device=self.device), torch.tensor(0.5, device=self.device) ) else: prob_male = torch.full((total_eggs,), 0.5, device=self.device) new_is_male = torch.rand(total_eggs, device=self.device) < prob_male # --- Age & life expectancy --- new_age = torch.zeros(total_eggs, device=self.device, dtype=torch.int32) new_max_life = torch.randint(cfg.life_expectancy_min, cfg.life_expectancy_max + 1, (total_eggs,), device=self.device, dtype=torch.int32) new_last_mate = torch.full((total_eggs,), -cfg.mating_cooldown_female, device=self.device, dtype=torch.int32) new_is_egg = torch.ones(total_eggs, device=self.device, dtype=torch.bool) return { 'x': new_x, 'y': new_y, 'infected': new_infected, 'is_male': new_is_male, 'age': new_age, 'max_life': new_max_life, 'last_mate': new_last_mate, 'is_egg': new_is_egg, } def _add_offspring(self, offspring_list): """Concatenate all offspring dicts into the population state.""" if not offspring_list: return combined = {} for key in offspring_list[0]: combined[key] = torch.cat([o[key] for o in offspring_list], dim=0) self.pop.append(**combined) # ------------------------------------------------------------------ # Statistics # ------------------------------------------------------------------ def _record_stats(self): adult = ~self.pop.is_egg n_adults = int(adult.sum().item()) if n_adults == 0: self.infection_history.append(0.0) else: n_inf = int((self.pop.infected & adult).sum().item()) self.infection_history.append(n_inf / n_adults) self.population_history.append(n_adults)
[docs] def get_infection_rate(self) -> float: return self.infection_history[-1] if self.infection_history else 0.0
[docs] def get_population_size(self) -> int: return self.population_history[-1] if self.population_history else 0
[docs] def get_sex_ratio(self) -> dict: """Returns counts of adult males and females (infected and uninfected).""" adult = ~self.pop.is_egg return { 'F_U': int((adult & ~self.pop.is_male & ~self.pop.infected).sum().item()), 'F_I': int((adult & ~self.pop.is_male & self.pop.infected).sum().item()), 'M_U': int((adult & self.pop.is_male & ~self.pop.infected).sum().item()), 'M_I': int((adult & self.pop.is_male & self.pop.infected).sum().item()), }
# ------------------------------------------------------------------ # CSV export (compatible with existing analysis pipeline) # ------------------------------------------------------------------
[docs] def export_history_csv(self, path: str): """Write the simulation time series to a CSV file. Columns: ``Population Size``, ``Infection Rate`` (one row per recorded time point — typically daily). Args: path (str): Output CSV file path. """ import csv n = min(len(self.infection_history), len(self.population_history)) with open(path, 'w', newline='') as f: writer = csv.writer(f) writer.writerow(['Population Size', 'Infection Rate']) for i in range(n): writer.writerow([self.population_history[i], f"{self.infection_history[i]:.6f}"])
# --------------------------------------------------------------------------- # Convenience: run a full experiment # ---------------------------------------------------------------------------
[docs] def run_experiment(cfg: SimConfig, n_days: int = 365, verbose: bool = True) -> GPUSimulation: """Run a complete simulation experiment. Creates a :class:`GPUSimulation` from the config, runs for ``n_days × 24`` hours, and returns the simulation object with its recorded history. Args: config (SimConfig): Simulation parameters. n_days (int): Duration in days. Defaults to ``365``. Returns: GPUSimulation: Completed simulation with history. """ sim = GPUSimulation(cfg) start = time.time() for day in range(1, n_days + 1): sim.step_one_day() if verbose and day % 30 == 0: elapsed = time.time() - start print(f" Day {day:4d} | Pop {sim.get_population_size():6d} | " f"Inf {sim.get_infection_rate():.3f} | " f"Eggs {sim.pop.n_eggs:6d} | " f"Total {sim.pop.n:7d} | " f"{elapsed:.1f}s elapsed") if verbose: print(f"Completed {n_days} days in {time.time()-start:.1f}s") return sim
# --------------------------------------------------------------------------- # CLI entry point # --------------------------------------------------------------------------- if __name__ == '__main__': import argparse, json parser = argparse.ArgumentParser(description="WINGS GPU Simulation") parser.add_argument('--population', type=int, default=50) parser.add_argument('--max-pop', type=int, default=20_000) parser.add_argument('--max-eggs', type=int, default=800_000, help='Egg buffer cap (must be large for 23-day pipeline)') parser.add_argument('--grid-size', type=int, default=500) parser.add_argument('--days', type=int, default=365) parser.add_argument('--ci', action='store_true', help='Enable cytoplasmic incompatibility') parser.add_argument('--mk', action='store_true', help='Enable male killing') parser.add_argument('--er', action='store_true', help='Enable increased exploration rate') parser.add_argument('--ie', action='store_true', help='Enable increased eggs') parser.add_argument('--re', action='store_true', help='Enable reduced eggs') parser.add_argument('--ci-strength', type=float, default=1.0) parser.add_argument('--mortality', choices=['none', 'logistic', 'cannibalism', 'contest'], default='cannibalism', help='Density-dependent mortality mode') parser.add_argument('--mortality-beta', type=float, default=2.0, help='Exponent for density-dependent effects') parser.add_argument('--cannibalism-rate', type=float, default=6e-7, help='Egg cannibalism rate per adult per hour at N=K (default: 6e-7, calibrated for K=20000)') parser.add_argument('--backend', choices=['brute', 'cell_list'], default='cell_list') parser.add_argument('--device', choices=['cuda', 'cpu'], default='cuda') parser.add_argument('--infected-fraction', type=float, default=0.10, help='Initial fraction of population infected (default: 0.10)') parser.add_argument('--seed', type=int, default=None) parser.add_argument('--output', type=str, default=None, help='Output CSV path') args = parser.parse_args() cfg = SimConfig( initial_population=args.population, max_population=args.max_pop, max_eggs=args.max_eggs, grid_size=args.grid_size, infected_fraction=args.infected_fraction, ci_strength=args.ci_strength, mortality_mode=args.mortality, mortality_beta=args.mortality_beta, cannibalism_rate=args.cannibalism_rate, mating_backend=args.backend, device=args.device, seed=args.seed, wolbachia_effects={ 'cytoplasmic_incompatibility': args.ci, 'male_killing': args.mk, 'increased_exploration_rate': args.er, 'increased_eggs': args.ie, 'reduced_eggs': args.re, }, ) print(f"WINGS GPU Simulation") print(f" Device: {args.device}") print(f" Backend: {args.backend}") print(f" Population: {args.population}") print(f" Max pop: {args.max_pop} | Max eggs: {cfg.max_eggs}") print(f" Mortality: {args.mortality} (beta={args.mortality_beta})") print(f" Grid: {args.grid_size}×{args.grid_size}") print(f" Inf frac: {args.infected_fraction}") print(f" Days: {args.days}") print(f" Effects: {json.dumps(cfg.wolbachia_effects)}") print() sim = run_experiment(cfg, n_days=args.days) if args.output: sim.export_history_csv(args.output) print(f"Results saved to {args.output}")