# Source code for arista.viz.response_curves

# ─────────────────────────────────────────────────────────────────
#  arista.viz.response_curves
#  « median ΔF/F + 95% bootstrap CI per (strain × cell_type) per step »
# ─────────────────────────────────────────────────────────────────
"""Reproduce Kossen 2019 Fig 19 / 22-23: response curves per strain × cell-type.

One panel per cell type (CC / HC, optionally WC). Within each panel
one line per strain, x-axis is ``delta_target_c`` (step target minus
22 °C baseline) or absolute ``target_temp_c``, y-axis is the per-step
median ΔF/F across recordings of that strain × cell_type.

Error band is **median ± 95% bootstrap CI of the median** by default —
non-parametric, no assumption of normality. The alternative
``error="iqr"`` mode plots median + interquartile range. The parametric
``error="sem"`` mode (mean ± SEM) runs a Shapiro-Wilk normality test per
x-value and, for any strain where a group rejects normality (α = 0.05 by
default), falls back to the CI band and emits a ``RuntimeWarning``.
Groups with fewer than three observations cannot be tested and do not
trigger the fallback. This matches Bart's standing rule: median-based
summaries by default, parametric only after normality is checked.

Colours come from :data:`arista.constants.STRAIN_COLOURS`. The class
:class:`ResponseCurves` carries default styling for batch reuse; the
free function :func:`plot_response_curves` is the primary one-off API.
"""

from __future__ import annotations

import sqlite3
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Literal

import numpy as np
import pandas as pd

from arista.constants import (
    FIGURE_DPI,
    STRAIN_COLOURS,
)

if TYPE_CHECKING:
    from matplotlib.figure import Figure


#: Error-band modes accepted by :func:`plot_response_curves` (``error=``):
#: ``"ci"`` median ± bootstrap CI of the median, ``"iqr"`` median + IQR,
#: ``"sem"`` mean ± SEM gated by Shapiro-Wilk, ``"none"`` centre line only.
ErrorBand = Literal["ci", "iqr", "sem", "none"]

#: Minimum frames-in-window for a stimulus_responses row to qualify as
#: "well-matched". Filters out the Alex protocol-mismatch rows where
#: the sensor only briefly touched the target window. Default for the
#: ``min_frames_in_window`` argument of :func:`fetch_response_data`.
_MIN_FRAMES_IN_WINDOW: int = 50

#: Bootstrap parameters for the 95% CI of the median — the defaults of
#: :func:`aggregate_response_data`: number of resamples, CI coverage,
#: and the seed for its ``np.random.Generator``.
_DEFAULT_N_BOOTSTRAP: int = 2000
_DEFAULT_CONFIDENCE_LEVEL: float = 0.95
_DEFAULT_RNG_SEED: int = 42

#: Shapiro-Wilk significance threshold for the parametric ``sem`` mode;
#: a group whose p-value falls below it is treated as non-normal.
_SHAPIRO_ALPHA: float = 0.05


# ─────────────────────────────────────────────────────────────────
#  Data fetch
# ─────────────────────────────────────────────────────────────────


