
Fused Linear Cross-Entropy

Why fusing the LM head projection with cross-entropy is the single biggest memory win for training LLMs at long context.

The baseline path keeps the full N × V logits tensor resident; the fused path slides a C × V chunk through SRAM. Rendered with Manim.

Claim

For a vocabulary of size $V$ and a batch of $N = B \cdot T$ tokens, the standard "linear → softmax → NLL" path materializes an $N \times V$ logits tensor in float32 just to throw it away one row at a time. Fusing the projection with the loss removes that tensor from peak memory entirely, with no change to the gradient.

Why the logits tensor dominates

Take a 7B model trained at $T = 8192$ on a Llama-3 tokenizer with $V \approx 128\text{k}$. The LM head is $W \in \mathbb{R}^{V \times d}$ and the logits are shaped $[N, V]$.

At per-rank micro-batch $B = 1$, the logits in float32 occupy

$$8192 \times 128000 \times 4\,\text{B} \;\approx\; 3.9\,\text{GiB},$$

with another 3.9 GiB for $\partial L / \partial \text{logits}$ during backward. The hidden activations entering the head are $8192 \times 4096 \times 2\,\text{B} = 64\,\text{MiB}$ in bfloat16, and the weight matrix itself is $\approx 1\,\text{GiB}$ in bfloat16.

So we pay roughly 8 GiB of activation memory to compute a scalar loss. The logits are written, read once for the softmax, and then thrown away. Every other activation in the model is smaller than this one tensor.
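
To sanity-check these figures, here is a tiny helper (a sketch, not from any library) that reproduces the arithmetic for arbitrary shapes:

def logits_footprint_gib(n_tokens, vocab, bytes_per_el=4, fwd_and_bwd=True):
    # N x V logits, plus an equally sized gradient tensor during backward.
    n_tensors = 2 if fwd_and_bwd else 1
    return n_tokens * vocab * bytes_per_el * n_tensors / 2**30

print(logits_footprint_gib(8192, 128_000, fwd_and_bwd=False))  # ~3.9 GiB, forward only
print(logits_footprint_gib(8192, 128_000))                     # ~7.8 GiB with the backward buffer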

[Interactive: logits memory, fp32, forward + backward. Two bars drawn to the same scale, so the fused bar shrinks linearly with chunk size. At the defaults T = 8,192, V = 128k, C = 1024: baseline 7.81 GiB, fused 1000 MiB, 87.5% saved.]

Setup

The unfused path computes

$$\ell_i = -\log \mathrm{softmax}(W h_i)_{y_i}, \qquad L = \frac{1}{N}\sum_i \ell_i,$$

storing the full logits $z_i = W h_i \in \mathbb{R}^V$ before reducing.
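
In PyTorch terms, the unfused baseline is the familiar projection followed by F.cross_entropy (a minimal sketch with the shapes from above; note that the full logits tensor is materialized):

import torch
import torch.nn.functional as F

N, d, V = 8192, 4096, 128_000
h = torch.randn(N, d, dtype=torch.bfloat16)        # hidden states entering the head
W = torch.randn(V, d, dtype=torch.bfloat16)        # LM-head weight
targets = torch.randint(0, V, (N,))

logits = h @ W.T                                   # (N, V) materialized in full
loss = F.cross_entropy(logits.float(), targets)    # fp32 softmax + NLL, mean over N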

Cross-entropy reduces to a scalar per token, so the full $z_i$ never needs to be resident at once. The gradient w.r.t. logits has the familiar form

$$\frac{\partial L}{\partial z_i} \;=\; \frac{1}{N}\bigl(\mathrm{softmax}(z_i) - e_{y_i}\bigr),$$

so per-token contributions to $\partial L/\partial h_i$ and $\partial L/\partial W$ can be accumulated row-by-row without ever materializing the full $[N, V]$ tensor.
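
The identity is easy to check against autograd on a toy example (a sketch; small shapes just to keep it fast):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, V = 4, 11
z = torch.randn(N, V, requires_grad=True)
y = torch.randint(0, V, (N,))

F.cross_entropy(z, y).backward()                   # mean reduction over N tokens

manual = (torch.softmax(z, -1) - F.one_hot(y, V).float()) / N
print(torch.allclose(z.grad, manual.detach(), atol=1e-6))  # True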

The fused kernel

Stream over rows of $h$ in chunks of size $C$. For each chunk:

  1. Compute $z = h_{[c:c+C]} W^\top \in \mathbb{R}^{C \times V}$ in bf16.
  2. Compute $\mathrm{lse}_i = \log \sum_j e^{z_{ij}}$ in fp32, accumulate $\ell_i = \mathrm{lse}_i - z_{i, y_i}$.
  3. Form $p_{ij} = e^{z_{ij} - \mathrm{lse}_i}$ in place; subtract one at column $y_i$.
  4. Accumulate $\partial L/\partial h_{[c:c+C]} \mathrel{+}= p\, W$ and $\partial L/\partial W \mathrel{+}= p^\top h_{[c:c+C]}$.
  5. Free $p$, $z$. Move to the next chunk.

Peak logits memory drops from $N \cdot V$ to $C \cdot V$. With $C = 1024$ and $V = 128\text{k}$, that's 500 MiB instead of 3.9 GiB during forward, and the backward grad tensor is reused in place, so the saving doubles.

[Interactive: chunked loop over 32 tokens with chunk size 4. Stepping through the kernel highlights the live chunk holding p in SRAM (teal) while already-processed tokens are freed (grey).]

Implementation sketch

import torch

def fused_lce(h, W, targets, chunk_size=1024):
    # h: (N, d) bf16 hidden states entering the LM head
    # W: (V, d) bf16 LM-head weight
    # targets: (N,) int64 token ids
    N = h.shape[0]
    loss = h.new_zeros((), dtype=torch.float32)
    grad_h = torch.zeros_like(h)
    grad_W = torch.zeros_like(W)

    # Gradients are accumulated by hand, so keep autograd out of the loop;
    # otherwise the saved per-chunk activations would defeat the memory saving.
    with torch.no_grad():
        for c in range(0, N, chunk_size):
            end = min(c + chunk_size, N)
            h_c = h[c:end]
            y_c = targets[c:end]

            z = h_c @ W.T                              # (C, V) bf16 logits
            z_f = z.float()                            # fp32 for a stable softmax
            lse = torch.logsumexp(z_f, dim=-1)         # (C,) fp32
            nll = lse - z_f.gather(1, y_c[:, None]).squeeze(1)
            loss += nll.sum()

            p = (z_f - lse[:, None]).exp_()            # softmax, fp32, in place
            # dL/dz = softmax(z) - one_hot(y): subtract 1 at the target column.
            p.scatter_add_(1, y_c[:, None], -torch.ones_like(p[:, :1]))

            grad_h[c:end] = (p @ W.float()).to(h.dtype)
            grad_W += (p.T @ h_c.float()).to(W.dtype)

    loss /= N
    grad_h /= N
    grad_W /= N
    return loss, grad_h, grad_W

The above is the pedagogical version. A production implementation is a Triton or CUDA kernel that fuses steps 1 through 4 inside one tile, so the $C \times V$ block of $p$ never leaves SRAM. Liger and Apple's Cut Cross-Entropy both do this, with different choices around bf16 vs fp32 accumulation and how the softmax is stabilized across tiles.
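
If you want the pedagogical version to plug into a normal loss.backward() training step, one option is to wrap it in a custom torch.autograd.Function. A minimal sketch built on the fused_lce helper above (this is not how Liger or Cut Cross-Entropy are structured internally):

import torch

class FusedLinearCE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, h, W, targets, chunk_size=1024):
        # Compute the loss and its grads in one chunked pass, stash the grads.
        loss, grad_h, grad_W = fused_lce(h, W, targets, chunk_size)
        ctx.save_for_backward(grad_h, grad_W)
        return loss

    @staticmethod
    def backward(ctx, grad_out):
        grad_h, grad_W = ctx.saved_tensors
        # Scale the precomputed grads by the upstream gradient of the scalar loss.
        return (grad_out * grad_h).to(grad_h.dtype), (grad_out * grad_W).to(grad_W.dtype), None, None

# usage: loss = FusedLinearCE.apply(h, W, targets)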

Numbers

Llama 3 8B, $T = 8192$, $B = 1$ per rank, H100 80GB, vocab $V = 128\text{k}$.

path              logits peak   fwd+bwd time
baseline          7.8 GiB       1.00x
fused, C = 1024   0.50 GiB      1.04x
fused, C = 256    0.13 GiB      1.11x

The runtime overhead is small because the projection is memory-bound at $V = 128\text{k}$; chunking trades a bit of kernel-launch overhead for a much smaller resident footprint. In return you get headroom for longer context, less aggressive activation checkpointing, or both.
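
A quick way to check the headroom on your own setup is to compare allocator peaks directly (a sketch; assumes a CUDA device and the fused_lce helper above):

import torch
import torch.nn.functional as F

def peak_gib(fn):
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**30

N, d, V = 8192, 4096, 128_000
h = torch.randn(N, d, dtype=torch.bfloat16, device="cuda")
W = torch.randn(V, d, dtype=torch.bfloat16, device="cuda")
y = torch.randint(0, V, (N,), device="cuda")

print(peak_gib(lambda: F.cross_entropy((h @ W.T).float(), y)))   # dominated by the N x V logits
print(peak_gib(lambda: fused_lce(h, W, y, chunk_size=1024)))     # stays near the weight and grad buffers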

Caveats

  • Numerical stability. The softmax must be computed in fp32 with the standard logsumexp trick. Skipping the fp32 cast silently degrades training loss on long sequences.
  • Label smoothing, z-loss, auxiliary losses. All are linear in the logits or in $p$, so they fit cleanly into the same loop. The popular $(\mathrm{lse})^2$ z-loss only needs $\mathrm{lse}$, which you already compute in step 2; a sketch of folding it in follows this list.
  • Weight tying. If the head shares weights with the input embedding, grad_W accumulates into the same parameter; the embedding-lookup gradient from the input side adds on top.
  • Sequence parallelism. The chunked loop runs along the token axis, so it composes with sequence-parallel layouts. Just be careful to reduce loss and grads consistently across the sequence shards.
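
For concreteness, here is a sketch of how the $(\mathrm{lse})^2$ z-loss slots into one chunk of the loop; alpha is the z-loss weight (a hypothetical name), and the $1/N$ normalization still happens outside, as in the main sketch:

import torch

def chunk_loss_and_grad_with_zloss(z_f, y_c, alpha=1e-4):
    # z_f: (C, V) fp32 logits for one chunk, y_c: (C,) int64 targets.
    lse = torch.logsumexp(z_f, dim=-1)                    # (C,), already needed for the NLL
    nll = lse - z_f.gather(1, y_c[:, None]).squeeze(1)
    loss = nll.sum() + alpha * (lse ** 2).sum()           # CE + z-loss for this chunk

    p = (z_f - lse[:, None]).exp()                        # softmax(z)
    p *= 1 + 2 * alpha * lse[:, None]                     # d(alpha * lse^2)/dz = 2*alpha*lse*softmax(z)
    p.scatter_add_(1, y_c[:, None], -torch.ones_like(p[:, :1]))
    return loss, p                                        # p is now dL/dz for the chunk (before the 1/N)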

References

Compare against Liger Kernel and Apple's Cut Cross-Entropy. The shape of the trick is the same; the differences are in tile sizes, dtype choices, and how much state lives in SRAM versus HBM.