Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
7838d62
Add ContinuousDiD estimator for continuous treatment dose-response
igerber Feb 21, 2026
b9e97f0
Fix PR #177 review issues: control group bug, safe_inference, dose va…
igerber Feb 21, 2026
8c3980b
Fix analytical SE scaling, add empty post_gt guard, and validate params
igerber Feb 21, 2026
40d22d4
Fix bootstrap percentile inference, add P(D=0) warning, and analytica…
igerber Feb 21, 2026
1e01fd4
Address round-4 review: harden validation and fix tempfile usage
igerber Feb 22, 2026
0f2849b
Fix bootstrap ACRT^{glob} centering bug and add regression test
igerber Feb 22, 2026
f737ab6
Address PR #177 review round 6: boundary knot docs, results provenanc…
igerber Feb 22, 2026
ce8240e
Address PR #177 review round 7: clarify global knots and aggregation …
igerber Feb 22, 2026
9808280
Guard bootstrap NaN propagation: SE/CI/p-value all NaN when SE invalid
igerber Feb 22, 2026
b9bd264
Store bootstrap p-values in DoseResponseCurve, add event-study parame…
igerber Feb 22, 2026
55b0d2d
Fix NaN propagation in rank-deficient spline predictions
igerber Feb 22, 2026
a8c1c2e
Fix bootstrap NaN propagation for rank-deficient cells and remove unu…
igerber Feb 22, 2026
6e570a7
Replace SunAbraham manual bootstrap stats with NaN-gated utility
igerber Feb 22, 2026
2d663d9
Fix test_all_same_dose bootstrap assertion on macOS
igerber Feb 22, 2026
fddb6c3
Use heterogeneous outcomes instead of noise injection in bootstrap test
igerber Feb 22, 2026
5449bbb
Guard non-finite original_effect in compute_effect_bootstrap_stats
igerber Feb 22, 2026
999da34
Fix not-yet-treated control mask to respect anticipation parameter
igerber Feb 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ Deferred items from PR reviews that were not addressed before merge.
| Issue | Location | PR | Priority |
|-------|----------|----|----------|
| ImputationDiD dense `(A0'A0).toarray()` scales O((U+T+K)^2), OOM risk on large panels | `imputation.py` | #141 | Medium (deferred — only triggers when sparse solver fails; fixing requires sparse least-squares alternatives) |
| Bootstrap NaN-gating gap: manual SE/CI/p-value without non-finite filtering or SE<=0 guard | `imputation_bootstrap.py`, `two_stage_bootstrap.py` | #177 | Medium — migrate to `compute_effect_bootstrap_stats` from `bootstrap_utils.py` |

#### Performance