def fetch_response_data(
    conn: sqlite3.Connection,
    *,
    stimulus_name: str,
    cell_types: tuple[str, ...] = ("CC", "HC"),
    strains: tuple[str, ...] | None = None,
    min_frames_in_window: int = _MIN_FRAMES_IN_WINDOW,
) -> pd.DataFrame:
    """JOIN ``stimulus_responses`` with ``v_recordings`` and filter.

    Args:
        conn: Open SQLite connection to a populated ``arista.db``.
        stimulus_name: Canonical stimulus protocol name (``"ascAmp"``,
            ``"descAmp"``, …).
        cell_types: Which cell types to include.
        strains: Optional whitelist of canonical strain names; ``None``
            means "every strain that has matching responses".
        min_frames_in_window: Drop rows that medianed over fewer than
            this many frames — filters Alex's protocol-mismatch cases
            where the sensor barely touched a step's target window.

    Returns:
        DataFrame with one row per (recording, step), ordered by
        ``recording_id`` then ``step_index``. SQL leaves row order
        unspecified without ``ORDER BY``; pinning it matters because the
        seeded bootstrap in :func:`aggregate_response_data` resamples
        values by position, so a different row order would change the
        CIs for the same seed.
    """
    # Only "?" markers are formatted into the SQL text; every
    # caller-supplied value is bound through ``params``.
    cell_marks = ",".join(["?"] * len(cell_types))
    conditions = [
        "v.stimulus_name = ?",
        f"v.cell_type IN ({cell_marks})",
        "sr.n_frames_in_window >= ?",
    ]
    params: list[object] = [stimulus_name, *cell_types, min_frames_in_window]
    if strains is not None:
        strain_marks = ",".join(["?"] * len(strains))
        conditions.append(f"v.strain_name IN ({strain_marks})")
        params.extend(strains)

    where_sql = " AND ".join(conditions)
    query = f"""
        SELECT
            v.recording_id, v.researcher_name, v.strain_name, v.cell_type,
            v.cell_number, v.hemisphere, v.animal_number,
            sr.step_index, sr.target_temp_c, sr.delta_target_c,
            sr.dfbf_response_median, sr.observed_temp_median,
            sr.n_frames_in_window
        FROM stimulus_responses sr
        JOIN v_recordings v ON v.recording_id = sr.recording_id
        WHERE {where_sql}
        ORDER BY v.recording_id, sr.step_index
    """
    return pd.read_sql_query(query, conn, params=params)
# ─────────────────────────────────────────────────────────────────
#  Aggregation: median + bootstrap CI + Shapiro-Wilk
# ─────────────────────────────────────────────────────────────────


def _bootstrap_median_ci(
    values: np.ndarray,
    *,
    n_resamples: int,
    confidence_level: float,
    rng: np.random.Generator,
) -> tuple[float, float]:
    """Percentile bootstrap confidence interval for ``median(values)``.

    Degenerate inputs return before ``rng`` is touched: a single value
    yields ``(v, v)`` and an empty array yields ``(nan, nan)``.
    """
    size = len(values)
    if size == 0:
        return (float("nan"), float("nan"))
    if size == 1:
        only = float(values[0])
        return (only, only)
    # One row of resample indices per bootstrap replicate.
    draws = rng.integers(0, size, size=(n_resamples, size))
    replicate_medians = np.median(values[draws], axis=1)
    tail = (1.0 - confidence_level) / 2.0
    return (
        float(np.quantile(replicate_medians, tail)),
        float(np.quantile(replicate_medians, 1.0 - tail)),
    )


def _shapiro_p(values: np.ndarray) -> float:
    """Shapiro-Wilk p-value, or NaN when the test cannot be run.

    NaN is returned for fewer than three observations (the test's
    minimum) and whenever scipy raises.
    """
    if len(values) >= 3:
        from scipy.stats import shapiro

        try:
            return float(shapiro(values).pvalue)
        except Exception:
            # Best-effort: a failed normality test must not abort
            # the aggregation of the whole table.
            return float("nan")
    return float("nan")


def _aggregate_group(
    values: pd.Series,
    *,
    n_bootstrap: int,
    confidence_level: float,
    rng: np.random.Generator,
) -> dict[str, float]:
    """Summarise one (strain, cell_type, x) bucket.

    Returns the median and its bootstrap CI, the quartiles, mean, SEM,
    Shapiro-Wilk p and the count of non-NaN values. An empty or all-NaN
    bucket yields NaN statistics with ``n == 0``.
    """
    sample = values.dropna().to_numpy(dtype=float)
    count = len(sample)
    if not count:
        summary = dict.fromkeys(
            ("median", "ci_low", "ci_high", "q25", "q75",
             "mean", "sem", "shapiro_p"),
            np.nan,
        )
        summary["n"] = 0
        return summary

    low, high = _bootstrap_median_ci(
        sample,
        n_resamples=n_bootstrap,
        confidence_level=confidence_level,
        rng=rng,
    )
    if count >= 2:
        sem = float(np.std(sample, ddof=1) / np.sqrt(count))
    else:
        sem = float("nan")
    return {
        "median": float(np.median(sample)),
        "ci_low": low,
        "ci_high": high,
        "q25": float(np.quantile(sample, 0.25)),
        "q75": float(np.quantile(sample, 0.75)),
        "mean": float(np.mean(sample)),
        "sem": sem,
        "shapiro_p": _shapiro_p(sample),
        "n": int(count),
    }
