Fused Linear Cross-Entropy
LM 헤드 프로젝션을 크로스 엔트로피와 결합하면, 긴 컨텍스트 학습에서 메모리를 가장 크게 줄일 수 있습니다.
주장
어휘 크기 , 토큰 수 인 배치를 생각해 봅시다. 일반적인 "선형 → 소프트맥스 → NLL" 경로는 한 행씩 버려질 뿐인 로짓 텐서를 float32 로 끝까지 만들어 둡니다. 프로젝션을 손실과 결합하면 이 텐서가 피크 메모리에서 완전히 사라집니다. 그래디언트는 그대로 유지됩니다.
로짓 텐서가 메모리를 차지하는 이유
Llama-3 토크나이저로 인 7B 모델을 로 학습한다고 합시다. LM 헤드는 이고 로짓은 모양입니다.
per-rank 마이크로 배치 에서 float32 로짓이 차지하는 크기는
그리고 백워드에서 가 다시 4.2 GiB 를 더 사용합니다. 헤드로 들어오는 히든 활성값은 (bf16) 이고, 가중치 행렬은 bf16 으로 약 1 GiB 입니다.
즉 스칼라 손실 하나를 계산하려고 활성 메모리에 8 GiB 를 쓰는 셈입니다. 로짓은 한 번 쓰고 소프트맥스에서 한 번 읽힌 뒤 버려집니다. 모델 안의 다른 활성값들은 모두 이 텐서보다 작습니다.
interactive · logits memory
fp32, fwd + bwd
Drag the sliders. The two bars are drawn to the same scale, so the fused bar shrinks linearly with chunk size.
설정
기본 경로는 다음을 계산합니다.
여기서 전체 로짓 를 먼저 저장한 뒤 줄입니다.
크로스 엔트로피는 토큰당 스칼라로 환원되므로, 전체가 동시에 메모리에 있어야 할 이유가 없습니다. 로짓에 대한 그래디언트는 익숙한 형태로
이므로 와 에 대한 토큰별 기여를 텐서를 만들지 않고 한 행씩 누적할 수 있습니다.
결합 커널
의 행을 크기 인 청크로 흘려보냅니다. 청크마다 다음을 수행합니다.
- 를 bf16 으로 계산합니다.
- 를 fp32 로 계산하고 를 누적합니다.
- 를 그 자리에 만들고, 열에서 1 을 뺍니다.
- 와 를 누적합니다.
- , 를 버리고 다음 청크로 넘어갑니다.
피크 로짓 메모리는 에서 로 줄어듭니다. , 이면 포워드 동안 4.2 GiB 대신 512 MiB 만 들고 있게 됩니다. 백워드의 그래디언트 텐서도 그 자리를 재사용하므로 절감 효과가 두 배가 됩니다.
interactive · chunked loop
32 tokens · chunk 4
Scrub the top slider to step through the kernel. The teal block is the live chunk that holds p in SRAM. Grey cells are tokens already processed and freed.
구현 스케치
import torch
def fused_lce(h, W, targets, chunk_size=1024):
# h: (N, d) bf16, requires_grad
# W: (V, d) bf16, requires_grad
# targets: (N,) int64
N = h.shape[0]
loss = h.new_zeros((), dtype=torch.float32)
grad_h = torch.zeros_like(h)
grad_W = torch.zeros_like(W)
for c in range(0, N, chunk_size):
end = min(c + chunk_size, N)
h_c = h[c:end]
y_c = targets[c:end]
z = h_c @ W.T # (C, V) bf16
z_f = z.float()
lse = torch.logsumexp(z_f, dim=-1) # (C,) fp32
nll = lse - z_f.gather(1, y_c[:, None]).squeeze(1)
loss += nll.sum()
p = (z_f - lse[:, None]).exp_() # softmax, fp32
p.scatter_add_(1, y_c[:, None], -torch.ones_like(p[:, :1]))
grad_h[c:end] = (p @ W.float()).to(h.dtype)
grad_W += (p.T @ h_c.float()).to(W.dtype)
loss /= N
grad_h /= N
grad_W /= N
return loss, grad_h, grad_W
위 코드는 교육용 버전입니다. 실제 구현은 Triton 또는 CUDA 커널에서 1 ~ 4 단계를 하나의 타일 안에 fuse 하여, 크기의 블록이 SRAM 을 벗어나지 않도록 합니다. Liger 와 Apple 의 Cut Cross-Entropy 가 모두 이 방식이며, bf16 또는 fp32 누산이나 타일 간 소프트맥스 안정화 같은 디테일에서 차이가 있습니다.
수치
Llama 3 8B, , per-rank , H100 80GB, vocab .
| 경로 | 로짓 피크 | fwd+bwd |
|---|---|---|
| baseline | 8.4 GiB | 1.00x |
| fused, | 0.50 GiB | 1.04x |
| fused, | 0.13 GiB | 1.11x |
에서 프로젝션은 메모리 바운드이므로 런타임 오버헤드는 작습니다. 청킹은 약간의 커널 런치 오버헤드를 내주는 대신 훨씬 작은 메모리 풋프린트를 얻습니다. 덕분에 더 긴 컨텍스트나 덜 공격적인 활성 체크포인팅, 또는 둘 다 사용할 여유가 생깁니다.
주의사항
- 수치 안정성. 소프트맥스는 표준 logsumexp 트릭을 사용해 fp32 로 계산해야 합니다. fp32 캐스팅을 빠뜨리면 긴 시퀀스에서 학습 손실이 조용히 망가집니다.
- 레이블 스무딩, z-loss, 보조 손실. 모두 로짓 또는 에 대해 선형이므로 같은 루프 안에 깔끔하게 들어갑니다. 인기 있는 z-loss 는 2 단계에서 이미 계산한 만 있으면 됩니다.
- 가중치 공유. 헤드가 입력 임베딩과 가중치를 공유하면
grad_W가 같은 파라미터에 누적되며, 임베딩 lookup 의 그래디언트가 거기에 더해집니다. - 시퀀스 병렬. 청크 루프는 토큰 축을 따라 돌므로 시퀀스 병렬 레이아웃과
자연스럽게 합쳐집니다. 다만 시퀀스 샤드 간에
loss와 그래디언트를 일관되게 리듀스해야 합니다.
참고
Liger Kernel 과 Apple 의 Cut Cross-Entropy 와 비교해 보시면 좋습니다. 트릭의 골격은 같고, 타일 크기, 데이터 타입 선택, SRAM 과 HBM 사이의 상태 배분에서 차이가 있습니다.