[sglang] SGLang EAGLE 디코딩 최적화: 불필요한 Softmax 연산 제거로 성능 향상
PR 링크: sgl-project/sglang#26235 상태: Merged | 변경: +26 / -6
들어가며
최근 대규모 언어 모델(LLM)의 추론 속도 향상을 위해 다양한 최적화 기법이 연구되고 있습니다. 그중 하나인 추측적 디코딩(Speculative Decoding)은 작은 모델(draft model)이 생성한 토큰 시퀀스를 큰 모델(target model)이 검증하는 방식으로, 전체 추론 속도를 크게 향상시킬 수 있습니다. SGLang은 EAGLE이라는 추측적 디코딩 방식을 구현하여 LLM 추론 성능을 높이고 있습니다.
이번 PR은 SGLang의 EAGLE 디코딩 방식에서 topk == 1일 때 발생하는 불필요한 연산을 제거하여 성능을 개선하는 것을 목표로 합니다. 특히, greedy 디코딩 경로에서 매 스텝마다 수행되던 torch.softmax 연산이 실제로는 아무런 영향을 주지 않음에도 불구하고 계산 비용을 발생시키는 문제를 해결합니다.
코드 변경 분석
이번 PR은 주로 eagle_worker_v2.py와 eagle_draft_extend_cuda_graph_runner.py 파일에서 EAGLE 디코딩 로직을 수정했습니다. 핵심 변경 사항은 topk == 1인 경우, 전체 어휘(vocabulary)에 대한 torch.softmax 연산을 건너뛰고 바로 torch.argmax를 사용하여 가장 확률이 높은 토큰을 선택하도록 변경한 것입니다.
1. eagle_worker_v2.py
이 파일은 EAGLE 디코딩의 메인 워커 로직을 담당합니다. 두 곳에서 topk == 1일 때의 최적화가 이루어졌습니다.
1.1. draft_forward 함수 내 변경
draft_forward 함수는 작은 모델이 토큰을 생성하는 핵심 루프입니다. 각 스텝마다 다음 토큰의 로짓(logits)을 계산하고, 이를 기반으로 topk 샘플링을 수행합니다. 기존 코드에서는 topk 값에 상관없이 항상 torch.softmax를 먼저 계산한 후 fast_topk 함수를 호출했습니다.
Before:
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
After:
if self.topk == 1:
# topk=1 → degenerate single-path tree; `topk_p` is unused
# downstream, so skip softmax and just argmax over logits.
topk_index = torch.argmax(
logits_output.next_token_logits, dim=-1, keepdim=True
)
topk_p = torch.ones_like(topk_index, dtype=torch.float32)
else:
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
topk == 1인 경우, torch.softmax 연산을 건너뛰고 torch.argmax를 직접 사용하여 가장 높은 로짓 값을 가진 토큰의 인덱스를 얻습니다. 이때 topk_p는 1로 채워진 텐서로 설정됩니다. 이는 topk == 1일 때, 즉 greedy 디코딩 경로에서는 topk_p 값이 실제로 사용되지 않기 때문입니다. topk_index = argmax(logits)는 topk_index = argmax(softmax(logits))와 동일한 결과를 반환하므로, softmax 연산은 불필요한 계산이 됩니다.
1.2. _draft_extend_for_decode 함수 내 변경
이 함수는 생성된 draft 토큰을 기반으로 실제 모델의 출력을 확장하는 부분입니다. 이 함수 내부에서도 마찬가지로 topk == 1일 때 torch.softmax 연산을 건너뛰도록 수정되었습니다.
Before:
probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1)
ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)
After:
if self.topk == 1:
ret_topk_index = torch.argmax(
draft_logits_output.next_token_logits, dim=-1, keepdim=True
)
ret_topk_p = torch.ones_like(ret_topk_index, dtype=torch.float32)
else:
probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1)
ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)
draft_forward와 동일한 로직으로, topk == 1일 때 softmax 연산을 제거하고 argmax를 사용하도록 변경되었습니다.
2. eagle_draft_extend_cuda_graph_runner.py
이 파일은 CUDA 그래프를 사용하여 _draft_extend_for_decode 로직을 최적화하는 부분을 담당합니다. CUDA 그래프는 연산을 캡처하여 재사용함으로써 오버헤드를 줄입니다. 이 부분에서도 topk == 1일 때 torch.softmax 연산이 불필요하게 포함되어 있었습니다.
Before:
probs = torch.softmax(ret.next_token_logits, dim=-1)
ret.topk_p, ret.topk_index = fast_topk(probs, self.topk, dim=-1)
After:
if self.topk == 1:
ret.topk_index = torch.argmax(
ret.next_token_logits, dim=-1, keepdim=True
)
ret.topk_p = torch.ones_like(ret.topk_index, dtype=torch.float32)
else:
probs = torch.softmax(ret.next_token_logits, dim=-1)
ret.topk_p, ret.topk_index = fast_topk(probs, self.topk, dim=-1)
CUDA 그래프 캡처 내에서도 topk == 1일 경우 softmax 연산을 제거하고 argmax를 사용하도록 수정되었습니다. 이는 CUDA 그래프 실행 시에도 불필요한 연산이 수행되는 것을 방지하여 성능을 향상시킵니다.
왜 이게 좋은가?
이번 PR의 핵심은 topk == 1이라는 특정 조건 하에서 torch.softmax 연산이 불필요하다는 점을 정확히 파악하고 이를 제거했다는 것입니다.
성능 향상
PR 설명에 따르면, 이 변경으로 인해 다음과 같은 성능 향상이 측정되었습니다 (Kimi-K2.5-NVFP4 / TP=4 / 80K ctx / EAGLE3 3-step / bs=1 워크로드 기준):
cunn_SoftMaxForward연산이 GPU 프로파일링에서 완전히 사라졌습니다.- 이전에는
DRAFT_DECODE루프에서 2번,_draft_extend_for_decode에서 1번,DRAFT_EXTENDCUDA 그래프 내에서 1번, 총 4번의softmax연산이 수행되었으나, 이제는 이 연산이 제거되었습니다. - Mean TPOT (Time Per Output Token): 2.41 ms → 2.36 ms (-0.05 ms)
- Median TPOT: 2.37 ms → 2.34 ms (-0.03 ms)
- 1000/Mean (Tokens/sec): 414.9 → 423.7 (+8.8 tok/s, +2.1%)
- 1000/Med: 421.9 → 427.4 (+5.5 tok/s, +1.3%)
이 수치들은 작아 보일 수 있지만, LLM 추론에서 초당 처리할 수 있는 토큰 수가 증가하는 것은 매우 의미 있는 개선입니다. 특히 대규모 트래픽을 처리하는 서비스에서는 이러한 작은 개선들이 모여 큰 성능 향상을 가져올 수 있습니다.
일반적인 교훈
- 조건부 최적화의 중요성: 모든 코드 경로에 대해 동일한 최적화를 적용하는 것보다, 특정 조건(예:
topk == 1)에서 불필요한 연산을 식별하고 제거하는 것이 훨씬 효과적일 수 있습니다. 알고리즘의 특성을 깊이 이해하는 것이 중요합니다. - 프로파일링의 힘:
cunn_SoftMaxForward연산이 병목 지점임을 프로파일링을 통해 정확히 파악하고, 이를 제거함으로써 성능 향상을 이끌어냈습니다. 성능 최적화는 측정 가능한 데이터에 기반해야 합니다. - 수학적 동등성 활용:
argmax(logits)와argmax(softmax(logits))가 동일한 결과를 반환한다는 수학적 사실을 활용하여 불필요한softmax연산을 제거했습니다. 이는 연산의 본질을 이해하는 것의 중요성을 보여줍니다. - CUDA 그래프 최적화: CUDA 그래프와 같이 성능에 민감한 부분에서도 불필요한 연산을 제거하는 것이 중요합니다. 그래프 내부에 포함된 연산 하나하나가 실제 실행 시간에 영향을 미칩니다.
결론
이번 PR은 SGLang의 EAGLE 추측적 디코딩에서 topk == 1일 때 발생하는 불필요한 torch.softmax 연산을 제거함으로써 추론 성능을 약 1.3% ~ 2.1% 향상시켰습니다. 이는 알고리즘의 특성을 깊이 이해하고, 프로파일링을 통해 병목 지점을 정확히 찾아내어 최적화를 수행한 좋은 사례입니다. 이러한 최적화는 LLM 추론의 효율성을 높이는 데 크게 기여할 것입니다.
References
참고 자료
- https://pytorch.org/docs/stable/generated/torch.nn.functional.softmax.html
- https://pytorch.org/docs/stable/generated/torch.argmax.html
- https://github.com/sgl-project/sglang
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [sglang] 성능 최적화의 함정: DeepSeek-V3.2 정확도 붕괴를 막기 위한 SGLang의 긴급 롤백 분석
- [sglang] SGLang Ngram Speculative Decoding 최적화: MatchState 증분 업데이트 성능 개선
- [sglang] [SGLang] Blackwell(B200)에서 Diffusion Attention 성능을 7배 끌어올리는 Triton 커널 최적화 분석
- [sglang] LTX2 스플릿 로터리 커널 최적화: 헤드 배치 처리로 성능 2배 향상
- [sglang] SGLang NPU 최적화: MoE 모델을 위한 Dual Stream 병렬 처리 도입
PR Analysis 의 다른글
- 이전글 [cpython] Python의 os.fork 후 발생하던 성능 프로파일링 충돌 문제 해결 및 최적화 분석
- 현재글 : [sglang] SGLang EAGLE 디코딩 최적화: 불필요한 Softmax 연산 제거로 성능 향상
- 다음글 [vllm] vLLM, GDN Prefill 커널을 CuteDSL로 최적화하여 성능 향상
댓글