def aggregate_response_data(
    df: pd.DataFrame,
    *,
    by_x: str = "delta_target_c",
    n_bootstrap: int = _DEFAULT_N_BOOTSTRAP,
    confidence_level: float = _DEFAULT_CONFIDENCE_LEVEL,
    rng_seed: int = _DEFAULT_RNG_SEED,
) -> pd.DataFrame:
    """Summarise ``dfbf_response_median`` per (strain, cell_type, x).

    One output row per bucket, carrying every column needed to draw any
    supported error band (``ci`` / ``iqr`` / ``sem`` / ``none``) without
    re-aggregating:

    - ``median, ci_low, ci_high`` — non-parametric default; the CI is a
      percentile bootstrap of the median at ``confidence_level``, using
      ``n_bootstrap`` resamples from one seeded ``np.random.Generator``
      shared by all buckets in groupby order.
    - ``q25, q75`` — interquartile range (``error="iqr"``).
    - ``mean, sem`` — parametric ``error="sem"`` mode.
    - ``shapiro_p`` — per-bucket Shapiro-Wilk p (NaN when n < 3), used
      by the plot layer to gate ``error="sem"``.
    - ``n`` — number of non-NaN responses in the bucket.

    Args:
        df: Output of :func:`fetch_response_data`.
        by_x: ``"delta_target_c"`` (default, Kossen-style relative
            amplitude) or ``"target_temp_c"`` (absolute temperature).
        n_bootstrap: Number of bootstrap resamples for the median CI.
        confidence_level: CI coverage (default 0.95).
        rng_seed: Seed for the bootstrap RNG (reproducibility).

    Returns:
        DataFrame with columns ``strain_name, cell_type, <by_x>, n,
        median, ci_low, ci_high, q25, q75, mean, sem, shapiro_p``; empty
        (columns only) when ``df`` is empty.
    """
    output_columns = [
        "strain_name", "cell_type", by_x, "n",
        "median", "ci_low", "ci_high",
        "q25", "q75", "mean", "sem", "shapiro_p",
    ]
    if df.empty:
        return pd.DataFrame(columns=output_columns)

    generator = np.random.default_rng(rng_seed)
    group_keys = ["strain_name", "cell_type", by_x]
    # Buckets are visited in groupby's sorted key order, so each bucket
    # always consumes the same slice of the shared RNG stream.
    records = [
        {
            "strain_name": strain,
            "cell_type": cell_type,
            by_x: x_val,
            **_aggregate_group(
                bucket["dfbf_response_median"],
                n_bootstrap=n_bootstrap,
                confidence_level=confidence_level,
                rng=generator,
            ),
        }
        for (strain, cell_type, x_val), bucket in df.groupby(group_keys)
    ]
    return pd.DataFrame(records, columns=output_columns)
# ─────────────────────────────────────────────────────────────────
#  Plot
# ─────────────────────────────────────────────────────────────────


