본문으로 건너뛰기

[sglang] SGLang EAGLE 디코딩 최적화: 불필요한 Softmax 연산 제거로 성능 향상

PR 링크: sgl-project/sglang#26235 상태: Merged | 변경: +26 / -6

들어가며

최근 대규모 언어 모델(LLM)의 추론 속도 향상을 위해 다양한 최적화 기법이 연구되고 있습니다. 그중 하나인 추측적 디코딩(Speculative Decoding)은 작은 모델(draft model)이 생성한 토큰 시퀀스를 큰 모델(target model)이 검증하는 방식으로, 전체 추론 속도를 크게 향상시킬 수 있습니다. SGLang은 EAGLE이라는 추측적 디코딩 방식을 구현하여 LLM 추론 성능을 높이고 있습니다.

이번 PR은 SGLang의 EAGLE 디코딩 방식에서 topk == 1일 때 발생하는 불필요한 연산을 제거하여 성능을 개선하는 것을 목표로 합니다. 특히, greedy 디코딩 경로에서 매 스텝마다 수행되던 torch.softmax 연산이 실제로는 아무런 영향을 주지 않음에도 불구하고 계산 비용을 발생시키는 문제를 해결합니다.

코드 변경 분석

이번 PR은 주로 eagle_worker_v2.pyeagle_draft_extend_cuda_graph_runner.py 파일에서 EAGLE 디코딩 로직을 수정했습니다. 핵심 변경 사항은 topk == 1인 경우, 전체 어휘(vocabulary)에 대한 torch.softmax 연산을 건너뛰고 바로 torch.argmax를 사용하여 가장 확률이 높은 토큰을 선택하도록 변경한 것입니다.

1. eagle_worker_v2.py

이 파일은 EAGLE 디코딩의 메인 워커 로직을 담당합니다. 두 곳에서 topk == 1일 때의 최적화가 이루어졌습니다.

1.1. draft_forward 함수 내 변경

draft_forward 함수는 작은 모델이 토큰을 생성하는 핵심 루프입니다. 각 스텝마다 다음 토큰의 로짓(logits)을 계산하고, 이를 기반으로 topk 샘플링을 수행합니다. 기존 코드에서는 topk 값에 상관없이 항상 torch.softmax를 먼저 계산한 후 fast_topk 함수를 호출했습니다.

Before:

probs = torch.softmax(logits_output.next_token_logits, dim=-1)
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)

After:

if self.topk == 1:
    # topk=1 → degenerate single-path tree; `topk_p` is unused
    # downstream, so skip softmax and just argmax over logits.
    topk_index = torch.argmax(
        logits_output.next_token_logits, dim=-1, keepdim=True
    )
    topk_p = torch.ones_like(topk_index, dtype=torch.float32)
else:
    probs = torch.softmax(logits_output.next_token_logits, dim=-1)
    topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)

topk == 1인 경우, torch.softmax 연산을 건너뛰고 torch.argmax를 직접 사용하여 가장 높은 로짓 값을 가진 토큰의 인덱스를 얻습니다. 이때 topk_p는 1로 채워진 텐서로 설정됩니다. 이는 topk == 1일 때, 즉 greedy 디코딩 경로에서는 topk_p 값이 실제로 사용되지 않기 때문입니다. topk_index = argmax(logits)topk_index = argmax(softmax(logits))와 동일한 결과를 반환하므로, softmax 연산은 불필요한 계산이 됩니다.

1.2. _draft_extend_for_decode 함수 내 변경

이 함수는 생성된 draft 토큰을 기반으로 실제 모델의 출력을 확장하는 부분입니다. 이 함수 내부에서도 마찬가지로 topk == 1일 때 torch.softmax 연산을 건너뛰도록 수정되었습니다.

Before:

probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1)
ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)

After:

if self.topk == 1:
    ret_topk_index = torch.argmax(
        draft_logits_output.next_token_logits, dim=-1, keepdim=True
    )
    ret_topk_p = torch.ones_like(ret_topk_index, dtype=torch.float32)
else:
    probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1)
    ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)

draft_forward와 동일한 로직으로, topk == 1일 때 softmax 연산을 제거하고 argmax를 사용하도록 변경되었습니다.

2. eagle_draft_extend_cuda_graph_runner.py

이 파일은 CUDA 그래프를 사용하여 _draft_extend_for_decode 로직을 최적화하는 부분을 담당합니다. CUDA 그래프는 연산을 캡처하여 재사용함으로써 오버헤드를 줄입니다. 이 부분에서도 topk == 1일 때 torch.softmax 연산이 불필요하게 포함되어 있었습니다.

