Multiclass Logistic Regression
Extend logistic regression from binary to multiclass classification using the softmax function, cross-entropy loss, and gradient descent — with full derivation and interactive demo.
From Binary to Multiclass
In Logistic Regression, we modeled a binary outcome using the sigmoid function. But many real-world problems have more than two classes — for example, classifying handwritten digits (0–9) or categorizing species.
How do we extend logistic regression to handle $K$ classes?
The Setup
For each class $k = 1, \dots, K$, we define a linear predictor:

$$z_k = \mathbf{w}_k^\top \mathbf{x} + b_k$$

where $\mathbf{x}$ is the feature vector, and each class $k$ has its own weight vector $\mathbf{w}_k$ and bias $b_k$.
The Softmax Function
Instead of the sigmoid, we use the softmax function to convert the logits $z_1, \dots, z_K$ into probabilities:

$$p_k = \operatorname{softmax}(\mathbf{z})_k = \frac{e^{z_k}}{\sum_{j=1}^{K} e^{z_j}}$$
This guarantees two essential properties:
- Each probability is positive: $p_k > 0$ for all $k$
- Probabilities sum to one: $\sum_{k=1}^{K} p_k = 1$
The predicted class is the one with the highest probability:

$$\hat{y} = \arg\max_{k}\, p_k$$
Notice that when $K = 2$, softmax reduces to the sigmoid function from binary logistic regression.
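To make this concrete, here is a minimal NumPy sketch of the softmax computation (the function name and the max-subtraction trick for numerical stability are implementation choices of this sketch, not part of the math above):

```python
import numpy as np

def softmax(z):
    """Convert a vector of logits z into probabilities that are positive and sum to 1."""
    z = z - np.max(z)              # subtract the max for numerical stability (result is unchanged)
    exp_z = np.exp(z)
    return exp_z / np.sum(exp_z)

z = np.array([2.0, 1.0, -1.0])     # logits for K = 3 classes
p = softmax(z)
print(p)                           # ≈ [0.705, 0.259, 0.035]
print(p.sum())                     # 1.0
print(np.argmax(p))                # predicted class: 0
```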
One-Hot Encoding and Likelihood
To express the true label as a vector, we use one-hot encoding: $y_k = 1$ if class $k$ is correct, $y_k = 0$ otherwise. Let $\mathbf{y} = (y_1, \dots, y_K)$.
The likelihood of observing the correct label is:

$$P(\mathbf{y} \mid \mathbf{x}) = \prod_{k=1}^{K} p_k^{\,y_k}$$

Since only one $y_k = 1$, this picks out the probability of the true class. Taking the log:

$$\log P(\mathbf{y} \mid \mathbf{x}) = \sum_{k=1}^{K} y_k \log p_k$$
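A small numerical sketch of the one-hot trick (the probability values are taken from the softmax example above and are purely illustrative):

```python
import numpy as np

p = np.array([0.705, 0.259, 0.035])     # softmax output for one sample
y = np.array([0, 1, 0])                 # one-hot label: the true class is class 1

likelihood = np.prod(p ** y)            # picks out p[1] = 0.259
log_likelihood = np.sum(y * np.log(p))  # log(0.259) ≈ -1.351
print(likelihood, log_likelihood)
```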
Cross-Entropy Loss
Negating the log-likelihood gives us the cross-entropy loss:

$$L = -\sum_{k=1}^{K} y_k \log p_k$$
This is the natural generalization of the binary cross-entropy loss from Logistic Regression. It heavily penalizes the model when the predicted probability for the true class is small.
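In code the loss is a one-liner; the small epsilon guarding against log(0) is a practical convention assumed here, not part of the formula:

```python
import numpy as np

def cross_entropy(y_one_hot, p, eps=1e-12):
    """Negative log-probability assigned to the true class."""
    return -np.sum(y_one_hot * np.log(p + eps))

p = np.array([0.705, 0.259, 0.035])
print(cross_entropy(np.array([1, 0, 0]), p))   # ≈ 0.35 : confident and correct, small loss
print(cross_entropy(np.array([0, 0, 1]), p))   # ≈ 3.35 : confident and wrong, large loss
```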
Deriving the Gradient
To optimize with gradient descent, we need $\partial L / \partial \mathbf{w}_k$ and $\partial L / \partial b_k$. We use the chain rule:

$$\frac{\partial L}{\partial \mathbf{w}_k} = \sum_{j=1}^{K} \frac{\partial L}{\partial p_j} \cdot \frac{\partial p_j}{\partial z_k} \cdot \frac{\partial z_k}{\partial \mathbf{w}_k}$$
Step 1 — Loss w.r.t. probabilities:

$$\frac{\partial L}{\partial p_j} = -\frac{y_j}{p_j}$$
Step 2 — Softmax derivative (the tricky part!):

The softmax derivative requires care because both the numerator and denominator depend on $z_k$. Using the Kronecker delta $\delta_{jk}$ (where $\delta_{jk} = 1$ if $j = k$, else $0$), the quotient rule gives:

$$\frac{\partial p_j}{\partial z_k} = p_j\,(\delta_{jk} - p_k)$$
Step 3 — Combining via the chain rule:

$$\frac{\partial L}{\partial z_k} = \sum_{j=1}^{K} \frac{\partial L}{\partial p_j}\,\frac{\partial p_j}{\partial z_k} = -\sum_{j=1}^{K} \frac{y_j}{p_j}\, p_j\,(\delta_{jk} - p_k) = -\sum_{j=1}^{K} y_j\,(\delta_{jk} - p_k)$$

Since $\sum_j y_j = 1$ and $\delta_{jk}$ picks out the $j = k$ term:

$$\frac{\partial L}{\partial z_k} = p_k - y_k$$
This is remarkably simple and elegant — the gradient is just the difference between the predicted probability and the true label!
Step 4 — Finally, since $\partial z_k / \partial \mathbf{w}_k = \mathbf{x}$ and $\partial z_k / \partial b_k = 1$:

$$\frac{\partial L}{\partial \mathbf{w}_k} = (p_k - y_k)\,\mathbf{x}, \qquad \frac{\partial L}{\partial b_k} = p_k - y_k$$
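A quick way to trust this result is to compare the analytic gradient $(p_k - y_k)\,\mathbf{x}$ against a finite-difference estimate. The sketch below does that for a single random weight entry (the helper names and random setup are illustrative):

```python
import numpy as np

def loss_and_probs(W, b, x, y):
    """Cross-entropy loss and softmax probabilities for a single sample."""
    z = W @ x + b
    p = np.exp(z - z.max())
    p /= p.sum()
    return -np.sum(y * np.log(p)), p

rng = np.random.default_rng(0)
K, d = 3, 2
W, b = rng.normal(size=(K, d)), rng.normal(size=K)
x, y = rng.normal(size=d), np.array([0.0, 1.0, 0.0])

_, p = loss_and_probs(W, b, x, y)
grad_W = np.outer(p - y, x)     # analytic gradient: (p_k - y_k) * x for each class k
grad_b = p - y

# finite-difference check on one weight entry
eps = 1e-6
W_pert = W.copy()
W_pert[1, 0] += eps
numeric = (loss_and_probs(W_pert, b, x, y)[0] - loss_and_probs(W, b, x, y)[0]) / eps
print(numeric, grad_W[1, 0])    # the two values should agree closely
```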
Gradient Descent Update
The parameters are updated iteratively:

$$\mathbf{w}_k \leftarrow \mathbf{w}_k - \eta\,\frac{\partial L}{\partial \mathbf{w}_k}, \qquad b_k \leftarrow b_k - \eta\,\frac{\partial L}{\partial b_k}$$

where $\eta$ is the learning rate. Compare this to the gradient in binary logistic regression — the structure is identical, just extended to $K$ classes.
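Putting the pieces together, a full-batch gradient-descent loop might look like the sketch below (the synthetic three-blob data, learning rate, and iteration count are arbitrary choices for illustration, not the demo's actual settings):

```python
import numpy as np

rng = np.random.default_rng(42)
K, d, n = 3, 2, 150

# three Gaussian blobs, one per class (synthetic data)
centers = np.array([[2, 2], [-2, 2], [0, -2]])
X = np.vstack([rng.normal(c, 0.8, size=(n // K, d)) for c in centers])
labels = np.repeat(np.arange(K), n // K)
Y = np.eye(K)[labels]                      # one-hot targets, shape (n, K)

W, b, eta = np.zeros((K, d)), np.zeros(K), 0.1
for step in range(500):
    Z = X @ W.T + b                        # logits, shape (n, K)
    P = np.exp(Z - Z.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)      # row-wise softmax
    grad_W = (P - Y).T @ X / n             # average of (p - y) x^T over the batch
    grad_b = (P - Y).mean(axis=0)
    W -= eta * grad_W
    b -= eta * grad_b

accuracy = (P.argmax(axis=1) == labels).mean()
print(f"training accuracy: {accuracy:.2f}")
```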
Interactive Demo
Explore multiclass logistic regression with 3 classes in 2D. The Decision Regions tab shows how the model partitions the input space, and the Softmax Probabilities tab visualizes $p_k(\mathbf{x})$ for each class.
Learned Parameters
| Class | w₁ | w₂ | b |
|---|---|---|---|
| Class 0 | 0.3000 | 0.3000 | 0.0000 |
| Class 1 | 0.3000 | -0.3000 | 0.0000 |
| Class 2 | -0.3000 | 0.0000 | 0.0000 |
How It Works
Things to Try
- Start Gradient Descent and watch the decision boundaries form in real time
- Hover over data points to see the softmax probability breakdown for each class
- Switch to the Softmax Probabilities tab to see individual class probability heatmaps — notice how they always sum to 1
- Add overlapping clusters to see how the model handles ambiguous regions
Connection to Binary Logistic Regression
When $K = 2$, the softmax function simplifies:

$$p_1 = \frac{e^{z_1}}{e^{z_1} + e^{z_2}} = \frac{1}{1 + e^{-(z_1 - z_2)}} = \sigma(z_1 - z_2)$$

This is exactly the sigmoid function from Logistic Regression with an effective weight $\mathbf{w} = \mathbf{w}_1 - \mathbf{w}_2$ (and bias $b = b_1 - b_2$).
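A tiny numerical check of this equivalence (the logit values are chosen arbitrarily):

```python
import numpy as np

z1, z2 = 1.3, -0.4
softmax_p1 = np.exp(z1) / (np.exp(z1) + np.exp(z2))
sigmoid_p1 = 1.0 / (1.0 + np.exp(-(z1 - z2)))
print(softmax_p1, sigmoid_p1)   # both ≈ 0.8455, identical up to floating-point error
```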
Connection to GLMs
Like binary logistic regression, multiclass logistic regression fits into the GLM framework:
| Component | Choice |
|---|---|
| Distribution | Categorical (Multinoulli) |
| Link function | Softmax (generalized logit) |
| Linear predictor | $z_k = \mathbf{w}_k^\top \mathbf{x} + b_k$ for each class $k$ |
After classification, evaluate your model’s per-class performance using metrics like Sensitivity, Specificity, and ROC curves (applied in a one-vs-rest fashion).
Summary
Multiclass logistic regression extends binary classification to $K$ classes by replacing the sigmoid with the softmax function. The cross-entropy loss generalizes naturally, and the gradient has the same elegant form as the binary case. The softmax derivative is the trickiest part of the derivation, requiring careful application of the quotient rule and Kronecker delta — but the final result is clean and computationally efficient.