from _api_doc_utils import *

MultinomialLogit
Multiclass logistic regression
1 Where it fits
Group: Regression
MultinomialLogit generalizes binary logit to \(K\) classes with softmax probabilities:
\[ \Pr(Y_i=k\mid X_i=x_i)=\frac{\exp(\alpha_k+x_i'\beta_k)}{\sum_\ell \exp(\alpha_\ell+x_i'\beta_\ell)}. \]
The summary packs coefficients and standard errors by class.
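The softmax formula above can be sketched directly in NumPy. This is an illustrative stand-alone sketch, not the library's implementation: the function name `softmax_probabilities` and the parameter values are assumptions made for the example. It also shows that adding the same intercept and slope shift to every class leaves the probabilities unchanged, which is why the parameters are only identified up to a normalization.

```python
import numpy as np

def softmax_probabilities(x, alpha, beta):
    """Class probabilities for each row of x.

    x:     (n, p) feature matrix
    alpha: (K,)   class intercepts
    beta:  (K, p) class slope vectors, one row per class
    """
    scores = alpha + x @ beta.T                           # (n, K) linear index per class
    scores = scores - scores.max(axis=1, keepdims=True)   # guard against overflow in exp
    weights = np.exp(scores)
    return weights / weights.sum(axis=1, keepdims=True)

# Illustrative parameters for K = 3 classes and p = 2 features (assumed values).
alpha = np.array([0.1, -0.2, 0.0])
beta = np.array([[0.6, -0.3], [-0.4, 0.5], [0.2, 0.2]])
x = np.random.default_rng(0).normal(size=(5, 2))

probs = softmax_probabilities(x, alpha, beta)

# Shift every class by the same intercept and the same slope vector.
shifted = softmax_probabilities(x, alpha + 1.5, beta + np.array([0.3, -0.7]))
```

Each row of `probs` sums to one, and `shifted` equals `probs` up to floating-point error.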
2 Python API
Constructor: cm.MultinomialLogit
fit(x, y) expects integer class labels; the examples on this page pass y as an np.int32 array. predict(x) returns class labels. summary() returns coef and se matrices rather than the scalar-intercept/vector-coefficient schema used by binary GLMs.
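When the raw outcome is not already an integer code, it has to be encoded before fitting. A minimal sketch using only NumPy follows; the label values are made up, and the assumption that codes should run from 0 to K-1 is taken from the examples on this page rather than from a documented requirement.

```python
import numpy as np

# Hypothetical raw labels; any hashable, sortable label type works the same way.
raw_labels = np.array(["bus", "car", "train", "car", "bus", "train", "car"])

# classes holds the sorted unique labels; codes holds each label's index in classes.
classes, codes = np.unique(raw_labels, return_inverse=True)
y = codes.astype(np.int32)

# Map integer predictions back to the original labels.
# predicted_codes is a stand-in for the output of a fitted model's predict(x).
predicted_codes = np.array([1, 0, 2], dtype=np.int32)
predicted_labels = classes[predicted_codes]
```

Keeping `classes` around is what makes the integer predictions interpretable afterwards.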
print(inspect.signature(cm.MultinomialLogit))

(alpha=1.0, max_iterations=100, gradient_tolerance=0.0001)
cls = cm.MultinomialLogit
display(HTML(html_table(["Public method"], public_methods(cls))))

| Public method |
|---|
| bootstrap(self, /, n_bootstrap, seed=None) |
| fit(self, /, x, y) |
| predict(self, /, x) |
| summary(self, /) |
3 Minimal example
# Simulate 240 observations with 2 features and 3 classes.
rng = np.random.default_rng(6)
x = rng.normal(size=(240, 2))
logits = x @ np.array([[.6, -.3], [-.4, .5], [.2, .2]]).T + np.array([.1, -.2, 0.])
# Numerically stable softmax: subtract the row maximum before exponentiating.
p = np.exp(logits - logits.max(axis=1, keepdims=True))
p = p / p.sum(axis=1, keepdims=True)
# Draw one integer class label per row from its class probabilities.
y = np.array([rng.choice(3, p=row) for row in p], dtype=np.int32)
model = cm.MultinomialLogit(max_iterations=200)
model.fit(x, y)
print(model.summary()["coef"])
print(model.predict(x[:5]))

[[ 0.22393229  0.46606385 -0.46127557]
 [-0.18578571 -0.41057874  0.38304187]
 [-0.03814657 -0.05548511  0.0782337 ]]
[1 1 0 1 0]
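Each column of the printed coefficient matrix sums to roughly zero. That pattern suggests a sum-to-zero normalization with one row per class and columns ordered as intercept, first feature, second feature. This is an inference from the output above, not a documented guarantee. Under that assumption, the data-generating parameters can be centered the same way to give a reference point for the estimates:

```python
import numpy as np

# Data-generating parameters from the example above.
true_beta = np.array([[0.6, -0.3], [-0.4, 0.5], [0.2, 0.2]])   # (classes, features)
true_alpha = np.array([0.1, -0.2, 0.0])                         # (classes,)

# Assumed layout: one row per class, columns = [intercept, x1, x2].
true_params = np.column_stack([true_alpha, true_beta])

# Subtract the across-class mean of each column so every column sums to zero.
centered = true_params - true_params.mean(axis=0, keepdims=True)
print(np.round(centered, 3))
```

Centering does not change the implied class probabilities, since softmax is invariant to a shift shared by all classes. Any remaining gap between `centered` and the fitted matrix reflects sampling noise and whatever regularization the estimator applies.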
4 summary() contract
The table below is generated by fitting the live class in this repository and then inspecting summary(). Shapes are shown because most values are plain NumPy arrays or scalars.
rng = np.random.default_rng(106)
x = rng.normal(size=(100, 2))
logits = x @ np.array([[.6, -.3], [-.4, .5], [.2, .2]]).T
p = np.exp(logits - logits.max(1, keepdims=True))
p = p / p.sum(1, keepdims=True)
y = np.array([rng.choice(3, p=row) for row in p], dtype=np.int32)
model = cm.MultinomialLogit(max_iterations=200)
model.fit(x, y)
summary = model.summary()
display(HTML(html_table(["summary() key", "shape"], summary_shape_rows(summary))))

| summary() key | shape |
|---|---|
| coef | (3, 3) |
| se | (3, 3) |
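Because coef and se share the same shape, elementwise statistics follow directly. The sketch below computes Wald z-statistics and two-sided normal p-values. The coef values are copied from the printed output of the minimal example; the se values are hypothetical placeholders, not library output. Whether a normal approximation is appropriate for this estimator is an assumption.

```python
import math
import numpy as np

# Coefficients copied from the printed output of the minimal example.
coef = np.array([
    [0.22393229, 0.46606385, -0.46127557],
    [-0.18578571, -0.41057874, 0.38304187],
    [-0.03814657, -0.05548511, 0.0782337],
])
# Hypothetical standard errors, NOT produced by the library; placeholders only.
se = np.full_like(coef, 0.1)

z = coef / se  # elementwise Wald z-statistics

# Two-sided normal p-value: P(|Z| > |z|) = erfc(|z| / sqrt(2)).
p_values = np.vectorize(math.erfc)(np.abs(z) / math.sqrt(2.0))
```

With a fitted model, `coef` and `se` would instead come from `summary["coef"]` and `summary["se"]`.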