Exact Compression and Row Sketching for Discrete Covariates

When MECE cells make least squares small without approximation

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import HTML, display

np.set_printoptions(precision=4, suppress=True)


def html_table(df, digits=4):
    return HTML(
        df.to_html(
            index=False,
            border=0,
            classes="table table-sm table-striped",
            float_format=lambda x: f"{x:.{digits}f}",
        )
    )

1 The question

Suppose a large least-squares problem includes discrete covariates whose categories form mutually exclusive and collectively exhaustive (MECE) cells, so each row belongs to exactly one cell. If \(k\) cells are occupied, the dummy matrix for those cells has rank \(k\), even when the raw dataset has millions of rows.

This creates two distinct ways to make the regression smaller.

  1. Exact compression. Collapse rows to cell-level sufficient statistics. This preserves the least-squares normal equations exactly.
  2. Row sketching. Randomly project the rows to a smaller synthetic dataset. This approximately preserves the least-squares geometry when the sketch is large enough.

The first is deterministic and lossless for the target normal equations. The second is randomized and approximate. They are connected because both replace a large row space by a smaller representation, but they are not the same object.

2 Pure discrete least squares

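Let \(D \in \{0,1\}^{n \times k}\) be the MECE cell-dummy matrix with rows \(d_i'\). Row \(i\) has exactly one nonzero entry, in the column of its cell \(g(i)\). Write \(n_g\) for the number of rows in cell \(g\) and \(\bar y_g\) for their mean outcome. If every cell is occupied, \(D\) has rank \(k\).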

OLS on only these discrete covariates solves

\[ \hat\alpha = \arg\min_a \sum_{i=1}^n (y_i - d_i'a)^2. \]

The full normal equations are

\[ D'D\hat\alpha = D'y. \]

Because \(D\) is a cell-membership matrix,

\[ D'D = \operatorname{diag}(n_1,\ldots,n_k), \qquad D'y = \left(\sum_{i:g(i)=1}y_i,\ldots,\sum_{i:g(i)=k}y_i\right)'. \]

So the full dataset can be replaced by \(k\) weighted representative rows:

\[ \tilde D = \operatorname{diag}(\sqrt{n_1},\ldots,\sqrt{n_k}), \qquad \tilde y_g = \sqrt{n_g}\,\bar y_g. \]

Then

\[ \tilde D'\tilde D = D'D, \qquad \tilde D'\tilde y = D'y. \]

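This is the clean row-span interpretation: the \(n\) rows of \(D\) take only \(k\) distinct values, so they collapse to those unique patterns, each scaled by the square root of its multiplicity (equivalently, weighted least squares with weights \(n_g\)). The solution is \(\hat\alpha_g = \bar y_g\), and no approximation is involved.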
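A quick numerical check of these identities, written as a minimal sketch on simulated data (the sizes, seed, and variable names are illustrative assumptions, not part of the original analysis). It builds the full dummy matrix, forms the \(k\)-row compressed system, and confirms that both give the same normal equations and that the solution is the vector of cell means.

```python
import numpy as np

# Toy data: n rows, each assigned to one of k MECE cells labelled 0..k-1.
rng = np.random.default_rng(0)
n, k = 1_000, 5
g = rng.integers(0, k, size=n)
y = rng.normal(size=n) + g  # cell-dependent mean

# Full n x k dummy matrix D and the full normal equations D'D a = D'y.
D = np.zeros((n, k))
D[np.arange(n), g] = 1.0
alpha_full = np.linalg.solve(D.T @ D, D.T @ y)

# Compressed k-row system: one row per cell, scaled by sqrt(n_g).
n_g = np.bincount(g, minlength=k).astype(float)
ybar = np.bincount(g, weights=y, minlength=k) / n_g
D_tilde = np.diag(np.sqrt(n_g))
y_tilde = np.sqrt(n_g) * ybar
alpha_comp = np.linalg.solve(D_tilde.T @ D_tilde, D_tilde.T @ y_tilde)

print(np.allclose(D_tilde.T @ D_tilde, D.T @ D))
print(np.allclose(D_tilde.T @ y_tilde, D.T @ y))
print(np.allclose(alpha_full, alpha_comp), np.allclose(alpha_comp, ybar))
```

The compressed system fits its \(k\) rows exactly, so the full-data residual sum of squares (the within-cell sum of squares) cannot be recovered from \(\tilde D\) and \(\tilde y\) alone; only the minimizer is preserved.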

3 Discrete controls plus continuous regressors

Now suppose the target coefficients are those on continuous or treatment variables \(X \in \mathbb{R}^{n \times p}\), while \(D\) is a high-cardinality discrete control. The model is

\[ y = X\beta + D\alpha + \varepsilon. \]

By Frisch–Waugh–Lovell,

\[ \hat\beta = (X'M_DX)^{-1}X'M_Dy, \qquad M_D = I - D(D'D)^{-1}D'. \]

For MECE cells, applying \(M_D\) is just within-cell demeaning:

\[ (M_Dx)_i = x_i - \bar x_{g(i)}, \qquad (M_Dy)_i = y_i - \bar y_{g(i)}. \]
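The Frisch–Waugh–Lovell step and the demeaning claim can be checked directly on a small simulated problem. This is a sketch under illustrative assumptions: \(n\) is kept small so the \(n\times n\) matrix \(M_D\) can be formed explicitly, which would be impractical at scale. It compares the long regression of \(y\) on \([X\ D]\), the explicit projection \(M_D\), and within-cell demeaning.

```python
import numpy as np

rng = np.random.default_rng(1)
n, k, p = 500, 8, 2
g = rng.integers(0, k, size=n)
X = rng.normal(size=(n, p)) + 0.3 * g[:, None]  # X correlated with the cells
y = X @ np.array([1.5, -0.5]) + rng.normal(size=k)[g] + rng.normal(size=n)

D = np.zeros((n, k))
D[np.arange(n), g] = 1.0

# Long regression of y on [X D]; the first p coefficients estimate beta.
coef_long, *_ = np.linalg.lstsq(np.hstack([X, D]), y, rcond=None)
beta_long = coef_long[:p]

# Explicit annihilator M_D = I - D (D'D)^{-1} D' (small n only).
M_D = np.eye(n) - D @ np.linalg.solve(D.T @ D, D.T)

# Within-cell demeaning using group means computed with bincount.
counts = np.bincount(g, minlength=k)
xbar = np.stack(
    [np.bincount(g, weights=X[:, j], minlength=k) for j in range(p)], axis=1
) / counts[:, None]
ybar = np.bincount(g, weights=y, minlength=k) / counts
X_dm = X - xbar[g]
y_dm = y - ybar[g]

# FWL: regress demeaned y on demeaned X.
beta_fwl = np.linalg.solve(X_dm.T @ X_dm, X_dm.T @ y_dm)
print(beta_long, beta_fwl)
```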

Therefore the residualized normal equations are sums of within-cell cross-products:

\[ X'M_DX = \sum_g \sum_{i\in g} (x_i-\bar x_g)(x_i-\bar x_g)', \]

and

\[ X'M_Dy = \sum_g \sum_{i\in g} (x_i-\bar x_g)(y_i-\bar y_g). \]

These are exactly recoverable from per-cell sufficient statistics:

\[ n_g,\quad \sum_{i\in g}x_i,\quad \sum_{i\in g}y_i,\quad \sum_{i\in g}x_ix_i',\quad \sum_{i\in g}x_iy_i. \]
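The recovery uses one identity per cell. Expanding the centered products and using \(\sum_{i\in g}x_i = n_g\bar x_g\) and \(\sum_{i\in g}y_i = n_g\bar y_g\) gives

\[ \sum_{i\in g}(x_i-\bar x_g)(x_i-\bar x_g)' = \sum_{i\in g}x_ix_i' - \frac{1}{n_g}\Big(\sum_{i\in g}x_i\Big)\Big(\sum_{i\in g}x_i\Big)', \]

and

\[ \sum_{i\in g}(x_i-\bar x_g)(y_i-\bar y_g) = \sum_{i\in g}x_iy_i - \frac{1}{n_g}\Big(\sum_{i\in g}x_i\Big)\sum_{i\in g}y_i. \]

Summing the right-hand sides over \(g\) yields \(X'M_DX\) and \(X'M_Dy\) from the five statistics above; the experiment's exact_compressed_beta function evaluates exactly these sums.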

Thus the compression is not row-sketching of the raw rows. It is exact compression of the residualized Gram system induced by the rank-\(k\) nuisance span. The large data pass only needs to return \(O(kp^2)\) aggregates, and the final solve is \(p\times p\).
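To make the single data pass concrete, here is a minimal sketch assuming rows arrive in chunks (for example from files or a database cursor) with cell labels already encoded as integers \(0,\ldots,k-1\). The helpers accumulate and finalize are hypothetical names, not from any library. Because every statistic is a sum, chunk results simply add, and memory stays \(O(kp^2)\) regardless of \(n\).

```python
import numpy as np


def accumulate(stats, x, y, g):
    """Add one chunk's per-cell sums into `stats` (updated in place)."""
    k = stats["n"].shape[0]
    stats["n"] += np.bincount(g, minlength=k)
    stats["sy"] += np.bincount(g, weights=y, minlength=k)
    np.add.at(stats["sx"], g, x)
    np.add.at(stats["sxy"], g, x * y[:, None])
    # Per-row outer products x_i x_i', summed by cell.
    np.add.at(stats["sxx"], g, x[:, :, None] * x[:, None, :])


def finalize(stats):
    """Turn per-cell sums into X'M_D X and X'M_D y, then solve for beta."""
    occ = stats["n"] > 0  # empty cells carry no information
    n = stats["n"][occ].astype(float)
    sx, sy = stats["sx"][occ], stats["sy"][occ]
    xtx = stats["sxx"][occ].sum(axis=0) - np.einsum("gp,gq,g->pq", sx, sx, 1.0 / n)
    xty = stats["sxy"][occ].sum(axis=0) - sx.T @ (sy / n)
    return np.linalg.solve(xtx, xty)


rng = np.random.default_rng(2)
n_rows, k, p = 50_000, 200, 3
g = rng.integers(0, k, size=n_rows)
x = rng.normal(size=(n_rows, p)) + rng.normal(size=(k, p))[g]
y = x @ np.array([0.5, -1.0, 2.0]) + rng.normal(size=k)[g] + rng.normal(size=n_rows)

stats = {
    "n": np.zeros(k, dtype=np.int64),
    "sy": np.zeros(k),
    "sx": np.zeros((k, p)),
    "sxy": np.zeros((k, p)),
    "sxx": np.zeros((k, p, p)),
}
for start in range(0, n_rows, 10_000):  # five chunks of 10,000 rows
    sl = slice(start, start + 10_000)
    accumulate(stats, x[sl], y[sl], g[sl])

beta_stream = finalize(stats)
print(beta_stream)
```

One design note: because the constant vector lies in the span of \(D\), shifting every \(x_i\) and \(y_i\) by a global constant leaves \(X'M_DX\) and \(X'M_Dy\) unchanged. Subtracting a rough global mean before accumulating is therefore a cheap guard against cancellation in the raw-moment formula when cell means are large relative to within-cell spread.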

4 Row sketching as the approximate cousin

A row sketch draws a random matrix \(S \in \mathbb{R}^{s\times n}\) and solves

\[ \hat\theta_S = \arg\min_\theta \lVert S(y - Z\theta)\rVert_2^2, \qquad Z = [X\ D]. \]

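If \(S\) is an \(\varepsilon\)-subspace embedding for the column span of \([Z\ y]\), meaning \((1-\varepsilon)\lVert v\rVert_2^2 \le \lVert Sv\rVert_2^2 \le (1+\varepsilon)\lVert v\rVert_2^2\) for every \(v\) in that span, then the sketched objective is within a factor \(1\pm\varepsilon\) of the true objective for every \(\theta\), and the sketched minimizer's true objective is at most \((1+\varepsilon)/(1-\varepsilon)\) times the optimum. That span has dimension up to \(p+k+1\), so a sketch of the raw design needs at least \(p+k+1\) rows; sketching after residualizing on \(D\) only has to embed a \((p+1)\)-dimensional span. Row sketching is useful when the design does not have an obvious exact grouping structure.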
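A minimal sketch-and-solve illustration, assuming a dense Gaussian sketch with i.i.d. \(N(0, 1/s)\) entries (one standard choice of subspace embedding; the experiment below uses CountSketch instead). It sketches the raw design \(Z = [X\ D]\) with only a few cells, so that \(s\) comfortably exceeds \(p+k\), and compares the full-data objective at the sketched and exact minimizers. Sizes and seed are illustrative.

```python
import numpy as np

rng = np.random.default_rng(3)
n, k, p, s = 5_000, 6, 2, 400
g = rng.integers(0, k, size=n)
X = rng.normal(size=(n, p))
D = np.zeros((n, k))
D[np.arange(n), g] = 1.0
Z = np.hstack([X, D])  # p + k = 8 columns
y = X @ np.array([1.0, -2.0]) + rng.normal(size=k)[g] + rng.normal(size=n)

# Exact least squares on all n rows.
theta_full, *_ = np.linalg.lstsq(Z, y, rcond=None)

# Dense Gaussian sketch S (s x n), scaled so that E[S'S] = I.
S = rng.normal(scale=1.0 / np.sqrt(s), size=(s, n))
theta_sketch, *_ = np.linalg.lstsq(S @ Z, S @ y, rcond=None)


def objective(theta):
    """Full-data residual sum of squares ||y - Z theta||^2."""
    r = y - Z @ theta
    return r @ r


ratio = objective(theta_sketch) / objective(theta_full)
print(f"full-data objective ratio (sketched / exact): {ratio:.4f}")
```

The ratio is at least 1 by construction, since theta_full minimizes the full objective; how close it gets to 1 depends on how large \(s\) is relative to the embedded dimension.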

But if \(D\) is MECE and the target is the coefficient on \(X\) after absorbing \(D\), exact sufficient-statistic compression is usually the sharper statement:

  • it is deterministic;
  • it preserves the relevant normal equations exactly;
  • its compressed size depends on the number of occupied cells and the number of continuous regressors, not on a random embedding dimension;
  • it gives a transparent audit trail: counts, first moments, and second moments by cell.

Row sketching can still be useful as a computational baseline or for problems where continuous covariates generate too many distinct rows to collapse. It should not be sold as equivalent to the MECE sufficient-statistic trick.
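The contrast can be made concrete. When every column of the design is discrete, duplicate rows of the full design collapse exactly and the compressed size is the number of distinct rows, whereas a continuous column makes almost every row distinct. A minimal sketch, assuming a small all-discrete design with an interaction, that uses np.unique with axis=0 to find distinct rows and then solves the weighted problem on them:

```python
import numpy as np

rng = np.random.default_rng(4)
n = 100_000
# Intercept plus two discrete covariates and their product: at most 3 * 4 = 12 distinct rows.
a = rng.integers(0, 3, size=n)
b = rng.integers(0, 4, size=n)
Z = np.column_stack([np.ones(n), a, b, a * b]).astype(float)
y = 1.0 + 0.5 * a - 0.3 * b + 0.2 * a * b + rng.normal(size=n)

theta_full, *_ = np.linalg.lstsq(Z, y, rcond=None)

# Collapse to unique rows z_u with multiplicities m_u and per-row sums of y.
Z_u, inv, m_u = np.unique(Z, axis=0, return_inverse=True, return_counts=True)
inv = inv.ravel()  # guard against inverse-shape differences across NumPy versions
sum_y = np.bincount(inv, weights=y, minlength=len(m_u))

# Weighted LS on unique rows: scale row u by sqrt(m_u) and use sum_y / sqrt(m_u),
# i.e. sqrt(m_u) * mean(y), so the normal equations match the full data exactly.
w = np.sqrt(m_u)
theta_comp, *_ = np.linalg.lstsq(Z_u * w[:, None], sum_y / w, rcond=None)
print(len(Z_u), "distinct rows out of", n)
```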

5 Numerical experiment

The experiment below simulates a large dataset with many discrete cells and a small number of continuous regressors. We compare four estimates of the coefficients on \(X\):

  1. full OLS with cell dummies, computed through residualization;
  2. exact cell compression using only sufficient statistics;
  3. CountSketch applied to the residualized rows;
  4. uniform row subsampling of residualized rows.

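The exact compression should match the full residualized OLS up to numerical precision, and the randomized methods should approach it as the sketch size grows. Methods 3 and 4 sketch rows that were already residualized with full-data cell means, so they still depend on a grouped pass over every row; the comparison isolates the error from shrinking the row dimension of the final \(p\)-column problem.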

def make_data(n=200_000, k=1_000, p=4, seed=123):
    rng = np.random.default_rng(seed)
    # Unequal cell probabilities make the compression problem look realistic.
    weights = rng.gamma(shape=1.5, scale=1.0, size=k)
    probs = weights / weights.sum()
    g = rng.choice(k, size=n, p=probs)

    cell_effect = rng.normal(scale=2.0, size=k)
    cell_x_shift = rng.normal(scale=0.8, size=(k, p))
    x = cell_x_shift[g] + rng.normal(size=(n, p))
    beta = np.array([1.0, -0.7, 0.35, 0.2])
    assert beta.size == p, "the true coefficient vector is hard-coded for p = 4"
    y = x @ beta + cell_effect[g] + rng.normal(scale=1.0, size=n)
    return x, y, g, beta


def residualize_by_group(x, y, g, k):
    n, p = x.shape
    counts = np.bincount(g, minlength=k).astype(float)
    sx = np.zeros((k, p))
    sy = np.bincount(g, weights=y, minlength=k)
    np.add.at(sx, g, x)
    xbar = sx / counts[:, None]
    ybar = sy / counts
    return x - xbar[g], y - ybar[g]


def full_residualized_ols(x, y, g, k):
    xr, yr = residualize_by_group(x, y, g, k)
    beta = np.linalg.solve(xr.T @ xr, xr.T @ yr)
    return beta, xr, yr


def exact_compressed_beta(x, y, g, k):
    n, p = x.shape
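    # Clamp empty cells' counts to 1: their sums are all zero, so they contribute
    # nothing instead of producing 0/0 = NaN when labels 0..k-1 are not all occupied.
    counts = np.maximum(np.bincount(g, minlength=k), 1).astype(float)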
    sx = np.zeros((k, p))
    sxy = np.zeros((k, p))
    sxx = np.zeros((k, p, p))
    sy = np.bincount(g, weights=y, minlength=k)

    np.add.at(sx, g, x)
    np.add.at(sxy, g, x * y[:, None])
    for j in range(p):
        for l in range(j, p):
            vals = x[:, j] * x[:, l]
            accum = np.bincount(g, weights=vals, minlength=k)
            sxx[:, j, l] = accum
            sxx[:, l, j] = accum

    # Per-cell identity: sum (x - xbar)(x - xbar)' = S_xx - S_x S_x' / n_g, likewise for x'y.
    xtx = sxx.sum(axis=0) - np.einsum("gp,gq,g->pq", sx, sx, 1.0 / counts)
    xty = sxy.sum(axis=0) - (sx * (sy / counts)[:, None]).sum(axis=0)
    return np.linalg.solve(xtx, xty), xtx, xty


def countsketch_beta(xr, yr, s, seed):
    rng = np.random.default_rng(seed)
    n, p = xr.shape
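    # CountSketch: each row goes to one random bucket with a random sign,
    # so S has exactly one +/-1 entry per column and applying it costs O(n p).
    buckets = rng.integers(0, s, size=n)
    signs = rng.choice(np.array([-1.0, 1.0]), size=n)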
    sx = np.zeros((s, p))
    sy = np.zeros(s)
    np.add.at(sx, buckets, xr * signs[:, None])
    np.add.at(sy, buckets, yr * signs)
    return np.linalg.solve(sx.T @ sx, sx.T @ sy)


def subsample_beta(xr, yr, s, seed):
    rng = np.random.default_rng(seed)
    n = xr.shape[0]
    idx = rng.choice(n, size=s, replace=False)
    # Uniform rescaling cancels in OLS, so ordinary OLS on the sampled rows is enough.
    xs = xr[idx]
    ys = yr[idx]
    return np.linalg.solve(xs.T @ xs, xs.T @ ys)


x, y, g, beta_true = make_data()
k = int(g.max() + 1)
beta_full, xr, yr = full_residualized_ols(x, y, g, k)
beta_exact, xtx_exact, xty_exact = exact_compressed_beta(x, y, g, k)

print("n rows:", len(y))
print("occupied cells:", np.count_nonzero(np.bincount(g, minlength=k)))
print("continuous regressors:", x.shape[1])
print("true beta:", beta_true)
print("full residualized OLS:", beta_full)
print("exact compressed beta:", beta_exact)
print("max absolute difference, full vs exact:", np.max(np.abs(beta_full - beta_exact)))
n rows: 200000
occupied cells: 1000
continuous regressors: 4
true beta: [ 1.   -0.7   0.35  0.2 ]
full residualized OLS: [ 0.9998 -0.7013  0.3501  0.2004]
exact compressed beta: [ 0.9998 -0.7013  0.3501  0.2004]
max absolute difference, full vs exact: 2.55351295663786e-15

The exact method has compressed the relevant least-squares information to one small \(p\times p\) matrix and one \(p\)-vector, after a grouped aggregation pass.

sketch_sizes = [250, 500, 1_000, 2_000, 4_000, 8_000]
replications = 80
rows = []
for s in sketch_sizes:
    for r in range(replications):
        b_cs = countsketch_beta(xr, yr, s=s, seed=10_000 + 97 * r + s)
        b_sub = subsample_beta(xr, yr, s=s, seed=20_000 + 97 * r + s)
        rows.append(
            {
                "method": "CountSketch",
                "sketch_rows": s,
                "rel_error": np.linalg.norm(b_cs - beta_full) / np.linalg.norm(beta_full),
                "max_abs_error": np.max(np.abs(b_cs - beta_full)),
            }
        )
        rows.append(
            {
                "method": "Uniform subsample",
                "sketch_rows": s,
                "rel_error": np.linalg.norm(b_sub - beta_full) / np.linalg.norm(beta_full),
                "max_abs_error": np.max(np.abs(b_sub - beta_full)),
            }
        )

results = pd.DataFrame(rows)
summary = (
    results.groupby(["method", "sketch_rows"])
    .agg(
        median_relative_error=("rel_error", "median"),
        p90_relative_error=("rel_error", lambda z: np.quantile(z, 0.9)),
        median_max_abs_error=("max_abs_error", "median"),
    )
    .reset_index()
)
display(html_table(summary))
| Method | Sketch rows | Median relative error | 90th-percentile relative error | Median max absolute error |
|---|---|---|---|---|
| CountSketch | 250 | 0.0831 | 0.1310 | 0.0837 |
| CountSketch | 500 | 0.0575 | 0.0914 | 0.0568 |
| CountSketch | 1000 | 0.0462 | 0.0645 | 0.0447 |
| CountSketch | 2000 | 0.0313 | 0.0495 | 0.0317 |
| CountSketch | 4000 | 0.0217 | 0.0359 | 0.0216 |
| CountSketch | 8000 | 0.0159 | 0.0253 | 0.0157 |
| Uniform subsample | 250 | 0.0896 | 0.1362 | 0.0855 |
| Uniform subsample | 500 | 0.0605 | 0.0999 | 0.0642 |
| Uniform subsample | 1000 | 0.0447 | 0.0685 | 0.0447 |
| Uniform subsample | 2000 | 0.0308 | 0.0476 | 0.0302 |
| Uniform subsample | 4000 | 0.0244 | 0.0369 | 0.0242 |
| Uniform subsample | 8000 | 0.0153 | 0.0227 | 0.0155 |
fig, axes = plt.subplots(1, 2, figsize=(12, 4.5), constrained_layout=True)
for method, color in [("CountSketch", "tab:blue"), ("Uniform subsample", "tab:orange")]:
    block = summary[summary["method"] == method]
    axes[0].plot(
        block["sketch_rows"],
        block["median_relative_error"],
        marker="o",
        label=method,
        color=color,
    )
    axes[1].plot(
        block["sketch_rows"],
        block["median_max_abs_error"],
        marker="o",
        label=method,
        color=color,
    )

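# Reference line: relative error of exact compression (same metric as this panel),
# floored at machine epsilon so it stays drawable on the log axis.
exact_rel_err = max(np.linalg.norm(beta_full - beta_exact) / np.linalg.norm(beta_full), np.finfo(float).eps)
axes[0].axhline(exact_rel_err, color="gray", linestyle="--", linewidth=1, label="Exact compression error")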
axes[0].set_xscale("log")
axes[0].set_yscale("log")
axes[0].set_xlabel("Compressed / sketched rows")
axes[0].set_ylabel("Median relative coefficient error")
axes[0].set_title("Random row sketches approach full OLS")
axes[0].legend()

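# Gray rather than white so the line is visible on the default light background.
exact_abs_err = max(np.max(np.abs(beta_full - beta_exact)), np.finfo(float).eps)
axes[1].axhline(exact_abs_err, color="gray", linestyle="--", linewidth=1, label="Exact compression error")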
axes[1].set_xscale("log")
axes[1].set_yscale("log")
axes[1].set_xlabel("Compressed / sketched rows")
axes[1].set_ylabel("Median max absolute coefficient error")
axes[1].set_title("Exact compression is already at numerical zero")
axes[1].legend()
plt.show()

6 Interpretation

The experiment separates the two ideas cleanly.

  • The MECE cell structure gives an exact rank-\(k\) nuisance span. Absorbing the cell dummies only requires within-cell first and second moments.
  • Exact compression reproduces the full residualized OLS coefficients to floating-point tolerance.
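  • Row sketching and subsampling also reduce the row dimension, but only approximately. Median relative error falls from about 0.08 to 0.09 at 250 rows to about 0.015 to 0.016 at 8,000 rows, consistent with a roughly \(1/\sqrt{s}\) rate, and the gap between the median and 90th-percentile columns shows that individual random draws scatter. CountSketch and uniform subsampling are nearly indistinguishable in this design.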
  • In a system built around a database engine, exact compression is especially legible: one grouped aggregation query returns the sufficient statistics, and Python/Rust only solves a small dense system.
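The database version of the aggregation pass can be sketched with Python's built-in sqlite3 module. The table name obs and its columns cell, x, y are hypothetical, and the sketch assumes a single continuous regressor (\(p = 1\)) so that one GROUP BY query returns all five sufficient statistics per cell and the final solve is scalar.

```python
import sqlite3

import numpy as np

rng = np.random.default_rng(5)
n, k = 20_000, 50
cell = rng.integers(0, k, size=n)
x = rng.normal(size=n) + rng.normal(size=k)[cell]
y = 0.8 * x + rng.normal(scale=2.0, size=k)[cell] + rng.normal(size=n)

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE obs (cell INTEGER, x REAL, y REAL)")
# tolist() converts NumPy scalars to plain Python ints/floats that sqlite3 can bind.
con.executemany("INSERT INTO obs VALUES (?, ?, ?)", zip(cell.tolist(), x.tolist(), y.tolist()))

# One grouped aggregation query: n_g, sum x, sum y, sum x^2, sum xy per occupied cell.
rows = con.execute(
    """
    SELECT COUNT(*), SUM(x), SUM(y), SUM(x * x), SUM(x * y)
    FROM obs
    GROUP BY cell
    """
).fetchall()
con.close()
n_g, sx, sy, sxx, sxy = np.array(rows, dtype=float).T

# Scalar versions of X'M_D X and X'M_D y, then the 1 x 1 solve.
xtx = np.sum(sxx - sx * sx / n_g)
xty = np.sum(sxy - sx * sy / n_g)
beta_sql = xty / xtx
print(beta_sql)
```

GROUP BY only emits occupied cells, so empty cells never reach the division. With \(p > 1\) the same query would list every product \(x_jx_l\) and \(x_jy\).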

The conceptual slogan is:

Exact MECE compression is row-span compression when the rows are literally repeated discrete patterns; with continuous regressors, it becomes exact compression of the FWL-residualized Gram system. Row sketching is the approximate randomized analogue for designs without such a deterministic collapse.