Before:

probs = torch.softmax(ret.next_token_logits, dim=-1)
ret.topk_p, ret.topk_index = fast_topk(probs, self.topk, dim=-1)

After:

if self.topk == 1:
    ret.topk_index = torch.argmax(
        ret.next_token_logits, dim=-1, keepdim=True
    )
    ret.topk_p = torch.ones_like(ret.topk_index, dtype=torch.float32)
else:
    probs = torch.softmax(ret.next_token_logits, dim=-1)
    ret.topk_p, ret.topk_index = fast_topk(probs, self.topk, dim=-1)

CUDA 그래프 캡처 내에서도 topk == 1일 경우 softmax 연산을 제거하고 argmax를 사용하도록 수정되었습니다. 이는 CUDA 그래프 실행 시에도 불필요한 연산이 수행되는 것을 방지하여 성능을 향상시킵니다.

왜 이게 좋은가?

이번 PR의 핵심은 topk == 1이라는 특정 조건 하에서 torch.softmax 연산이 불필요하다는 점을 정확히 파악하고 이를 제거했다는 것입니다.

성능 향상

PR 설명에 따르면, 이 변경으로 인해 다음과 같은 성능 향상이 측정되었습니다 (Kimi-K2.5-NVFP4 / TP=4 / 80K ctx / EAGLE3 3-step / bs=1 워크로드 기준):

  • cunn_SoftMaxForward 연산이 GPU 프로파일링에서 완전히 사라졌습니다.
  • 이전에는 DRAFT_DECODE 루프에서 2번, _draft_extend_for_decode에서 1번, DRAFT_EXTEND CUDA 그래프 내에서 1번, 총 4번의 softmax 연산이 수행되었으나, 이제는 이 연산이 제거되었습니다.
  • Mean TPOT (Time Per Output Token): 2.41 ms → 2.36 ms (-0.05 ms)
  • Median TPOT: 2.37 ms → 2.34 ms (-0.03 ms)
  • 1000/Mean (Tokens/sec): 414.9 → 423.7 (+8.8 tok/s, +2.1%)
  • 1000/Med: 421.9 → 427.4 (+5.5 tok/s, +1.3%)

이 수치들은 작아 보일 수 있지만, LLM 추론에서 초당 처리할 수 있는 토큰 수가 증가하는 것은 매우 의미 있는 개선입니다. 특히 대규모 트래픽을 처리하는 서비스에서는 이러한 작은 개선들이 모여 큰 성능 향상을 가져올 수 있습니다.

일반적인 교훈

  1. 조건부 최적화의 중요성: 모든 코드 경로에 대해 동일한 최적화를 적용하는 것보다, 특정 조건(예: topk == 1)에서 불필요한 연산을 식별하고 제거하는 것이 훨씬 효과적일 수 있습니다. 알고리즘의 특성을 깊이 이해하는 것이 중요합니다.
  2. 프로파일링의 힘: cunn_SoftMaxForward 연산이 병목 지점임을 프로파일링을 통해 정확히 파악하고, 이를 제거함으로써 성능 향상을 이끌어냈습니다. 성능 최적화는 측정 가능한 데이터에 기반해야 합니다.
  3. 수학적 동등성 활용: argmax(logits)argmax(softmax(logits))가 동일한 결과를 반환한다는 수학적 사실을 활용하여 불필요한 softmax 연산을 제거했습니다. 이는 연산의 본질을 이해하는 것의 중요성을 보여줍니다.
  4. CUDA 그래프 최적화: CUDA 그래프와 같이 성능에 민감한 부분에서도 불필요한 연산을 제거하는 것이 중요합니다. 그래프 내부에 포함된 연산 하나하나가 실제 실행 시간에 영향을 미칩니다.

결론

이번 PR은 SGLang의 EAGLE 추측적 디코딩에서 topk == 1일 때 발생하는 불필요한 torch.softmax 연산을 제거함으로써 추론 성능을 약 1.3% ~ 2.1% 향상시켰습니다. 이는 알고리즘의 특성을 깊이 이해하고, 프로파일링을 통해 병목 지점을 정확히 찾아내어 최적화를 수행한 좋은 사례입니다. 이러한 최적화는 LLM 추론의 효율성을 높이는 데 크게 기여할 것입니다.

References

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글