Cross-entropy is the standard loss function for language models. In implementations, it is usually abstracted away by a function call (e.g., PyTorch’s F.cross_entropy). This is handy but hides what’s actually happening. In this blog post, I’ll explain what’s going on under the hood and discuss the relationship between cross-entropy and negative log likelihood (NLL).

Inputs of cross-entropy: logits and targets

Let’s first understand the inputs of F.cross_entropy: logits and targets. I’ve explained them in the table below, where $B$ denotes the batch size and $V$ denotes the vocabulary size.

| Tensor | Shape | Description |
|---|---|---|
| `logits` | $(B, V)$ | Each row has the predicted (by the LLM) unnormalized logits for all possible tokens in the vocabulary. Note: $p_{\text{LM}}(\cdot \mid x_{<t})$ is obtained by applying softmax to these logits. |
| `targets` | $(B,)$ | Labels: for each example in the batch, the label is the correct next token ID that should follow the given context. Each value is a token ID and thus an integer between $0$ and $V-1$. |
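
As a quick sanity check on these shapes, here is a minimal sketch; the values are random placeholders rather than real LM outputs, and $B$ and $V$ are chosen arbitrarily:

```python
import torch

B, V = 4, 1000  # arbitrary batch size and vocabulary size for illustration

logits = torch.randn(B, V)           # stand-in for the LM's output scores, one row per example
targets = torch.randint(0, V, (B,))  # one correct next-token ID per example

print(logits.shape)   # torch.Size([4, 1000])
print(targets.shape)  # torch.Size([4])
```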

To make this concrete, imagine the following setup: we train on a single tokenized sequence, $[45, 128, 293, 501, 672]$, with the usual next-token-prediction objective.

Assuming that we form a batch just from this sequence, we can say that:

- the batch size is $B = 4$: each of the four prefixes of the sequence is one example, and its label is the token that actually follows that prefix;
- for every example, the model outputs $V$ logits, one per vocabulary token.

The table below shows each example in the batch and its label.

| Preceding context, $x_{<t}$ | Correct next token, $x_t$ | Logits ($V$ logits, one per vocabulary token) |
|---|---|---|
| $[45]$ | $128$ | $[2.1, -1.4, 0.8, \ldots, 1.1]$ |
| $[45, 128]$ | $293$ | $[0.5, 2.8, -0.9, \ldots, -2.0]$ |
| $[45, 128, 293]$ | $501$ | $[1.3, 0.2, 4.1, \ldots, 0.5]$ |
| $[45, 128, 293, 501]$ | $672$ | $[-0.8, 3.5, 1.9, \ldots, -3.2]$ |
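
In code, the targets for this batch are just the four correct next-token IDs from the table; the logits would come from the LM's forward pass on the four contexts (the logit values in the table are purely illustrative, so random values stand in for them here):

```python
import torch

# Labels from the table above: the correct next token for each of the four contexts.
targets = torch.tensor([128, 293, 501, 672])  # shape (4,)

# In practice these come from the LM's forward pass; random values stand in here.
V = 1000                                      # placeholder vocabulary size
logits = torch.randn(4, V)                    # shape (4, V)
```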

Computing cross-entropy

As mentioned earlier, PyTorch has a built-in function for computing cross-entropy:

```python
loss = F.cross_entropy(logits, targets)
```

This produces the same result as the calculation outlined below:

| Step (in calculation of cross-entropy) | Shape | Comment |
|---|---|---|
| `counts = logits.exp()` | $(B, V)$ | Converts logits to unnormalized probabilities. |
| `p = counts / counts.sum(dim=1, keepdim=True)` | $(B, V)$ | Normalization. `p` is the probability distribution over all possible next tokens, $p_{\theta}(\cdot \mid x_{<t})$. |
| `p_labels = p[torch.arange(B), targets]` | $(B,)$ | Extracts the probability of each label. `p_labels` is $p_{\theta}(x_t \mid x_{<t})$. |
| `loss = -p_labels.log().mean()` | Scalar | Mean negative log likelihood (NLL), i.e. cross-entropy: computes the NLL for each example in the batch and averages across the batch. |
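
Putting those four steps together, here is a minimal sketch of the naive computation (with random placeholder logits and targets), which matches the built-in function:

```python
import torch
import torch.nn.functional as F

B, V = 4, 1000  # placeholder sizes
logits = torch.randn(B, V)
targets = torch.randint(0, V, (B,))

# Naive cross-entropy, step by step.
counts = logits.exp()                             # unnormalized probabilities, (B, V)
p = counts / counts.sum(dim=1, keepdim=True)      # softmax, (B, V)
p_labels = p[torch.arange(B), targets]            # probability of each label, (B,)
loss_manual = -p_labels.log().mean()              # mean NLL, scalar

loss_builtin = F.cross_entropy(logits, targets)
print(torch.allclose(loss_manual, loss_builtin))  # True (up to floating-point error)
```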

While valid, the calculation above isn’t numerically stable: a large logit can overflow when exponentiated. To overcome this, we can use the log-sum-exp trick; in fact, F.cross_entropy does this under the hood. I’ve outlined the key steps for computing cross-entropy with the log-sum-exp trick below. For improved readability, I’ve included a few extra steps.

| Step (in calculation of cross-entropy) | Shape | Comment |
|---|---|---|
| `logit_maxes = logits.max(dim=1, keepdim=True).values` | $(B, 1)$ | Gets the max logit value across the vocabulary dimension for each example in the batch. |
| `shifted_logits = logits - logit_maxes` | $(B, V)$ | Subtracts the max from each row for numerical stability, specifically preventing overflow during exponentiation. |
| `counts = shifted_logits.exp()` | $(B, V)$ | Computes unnormalized probabilities by exponentiating the shifted logits. |
| `counts_sum = counts.sum(dim=1, keepdim=True)` | $(B, 1)$ | Computes the sum of unnormalized probabilities for each example. |
| `p = counts / counts_sum` | $(B, V)$ | Softmax. `p` is $p_{\theta}(\cdot \mid x_{<t})$. |
| `logp = p.log()` | $(B, V)$ | Log probabilities (not log likelihoods!). |
| `logp_labels = logp[torch.arange(B), targets]` | $(B,)$ | Log likelihood of each label, $\log p_{\theta}(x_t \mid x_{<t})$. |
| `loss = -logp_labels.mean()` | Scalar | Mean NLL, i.e. cross-entropy. |
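
The same comparison works for the stable version; here is a sketch under the same placeholder setup:

```python
import torch
import torch.nn.functional as F

B, V = 4, 1000  # placeholder sizes
logits = torch.randn(B, V)
targets = torch.randint(0, V, (B,))

# Numerically stable cross-entropy via max subtraction (log-sum-exp trick).
logit_maxes = logits.max(dim=1, keepdim=True).values  # (B, 1)
shifted_logits = logits - logit_maxes                  # (B, V), prevents overflow in exp()
counts = shifted_logits.exp()                          # (B, V)
counts_sum = counts.sum(dim=1, keepdim=True)           # (B, 1)
p = counts / counts_sum                                # softmax, (B, V)
logp = p.log()                                         # log probabilities, (B, V)
logp_labels = logp[torch.arange(B), targets]           # log likelihood of each label, (B,)
loss_stable = -logp_labels.mean()                      # mean NLL, scalar

print(torch.allclose(loss_stable, F.cross_entropy(logits, targets)))  # True
```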

A final note on F.cross_entropy compared to the two methods I’ve discussed: