from _api_doc_utils import *
AIPW
Cross-fit augmented inverse-probability weighting
1 Where it fits
Group: Causal inference
AIPW estimates the average treatment effect (ATE) of a binary treatment by combining outcome regressions with a propensity model:
\[ \hat\tau = n^{-1}\sum_i \left[\hat\mu_1(x_i)-\hat\mu_0(x_i) + \frac{d_i(y_i-\hat\mu_1(x_i))}{\hat e(x_i)} - \frac{(1-d_i)(y_i-\hat\mu_0(x_i))}{1-\hat e(x_i)}\right]. \]
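As a minimal sketch of how the expression above can be evaluated, the function below takes nuisance predictions that have already been fitted and averages the per-observation scores. The name `aipw_ate`, its `clip` argument, and the standard error computed from the sample standard deviation of the scores are illustrative assumptions, not the internals of cm.AIPW.

```python
import numpy as np


def aipw_ate(y, d, mu0, mu1, e, clip=0.02):
    """Evaluate the AIPW formula from nuisance predictions (hypothetical helper).

    mu0, mu1: predicted outcomes under control and treatment.
    e: predicted propensity P(d = 1 | x), clipped to [clip, 1 - clip] here
       (an assumption about how clipping might be applied).
    """
    y, d, mu0, mu1 = (np.asarray(a, dtype=float) for a in (y, d, mu0, mu1))
    e = np.clip(np.asarray(e, dtype=float), clip, 1.0 - clip)
    # Per-observation score: regression contrast plus weighted residual corrections.
    psi = (mu1 - mu0
           + d * (y - mu1) / e
           - (1.0 - d) * (y - mu0) / (1.0 - e))
    ate = psi.mean()
    # Plug-in standard error from the spread of the scores (assumed, not the library's).
    se = psi.std(ddof=1) / np.sqrt(psi.size)
    return ate, se


# Usage with oracle nuisances on simulated data whose true effect is 1.0.
rng = np.random.default_rng(0)
x = rng.normal(size=(5000, 3))
e_true = 1 / (1 + np.exp(-(0.1 + x @ np.array([0.6, -0.3, 0.2]))))
d = rng.binomial(1, e_true).astype(float)
mu0_true = 0.5 + x @ np.array([0.2, -0.1, 0.3])
y = mu0_true + 1.0 * d + rng.normal(size=5000)
ate, se = aipw_ate(y, d, mu0_true, mu0_true + 1.0, e_true)
print(ate, se)
```

With the true nuisance functions plugged in, the estimate should land close to 1.0; in practice the nuisances are estimated, which is where cross-fitting comes in.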
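Here \(\hat\mu_0\) and \(\hat\mu_1\) are the outcome regressions for the control and treated arms, and \(\hat e\) is the propensity score. All three nuisance functions are cross-fit ridge models.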
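Cross-fitting means each observation is predicted by a model trained on the other folds only. The sketch below illustrates the idea with a closed-form ridge regression; `ridge_fit` and `cross_fit_predict` are hypothetical helpers, and the fold assignment, penalty selection, and propensity link actually used by cm.AIPW are not documented on this page.

```python
import numpy as np


def ridge_fit(x, y, penalty):
    """Closed-form ridge with an unpenalized intercept (illustrative helper)."""
    x_mean, y_mean = x.mean(axis=0), y.mean()
    xc = x - x_mean
    gram = xc.T @ xc + penalty * np.eye(x.shape[1])
    beta = np.linalg.solve(gram, xc.T @ (y - y_mean))
    return y_mean - x_mean @ beta, beta


def cross_fit_predict(x, y, n_folds=5, penalty=1.0, seed=0):
    """Out-of-fold ridge predictions: no row is predicted by a model that saw it."""
    n = x.shape[0]
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(n), n_folds)
    pred = np.empty(n)
    for held_out in folds:
        train = np.setdiff1d(np.arange(n), held_out)
        intercept, beta = ridge_fit(x[train], y[train], penalty)
        pred[held_out] = intercept + x[held_out] @ beta
    return pred


# Usage: out-of-fold predictions for a linear outcome.
rng = np.random.default_rng(1)
x = rng.normal(size=(300, 3))
y = 0.5 + x @ np.array([0.2, -0.1, 0.3]) + 0.1 * rng.normal(size=300)
pred = cross_fit_predict(x, y, n_folds=4, penalty=1.0, seed=2)
print(np.mean((pred - y) ** 2))
```

For the arm-specific outcome regressions, the same loop would presumably train on only the treated (or only the control) rows of each training split and still predict for every held-out row.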
2 Python API
Constructor: cm.AIPW
Call fit(y, d, x) with binary treatment d. summary() reports ate, se, vcov, and selected penalties for the outcome and propensity nuisance models.
print(inspect.signature(cm.AIPW))
(penalty=None, cv=5, n_folds=5, propensity_clip=0.02, seed=42)
cls = cm.AIPW
display(HTML(html_table(["Public method"], public_methods(cls))))
| Public method |
|---|
| fit(self, /, y, d, x) |
| summary(self, /, vcov=None, lags=None, clusters=None) |
3 Minimal example
rng = np.random.default_rng(14)
x = rng.normal(size=(420, 3))
pi = 1 / (1 + np.exp(-(0.1 + x @ np.array([0.6, -0.3, 0.2]))))  # true propensity
d = rng.binomial(1, pi, size=420).astype(float)
y = 0.5 + x @ np.array([0.2, -0.1, 0.3]) + 1.0 * d + rng.normal(size=420)  # true ATE is 1.0
model = cm.AIPW(penalty=np.logspace(-4, 1, 10), cv=3, n_folds=4, seed=2)
model.fit(y, d, x)
print(model.summary()["ate"])
print(model.summary()["se"])
1.1056887995229303
0.1055551427292932
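To read the two printed numbers together, one option is a normal-approximation confidence interval. The snippet below is a sketch that assumes the reported se is a standard error for which that approximation is reasonable; it copies the printed values rather than calling the library.

```python
from statistics import NormalDist

# Values copied from the printed output of the minimal example.
ate, se = 1.1056887995229303, 0.1055551427292932

z = NormalDist().inv_cdf(0.975)  # two-sided 95% critical value, about 1.96
ci_low, ci_high = ate - z * se, ate + z * se
print(f"95% CI: ({ci_low:.2f}, {ci_high:.2f})")  # 95% CI: (0.90, 1.31)
```

The simulated outcome uses a coefficient of 1.0 on d, so the interval covers the effect built into the data.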
4 summary() contract
The table below is generated by fitting the live class in this repository and then inspecting summary(). Shapes are shown because most values are plain NumPy arrays or scalars.
rng = np.random.default_rng(114)
x = rng.normal(size=(160, 3))
pi = 1 / (1 + np.exp(-(0.1 + x @ np.array([0.6, -0.3, 0.2]))))
d = rng.binomial(1, pi, size=160).astype(float)
y = 0.5 + x @ np.array([0.2, -0.1, 0.3]) + d + rng.normal(size=160)
model = cm.AIPW(penalty=np.logspace(-4, 1, 6), cv=3, n_folds=4, seed=2)
model.fit(y, d, x)
summary = model.summary()
display(HTML(html_table(["summary() key", "shape"], summary_shape_rows(summary))))
| summary() key | shape |
|---|---|
| ate | () |
| se | () |
| vcov | (1, 1) |
| outcome0_penalties | (4,) |
| outcome1_penalties | (4,) |
| propensity_penalties | (4,) |