def _band_for_strain(
    strain_data: pd.DataFrame,
    *,
    error: ErrorBand,
    label: str,
    shapiro_alpha: float,
) -> tuple[np.ndarray, np.ndarray | None, np.ndarray | None, str]:
    """Return ``(centre, lower, upper, centre_kind)`` for one strain.

    ``centre_kind`` is ``"median"`` or ``"mean"`` and drives the y-axis
    label. ``lower``/``upper`` are ``None`` for ``error == "none"``.

    For ``error == "sem"`` the band is mean ± SEM unless Shapiro-Wilk
    rejects normality (p < ``shapiro_alpha``) at any x-value. In that
    case it falls back to median + CI and emits a ``RuntimeWarning``
    naming ``label``. NaN p-values (groups too small to test) never count
    as rejections.

    Raises:
        ValueError: ``error`` is not one of ``"ci"``, ``"iqr"``,
            ``"sem"``, ``"none"``. Previously any unknown string silently
            took the SEM path.
    """
    median = strain_data["median"].to_numpy()
    ci_band = (
        median,
        strain_data["ci_low"].to_numpy(),
        strain_data["ci_high"].to_numpy(),
        "median",
    )
    if error == "ci":
        return ci_band
    if error == "iqr":
        return (
            median,
            strain_data["q25"].to_numpy(),
            strain_data["q75"].to_numpy(),
            "median",
        )
    if error == "none":
        return median, None, None, "median"
    if error != "sem":
        raise ValueError(
            f"unknown error band {error!r}; expected one of "
            "'ci', 'iqr', 'sem', 'none'"
        )

    # error == "sem": gated by Shapiro-Wilk. Force float so an
    # object-dtype column (e.g. from a hand-built frame) still works
    # with np.isfinite.
    shapiro_p = strain_data["shapiro_p"].to_numpy(dtype=float)
    finite = np.isfinite(shapiro_p)
    rejected = (shapiro_p < shapiro_alpha) & finite
    if np.any(rejected):
        warnings.warn(
            f"{label}: Shapiro-Wilk rejected normality at "
            f"{int(rejected.sum())} of {int(finite.sum())} x-values "
            f"(p < {shapiro_alpha}); falling back to median + 95% CI.",
            RuntimeWarning,
            stacklevel=3,
        )
        return ci_band
    mean = strain_data["mean"].to_numpy()
    sem = strain_data["sem"].to_numpy()
    return mean, mean - sem, mean + sem, "mean"
def _resolve_response_frame(
    data: pd.DataFrame | sqlite3.Connection,
    *,
    stimulus_name: str | None,
    strains: tuple[str, ...] | None,
    cell_types: tuple[str, ...],
) -> pd.DataFrame:
    """Return the per-step rows to plot, restricted to ``cell_types``/``strains``."""
    if isinstance(data, sqlite3.Connection):
        if stimulus_name is None:
            raise ValueError(
                "stimulus_name is required when data is a Connection"
            )
        return fetch_response_data(
            data,
            stimulus_name=stimulus_name,
            cell_types=cell_types,
            strains=strains,
        )
    if data.empty:
        return data
    # Apply the same whitelists the SQL path applies, so a pre-fetched
    # frame aggregates (and consumes the bootstrap RNG over) the same
    # rows a live fetch would.
    keep = data["cell_type"].isin(cell_types)
    if strains is not None:
        keep &= data["strain_name"].isin(strains)
    return data[keep]


def _response_ylabel(error: str, centre_kinds: set[str]) -> str:
    """Y-axis label naming the centre statistic and band actually drawn."""
    if error == "sem" and centre_kinds == {"mean"}:
        centre, band = "mean", " (mean ± SEM)"
    elif error == "sem" and centre_kinds == {"mean", "median"}:
        # Some strains passed Shapiro-Wilk, others fell back to the CI.
        centre = "mean / median"
        band = " (mean ± SEM; median ± 95% CI for non-normal strains)"
    else:
        centre = "median"
        band = {
            "ci": " (median ± 95% CI)",
            "iqr": " (median ± IQR)",
            "sem": " (median ± 95% CI)",
            "none": "",
        }[error]
    return rf"{centre} $\Delta F / F_0$" + band


