Skip to content
back to blog
·9 min read·

Fused Linear Cross-Entropy

LM 헤드 프로젝션을 크로스 엔트로피와 결합하면, 긴 컨텍스트 학습에서 메모리를 가장 크게 줄일 수 있습니다.

기본 경로는 N × V 로짓 텐서 전체를 메모리에 유지합니다. 결합 경로는 C × V 청크를 SRAM 안에서 슬라이딩합니다. Manim 으로 렌더링했습니다.

주장

어휘 크기 VV, 토큰 수 N=BTN = B \cdot T 인 배치를 생각해 봅시다. 일반적인 "선형 → 소프트맥스 → NLL" 경로는 한 행씩 버려질 뿐인 N×VN \times V 로짓 텐서를 float32 로 끝까지 만들어 둡니다. 프로젝션을 손실과 결합하면 이 텐서가 피크 메모리에서 완전히 사라집니다. 그래디언트는 그대로 유지됩니다.

로짓 텐서가 메모리를 차지하는 이유

Llama-3 토크나이저로 V128kV \approx 128\text{k} 인 7B 모델을 T=8192T = 8192 로 학습한다고 합시다. LM 헤드는 WRV×dW \in \mathbb{R}^{V \times d} 이고 로짓은 [N,V][N, V] 모양입니다.

per-rank 마이크로 배치 B=1B = 1 에서 float32 로짓이 차지하는 크기는

8192×128000×4B    4.2GiB,8192 \times 128000 \times 4\,\text{B} \;\approx\; 4.2\,\text{GiB},

그리고 백워드에서 L/logits\partial L / \partial \text{logits} 가 다시 4.2 GiB 를 더 사용합니다. 헤드로 들어오는 히든 활성값은 8192×4096×2B67MiB8192 \times 4096 \times 2\,\text{B} \approx 67\,\text{MiB} (bf16) 이고, 가중치 행렬은 bf16 으로 약 1 GiB 입니다.

즉 스칼라 손실 하나를 계산하려고 활성 메모리에 8 GiB 를 쓰는 셈입니다. 로짓은 한 번 쓰고 소프트맥스에서 한 번 읽힌 뒤 버려집니다. 모델 안의 다른 활성값들은 모두 이 텐서보다 작습니다.

interactive · logits memory

fp32, fwd + bwd

Drag the sliders. The two bars are drawn to the same scale, so the fused bar shrinks linearly with chunk size.

8,192
128k
1024
baseline7.81 GiB
fused1000.0 MiB
memory saved87.5%

설정

기본 경로는 다음을 계산합니다.

i=logsoftmax(Whi)yi,L=1Nii,\ell_i = -\log \mathrm{softmax}(W h_i)_{y_i}, \qquad L = \frac{1}{N}\sum_i \ell_i,

여기서 전체 로짓 zi=WhiRVz_i = W h_i \in \mathbb{R}^V 를 먼저 저장한 뒤 줄입니다.

크로스 엔트로피는 토큰당 스칼라로 환원되므로, ziz_i 전체가 동시에 메모리에 있어야 할 이유가 없습니다. 로짓에 대한 그래디언트는 익숙한 형태로

Lzi  =  1N(softmax(zi)eyi),\frac{\partial L}{\partial z_i} \;=\; \frac{1}{N}\bigl(\mathrm{softmax}(z_i) - e_{y_i}\bigr),

이므로 L/hi\partial L/\partial h_iL/W\partial L/\partial W 에 대한 토큰별 기여를 [N,V][N, V] 텐서를 만들지 않고 한 행씩 누적할 수 있습니다.

결합 커널

hh 의 행을 크기 CC 인 청크로 흘려보냅니다. 청크마다 다음을 수행합니다.

  1. z=h[c:c+C]WRC×Vz = h_{[c:c+C]} W^\top \in \mathbb{R}^{C \times V} 를 bf16 으로 계산합니다.
  2. lsei=logjezij\mathrm{lse}_i = \log \sum_j e^{z_{ij}} 를 fp32 로 계산하고 i=lseizi,yi\ell_i = \mathrm{lse}_i - z_{i, y_i} 를 누적합니다.
  3. pij=ezijlseip_{ij} = e^{z_{ij} - \mathrm{lse}_i} 를 그 자리에 만들고, yiy_i 열에서 1 을 뺍니다.
  4. L/h[c:c+C]+=pW\partial L/\partial h_{[c:c+C]} \mathrel{+}= p\, WL/W+=ph[c:c+C]\partial L/\partial W \mathrel{+}= p^\top h_{[c:c+C]} 를 누적합니다.
  5. pp, zz 를 버리고 다음 청크로 넘어갑니다.

피크 로짓 메모리는 NVN \cdot V 에서 CVC \cdot V 로 줄어듭니다. C=1024C = 1024, V=128kV = 128\text{k} 이면 포워드 동안 4.2 GiB 대신 512 MiB 만 들고 있게 됩니다. 백워드의 그래디언트 텐서도 그 자리를 재사용하므로 절감 효과가 두 배가 됩니다.

interactive · chunked loop

32 tokens · chunk 4

Scrub the top slider to step through the kernel. The teal block is the live chunk that holds p in SRAM. Grey cells are tokens already processed and freed.

hidden states h[0..N]softmax buffer p (resident in SRAM)chunk 3 / 8 · tokens processed 8 / 32
8
4

구현 스케치

import torch

def fused_lce(h, W, targets, chunk_size=1024):
    # h: (N, d) bf16, requires_grad
    # W: (V, d) bf16, requires_grad
    # targets: (N,) int64
    N = h.shape[0]
    loss = h.new_zeros((), dtype=torch.float32)
    grad_h = torch.zeros_like(h)
    grad_W = torch.zeros_like(W)

    for c in range(0, N, chunk_size):
        end = min(c + chunk_size, N)
        h_c = h[c:end]
        y_c = targets[c:end]

        z = h_c @ W.T                                  # (C, V) bf16
        z_f = z.float()
        lse = torch.logsumexp(z_f, dim=-1)             # (C,) fp32
        nll = lse - z_f.gather(1, y_c[:, None]).squeeze(1)
        loss += nll.sum()

        p = (z_f - lse[:, None]).exp_()                # softmax, fp32
        p.scatter_add_(1, y_c[:, None], -torch.ones_like(p[:, :1]))

        grad_h[c:end] = (p @ W.float()).to(h.dtype)
        grad_W += (p.T @ h_c.float()).to(W.dtype)

    loss /= N
    grad_h /= N
    grad_W /= N
    return loss, grad_h, grad_W

위 코드는 교육용 버전입니다. 실제 구현은 Triton 또는 CUDA 커널에서 1 ~ 4 단계를 하나의 타일 안에 fuse 하여, C×VC \times V 크기의 pp 블록이 SRAM 을 벗어나지 않도록 합니다. Liger 와 Apple 의 Cut Cross-Entropy 가 모두 이 방식이며, bf16 또는 fp32 누산이나 타일 간 소프트맥스 안정화 같은 디테일에서 차이가 있습니다.

수치

Llama 3 8B, T=8192T = 8192, per-rank B=1B = 1, H100 80GB, vocab V=128kV = 128\text{k}.

경로로짓 피크fwd+bwd
baseline8.4 GiB1.00x
fused, C=1024C = 10240.50 GiB1.04x
fused, C=256C = 2560.13 GiB1.11x

V=128kV = 128\text{k} 에서 프로젝션은 메모리 바운드이므로 런타임 오버헤드는 작습니다. 청킹은 약간의 커널 런치 오버헤드를 내주는 대신 훨씬 작은 메모리 풋프린트를 얻습니다. 덕분에 더 긴 컨텍스트나 덜 공격적인 활성 체크포인팅, 또는 둘 다 사용할 여유가 생깁니다.

주의사항

  • 수치 안정성. 소프트맥스는 표준 logsumexp 트릭을 사용해 fp32 로 계산해야 합니다. fp32 캐스팅을 빠뜨리면 긴 시퀀스에서 학습 손실이 조용히 망가집니다.
  • 레이블 스무딩, z-loss, 보조 손실. 모두 로짓 또는 pp 에 대해 선형이므로 같은 루프 안에 깔끔하게 들어갑니다. 인기 있는 (lse)2(\mathrm{lse})^2 z-loss 는 2 단계에서 이미 계산한 lse\mathrm{lse} 만 있으면 됩니다.
  • 가중치 공유. 헤드가 입력 임베딩과 가중치를 공유하면 grad_W 가 같은 파라미터에 누적되며, 임베딩 lookup 의 그래디언트가 거기에 더해집니다.
  • 시퀀스 병렬. 청크 루프는 토큰 축을 따라 돌므로 시퀀스 병렬 레이아웃과 자연스럽게 합쳐집니다. 다만 시퀀스 샤드 간에 loss 와 그래디언트를 일관되게 리듀스해야 합니다.

참고

Liger Kernel 과 Apple 의 Cut Cross-Entropy 와 비교해 보시면 좋습니다. 트릭의 골격은 같고, 타일 크기, 데이터 타입 선택, SRAM 과 HBM 사이의 상태 배분에서 차이가 있습니다.