A Python Balnet Prototype Using Adelie

Covariate-balancing propensity scores as a custom GLM path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import optimize
from IPython.display import HTML, display

from aipyw.balnet_adelie import (
    CBPSGlm,
    balancing_weights,
    effective_sample_size,
    fit_balnet_arm,
    fit_balnet_ate,
    standardized_mean_differences,
)

np.set_printoptions(precision=4, suppress=True)


def html_table(df, digits=4):
    return HTML(df.to_html(index=False, border=0, classes="table table-sm table-striped", float_format=lambda x: f"{x:.{digits}f}"))

1 Goal

The R package balnet fits regularized logistic propensity-score models but replaces the ordinary Bernoulli likelihood with a covariate-balancing loss. The key engineering trick is that this loss can be presented to Adelie as a custom GLM family, so Adelie’s pathwise group elastic-net solver does the hard optimization work.

This memo sketches the optimization problem and a Python proof of concept in aipyw:

aipyw/balnet_adelie.py

The implementation uses a Python subclass of Adelie’s GlmBase64 to define the balnet calibration loss, then calls adelie.grpnet(...) for the regularization path.

2 One-arm balancing loss

For one arm, let \(y_i \in \{0,1\}\) indicate membership in the arm being modeled. Let

\[ \eta_i = \alpha + x_i'\beta, \qquad p_i = \frac{1}{1+\exp(-\eta_i)}. \]

Balnet’s one-arm calibration loss is

\[ L(\alpha,\beta) = \sum_{i=1}^n \omega_i \left[y_i\exp(-\eta_i) + (1-y_i)\eta_i\right]. \]

The penalized lasso path solves

\[ \min_{\alpha,\beta}\quad L(\alpha,\beta) + \lambda \lVert \beta \rVert_1, \]

or, more generally, a group elastic-net penalty through Adelie’s grpnet interface.
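For reference, here is a sketch of that general problem. It assumes the group elastic-net form documented for Adelie’s grpnet. The notation is: coefficient groups \(\beta_g\), per-group penalty factors \(c_g\) (our symbol), and the mixing parameter passed as `alpha=`. That parameter is written here as \(\gamma\) so it does not clash with the intercept \(\alpha\):

\[ \min_{\alpha,\beta}\quad L(\alpha,\beta) + \lambda \sum_{g=1}^{G} c_g\left(\gamma \lVert \beta_g \rVert_2 + \frac{1-\gamma}{2}\lVert \beta_g \rVert_2^2\right). \]

With singleton groups, \(\gamma = 1\) and \(c_g = 1\), each \(\lVert \beta_g \rVert_2\) is just \(|\beta_j|\), and this reduces to the lasso problem above.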

The derivative of the loss with respect to the linear predictor is

\[ \frac{\partial L}{\partial \eta_i} = \omega_i\left[-y_i\exp(-\eta_i) + (1-y_i)\right]. \]

Adelie’s GLM interface expects the negative gradient, so the Python custom GLM returns

\[ r_i = \omega_i\left[y_i\exp(-\eta_i) - (1-y_i)\right], \]

with curvature

\[ h_i = \omega_i y_i\exp(-\eta_i). \]

This matches balnet’s C++ extension, up to optional target scaling.
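The loss, negative gradient and curvature formulas above can be checked against each other with central finite differences. Below is a minimal sketch in plain NumPy. The helper names and the uniform weights \(\omega_i = 1/n\) are assumptions, and this is not how the aipyw class itself is written.

```python
import numpy as np

# Plain NumPy versions of the one-arm calibration loss and its eta-derivatives.
# Hypothetical helper names; not the aipyw or Adelie API.


def calib_loss(eta, y, omega):
    """L = sum_i omega_i [y_i exp(-eta_i) + (1 - y_i) eta_i]."""
    return np.sum(omega * (y * np.exp(-eta) + (1.0 - y) * eta))


def neg_gradient(eta, y, omega):
    """r_i = -dL/deta_i = omega_i [y_i exp(-eta_i) - (1 - y_i)]."""
    return omega * (y * np.exp(-eta) - (1.0 - y))


def curvature(eta, y, omega):
    """h_i = d^2 L / deta_i^2 = omega_i y_i exp(-eta_i)."""
    return omega * y * np.exp(-eta)


rng = np.random.default_rng(0)
n = 50
eta = rng.normal(size=n)
y = rng.binomial(1, 0.4, size=n).astype(float)
omega = np.full(n, 1.0 / n)  # assumption: uniform observation weights

# Central finite differences, one coordinate of eta at a time.
step = 1e-6
fd_grad = np.empty(n)
fd_curv = np.empty(n)
for i in range(n):
    e_i = np.zeros(n)
    e_i[i] = step
    fd_grad[i] = (calib_loss(eta + e_i, y, omega) - calib_loss(eta - e_i, y, omega)) / (2 * step)
    # h_i is the derivative of -r_i with respect to eta_i.
    fd_curv[i] = (neg_gradient(eta - e_i, y, omega)[i] - neg_gradient(eta + e_i, y, omega)[i]) / (2 * step)

grad_err = np.max(np.abs(fd_grad + neg_gradient(eta, y, omega)))  # fd_grad approximates -r
curv_err = np.max(np.abs(fd_curv - curvature(eta, y, omega)))
print(f"max gradient error {grad_err:.2e}, max curvature error {curv_err:.2e}")
```

The curvature is nonnegative everywhere, so the loss is convex in \(\eta\). Only units in the modeled arm contribute curvature.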

3 Why this balances covariates

At an unpenalized optimum, the score equation for a covariate column \(x_j\) is

\[ \sum_i \omega_i x_{ij}\left[y_i\exp(-\eta_i) - (1-y_i)\right] = 0. \]

For the treated arm, \(y_i=W_i\). Since \(\exp(-\eta_i)=(1-p_i)/p_i\), this says

\[ \sum_i \omega_i W_i x_i \frac{1-p_i}{p_i} = \sum_i \omega_i (1-W_i)x_i. \]

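Adding \(\sum_i \omega_i W_i x_i\) to both sides gives \(\sum_i \omega_i W_i x_i / p_i = \sum_i \omega_i x_i\). In words, the inverse-propensity-weighted treated arm reproduces the full-sample (target) covariate total.

The intercept score, with \(x_{i0} = 1\), gives \(\sum_i \omega_i W_i / p_i = \sum_i \omega_i\), so the weights \(W_i/p_i\) are automatically normalized.

Equivalently, because \(y_i\exp(-\eta_i) - (1-y_i) = y_i/p_i - 1\), the residual is \(r_i = \omega_i(W_i/p_i - 1)\), and \(X_j'r\) is the weighted imbalance of covariate \(j\). With the lasso penalty, the KKT conditions become approximate balance: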

\[ \left| X_j' r \right| \leq \lambda \]

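These conditions are:

  • \(|X_j'r| \leq \lambda\) for inactive coordinates (\(\hat\beta_j = 0\));
  • \(X_j'r = \lambda\operatorname{sign}(\hat\beta_j)\) for active coordinates.

Here \(X_j\) is the \(j\)th standardized covariate column. This is the sense in which \(\lambda\) is a direct imbalance budget: no standardized covariate can have weighted imbalance \(|X_j'r|\) larger than \(\lambda\).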
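The KKT bound can be checked numerically without Adelie. The sketch below relies on three assumptions:

  • a hypothetical stand-in solver (proximal gradient descent with backtracking, not Adelie’s coordinate descent);
  • simulated standardized covariates;
  • uniform weights \(\omega_i = 1/n\).

It computes \(X_j'r\) from the fitted probabilities and confirms three properties: inactive columns satisfy \(|X_j'r| \le \lambda\), active columns sit at \(\lambda\operatorname{sign}(\hat\beta_j)\), and the unpenalized intercept makes \(W_i/\hat p_i\) average to one.

```python
import numpy as np


def soft_threshold(z, k):
    """Proximal operator of k * |.|, applied elementwise."""
    return np.sign(z) * np.maximum(np.abs(z) - k, 0.0)


def fit_lasso_calibration(X, y, omega, lam, max_iter=20_000, tol=1e-12):
    """Proximal gradient (ISTA with backtracking) for
        sum_i omega_i [y_i exp(-eta_i) + (1 - y_i) eta_i] + lam * ||beta||_1,
    eta = alpha + X @ beta, with the intercept alpha left unpenalized.
    Hypothetical stand-in solver, not Adelie's coordinate descent."""

    def smooth_loss(a, b):
        eta = a + X @ b
        return np.sum(omega * (y * np.exp(-eta) + (1 - y) * eta))

    alpha, beta, step = 0.0, np.zeros(X.shape[1]), 1.0
    for _ in range(max_iter):
        eta = alpha + X @ beta
        r = omega * (y * np.exp(-eta) - (1 - y))  # negative gradient in eta
        g_alpha, g_beta = -np.sum(r), -X.T @ r     # gradient in (alpha, beta)
        f0 = smooth_loss(alpha, beta)
        while True:
            a_new = alpha - step * g_alpha                       # plain step: intercept
            b_new = soft_threshold(beta - step * g_beta, step * lam)  # prox step: beta
            d_a, d_b = a_new - alpha, b_new - beta
            # Quadratic upper bound must hold; the small slack absorbs rounding.
            bound = f0 + g_alpha * d_a + g_beta @ d_b + (d_a**2 + d_b @ d_b) / (2 * step)
            if smooth_loss(a_new, b_new) <= bound + 1e-13:
                break
            step *= 0.5
        alpha, beta = a_new, b_new
        if max(abs(d_a), np.max(np.abs(d_b))) < tol:
            break
    return alpha, beta


rng = np.random.default_rng(7)
n, p = 1000, 5
X = rng.normal(size=(n, p))
X = (X - X.mean(axis=0)) / X.std(axis=0)  # standardized covariates
logit = -0.2 + 0.9 * X[:, 0] - 0.6 * X[:, 1]
W = rng.binomial(1, 1 / (1 + np.exp(-logit))).astype(float)
omega = np.full(n, 1.0 / n)
lam = 0.05

alpha, beta = fit_lasso_calibration(X, W, omega, lam)
p_hat = 1 / (1 + np.exp(-(alpha + X @ beta)))
# Since W exp(-eta) - (1 - W) = W / p - 1, X_j' r is the IPW imbalance of column j.
imbalance = X.T @ (omega * (W / p_hat - 1))
print("beta:", np.round(beta, 4))
print("X_j' r:", np.round(imbalance, 4))
print("mean W / p_hat:", np.mean(W / p_hat))
```

Every active covariate sits exactly on the budget. So whenever any coefficient is active, the largest absolute imbalance equals \(\lambda\).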

For the ATE, balnet fits two one-arm models:

  • a treated model with \(y=W\), producing treated weights \(W_i / \hat e_1(X_i)\);
  • a control model with \(y=1-W\), producing control weights \((1-W_i)/(1-\hat e_0(X_i))\). Here \(1-\hat e_0(X_i)\) is the control model’s fitted probability of being a control.

The Python proof of concept follows that structure.
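Below is a compact sketch of the two-arm construction, assuming unpenalized fits, \(\omega_i = 1/n\), and simulated data. `fit_arm` is a hypothetical damped-Newton helper, not aipyw’s `fit_balnet_arm`. The check is that each arm’s weights reproduce the full-sample covariate mean. Because the two models are fit separately, the implied propensities \(\hat e_1\) and \(\hat e_0\) need not coincide.

```python
import numpy as np


def fit_arm(X, y, max_iter=100, tol=1e-10):
    """Unpenalized one-arm calibration fit by damped Newton, omega_i = 1/n.
    Returns fitted probabilities p_i of membership in the arm (y == 1).
    Hypothetical helper, not the aipyw fit_balnet_arm."""
    n = X.shape[0]
    Z = np.column_stack([np.ones(n), X])  # intercept column first
    theta = np.zeros(Z.shape[1])

    def loss(t):
        eta = Z @ t
        return np.mean(y * np.exp(-eta) + (1 - y) * eta)

    for _ in range(max_iter):
        eta = Z @ theta
        grad = -Z.T @ (y * np.exp(-eta) - (1 - y)) / n
        if np.max(np.abs(grad)) < tol:
            break
        hess = Z.T @ ((y * np.exp(-eta))[:, None] * Z) / n
        direction = np.linalg.solve(hess, -grad)
        t, f0 = 1.0, loss(theta)
        # Armijo backtracking; the 1e-13 slack absorbs rounding near the optimum.
        while loss(theta + t * direction) > f0 + 1e-4 * t * (grad @ direction) + 1e-13 and t > 1e-8:
            t *= 0.5
        theta = theta + t * direction
    return 1 / (1 + np.exp(-(Z @ theta)))


rng = np.random.default_rng(3)
n = 3000
X = rng.normal(size=(n, 3))
e = 1 / (1 + np.exp(-(0.2 + X @ np.array([0.7, -0.4, 0.2]))))
W = rng.binomial(1, e).astype(float)

e1_hat = fit_arm(X, W)        # treated model, y = W: p_i estimates e(X_i)
p0_hat = fit_arm(X, 1 - W)    # control model, y = 1 - W: p_i estimates 1 - e(X_i)
w_treated = W / e1_hat
w_control = (1 - W) / p0_hat  # i.e. (1 - W) / (1 - e0_hat) with e0_hat = 1 - p0_hat

target = X.mean(axis=0)
imb_treated = np.max(np.abs(X.T @ w_treated / n - target))
imb_control = np.max(np.abs(X.T @ w_control / n - target))
print(f"treated imbalance {imb_treated:.1e}, control imbalance {imb_control:.1e}")
```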

4 How Adelie is used

The prototype defines:

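import adelie as ad

class CBPSGlm(ad.glm.glm_base, ad.glm.GlmBase64):
    # Abridged; the full class lives in aipyw/balnet_adelie.py.
    def gradient(self, eta, out): ...        # fills out with r_i = omega_i [y_i exp(-eta_i) - (1 - y_i)]
    def hessian(self, eta, grad, out): ...   # fills out with h_i = omega_i y_i exp(-eta_i)
    def loss(self, eta): ...                 # returns L = sum_i omega_i [y_i exp(-eta_i) + (1 - y_i) eta_i]
    def inv_link(self, eta, out): ...        # fills out with p_i = 1 / (1 + exp(-eta_i))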

Then it calls:

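state = ad.grpnet(
    X_standardized,                      # covariates centered and scaled by the wrapper
    CBPSGlm(y, weights=sample_weights),  # custom GLM carrying the calibration loss
    lmda_path=lambdas,                   # lambda values for the path
    alpha=alpha,                         # elastic-net mixing; alpha=1 gives the (group) lasso
    intercept=True,                      # fit an unpenalized intercept
)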

The returned Adelie state contains a sparse coefficient path and intercept path. The wrapper unstandardizes coefficients back to the original covariate scale and provides balancing weights.
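The unstandardization step is plain algebra. If the solver sees \(z = (x - c)/s\) columnwise, then

\[ \alpha_z + z'\beta_z = \Big(\alpha_z - \sum_j \beta_{z,j} c_j / s_j\Big) + x'(\beta_z / s). \]

Below is a minimal sketch with a hypothetical `unstandardize` helper. It applies the same arithmetic that the SciPy check in the next section performs by hand. Zeros on the path stay zero.

```python
import numpy as np


def unstandardize(intercept_std, beta_std, center, scale):
    """Map coefficients fit on z = (x - center) / scale back to the raw x scale.
    Hypothetical helper; returns (intercept_raw, beta_raw)."""
    beta_raw = beta_std / scale
    intercept_raw = intercept_std - np.sum(beta_std * center / scale)
    return intercept_raw, beta_raw


rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(20, 4))
center, scale = x.mean(axis=0), x.std(axis=0)
intercept_std, beta_std = 0.3, np.array([0.5, 0.0, -1.2, 0.0])  # sparse, as on a lasso path

intercept_raw, beta_raw = unstandardize(intercept_std, beta_std, center, scale)
eta_std = intercept_std + ((x - center) / scale) @ beta_std  # predictor on standardized scale
eta_raw = intercept_raw + x @ beta_raw                       # same predictor on raw scale
print(np.max(np.abs(eta_std - eta_raw)))
```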

5 Sanity check: Adelie vs direct SciPy optimization

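For a small one-arm lasso problem, we solve the same objective once with Adelie and once with a generic SciPy optimizer. The generic optimizer is slow and not pathwise, but it is useful as a correctness check. SciPy’s derivative-free Powell method is warm-started at the Adelie solution. Since the objective is convex, Powell’s failure to find a lower value is evidence that Adelie reached the minimum.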

def simulate(seed=123, n=900, p=8):
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(n, p))
    logit = -0.25 + 1.0 * x[:, 0] - 0.75 * x[:, 1] + 0.45 * x[:, 2]
    e = 1.0 / (1.0 + np.exp(-logit))
    w = rng.binomial(1, e).astype(float)
    tau = 1.0
    mu0 = 0.5 * x[:, 0] + 0.25 * x[:, 1] ** 2 - 0.4 * x[:, 3]
    y = mu0 + tau * w + rng.normal(scale=1.0, size=n)
    return x, w, y, e, tau

x_small, w_small, *_ = simulate(seed=1, n=180, p=5)
lam = 0.04
fit_one = fit_balnet_arm(x_small, w_small, lambdas=np.array([lam]), progress_bar=False)

stan = fit_one.standardization
xs = stan.transform(x_small)
weights = np.full(x_small.shape[0], 1.0 / x_small.shape[0])


def scipy_objective(theta):
    alpha = theta[0]
    beta = theta[1:]
    eta = np.clip(alpha + xs @ beta, -35, 35)
    loss = np.sum(weights * (w_small * np.exp(-eta) + (1 - w_small) * eta))
    return loss + lam * np.sum(np.abs(beta))

theta_start = np.r_[fit_one.adelie_state.intercepts[0], fit_one.adelie_state.betas.toarray()[0]]
scipy_fit = optimize.minimize(
    scipy_objective,
    theta_start,
    method="Powell",
    options={"xtol": 1e-8, "ftol": 1e-8, "maxiter": 3000},
)

beta_scipy = scipy_fit.x[1:] / stan.scale
alpha_scipy = scipy_fit.x[0] - scipy_fit.x[1:] @ (stan.center / stan.scale)
eta_adelie = fit_one.linear_predictor(x_small, 0)
eta_scipy = alpha_scipy + x_small @ beta_scipy

check = pd.DataFrame(
    [
        {
            "lambda": lam,
            "Adelie objective": scipy_objective(theta_start),
            "SciPy objective": scipy_objective(scipy_fit.x),
            "RMSE eta(Adelie, SciPy)": np.sqrt(np.mean((eta_adelie - eta_scipy) ** 2)),
            "SciPy success": scipy_fit.success,
        }
    ]
)
display(html_table(check, digits=6))
| lambda | Adelie objective | SciPy objective | RMSE eta(Adelie, SciPy) | SciPy success |
|---|---|---|---|---|
| 0.040000 | 0.176111 | 0.176111 | 0.000005 | True |

6 Path experiment: balance and effective sample size

Now fit the two-arm ATE path on a confounded simulated dataset. The diagnostic is the maximum absolute standardized mean difference against the full-sample covariate mean. Lower is better. Effective sample size tracks the variance cost of aggressive balancing.
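The memo does not show how aipyw’s `standardized_mean_differences` and `effective_sample_size` are defined. As an assumption, the sketch below uses two common conventions:

  • Kish’s effective sample size, \((\sum_i w_i)^2/\sum_i w_i^2\);
  • a mean difference scaled by the target sample’s standard deviation.

The aipyw functions may scale differently, and the helper names here are hypothetical.

```python
import numpy as np


def kish_ess(weights):
    """Kish effective sample size: (sum w)^2 / sum w^2."""
    w = np.asarray(weights, dtype=float)
    return w.sum() ** 2 / np.sum(w**2)


def smd_vs_target(X, arm_weights, target_weights):
    """Per-column (weighted arm mean - weighted target mean) / target SD.
    Assumption: scale by the target sample's (weighted) standard deviation."""
    X = np.asarray(X, dtype=float)
    arm_mean = np.average(X, axis=0, weights=arm_weights)
    target_mean = np.average(X, axis=0, weights=target_weights)
    target_sd = np.sqrt(np.average((X - target_mean) ** 2, axis=0, weights=target_weights))
    return (arm_mean - target_mean) / target_sd


print(kish_ess(np.ones(100)))          # equal weights: ESS equals n
print(kish_ess([1.0, 0.0, 0.0, 0.0]))  # all weight on one unit: ESS equals 1

X = np.array([[0.0], [1.0], [2.0], [3.0]])
w = np.array([0.0, 0.0, 1.0, 1.0])     # the arm is the last two units, unweighted
print(smd_vs_target(X, w, np.ones(4)))
```

Concentrating weight on fewer units pushes the ESS down. This is the variance cost that the path plot tracks as balance tightens.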

x, w, y, e_true, tau = simulate(seed=44, n=2_500, p=12)
fit = fit_balnet_ate(x, w, n_lambdas=35, min_ratio=2e-3, progress_bar=False)

rows = []
target_weights = np.ones_like(w)
unweighted_treated = np.max(np.abs(standardized_mean_differences(x, w, target_weights)))
unweighted_control = np.max(np.abs(standardized_mean_differences(x, 1 - w, target_weights)))

L = min(len(fit.lambdas["treated"]), len(fit.lambdas["control"]))
for idx in range(L):
    bw = balancing_weights(fit, x, w, lambda_index=idx)
    for arm in ["treated", "control"]:
        smd = standardized_mean_differences(x, bw[arm], target_weights)
        rows.append(
            {
                "path_index": idx,
                "arm": arm,
                "lambda": fit.lambdas[arm][idx],
                "max_abs_smd": np.max(np.abs(smd)),
                "mean_abs_smd": np.mean(np.abs(smd)),
                "ess": effective_sample_size(bw[arm]),
            }
        )

diag = pd.DataFrame(rows)
summary_last = diag.groupby("arm").tail(1).copy()
summary_last["unweighted_max_abs_smd"] = summary_last["arm"].map({"treated": unweighted_treated, "control": unweighted_control})
summary_last["balance_reduction"] = 1 - summary_last["max_abs_smd"] / summary_last["unweighted_max_abs_smd"]
display(html_table(summary_last[["arm", "lambda", "unweighted_max_abs_smd", "max_abs_smd", "mean_abs_smd", "ess", "balance_reduction"]]))
| arm | lambda | unweighted_max_abs_smd | max_abs_smd | mean_abs_smd | ess | balance_reduction |
|---|---|---|---|---|---|---|
| treated | 0.0008 | 0.4224 | 0.0008 | 0.0008 | 594.3043 | 0.9980 |
| control | 0.0007 | 0.3598 | 0.0007 | 0.0007 | 832.0036 | 0.9980 |
fig, axes = plt.subplots(1, 2, figsize=(12, 4.5), constrained_layout=True)
for arm, color in [("treated", "tab:blue"), ("control", "tab:orange")]:
    block = diag[diag["arm"] == arm]
    axes[0].plot(block["lambda"], block["max_abs_smd"], marker="o", label=arm, color=color)
    axes[1].plot(block["lambda"], block["ess"], marker="o", label=arm, color=color)
axes[0].axhline(unweighted_treated, color="tab:blue", linestyle="--", alpha=0.5)
axes[0].axhline(unweighted_control, color="tab:orange", linestyle="--", alpha=0.5)
for ax in axes:
    ax.set_xscale("log")
    ax.invert_xaxis()
    ax.legend()
axes[0].set_ylabel("max absolute SMD")
axes[0].set_xlabel("lambda")
axes[0].set_title("Balance improves down the path")
axes[1].set_ylabel("effective sample size")
axes[1].set_xlabel("lambda")
axes[1].set_title("Balancing costs ESS")
plt.show()

7 Treatment effect comparison

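The final (smallest-\(\lambda\)) path point supplies the IPW weights. The ATE estimate is the sample mean of \(y_i\) times the difference between each unit’s treated and control weight. The oracle row applies the same formula with the true propensities.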

bw_final = balancing_weights(fit, x, w, lambda_index=-1)
ate_unweighted = y[w == 1].mean() - y[w == 0].mean()
ate_true_ipw = np.mean(w * y / e_true - (1 - w) * y / (1 - e_true))
ate_balnet = np.mean(y * (bw_final["treated"] - bw_final["control"]))

ate_df = pd.DataFrame(
    [
        {"estimator": "unweighted difference", "estimate": ate_unweighted, "error_vs_tau": ate_unweighted - tau},
        {"estimator": "oracle true-propensity IPW", "estimate": ate_true_ipw, "error_vs_tau": ate_true_ipw - tau},
        {"estimator": "balnet-Adelie POC IPW", "estimate": ate_balnet, "error_vs_tau": ate_balnet - tau},
    ]
)
display(html_table(ate_df))
| estimator | estimate | error_vs_tau |
|---|---|---|
| unweighted difference | 1.4200 | 0.4200 |
| oracle true-propensity IPW | 0.9930 | -0.0070 |
| balnet-Adelie POC IPW | 0.9618 | -0.0382 |
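The balnet estimate above, `np.mean(y * (bw_final["treated"] - bw_final["control"]))`, is a Horvitz-Thompson style IPW estimator. When each arm’s weights average to one, it coincides with the normalized (Hajek) form.

Weights \(W_i/\hat p_i\) fit with \(\omega_i = 1/n\) and an unpenalized intercept average to one automatically. The intercept’s score equation, \(\sum_i \omega_i (W_i/\hat p_i - 1) = 0\), forces exactly that. Whether aipyw’s `balancing_weights` returns weights on this scale is an assumption.

Below is a toy sketch with a hypothetical `ipw_ate` helper and made-up numbers.

```python
import numpy as np


def ipw_ate(y, w_treated, w_control, normalize=False):
    """IPW estimate of the ATE from per-unit arm weights (zero outside the arm).
    normalize=False: Horvitz-Thompson form mean(y * (w1 - w0)).
    normalize=True:  Hajek form, each arm's weights rescaled to sum to one."""
    y = np.asarray(y, dtype=float)
    w1 = np.asarray(w_treated, dtype=float)
    w0 = np.asarray(w_control, dtype=float)
    if normalize:
        return np.sum(w1 * y) / np.sum(w1) - np.sum(w0 * y) / np.sum(w0)
    return np.mean(y * (w1 - w0))


y = np.array([3.0, 1.0, 2.0, 0.0])
w0 = np.array([0.0, 0.0, 2.0, 2.0])           # control weights, mean 1
w1_mean_one = np.array([3.0, 1.0, 0.0, 0.0])  # treated weights, mean 1
w1_mean_two = np.array([4.0, 4.0, 0.0, 0.0])  # treated weights, mean 2

# Same answer when both weight vectors average to one; different otherwise.
print(ipw_ate(y, w1_mean_one, w0), ipw_ate(y, w1_mean_one, w0, normalize=True))
print(ipw_ate(y, w1_mean_two, w0), ipw_ate(y, w1_mean_two, w0, normalize=True))
```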

8 Implementation notes and caveats

This is a proof of concept, not a drop-in port yet.

What is working:

  • custom Python Adelie GLM for the balnet calibration loss;
  • lasso path through adelie.grpnet;
  • coefficient unstandardization;
  • two-arm ATE weights;
  • basic balance and ESS diagnostics;
  • tests against direct SciPy optimization and simulated balance reduction.

What still needs work for a faithful package-level port:

  • interpolation at arbitrary lambda values;
  • grouped penalties and path printing matching balnet;
  • ATT target scaling and diagnostics audited against the R package;
  • CV helpers using balance loss;
  • stronger numerical guards for poor overlap, where the calibration loss can diverge as \(\lambda \to 0\);
  • side-by-side tests against R balnet on fixed seeds.
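To see why poor overlap needs guarding, consider a toy case with no overlap at all: one covariate that is \(-1\) for every control and \(+1\) for every treated unit. Along the ray \(\alpha = 0,\ \beta \to \infty\), the mean calibration loss equals \(-\beta/2 + e^{-\beta}/2\), so it is unbounded below. Adding \(\lambda|\beta|\) does not help for any \(\lambda < 1/2\).

The sketch below evaluates both along that ray. It uses a hypothetical helper and assumes \(\omega_i = 1/n\).

```python
import numpy as np


def mean_calib_loss(alpha, beta, x, y):
    """Calibration loss with omega_i = 1/n and a single covariate x."""
    eta = alpha + beta * x
    return np.mean(y * np.exp(-eta) + (1 - y) * eta)


# No overlap: controls all have x = -1, treated units all have x = +1.
x = np.array([-1.0, -1.0, 1.0, 1.0])
y = np.array([0.0, 0.0, 1.0, 1.0])
lam = 0.4  # any lambda below 1/2 leaves the penalized objective unbounded too

betas = [0.0, 1.0, 5.0, 20.0]
losses = [mean_calib_loss(0.0, b, x, y) for b in betas]
penalized = [loss + lam * abs(b) for loss, b in zip(losses, betas)]
for b, loss, pen in zip(betas, losses, penalized):
    print(f"beta={b:5.1f}  loss={loss:8.4f}  penalized={pen:8.4f}")
```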

The main takeaway is positive: Adelie’s Python API can solve the same mathematical problem by subclassing GlmBase64, so the Python port does not need to reimplement the coordinate-descent path solver.