def plot_response_curves(
    data: pd.DataFrame | sqlite3.Connection,
    *,
    stimulus_name: str | None = None,
    strains: tuple[str, ...] | None = None,
    cell_types: tuple[str, ...] = ("CC", "HC"),
    by_x: str = "delta_target_c",
    error: ErrorBand = "ci",
    figsize: tuple[float, float] = (10, 5),
    title: str | None = None,
    shapiro_alpha: float = _SHAPIRO_ALPHA,
) -> Figure:
    """Response curve figure: one panel per cell type.

    Args:
        data: Either an already-fetched DataFrame (output of
            :func:`fetch_response_data`) or a live SQLite connection.
            When a connection is passed ``stimulus_name`` becomes
            mandatory.
        stimulus_name: Stimulus protocol name (when fetching from DB).
        strains: Optional strain whitelist. Applied both to the DB fetch
            and to a pre-fetched DataFrame.
        cell_types: Cell types to plot, one panel per. Rows of other
            cell types in a pre-fetched DataFrame are dropped before
            aggregation.
        by_x: ``"delta_target_c"`` (Kossen's relative-amplitude axis,
            default) or ``"target_temp_c"`` (absolute).
        error: Error band mode. ``"ci"`` (default, median + 95%
            bootstrap CI of the median), ``"iqr"`` (median + IQR),
            ``"sem"`` (mean + SEM, gated by per-x Shapiro-Wilk; falls
            back to ``"ci"`` with a warning on non-normality), or
            ``"none"``.
        figsize: ``(width, height)`` inches.
        title: Figure suptitle. Auto-generated from ``stimulus_name``
            when ``None``.
        shapiro_alpha: Normality rejection threshold for the ``"sem"``
            mode (default 0.05).

    Returns:
        :class:`matplotlib.figure.Figure`. The caller owns I/O.

    Raises:
        ValueError: ``error`` is not a supported mode, ``cell_types`` is
            empty, or ``data`` is a connection without ``stimulus_name``.
            All are checked before a figure is created, so no figure is
            left open.
    """
    import matplotlib.pyplot as plt

    if error not in ("ci", "iqr", "sem", "none"):
        raise ValueError(
            f"error must be one of 'ci', 'iqr', 'sem', 'none'; got {error!r}"
        )
    if not cell_types:
        raise ValueError("cell_types must name at least one cell type")

    df = _resolve_response_frame(
        data, stimulus_name=stimulus_name, strains=strains,
        cell_types=cell_types,
    )
    agg = aggregate_response_data(df, by_x=by_x)

    # squeeze=False always yields a 2-D grid, so one code path serves a
    # single panel and many.
    fig, axes_grid = plt.subplots(
        1, len(cell_types), figsize=figsize, sharey=True, squeeze=False,
    )
    axes = axes_grid[0]

    centre_kinds: set[str] = set()
    for ax, cell_type in zip(axes, cell_types, strict=True):
        sub = agg[agg["cell_type"] == cell_type]
        if sub.empty:
            ax.text(
                0.5, 0.5, f"no data for {cell_type}",
                transform=ax.transAxes, ha="center", va="center",
                color="#888888",
            )
        for strain in sorted(sub["strain_name"].unique()):
            strain_data = sub[sub["strain_name"] == strain].sort_values(by_x)
            colour = STRAIN_COLOURS.get(strain, "#666666")
            n_label = int(strain_data["n"].max())
            # Called directly from here (not via a helper) so the
            # fallback warning's stacklevel points at the user's call.
            centre, lower, upper, kind = _band_for_strain(
                strain_data, error=error, label=f"{strain}/{cell_type}",
                shapiro_alpha=shapiro_alpha,
            )
            centre_kinds.add(kind)
            x = strain_data[by_x].to_numpy()
            ax.plot(
                x, centre, color=colour, marker="o", linewidth=1.4,
                markersize=4, label=f"{strain} (n≤{n_label})",
            )
            if lower is not None and upper is not None:
                ax.fill_between(
                    x, lower, upper, color=colour, alpha=0.18, linewidth=0,
                )
        ax.set_title(f"{cell_type} cells")
        ax.axhline(0.0, color="#cccccc", linewidth=0.5, linestyle="--", zorder=0)
        ax.axvline(0.0, color="#cccccc", linewidth=0.5, linestyle="--", zorder=0)
        ax.set_xlabel(
            r"$\Delta$ target T (°C from baseline)"
            if by_x == "delta_target_c" else "target T (°C)"
        )
        # Empty panels have no labelled artists; matplotlib warns if a
        # legend is requested for them.
        if ax.get_legend_handles_labels()[0]:
            ax.legend(loc="best", fontsize=7, framealpha=0.85)
        ax.grid(True, axis="y", linestyle=":", linewidth=0.4, alpha=0.5)

    axes[0].set_ylabel(_response_ylabel(error, centre_kinds))
    if title is None and stimulus_name:
        title = f"Response curves — {stimulus_name}"
    if title:
        fig.suptitle(title, fontsize=11)
    fig.tight_layout()
    return fig