Expand Down
10 changes: 10 additions & 0 deletions diff_diff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
aggregate_to_cohorts,
balance_panel,
create_event_time,
generate_continuous_did_data,
generate_did_data,
generate_ddd_data,
generate_event_study_data,
Expand Down Expand Up @@ -122,6 +123,11 @@
TripleDifferenceResults,
triple_difference,
)
from diff_diff.continuous_did import (
ContinuousDiD,
ContinuousDiDResults,
DoseResponseCurve,
)
from diff_diff.trop import (
TROP,
TROPResults,
Expand Down Expand Up @@ -161,6 +167,7 @@
"MultiPeriodDiD",
"SyntheticDiD",
"CallawaySantAnna",
"ContinuousDiD",
"SunAbraham",
"ImputationDiD",
"TwoStageDiD",
Expand All @@ -181,6 +188,8 @@
"CallawaySantAnnaResults",
"CSBootstrapResults",
"GroupTimeEffect",
"ContinuousDiDResults",
"DoseResponseCurve",
"SunAbrahamResults",
"SABootstrapResults",
"ImputationDiDResults",
Expand Down Expand Up @@ -228,6 +237,7 @@
"generate_ddd_data",
"generate_panel_data",
"generate_event_study_data",
"generate_continuous_did_data",
"create_event_time",
"aggregate_to_cohorts",
"rank_control_units",
Expand Down
279 changes: 279 additions & 0 deletions diff_diff/bootstrap_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,279 @@
"""
Shared bootstrap utilities for multiplier bootstrap inference.

Provides weight generation, percentile CI, and p-value helpers used by
both CallawaySantAnna and ContinuousDiD estimators.
"""

import warnings
from typing import Optional, Tuple

import numpy as np

from diff_diff._backend import HAS_RUST_BACKEND, _rust_bootstrap_weights

__all__ = [
"generate_bootstrap_weights",
"generate_bootstrap_weights_batch",
"generate_bootstrap_weights_batch_numpy",
"compute_percentile_ci",
"compute_bootstrap_pvalue",
"compute_effect_bootstrap_stats",
]


def generate_bootstrap_weights(
n_units: int,
weight_type: str,
rng: np.random.Generator,
) -> np.ndarray:
"""
Generate bootstrap weights for multiplier bootstrap.

Parameters
----------
n_units : int
Number of units (clusters) to generate weights for.
weight_type : str
Type of weights: "rademacher", "mammen", or "webb".
rng : np.random.Generator
Random number generator.

Returns
-------
np.ndarray
Array of bootstrap weights with shape (n_units,).
"""
if weight_type == "rademacher":
return rng.choice([-1.0, 1.0], size=n_units)
elif weight_type == "mammen":
sqrt5 = np.sqrt(5)
val1 = -(sqrt5 - 1) / 2
val2 = (sqrt5 + 1) / 2
p1 = (sqrt5 + 1) / (2 * sqrt5)
return rng.choice([val1, val2], size=n_units, p=[p1, 1 - p1])
elif weight_type == "webb":
values = np.array([
-np.sqrt(3 / 2), -np.sqrt(2 / 2), -np.sqrt(1 / 2),
np.sqrt(1 / 2), np.sqrt(2 / 2), np.sqrt(3 / 2)
])
return rng.choice(values, size=n_units)
else:
raise ValueError(
f"weight_type must be 'rademacher', 'mammen', or 'webb', "
f"got '{weight_type}'"
)


def generate_bootstrap_weights_batch(
n_bootstrap: int,
n_units: int,
weight_type: str,
rng: np.random.Generator,
) -> np.ndarray:
"""
Generate all bootstrap weights at once (vectorized).

Uses Rust backend if available for parallel generation.

Parameters
----------
n_bootstrap : int
Number of bootstrap iterations.
n_units : int
Number of units (clusters) to generate weights for.
weight_type : str
Type of weights: "rademacher", "mammen", or "webb".
rng : np.random.Generator
Random number generator.

Returns
-------
np.ndarray
Array of bootstrap weights with shape (n_bootstrap, n_units).
"""
if HAS_RUST_BACKEND and _rust_bootstrap_weights is not None:
seed = rng.integers(0, 2**63 - 1)
return _rust_bootstrap_weights(n_bootstrap, n_units, weight_type, seed)
return generate_bootstrap_weights_batch_numpy(n_bootstrap, n_units, weight_type, rng)


def generate_bootstrap_weights_batch_numpy(
n_bootstrap: int,
n_units: int,
weight_type: str,
rng: np.random.Generator,
) -> np.ndarray:
"""
NumPy fallback implementation of :func:`generate_bootstrap_weights_batch`.

Parameters
----------
n_bootstrap : int
Number of bootstrap iterations.
n_units : int
Number of units (clusters) to generate weights for.
weight_type : str
Type of weights: "rademacher", "mammen", or "webb".
rng : np.random.Generator
Random number generator.

Returns
-------
np.ndarray
Array of bootstrap weights with shape (n_bootstrap, n_units).
"""
if weight_type == "rademacher":
return rng.choice([-1.0, 1.0], size=(n_bootstrap, n_units))
elif weight_type == "mammen":
sqrt5 = np.sqrt(5)
val1 = -(sqrt5 - 1) / 2
val2 = (sqrt5 + 1) / 2
p1 = (sqrt5 + 1) / (2 * sqrt5)
return rng.choice([val1, val2], size=(n_bootstrap, n_units), p=[p1, 1 - p1])
elif weight_type == "webb":
values = np.array([
-np.sqrt(3 / 2), -np.sqrt(2 / 2), -np.sqrt(1 / 2),
np.sqrt(1 / 2), np.sqrt(2 / 2), np.sqrt(3 / 2)
])
return rng.choice(values, size=(n_bootstrap, n_units))
else:
raise ValueError(
f"weight_type must be 'rademacher', 'mammen', or 'webb', "
f"got '{weight_type}'"
)


def compute_percentile_ci(
boot_dist: np.ndarray,
alpha: float,
) -> Tuple[float, float]:
"""
Compute percentile confidence interval from bootstrap distribution.

Parameters
----------
boot_dist : np.ndarray
Bootstrap distribution (1-D array).
alpha : float
Significance level (e.g., 0.05 for 95% CI).

Returns
-------
tuple of float
``(lower, upper)`` confidence interval bounds.
"""
lower = float(np.percentile(boot_dist, alpha / 2 * 100))
upper = float(np.percentile(boot_dist, (1 - alpha / 2) * 100))
return (lower, upper)


def compute_bootstrap_pvalue(
original_effect: float,
boot_dist: np.ndarray,
n_valid: Optional[int] = None,
) -> float:
"""
Compute two-sided bootstrap p-value using the percentile method.

Parameters
----------
original_effect : float
Original point estimate.
boot_dist : np.ndarray
Bootstrap distribution of the effect.
n_valid : int, optional
Number of valid bootstrap samples for p-value floor.
If None, uses ``len(boot_dist)``.

Returns
-------
float
Two-sided bootstrap p-value.
"""
if original_effect >= 0:
p_one_sided = np.mean(boot_dist <= 0)
else:
p_one_sided = np.mean(boot_dist >= 0)

p_value = min(2 * p_one_sided, 1.0)
n_for_floor = n_valid if n_valid is not None else len(boot_dist)
p_value = max(p_value, 1 / (n_for_floor + 1))
return float(p_value)


def compute_effect_bootstrap_stats(
original_effect: float,
boot_dist: np.ndarray,
alpha: float = 0.05,
context: str = "bootstrap distribution",
) -> Tuple[float, Tuple[float, float], float]:
"""
Compute bootstrap statistics for a single effect.

Filters non-finite samples, returning NaN for all statistics if
fewer than 50% of samples are valid.

Parameters
----------
original_effect : float
Original point estimate.
boot_dist : np.ndarray
Bootstrap distribution of the effect.
alpha : float, default=0.05
Significance level.
context : str, optional
Description for warning messages.

Returns
-------
se : float
Bootstrap standard error.
ci : tuple of float
Percentile confidence interval.
p_value : float
Bootstrap p-value.
"""
if not np.isfinite(original_effect):
return np.nan, (np.nan, np.nan), np.nan

finite_mask = np.isfinite(boot_dist)
n_valid = np.sum(finite_mask)
n_total = len(boot_dist)

if n_valid < n_total:
n_nonfinite = n_total - n_valid
warnings.warn(
f"Dropping {n_nonfinite}/{n_total} non-finite bootstrap samples "
f"in {context}. Bootstrap estimates based on remaining valid samples.",
RuntimeWarning,
stacklevel=3,
)

if n_valid < n_total * 0.5:
warnings.warn(
f"Too few valid bootstrap samples ({n_valid}/{n_total}) in {context}. "
"Returning NaN for SE/CI/p-value to signal invalid inference.",
RuntimeWarning,
stacklevel=3,
)
return np.nan, (np.nan, np.nan), np.nan

valid_dist = boot_dist[finite_mask]
se = float(np.std(valid_dist, ddof=1))

# Guard: if SE is not finite or zero, all inference fields must be NaN.
if not np.isfinite(se) or se <= 0:
warnings.warn(
f"Bootstrap SE is non-finite or zero (n_valid={n_valid}) in {context}. "
"Returning NaN for SE/CI/p-value.",
RuntimeWarning,
stacklevel=3,
)
return np.nan, (np.nan, np.nan), np.nan

ci = compute_percentile_ci(valid_dist, alpha)
p_value = compute_bootstrap_pvalue(
original_effect, valid_dist, n_valid=len(valid_dist)
)
return se, ci, p_value
Loading