crabbymetrics

Randomized Sketching OLS

This cached ablation checks the new randomized sketching path for tall least-squares designs.

The library path under test is OLS.fit_sketch(...), which uses a CountSketch row embedding. Each original observation is assigned to one sketch bucket with a random sign, so applying the sketch to an \(n \times p\) design takes a single pass over the data and costs \(O(np)\). The sketched estimator then solves the smaller least-squares problem

\[ \hat\beta_S = \arg\min_b \lVert S y - S X b \rVert_2^2, \]

where \(S\) is the sparse \(s \times n\) signed embedding and \(s\) is the sketch size. This is deliberately different from a dense Gaussian or Rademacher sketch, for which forming \(SX\) directly would cost \(O(snp)\).
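To make the construction concrete, here is a minimal pure-NumPy sketch-and-solve with a CountSketch. It illustrates the technique under stated assumptions and is not the crabbymetrics implementation: the helper names `apply_countsketch` and `sketch_and_solve` are hypothetical, buckets and signs are drawn from NumPy's random generator rather than a hash function, and the intercept is handled by prepending a column of ones. Stacking \(y\) next to \(X\) guarantees that both are multiplied by the same \(S\).

```python
import numpy as np


def apply_countsketch(a, buckets, signs, sketch_size):
    """Return S @ a for the CountSketch S defined by `buckets` and `signs`.

    Row i of `a` is multiplied by signs[i] and added into output row buckets[i],
    so one pass over the n rows suffices: O(n * k) work for an n x k array.
    """
    out = np.zeros((sketch_size, a.shape[1]))
    # np.add.at accumulates correctly when several rows land in the same bucket.
    np.add.at(out, buckets, signs[:, None] * a)
    return out


def sketch_and_solve(x, y, sketch_size, seed=0):
    """Hypothetical helper: CountSketch least squares with an intercept.

    Returns [intercept, coef_1, ..., coef_p] minimizing ||S y - S X b||^2,
    where X = [1, x].
    """
    rng = np.random.default_rng(seed)
    n = x.shape[0]
    # One bucket and one random sign per observation.
    buckets = rng.integers(0, sketch_size, size=n)
    signs = rng.choice(np.array([-1.0, 1.0]), size=n)
    # Stack [1, x, y] so the design and the response share the same S.
    stacked = np.column_stack([np.ones(n), x, y])
    sketched = apply_countsketch(stacked, buckets, signs, sketch_size)
    sx, sy = sketched[:, :-1], sketched[:, -1]
    # Solve the small sketch_size x (p + 1) least-squares problem.
    beta_hat, *_ = np.linalg.lstsq(sx, sy, rcond=None)
    return beta_hat


# Usage: with noiseless data, S y = S X beta exactly, so the sketched fit
# recovers beta whenever S X has full column rank.
rng = np.random.default_rng(0)
x = rng.normal(size=(5_000, 5))
beta = np.array([0.4, 1.0, -2.0, 0.5, 0.0, 3.0])  # intercept first
y = beta[0] + x @ beta[1:]
print(np.round(sketch_and_solve(x, y, sketch_size=60, seed=1), 6))
```

The `np.add.at` accumulation is the \(O(np)\) pass described above; with noisy data the sketched coefficients differ from full OLS by an amount that shrinks as `sketch_size` grows.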

The goal here is not to show that sketching beats full OLS for every moderate problem. It is to document the accuracy/speed tradeoff on a reproducible synthetic design and catch obvious regressions in the implementation.

1 Setup

```python
from html import escape
import time

import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML, display

import crabbymetrics as cm

np.set_printoptions(precision=4, suppress=True)


def html_table(headers, rows):
    """Render headers and rows as a bare HTML table (headers are escaped, cells are not)."""
    parts = ["<table>", "<thead>", "<tr>"]
    parts.extend(f"<th>{escape(str(header))}</th>" for header in headers)
    parts.extend(["</tr>", "</thead>", "<tbody>"])
    for row in rows:
        parts.append("<tr>")
        parts.extend(f"<td>{cell}</td>" for cell in row)
        parts.append("</tr>")
    parts.extend(["</tbody>", "</table>"])
    return "".join(parts)
```

2 Design

The DGP is intentionally friendly to OLS: standardized Gaussian regressors, a dense linear signal, and homoskedastic additive noise. That makes the full OLS solution a useful numerical target. We vary the sketch size as a multiple of the full design dimension, including the intercept.

```python
def run_once(n, p, noise, multiple, seed):
    # Simulate a tall Gaussian design with a dense signal and homoskedastic noise.
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(n, p))
    beta = rng.normal(size=p)
    intercept = 0.4
    y = intercept + x @ beta + noise * rng.normal(size=n)
    # Sketch size is a multiple of the design dimension (p slopes plus the intercept).
    sketch_size = int(multiple * (p + 1))

    # Full OLS: the numerical target.
    full = cm.OLS()
    t0 = time.perf_counter()
    full.fit(x, y)
    full_time = time.perf_counter() - t0
    full_summary = full.summary(vcov="vanilla")
    full_param = np.concatenate([[full_summary["intercept"]], full_summary["coef"]])

    # CountSketch OLS on the same data; the sketch seed is offset from the data seed.
    sketch = cm.OLS()
    t0 = time.perf_counter()
    sketch.fit_sketch(x, y, sketch_size=sketch_size, seed=seed + 10_000)
    sketch_time = time.perf_counter() - t0
    sketch_summary = sketch.summary(vcov="vanilla")
    sketch_param = np.concatenate([[sketch_summary["intercept"]], sketch_summary["coef"]])

    # Compare both fits on fresh test points from the same regressor distribution.
    x_test = rng.normal(size=(2000, p))
    full_pred = full.predict(x_test)
    sketch_pred = sketch.predict(x_test)

    return {
        "n": n,
        "p": p,
        "noise": noise,
        "multiple": multiple,
        "sketch_size": sketch_size,
        "coef_rel_error": float(
            np.linalg.norm(sketch_param - full_param) / np.linalg.norm(full_param)
        ),
        "prediction_rmse_vs_full": float(np.sqrt(np.mean((sketch_pred - full_pred) ** 2))),
        "full_fit_seconds": full_time,
        "sketch_fit_seconds": sketch_time,
        "speedup": full_time / sketch_time if sketch_time > 0 else np.inf,
    }


n = 20_000
p = 40
noise = 0.2
multiples = [2, 4, 8, 12, 16]
seeds = range(5)

# Every seed is run at every sketch multiple: 5 multiples x 5 seeds.
rows = [run_once(n, p, noise, multiple, seed) for multiple in multiples for seed in seeds]
print("replications:", len(rows))
print("design:", {"n": n, "p": p, "noise": noise})
```

```
replications: 25
design: {'n': 20000, 'p': 40, 'noise': 0.2}
```
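For reading the summary table, the two accuracy metrics returned by `run_once` can be written out explicitly. Let \(\hat\theta = (\hat\alpha, \hat\beta)\) be the full-OLS intercept and slopes, \(\hat\theta_S\) the sketched counterpart, and \(\tilde x_j = (1, x_j)\) for the \(m = 2000\) test rows:

\[ \mathrm{RelErr} = \frac{\lVert \hat\theta_S - \hat\theta \rVert_2}{\lVert \hat\theta \rVert_2}, \qquad \mathrm{RMSE}_{\text{vs full}} = \sqrt{\frac{1}{m} \sum_{j=1}^{m} \bigl(\tilde x_j^\top \hat\theta_S - \tilde x_j^\top \hat\theta\bigr)^2}. \]

The speedup is the wall-clock ratio \(t_{\text{full}} / t_{\text{sketch}}\) for a single replication, and each table entry is the median over the five seeds.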

3 Summary

```python
# Collapse the five seeds at each sketch multiple into median metrics.
summary = []
for multiple in multiples:
    block = [row for row in rows if row["multiple"] == multiple]
    summary.append(
        {
            "multiple": multiple,
            "sketch_size": block[0]["sketch_size"],
            "median_coef_rel_error": float(np.median([r["coef_rel_error"] for r in block])),
            "median_prediction_rmse_vs_full": float(
                np.median([r["prediction_rmse_vs_full"] for r in block])
            ),
            "median_speedup": float(np.median([r["speedup"] for r in block])),
        }
    )

summary_rows = [
    [
        f"{entry['multiple']}x",
        entry["sketch_size"],
        f"{entry['median_coef_rel_error']:.4f}",
        f"{entry['median_prediction_rmse_vs_full']:.4f}",
        f"{entry['median_speedup']:.2f}x",
    ]
    for entry in summary
]

display(
    HTML(
        html_table(
            [
                "Sketch multiple",
                "Sketch size",
                "Median coefficient relative error",
                "Median prediction RMSE vs full OLS",
                "Median speedup",
            ],
            summary_rows,
        )
    )
)
```
| Sketch multiple | Sketch size | Median coefficient relative error | Median prediction RMSE vs full OLS | Median speedup |
|---|---|---|---|---|
| 2x | 82 | 0.0312 | 0.2242 | 3.70x |
| 4x | 164 | 0.0188 | 0.1140 | 4.49x |
| 8x | 328 | 0.0106 | 0.0745 | 4.34x |
| 12x | 492 | 0.0086 | 0.0629 | 3.29x |
| 16x | 656 | 0.0083 | 0.0562 | 3.53x |

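The expected pattern is that accuracy improves monotonically as the sketch dimension grows, and the medians follow it: coefficient error falls from about 3.1% at 2x to about 0.8% at 16x, and prediction disagreement with full OLS falls from 0.224 to 0.056, with diminishing returns beyond 8x. The speedups are real because CountSketch touches each row once and then solves a much smaller least-squares problem, but in this run they stay between roughly 3.3x and 4.5x and are not monotone in the sketch size.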
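A useful reference point for these diminishing returns is the exact result for a dense Gaussian sketch; the CountSketch used here is only expected to behave similarly, not identically. Suppose \(S\) has i.i.d. Gaussian entries, \(X\) has \(k\) columns (here \(k = p + 1 = 41\)), and \(\hat\beta\) is the full OLS fit. Because the residual \(y - X\hat\beta\) is orthogonal to the column space of \(X\), \(SX\) and \(S(y - X\hat\beta)\) are independent, and a Wishart moment calculation gives

\[ \mathbb{E}\,\lVert X(\hat\beta_S - \hat\beta) \rVert_2^2 = \frac{k}{s - k - 1}\,\lVert y - X\hat\beta \rVert_2^2, \qquad s > k + 1. \]

With roughly isotropic regressors, \(X^\top X \approx nI\) and \(\lVert y - X\hat\beta \rVert_2^2 \approx n\sigma^2\), so the prediction RMSE against full OLS is roughly \(\sigma\sqrt{k/(s - k - 1)}\). For \(\sigma = 0.2\) and \(k = 41\) this is about 0.20 at \(s = 82\) and about 0.05 at \(s = 656\), the same order as the 0.224 and 0.056 in the table, and it shows why the error shrinks only like \(1/\sqrt{s}\) once \(s \gg k\).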

4 Plots

```python
fig, axes = plt.subplots(1, 3, figsize=(14, 4), constrained_layout=True)

x_axis = np.array([entry["multiple"] for entry in summary])
coef_error = np.array([entry["median_coef_rel_error"] for entry in summary])
pred_error = np.array([entry["median_prediction_rmse_vs_full"] for entry in summary])
speedup = np.array([entry["median_speedup"] for entry in summary])

axes[0].plot(x_axis, coef_error, marker="o")
axes[0].set_xlabel("Sketch size / design dimension")
axes[0].set_ylabel("Median relative coefficient error")
axes[0].set_title("Coefficient accuracy")

axes[1].plot(x_axis, pred_error, marker="o", color="tab:orange")
axes[1].set_xlabel("Sketch size / design dimension")
axes[1].set_ylabel("Median RMSE vs full OLS prediction")
axes[1].set_title("Prediction agreement")

axes[2].plot(x_axis, speedup, marker="o", color="tab:green")
axes[2].axhline(1.0, color="black", linestyle="--", linewidth=1)
axes[2].set_xlabel("Sketch size / design dimension")
axes[2].set_ylabel("Median full/sketch fit time")
axes[2].set_title("Runtime")

plt.show()
```

5 Takeaway

For this tall, well-conditioned design, a sketch size between \(8(p+1)\) and \(16(p+1)\) gives roughly one-percent coefficient error relative to full OLS while still cutting median fit time by a factor of about 3.3 to 4.3. Smaller sketches are not meaningfully faster in this run (the 2x and 4x median speedups, 3.70x and 4.49x, fall in the same range), but they introduce noticeably larger coefficient and prediction error, so they should be treated as approximate exploratory fits rather than a drop-in replacement for final inference.