# ───────────────────────────────────────────────────────────────── # Class wrapper for batch reuse # ─────────────────────────────────────────────────────────────────
@dataclass
class ResponseCurves:
    """Callable wrapper holding default styling.

    Useful when rendering many figures (one per stimulus protocol) with
    consistent kwargs::

        plotter = ResponseCurves(figsize=(12, 5))
        for stim in ("ascAmp", "descAmp"):
            fig = plotter(conn, stimulus_name=stim)
            plotter.save(fig, output_dir / f"resp_{stim}.png")
    """

    #: Cell types to plot, one panel per entry.
    cell_types: tuple[str, ...] = ("CC", "HC")
    #: x-axis column: ``"delta_target_c"`` or ``"target_temp_c"``.
    by_x: str = "delta_target_c"
    #: Error-band mode forwarded to :func:`plot_response_curves`.
    error: ErrorBand = "ci"
    #: Figure size in inches, ``(width, height)``.
    figsize: tuple[float, float] = (10, 5)
    #: Raster resolution used by :meth:`save` for the PNG.
    dpi: int = FIGURE_DPI
    #: Optional strain whitelist; ``None`` keeps every strain.
    strains: tuple[str, ...] | None = None

    def plot(
        self,
        data: pd.DataFrame | sqlite3.Connection,
        *,
        stimulus_name: str | None = None,
        title: str | None = None,
        **overrides,
    ) -> Figure:
        """Render with this instance's defaults; ``overrides`` take precedence.

        ``overrides`` may carry any other keyword of
        :func:`plot_response_curves` (e.g. ``shapiro_alpha``).
        """
        defaults = {
            "cell_types": self.cell_types,
            "by_x": self.by_x,
            "error": self.error,
            "figsize": self.figsize,
            "strains": self.strains,
        }
        return plot_response_curves(
            data,
            stimulus_name=stimulus_name,
            title=title,
            **{**defaults, **overrides},
        )

    def __call__(
        self,
        data: pd.DataFrame | sqlite3.Connection,
        **kwargs,
    ) -> Figure:
        """Alias for :meth:`plot`."""
        return self.plot(data, **kwargs)

    def save(
        self,
        fig: Figure,
        path: Path | str,
        *,
        dpi: int | None = None,
        close: bool = True,
    ) -> Path:
        """Save ``fig`` as PNG plus an SVG sibling; return the PNG path.

        Args:
            fig: Figure to write.
            path: Target path; its suffix is replaced by ``.png`` and
                ``.svg``. Parent directories are created as needed.
            dpi: PNG resolution; ``None`` uses ``self.dpi``.
            close: Close ``fig`` afterwards. This also happens when a save
                raises, so batch loops do not accumulate open figures.

        Returns:
            Path of the written PNG.
        """
        import matplotlib.pyplot as plt

        base = Path(path)
        base.parent.mkdir(parents=True, exist_ok=True)
        png_path = base.with_suffix(".png")
        svg_path = base.with_suffix(".svg")
        try:
            fig.savefig(png_path, dpi=self.dpi if dpi is None else dpi)
            fig.savefig(svg_path)
        finally:
            if close:
                plt.close(fig)
        return png_path