Cross-entropy is the standard loss function for language models. In implementations, it is usually abstracted away by a function call (e.g., PyTorch’s `F.cross_entropy`). This is handy, but it hides what’s actually happening. In this blog post, I’ll explain what’s going on under the hood and discuss the relationship between cross-entropy and negative log likelihood (NLL).
Let’s first understand the inputs of `F.cross_entropy`: `logits` and `targets`. I’ve explained them in the table below, where $B$ denotes the batch size and $V$ denotes the vocabulary size.
Tensor | Shape | Description |
---|---|---|
`logits` | $(B, V)$ | Each row has the predicted (by the LLM) unnormalized logits for all possible tokens in the vocabulary. Note: $p_{\text{LM}}(\cdot \mid x_{<t})$ is obtained by applying softmax to these logits. |
`targets` | $(B,)$ | Labels: for each example in the batch, the label is the correct next token ID that should follow the given context. Each value is a token ID and thus an integer between 0 and $V-1$. |
To make this concrete, imagine the following setup: the vocabulary size is $V = 5000$, and we train on the tokenized sequence $[45, 128, 293, 501, 672]$. Assuming that we form a batch just from this sequence, we can say that:

- `targets` is $[128, 293, 501, 672]$, which has shape $(4,)$.
- `logits` has shape $(4, 5000)$.

The table below shows each example in the batch and its label.
Preceding context, $x_{<t}$ | Correct next token, $x_t$ | Logits ($V$ logits, one per vocabulary token) |
---|---|---|
$[45]$ | $128$ | $[2.1, -1.4, 0.8, \ldots,1.1]$ |
$[45, 128]$ | $293$ | $[0.5, 2.8, -0.9, \ldots,-2.0]$ |
$[45, 128, 293]$ | $501$ | $[1.3, 0.2, 4.1, \ldots,0.5]$ |
$[45, 128, 293, 501]$ | $672$ | $[-0.8, 3.5, 1.9, \ldots,-3.2]$ |
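To see these shapes in code, here is a minimal sketch of how such a batch could be constructed. The vocabulary size and token IDs are the ones assumed above; the logits are random placeholders standing in for an actual model’s output.

```python
import torch

V = 5000                                           # vocabulary size (assumed above)
sequence = torch.tensor([45, 128, 293, 501, 672])  # token IDs of the toy sequence

# Every position after the first becomes one training example:
# the context is sequence[:t] and the label is sequence[t].
targets = sequence[1:]       # shape (4,): tensor([128, 293, 501, 672])
B = targets.shape[0]         # batch size, 4

# In practice, logits come from the LLM's forward pass over the contexts.
# Random values with the right shape are used here purely for illustration.
logits = torch.randn(B, V)   # shape (4, 5000)
```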
As mentioned earlier, PyTorch has a built-in function for computing cross-entropy:
`loss = F.cross_entropy(logits, targets)`
This produces the same result as the calculation outlined below:
Step (in calculation of cross-entropy) | Shape | Comment |
---|---|---|
`counts = logits.exp()` | $(B, V)$ | Converts logits to unnormalized probabilities. |
`p = counts / counts.sum(dim=1, keepdim=True)` | $(B, V)$ | Normalization. `p` is the probability distribution over all possible next tokens, $p_{\theta}(\cdot \mid x_{<t})$. |
`p_labels = p[torch.arange(B), targets]` | $(B,)$ | Extracts the probability for each label. `p_labels` is $p_{\theta}(x_t \mid x_{<t})$. |
`loss = -p_labels.log().mean()` | Scalar | Mean negative log likelihood (NLL), i.e. cross-entropy: computes the NLL for each example in the batch and averages across the batch. |
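Putting these steps into a runnable sketch (with random placeholder logits and the toy targets from above), the manual result matches the built-in function:

```python
import torch
import torch.nn.functional as F

B, V = 4, 5000
logits = torch.randn(B, V)                        # placeholder for the model's output
targets = torch.tensor([128, 293, 501, 672])      # labels from the toy example

counts = logits.exp()                             # (B, V) unnormalized probabilities
p = counts / counts.sum(dim=1, keepdim=True)      # (B, V) softmax
p_labels = p[torch.arange(B), targets]            # (B,)  probability of each label
loss_manual = -p_labels.log().mean()              # scalar: mean NLL

loss_builtin = F.cross_entropy(logits, targets)
print(torch.allclose(loss_manual, loss_builtin))  # True (up to floating-point error)
```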
While valid, the calculation above isn’t numerically stable. To overcome this, we can use the log-sum-exp trick; in fact, `F.cross_entropy` does this too. I’ve outlined the key steps for computing cross-entropy with the log-sum-exp trick below. For improved readability, I’ve included a few extra steps.
Step (in calculation of cross-entropy) | Shape | Comment |
---|---|---|
`logit_maxes = logits.max(dim=1, keepdim=True).values` | $(B, 1)$ | Gets the max logit value across the vocabulary dimension for each example in the batch. |
`shifted_logits = logits - logit_maxes` | $(B, V)$ | Subtracts the max from each row for numerical stability, specifically preventing overflow during exponentiation. |
`counts = shifted_logits.exp()` | $(B, V)$ | Computes unnormalized probabilities by exponentiating the shifted logits. |
`counts_sum = counts.sum(dim=1, keepdim=True)` | $(B, 1)$ | Computes the sum of unnormalized probabilities for each example. |
`p = counts / counts_sum` | $(B, V)$ | Softmax. `p` is $p_{\theta}(\cdot \mid x_{<t})$. |
`logp = p.log()` | $(B, V)$ | Log probabilities (not log likelihoods!). |
`logp_labels = logp[torch.arange(B), targets]` | $(B,)$ | Log likelihood, $\log p_{\theta}(x_t \mid x_{<t})$. |
`loss = -logp_labels.mean()` | Scalar | Mean NLL, i.e. cross-entropy. |
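The same steps as a runnable sketch (again with placeholder inputs); the result agrees with both the naive version and `F.cross_entropy`:

```python
import torch
import torch.nn.functional as F

B, V = 4, 5000
logits = torch.randn(B, V)                            # placeholder for the model's output
targets = torch.tensor([128, 293, 501, 672])          # labels from the toy example

logit_maxes = logits.max(dim=1, keepdim=True).values  # (B, 1) per-row max
shifted_logits = logits - logit_maxes                 # (B, V) shift for stability
counts = shifted_logits.exp()                         # (B, V) unnormalized probabilities
counts_sum = counts.sum(dim=1, keepdim=True)          # (B, 1) row sums
p = counts / counts_sum                               # (B, V) softmax
logp = p.log()                                        # (B, V) log probabilities
logp_labels = logp[torch.arange(B), targets]          # (B,)  log-likelihoods of the labels
loss_stable = -logp_labels.mean()                     # scalar: mean NLL

print(torch.allclose(loss_stable, F.cross_entropy(logits, targets)))  # True
```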
A final note on `F.cross_entropy` compared to the two methods I’ve discussed: