from _api_doc_utils import *

MatrixCompletion
Nuclear-norm panel counterfactual completion
1 Where it fits
Group: Causal inference
MatrixCompletion treats untreated cells as observed entries and treated cells as missing counterfactuals. It estimates a low-rank surface for the untreated outcomes, optionally with unit and time effects, by shrinking singular values toward zero (nuclear-norm regularization).
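The shrinkage idea can be illustrated with the classic Soft-Impute scheme. This is iterative singular-value soft-thresholding, which converges toward the minimizer of $\tfrac12\lVert P_\Omega(Y - L)\rVert_F^2 + \lambda\lVert L\rVert_*$ over the observed (untreated) cells $\Omega$. The sketch below is not MatrixCompletion's implementation. It fits no unit or time effects. The helper name `soft_impute` and the threshold rule (a fraction of the top singular value) are illustrative assumptions.

```python
import numpy as np


def soft_impute(y, observed, frac=0.02, max_iter=500, tol=1e-6):
    """Hypothetical helper: Soft-Impute completion of y from cells where observed is True.

    Fits only a low-rank matrix; unit and time effects are not modelled here.
    """
    y = np.asarray(y, dtype=float)
    observed = np.asarray(observed, dtype=bool)
    # Assumed threshold rule: a fraction of the top singular value of the zero-filled panel
    lam = frac * np.linalg.svd(np.where(observed, y, 0.0), compute_uv=False)[0]
    low_rank = np.zeros_like(y)
    for _ in range(max_iter):
        # Keep observed cells, fill missing (treated) cells with the current estimate
        filled = np.where(observed, y, low_rank)
        u, s, vt = np.linalg.svd(filled, full_matrices=False)
        # Soft-threshold the singular values: the proximal step for the nuclear norm
        new = (u * np.maximum(s - lam, 0.0)) @ vt
        change = np.linalg.norm(new - low_rank) / max(np.linalg.norm(low_rank), 1e-12)
        low_rank = new
        if change < tol:
            break
    return low_rank


# Rank-2 panel with a 5 x 5 "treated" block held out as missing
rng = np.random.default_rng(0)
truth = rng.normal(size=(30, 2)) @ rng.normal(size=(2, 20))
observed = np.ones(truth.shape, dtype=bool)
observed[25:, 15:] = False
completed = soft_impute(truth, observed)
print(np.abs(completed - truth)[~observed].mean())
```

A larger threshold kills more of the spectrum, giving a lower-rank but more biased fit. The `lambda_l` and `lambda_fraction` constructor arguments presumably play this role in the real class, but their exact rule is not documented here.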
The completed values in treated cells become counterfactual outcomes for ATT and event-study summaries.
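The sketch below shows the bookkeeping from a completed matrix to the ATT. It assumes the ATT is the plain average of `observed - counterfactual` over cells with `w == 1`, which is the usual definition; the class may weight cells differently. The arrays are toy values, not library output. Zeroing the effect on untreated cells is a choice made only for this sketch.

```python
import numpy as np

# Toy 3-unit x 4-period panel; unit 2 is treated in periods 2 and 3
y = np.array([[1.0, 1.2, 1.1, 1.3],
              [0.9, 1.0, 1.2, 1.1],
              [1.1, 1.0, 2.2, 2.4]])
w = np.zeros_like(y)
w[2, 2:] = 1

# Hypothetical completed untreated surface: matches y except on treated cells
counterfactual = y.copy()
counterfactual[2, 2:] = [1.2, 1.3]

# Effects are kept only on treated cells (a choice made for this sketch)
treatment_effect = np.where(w == 1, y - counterfactual, 0.0)
att = treatment_effect[w == 1].mean()
print(att)
```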
2 Python API
Constructor: cm.MatrixCompletion
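Call MatrixCompletion(...).fit(y, w). Here y is an N × T outcome panel (rows are units, columns are periods) and w is a 0/1 treatment indicator of the same shape, with 1 marking treated cells. predict() returns the completed (counterfactual) matrix, which has the same shape as y. summary() reports the ATT, completed matrices, treatment effects, low-rank components, singular values, objective history, and panel summaries.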
print(inspect.signature(cm.MatrixCompletion))

(lambda_l=None, lambda_fraction=0.25, fit_unit_effects=True, fit_time_effects=True, max_iterations=500, effect_iterations=2, tolerance=1e-06, svd_method=Ellipsis, svd_rank=None, svd_oversamples=10, svd_power_iter=1, svd_seed=None)
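The parameters `svd_rank`, `svd_oversamples`, `svd_power_iter` and `svd_seed` share their names with the standard knobs of randomized SVD (Halko, Martinsson and Tropp). The signature alone does not confirm that this is what they control. Assuming it is, for some choice of `svd_method`, the sketch below shows the role of each knob. `randomized_svd` here is a hypothetical stand-alone helper, not part of the package.

```python
import numpy as np


def randomized_svd(a, rank, oversamples=10, power_iter=1, seed=None):
    """Hypothetical helper: approximate top-`rank` SVD of `a` by random projection."""
    rng = np.random.default_rng(seed)
    m, n = a.shape
    # Sketch width: target rank plus oversampling, capped by the matrix size
    k = min(rank + oversamples, m, n)
    # Project a onto a random k-dimensional subspace and orthonormalize
    q, _ = np.linalg.qr(a @ rng.standard_normal((n, k)))
    # Power iterations sharpen the captured range when singular values decay slowly
    for _ in range(power_iter):
        q, _ = np.linalg.qr(a.T @ q)
        q, _ = np.linalg.qr(a @ q)
    # Exact SVD of the small k x n matrix, then lift the left vectors back
    u_small, s, vt = np.linalg.svd(q.T @ a, full_matrices=False)
    u = q @ u_small
    return u[:, :rank], s[:rank], vt[:rank]


rng = np.random.default_rng(1)
a = rng.normal(size=(200, 5)) @ rng.normal(size=(5, 80))
u, s, vt = randomized_svd(a, rank=5, oversamples=10, power_iter=1, seed=0)
exact = np.linalg.svd(a, compute_uv=False)[:5]
print(s, exact)
```

Only a k × n matrix is decomposed exactly, which is why randomized SVD pays off when the panel is large and the useful rank is small.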
cls = cm.MatrixCompletion
display(HTML(html_table(["Public method"], public_methods(cls))))

| Public method |
|---|
| fit(self, /, y, w) |
| predict(self, /) |
| summary(self, /) |
3 Minimal example
rng = np.random.default_rng(18)
# Rank-2 signal for 10 units x 14 periods, plus small noise
load = rng.normal(size=(10, 2)); fac = rng.normal(size=(2, 14))
y = load @ fac + rng.normal(scale=.1, size=(10, 14))
# Units 7-9 are treated from period 9 onward (zero-based), with a true effect of 1
w = np.zeros_like(y); w[7:, 9:] = 1; y[7:, 9:] += 1
model = cm.MatrixCompletion(max_iterations=100, tolerance=1e-5); model.fit(y, w)
print(model.summary()["att"])
print(model.predict().shape)

1.201613382954045
(10, 14)
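The minimal example builds y from a known rank-2 surface plus a constant effect of 1 on the treated block. Because of that, the oracle ATT that a perfect recovery of `load @ fac` would report can be computed directly with NumPy and compared with the printed estimate. This recreates only the data, not the model. It uses the same seed and draw order, so the arrays match the example.

```python
import numpy as np

# Same seed and draw order as the minimal example
rng = np.random.default_rng(18)
load = rng.normal(size=(10, 2))
fac = rng.normal(size=(2, 14))
signal = load @ fac                       # noise-free rank-2 surface
y = signal + rng.normal(scale=.1, size=(10, 14))
w = np.zeros_like(y)
w[7:, 9:] = 1
y[7:, 9:] += 1                            # true effect of 1 on the treated block

# ATT if the low-rank surface were recovered exactly (1 plus the mean noise on treated cells)
oracle_att = (y - signal)[w == 1].mean()
print(oracle_att)
print(int(w.sum()))                       # number of treated cells
```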
4 summary() contract
The table below is generated by fitting the live class in this repository and then inspecting summary(). Shapes are shown because most values are plain NumPy arrays or scalars. A shape of () marks a scalar or another non-array value.
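# 8 units x 10 periods; units 6-7 are treated from period 7 with an added effect of 0.8
rng = np.random.default_rng(118)
y = rng.normal(size=(8, 10))
w = np.zeros_like(y); w[6:, 7:] = 1; y[6:, 7:] += .8
model = cm.MatrixCompletion(max_iterations=80, tolerance=1e-5); model.fit(y, w)
summary = model.summary()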
display(HTML(html_table(["summary() key", "shape"], summary_shape_rows(summary))))

| summary() key | shape |
|---|---|
| completed | (8, 10) |
| low_rank | (8, 10) |
| unit_effects | (8,) |
| time_effects | (10,) |
| singular_values | (8,) |
| lambda_l | () |
| objective | () |
| iterations | () |
| history_objective | (16,) |
| history_rmse | (16,) |
| svd_method | () |
| svd_rank | () |
| svd_oversamples | () |
| svd_power_iter | () |
| att | () |
| counterfactual | (8, 10) |
| treatment_effect | (8, 10) |
| event_study | () |
| group_means | () |
| control_units | (6,) |
| treated_units | (2,) |
| cohorts | (1,) |
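The panel-level keys (control_units, treated_units, cohorts, event_study) can be read as functions of w and treatment_effect. The sketch below is a hypothetical reconstruction, assuming:

- control units are never treated;
- cohorts are first-treatment periods;
- the event study averages effects by periods since adoption.

The package may define or store these differently. With the summary() example's w, this gives 6 control units, 2 treated units and one cohort, matching the shapes in the table.

```python
import numpy as np


def panel_structure(w):
    """Hypothetical helper: derive unit groups, adoption periods and cohorts from w."""
    w = np.asarray(w)
    if not np.isin(w, (0, 1)).all():
        raise ValueError("w must be a 0/1 treatment indicator")
    ever_treated = w.any(axis=1)
    treated_units = np.flatnonzero(ever_treated)
    control_units = np.flatnonzero(~ever_treated)
    # First treated period per treated unit (argmax returns the first 1 in each row)
    adoption = w[treated_units].argmax(axis=1)
    cohorts = np.unique(adoption)
    return control_units, treated_units, adoption, cohorts


def event_study(effect, w):
    """Hypothetical helper: mean treated-cell effect by periods since adoption (0 = first treated period)."""
    _, treated_units, adoption, _ = panel_structure(w)
    sums, counts = {}, {}
    for unit, start in zip(treated_units, adoption):
        for t in np.flatnonzero(w[unit]):
            k = int(t - start)
            sums[k] = sums.get(k, 0.0) + effect[unit, t]
            counts[k] = counts.get(k, 0) + 1
    return {k: sums[k] / counts[k] for k in sorted(sums)}


# Same treatment layout as the summary() example: 8 units x 10 periods
w = np.zeros((8, 10))
w[6:, 7:] = 1
control, treated, adoption, cohorts = panel_structure(w)
print(control.shape, treated.shape, cohorts)

effect = np.where(w == 1, 0.8, 0.0)  # toy constant effect on treated cells
es = event_study(effect, w)
print(es)
```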