[sglang] SGLang의 디코드 성능 향상을 위한 Temperature 및 Softmax 커널 융합
PR 링크: sgl-project/sglang#20501 상태: Merged | 변경: +None / -None
들어가며
LLM의 디코드(Decode) 단계는 매 토큰 생성마다 반복되는 연산으로, 전체 추론 성능을 결정짓는 핵심 병목 구간입니다. 기존 SGLang의 sampler.py는 logits.div_(temperature)와 torch.softmax를 별도의 CUDA 커널로 실행했습니다. 이 방식은 두 번의 커널 호출과 6번의 전역 메모리 패스(4번 읽기, 2번 쓰기)를 유발하여, 특히 Llama 3(128K)나 Qwen(152K)과 같은 대규모 어휘 사전(Large-vocab) 모델에서 성능 저하를 야기했습니다. 본 PR은 이 두 연산을 하나의 Triton 커널로 융합하여 메모리 대역폭 사용을 획기적으로 줄였습니다.
코드 분석
1. python/sglang/srt/layers/fused_sampling.py (신규)
핵심 변경 사항은 Triton을 이용한 커널 융합입니다. 어휘 사전 크기에 따라 두 가지 전략을 사용합니다.
- Single-pass kernel: 어휘 사전이 32,768 이하일 때 사용하며, 데이터를 레지스터에 로드하여 1번의 읽기와 1번의 쓰기만으로 연산을 완료합니다.
- Multi-pass kernel: 대규모 어휘 사전을 위해 3-pass online softmax 방식을 채택했습니다.
triton.autotune을 통해BLOCK_SIZE와num_warps를 최적화합니다.
# Before (sampler.py)
logits.div_(sampling_info.temperatures)
logits[:] = torch.softmax(logits, dim=-1)
# After (fused_sampling.py - Single-pass 예시)
x = tl.load(logits_ptr + row_idx * logits_stride + offsets, mask=mask)
x = (x / temp).to(tl.float32)
x_max = tl.max(x, axis=0)
exp_x = tl.exp(x - x_max)
prob = exp_x / tl.sum(exp_x, axis=0)
tl.store(output_ptr + row_idx * output_stride + offsets, prob, mask=mask)
2. python/sglang/srt/layers/sampler.py (수정)
단순히 커널을 교체하는 것을 넘어, 성능과 정확성을 모두 잡기 위해 하이브리드 디스패치 로직을 도입했습니다.
# 하이브리드 디스패치 도입
if batch_size < 128:
# 소규모 배치에서는 PyTorch 네이티브 연산이 더 빠름
logits.div_(temperatures)
torch.softmax(logits, dim=-1, out=logits)
else:
# 대규모 배치에서는 융합된 Triton 커널 사용
fused_temperature_softmax_inplace(logits, temperatures)
왜 이게 좋은가
이번 최적화의 핵심은 메모리 패스 감소와 커널 호출 오버헤드 제거입니다.
- 성능 향상: 벤치마크 결과, 배치 사이즈가 128 이상인 경우 기존 대비 약 2~4배의 속도 향상을 보였습니다. 특히 대규모 어휘 사전 환경에서 그 효과가 극대화됩니다.
- 정확성 유지: 초기 구현에서 PyTorch와 연산 순서가 달라 구조화된 출력(Grammar-constrained decoding)에서 문제가 발생했으나, 3-pass 연산 방식을 채택하여 PyTorch와 수학적으로 동일한 결과를 보장하도록 수정되었습니다.
- 하이브리드 전략: 무조건적인 융합이 아닌, 배치 사이즈에 따른 분기 처리를 통해 소규모 배치에서의 오버헤드까지 고려했습니다.
교훈: 커널 융합은 단순히 연산을 합치는 것뿐만 아니라, 메모리 접근 패턴(Memory Access Pattern)을 최적화하는 것이 핵심입니다. 또한, 최적화된 커널이라도 특정 워크로드(소규모 배치)에서는 네이티브 라이브러리가 더 효율적일 수 있음을 인지하고 하이브리드 전략을 취하는 것이 중요합니다.
참고 자료
- Triton Documentation — 본 PR에서 사용된 Triton 커널 작성 가이드
- PyTorch Softmax Documentation — 기존 연산 방식의 기준점
참고 자료
- https://triton-lang.org/main/index.html
- https://pytorch.org/docs/stable/generated/torch.softmax.html
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [sglang] SGLang의 FA3 디코드 최적화: get_scheduler_metadata 도입
- [sglang] HiSparse 도입: Sparse Attention 모델을 위한 효율적인 KV 캐시 관리
- [triton] Triton 커널 최적화: High Occupancy Persistent Matmul 구현을 통한 성능 향상
- [triton] Triton PROTON: CUDA 그래프 프로파일링 오버헤드를 줄이고 MsgPack API를 추가하여 성능을 대폭 개선
- [triton] [NVIDIA] SM120을 위한 FP4 Native Scaled Matmul 지원 및 성능 최적화 분석
PR Analysis 의 다른글
- 이전글 [Loki] Ingester 타임아웃 반영하여 레이턴시 알림 임계값 1초에서 5초로 조정
- 현재글 : [sglang] SGLang의 디코드 성능 향상을 위한 Temperature 및 Softmax 커널 융합
- 다음글 [sglang] SGLang: MiniMax-M2.5 MoE 모델을 위한 FP8 FlashInfer TRT-LLM 라우팅 최적화
댓글