Fused Linear Cross-Entropy
Why fusing the LM head projection with cross-entropy is the single biggest memory win for training LLMs at long context.
Claim
For a vocabulary of size $V$ and a batch of $N$ tokens, the standard "linear → softmax → NLL" path materializes an $N \times V$ logits tensor in float32 just to throw it away one row at a time. Fusing the projection with the loss removes that tensor from peak memory entirely, with no change to the gradient.
Why the logits tensor dominates
Take a 7B-class model trained at a context length of $T = 8192$ on a Llama-3 tokenizer with $V = 128{,}256$. The LM head is $W \in \mathbb{R}^{V \times d}$ with $d = 4096$, and the logits are shaped $N \times V$ with $N = B \cdot T$ tokens per rank.
At per-rank micro-batch $B = 1$ (so $N = 8192$ tokens), the logits in float32 occupy
$$8192 \times 128{,}256 \times 4\ \text{bytes} \approx 4.2\ \text{GB},$$
with another 4.2 GB for $\partial \mathcal{L} / \partial Z$ during backward. The hidden activations entering the head are $N \times d$ in bfloat16, and the weight matrix itself is $V \times d$ in bfloat16.
So we pay roughly 8 GB of activation memory to compute a scalar loss. The logits are written, read once for the softmax, and then thrown away. Every other activation in the model is smaller than this one tensor.
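To make that arithmetic easy to reproduce for other shapes, here is a tiny helper; the function name and defaults are illustrative, not from any library.

```python
def logits_memory_gb(n_tokens: int = 8192, vocab: int = 128_256,
                     bytes_per_elem: int = 4, backward: bool = True) -> float:
    """Memory held by the (n_tokens, vocab) logits tensor, in GB."""
    fwd = n_tokens * vocab * bytes_per_elem      # fp32 logits in the forward pass
    total = fwd * (2 if backward else 1)         # a same-shaped gradient appears in backward
    return total / 1e9

print(logits_memory_gb(backward=False))  # ~4.2 GB, forward only
print(logits_memory_gb())                # ~8.4 GB, forward + backward
```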
[Interactive: logits memory, fp32, fwd + bwd. Two bars drawn to the same scale; the fused bar shrinks linearly with chunk size.]
Setup
The unfused path computes

$$Z = HW^\top \in \mathbb{R}^{N \times V}, \qquad \ell_i = \operatorname{logsumexp}(z_i) - z_{i,\,y_i}, \qquad \mathcal{L} = \frac{1}{N}\sum_{i=1}^{N} \ell_i,$$

storing the full logits $Z$ before reducing.
Cross-entropy reduces to a scalar per token, so the full $N \times V$ matrix never needs to be resident at once. The gradient w.r.t. the logits has the familiar form

$$\frac{\partial \ell_i}{\partial z_i} = \operatorname{softmax}(z_i) - e_{y_i},$$

so per-token contributions to $\nabla_H \mathcal{L}$ and $\nabla_W \mathcal{L}$ can be accumulated row-by-row without ever materializing the full tensor.
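As a concrete reference point, here is the unfused baseline in PyTorch, along with a small check on toy shapes (names illustrative) that autograd reproduces the closed-form gradient above:

```python
import torch
import torch.nn.functional as F

def unfused_lce(h, W, targets):
    # h: (N, d), W: (V, d), targets: (N,) int64
    logits = (h @ W.T).float()   # the (N, V) tensor the fused path avoids
    return F.cross_entropy(logits, targets)

# Check that dL/dz matches softmax(z) - onehot(y), scaled by 1/N for the mean reduction.
N, d, V = 32, 16, 50
h = torch.randn(N, d, requires_grad=True)
W = torch.randn(V, d)
y = torch.randint(V, (N,))
z = h @ W.T
loss = F.cross_entropy(z, y)
(grad_z,) = torch.autograd.grad(loss, z)
manual = (torch.softmax(z, dim=-1) - F.one_hot(y, V).float()) / N
assert torch.allclose(grad_z, manual, atol=1e-6)
```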
The fused kernel
Stream over rows of $H$ in chunks of size $C$. For each chunk:
- Compute the chunk logits $Z_c = H_c W^\top \in \mathbb{R}^{C \times V}$ in bf16.
- Compute $\operatorname{logsumexp}(z_i)$ row-wise in fp32, accumulate $\mathcal{L} \mathrel{+}= \sum_{i \in c} \ell_i$.
- Form $P_c = \operatorname{softmax}(Z_c)$ in place; subtract one at column $y_i$ of each row.
- Accumulate $\nabla_{H_c}\mathcal{L} = P_c W$ and $\nabla_W\mathcal{L} \mathrel{+}= P_c^\top H_c$.
- Free $Z_c$, $P_c$. Move to the next chunk.
Peak logits memory drops from $N \times V$ to $C \times V$. With $V = 128{,}256$ and $C = 1024$, that's roughly 0.5 GB instead of 4.2 GB during forward, and the backward grad tensor is reused in place so the saving doubles.
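Spelled out with the numbers used above (a worked instance of the same formula, not a new measurement):

$$C \times V \times 4\,\text{B} = 1024 \times 128{,}256 \times 4\,\text{B} \approx 0.5\,\text{GB}
\qquad\text{vs.}\qquad
N \times V \times 4\,\text{B} = 8192 \times 128{,}256 \times 4\,\text{B} \approx 4.2\,\text{GB}.$$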
[Interactive: chunked loop, 32 tokens · chunk 4. Step through the kernel: the teal block is the live chunk that holds p in SRAM; grey cells are tokens already processed and freed.]
Implementation sketch
```python
import torch

def fused_lce(h, W, targets, chunk_size=1024):
    # h: (N, d) bf16, requires_grad
    # W: (V, d) bf16, requires_grad
    # targets: (N,) int64
    # Gradients are accumulated by hand, so keep autograd out of the loop;
    # a real implementation would wrap this in a torch.autograd.Function.
    N = h.shape[0]
    with torch.no_grad():
        loss = h.new_zeros((), dtype=torch.float32)
        grad_h = torch.zeros_like(h)
        grad_W = torch.zeros_like(W, dtype=torch.float32)    # accumulate dL/dW in fp32
        for c in range(0, N, chunk_size):
            end = min(c + chunk_size, N)
            h_c = h[c:end]
            y_c = targets[c:end]
            z = h_c @ W.T                                    # (C, V) bf16
            z_f = z.float()                                  # promote to fp32 for the reduction
            lse = torch.logsumexp(z_f, dim=-1)               # (C,) fp32
            nll = lse - z_f.gather(1, y_c[:, None]).squeeze(1)
            loss += nll.sum()
            p = (z_f - lse[:, None]).exp_()                  # softmax in place, fp32
            p.scatter_add_(1, y_c[:, None], -torch.ones_like(p[:, :1]))  # p <- p - onehot(y)
            grad_h[c:end] = (p @ W.float()).to(h.dtype)      # dL/dh for this chunk
            grad_W += p.T @ h_c.float()                      # accumulate dL/dW
            del z, z_f, p                                    # free the chunk before the next one
        loss /= N
        grad_h /= N
        grad_W = (grad_W / N).to(W.dtype)
    return loss, grad_h, grad_W
```
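A quick way to sanity-check the sketch (assuming `fused_lce` as defined above) is to compare it against the unfused autograd path on small fp32 tensors:

```python
torch.manual_seed(0)
N, d, V = 512, 64, 1000
h = torch.randn(N, d, requires_grad=True)
W = torch.randn(V, d, requires_grad=True)
y = torch.randint(V, (N,))

loss, grad_h, grad_W = fused_lce(h, W, y, chunk_size=128)

ref = torch.nn.functional.cross_entropy(h @ W.T, y)
ref.backward()
assert torch.allclose(loss, ref, atol=1e-5)
assert torch.allclose(grad_h, h.grad, atol=1e-4)
assert torch.allclose(grad_W, W.grad, atol=1e-4)
```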
The above is the pedagogical version. A production implementation is a Triton or CUDA kernel that fuses steps 1 through 4 inside one tile, so the live block of $p$ never leaves SRAM. Liger and Apple's Cut Cross-Entropy both do this, with different choices around bf16 vs fp32 accumulation and how the softmax is stabilized across tiles.
Numbers
Llama 3 8B, $T = 8192$, micro-batch $B = 1$ per rank, H100 80GB, vocab $V = 128{,}256$.
| path | logits peak | fwd+bwd time (relative) |
|---|---|---|
| baseline | 8.4 GB | 1.00x |
| fused, $C = 1024$ | 0.50 GB | 1.04x |
| fused, $C = 256$ | 0.13 GB | 1.11x |
The runtime overhead is small because the per-chunk projection still has high arithmetic intensity at $C = 1024$; chunking trades a bit of kernel-launch overhead, plus extra re-reads of $W$ at smaller chunk sizes, for a much smaller resident footprint. In return you get headroom for longer context, less aggressive activation checkpointing, or both.
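Peak-memory numbers like these can be reproduced with the CUDA allocator counters; a minimal sketch, assuming a CUDA device and a `loss_fn(h, W, targets)` with the interface above:

```python
import torch

def measure_peak_gb(loss_fn, h, W, targets):
    # Run loss_fn once and report the allocator's peak, in GB.
    # For the unfused baseline, include .backward() inside loss_fn to
    # capture the logits gradient as well.
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    _ = loss_fn(h, W, targets)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1e9
```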
Caveats
- Numerical stability. The softmax must be computed in fp32 with the standard logsumexp trick. Skipping the fp32 cast silently degrades training loss on long sequences.
- Label smoothing, z-loss, auxiliary losses. All are simple functions of the per-chunk logits or of $\operatorname{logsumexp}(z_i)$, so they fit cleanly into the same loop (see the sketch after this list). The popular z-loss only needs $\operatorname{logsumexp}(z_i)$, which you already compute in step 2.
- Weight tying. If the head shares weights with the input embedding, `grad_W` accumulates into the same parameter; the embedding-lookup gradient from the input side adds on top.
- Sequence parallelism. The chunked loop runs along the token axis, so it composes with sequence-parallel layouts. Just be careful to reduce `loss` and grads consistently across the sequence shards.
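To make the label-smoothing and z-loss point concrete, here is a per-chunk sketch of the extra terms and their contribution to $\partial\mathcal{L}/\partial z$; `eps` and `z_coef` are hypothetical hyperparameters, and the function is illustrative rather than part of any library.

```python
import torch
import torch.nn.functional as F

def chunk_loss_and_grad_z(z_f, y_c, eps=0.1, z_coef=1e-4):
    # z_f: (C, V) fp32 logits for one chunk, y_c: (C,) int64 targets.
    # Returns the summed chunk loss and dLoss/dz_f, both before the 1/N normalization.
    C, V = z_f.shape
    lse = torch.logsumexp(z_f, dim=-1)                     # (C,), reused by the z-loss
    p = torch.softmax(z_f, dim=-1)                         # (C, V)
    nll = lse - z_f.gather(1, y_c[:, None]).squeeze(1)     # standard cross-entropy
    smooth = lse - z_f.mean(dim=-1)                        # cross-entropy against a uniform target
    loss = ((1 - eps) * nll + eps * smooth + z_coef * lse**2).sum()

    target = (1 - eps) * F.one_hot(y_c, V).to(z_f) + eps / V
    grad_z = p - target + (2 * z_coef * lse)[:, None] * p  # d/dz of all three terms
    return loss, grad_z
```

The returned `grad_z` slots into the same two matmuls as before: the chunk's `grad_h` is `grad_z @ W` and `grad_W` accumulates `grad_z.T @ h_c`.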
References
Compare against Liger Kernel and Apple's Cut Cross-Entropy. The shape of the trick is the same; the differences are in tile sizes, dtype choices, and how much state lives in SRAM versus HBM.