Cross-Entropy Loss

Measures how well a predicted probability distribution matches a target distribution. The standard loss function for classification — both discrete (image classification, next-token prediction) and binary (spam detection, real/fake discrimination).

Imagine you’re placing bets. Cross-entropy asks: “if the true answer follows distribution P, how many bits do I waste by using my predicted distribution Q to encode it?” A perfect prediction wastes zero extra bits. The worse Q matches P, the more bits wasted.

For classification: the true distribution P is a one-hot vector (all probability on the correct class). Cross-entropy then simplifies to “how much probability did you put on the right answer?” If you put probability 0.9 on the correct class, your loss is −log(0.9) ≈ 0.105. If you put 0.01, your loss is −log(0.01) ≈ 4.6. The log means the penalty grows explosively as confidence in the wrong answer increases — this is the key property that makes it work well for training.

Note: cross-entropy on one-hot targets is mathematically identical to negative log-likelihood (NLL). The terms are used interchangeably in practice.
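A quick sanity check of the −log penalty using only Python's standard library, reproducing the numbers above:

```python
import math

# One-hot cross-entropy reduces to -log(probability on the correct class).
# Watch the penalty explode as less probability lands on the right answer.
for q in [0.9, 0.5, 0.1, 0.01]:
    print(f"prob on correct class = {q:5.2f}  ->  loss = {-math.log(q):.3f}")
```

Note the asymmetry: going from 0.9 to 0.99 saves only ~0.1 nats, but going from 0.1 to 0.01 costs ~2.3 extra nats.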

General form (discrete distributions):

H(P, Q) = -\sum_{i} P(i) \log Q(i)

Classification (one-hot target, class y is correct):

\mathcal{L} = -\log Q(y) = -\log \frac{e^{z_y}}{\sum_j e^{z_j}}

where z_j are the raw logits (unnormalised scores) from the network.

Expanding the log-softmax:

\mathcal{L} = -z_y + \log \sum_j e^{z_j}

This is the form actually computed — it’s numerically stable and avoids computing softmax explicitly.
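A small numpy sketch of why the stable form matters: with extreme logits, computing softmax first overflows to nan, while subtracting the max inside the logsumexp stays finite (the values here are made-up extremes):

```python
import numpy as np

logits = np.array([1000.0, 0.0])   # extreme raw scores
y = 0                              # correct class

# Naive route: softmax first, then log. exp(1000) overflows to inf,
# so the probabilities (and the loss) come out as nan.
with np.errstate(over="ignore", invalid="ignore"):
    probs = np.exp(logits) / np.exp(logits).sum()
    naive_loss = -np.log(probs[y])

# Stable route: -z_y + logsumexp(z), subtracting the max inside the exp.
m = logits.max()
stable_loss = -logits[y] + (m + np.log(np.exp(logits - m).sum()))

print(naive_loss)   # nan
print(stable_loss)  # 0.0: the model is (correctly) near-certain
```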

Binary cross-entropy (single probability q, target y ∈ {0, 1}):

\mathcal{L} = -\bigl[y \log q + (1 - y) \log (1 - q)\bigr]

With label smoothing (soften one-hot targets by mixing with uniform):

P'(i) = (1 - \alpha) \cdot P(i) + \frac{\alpha}{K}

where K is the number of classes and α is typically 0.1. Prevents the model from becoming overconfident.
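A minimal numpy sketch of the smoothing formula (K, α, and the class index below are made-up example values):

```python
import numpy as np

K, alpha = 5, 0.1
one_hot = np.zeros(K)
one_hot[2] = 1.0  # correct class = 2

# Mix the one-hot target with the uniform distribution over K classes
smoothed = (1 - alpha) * one_hot + alpha / K
print(smoothed)  # [0.02 0.02 0.92 0.02 0.02]

# Still a valid probability distribution
assert np.isclose(smoothed.sum(), 1.0)
```

The correct class keeps most of the mass (0.92), but the model can no longer drive its loss to zero by pushing logits to infinity.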

import torch
import torch.nn.functional as F

# ── Standard classification (logits → loss) ──────────────────────
# F.cross_entropy takes RAW LOGITS, not probabilities.
# It does log-softmax + NLL internally in a numerically stable way.
# NEVER apply softmax before this — you'll get wrong gradients and
# numerical issues.
logits = model(x)   # (B, n_classes) — raw scores
targets = labels    # (B,) — integer class indices
loss = F.cross_entropy(logits, targets)  # scalar

# ── With label smoothing ─────────────────────────────────────────
loss = F.cross_entropy(logits, targets, label_smoothing=0.1)

# ── Binary classification (single logit per sample) ──────────────
logit = model(x)         # (B, 1) or (B,) — single score
target = labels.float()  # (B,) — 0.0 or 1.0
loss = F.binary_cross_entropy_with_logits(logit, target)
# Again: takes raw logits, applies sigmoid internally.
import numpy as np

def cross_entropy_manual(logits, targets):
    """
    Equivalent to F.cross_entropy.

    logits: (B, C) raw scores — NOT probabilities
    targets: (B,) integer class indices
    """
    B, C = logits.shape
    # Numerically stable log-softmax: subtract max to prevent overflow in exp()
    shifted = logits - logits.max(axis=1, keepdims=True)              # (B, C)
    log_sum_exp = np.log(np.exp(shifted).sum(axis=1, keepdims=True))  # (B, 1)
    log_probs = shifted - log_sum_exp                                 # (B, C)
    # Pick the log-prob of the correct class for each sample
    loss_per_sample = -log_probs[np.arange(B), targets]               # (B,)
    return loss_per_sample.mean()

def binary_cross_entropy_manual(logits, targets):
    """
    Equivalent to F.binary_cross_entropy_with_logits.

    logits: (B,) raw scores
    targets: (B,) float 0.0 or 1.0
    """
    # Numerically stable form: max(0, logit) - logit*target + log(1 + exp(-|logit|))
    return (np.maximum(0, logits) - logits * targets
            + np.log1p(np.exp(-np.abs(logits)))).mean()
  • Image classification (ResNet, ViT): predict one class from K options
  • Language modelling / next-token prediction (GPT, LLaMA): cross-entropy over the full vocabulary at every position — this is THE training objective for LLMs
  • GAN discriminators (vanilla GAN): binary cross-entropy for real/fake classification
  • Knowledge distillation: cross-entropy between student and teacher softened distributions (with temperature)
  • Contrastive learning (SimCLR, CLIP): InfoNCE loss is cross-entropy where the “classes” are the positive pair vs. all negatives in the batch
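To make the InfoNCE connection concrete, here is a minimal numpy sketch (the function name and temperature value are illustrative, not from any library):

```python
import numpy as np

def info_nce(z_a, z_b, temperature=0.1):
    """InfoNCE as a B-way classification: for row i of z_a, the 'correct
    class' is the matching row i of z_b; the other B-1 rows are negatives."""
    # L2-normalise so dot products are cosine similarities
    z_a = z_a / np.linalg.norm(z_a, axis=1, keepdims=True)
    z_b = z_b / np.linalg.norm(z_b, axis=1, keepdims=True)
    logits = (z_a @ z_b.T) / temperature  # (B, B) similarity "scores"
    B = len(logits)
    # Plain stable cross-entropy with targets = [0, 1, ..., B-1]
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(B), np.arange(B)].mean()
```

Matched pairs (z_b close to z_a) give a low loss; unrelated embeddings give a loss near log B, exactly like an untrained B-way classifier.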
| Alternative | When to use | Tradeoff |
| --- | --- | --- |
| MSE loss | Regression, diffusion noise prediction | Doesn't penalise confident wrong answers as harshly; not suitable for classification |
| Focal loss | Class-imbalanced classification (object detection) | Down-weights easy examples, focuses on hard ones. Adds a (1-q)^γ modulating factor |
| Hinge loss | SVMs, hinge GAN | Margin-based: only penalises if the correct class score isn't above the margin. No probability interpretation |
| CTC loss | Sequence-to-sequence without alignment (speech recognition) | Marginalises over all valid alignments between input and output sequences |
| KL divergence | Soft target distributions (distillation, VAE regularisation) | Cross-entropy minus the target's own entropy. Identical gradient when the target is fixed |
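The KL relationship can be checked numerically: H(P, Q) = H(P) + KL(P‖Q), so when P is fixed the two objectives differ only by a constant. A small sketch with made-up distributions:

```python
import numpy as np

P = np.array([0.7, 0.2, 0.1])   # fixed target distribution
Q = np.array([0.5, 0.3, 0.2])   # model's prediction

cross_entropy = -(P * np.log(Q)).sum()
entropy = -(P * np.log(P)).sum()
kl = (P * np.log(P / Q)).sum()

# H(P, Q) = H(P) + KL(P || Q): with P fixed, H(P) is a constant,
# so minimising cross-entropy and minimising KL are the same problem.
assert np.isclose(cross_entropy, entropy + kl)
```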

Cross-entropy comes from information theory (Shannon, 1948), where it measures the expected message length when using code Q to encode messages from distribution P. It entered machine learning through logistic regression in statistics and was formalised as the standard classification loss through maximum likelihood estimation — minimising cross-entropy is equivalent to maximising the likelihood of the data under the model.

The key practical innovation was the “logits” formulation: computing log-softmax + NLL in a single fused operation (the logsumexp trick) rather than computing softmax first and then taking the log. This avoids numerical underflow/overflow and is the reason every modern framework has cross_entropy take raw logits rather than probabilities.

Label smoothing (Szegedy et al., 2016, “Rethinking the Inception Architecture”) was a simple but effective addition that prevents the model from learning to output extreme logits, improving generalisation and calibration.