[sglang] sglang의 torch.compile 활용: Advanced Indexing Gather 최적화로 LLM 추론 가속화
PR 링크: sgl-project/sglang#26129 상태: Merged | 변경: +37 / -6
들어가며
대규모 언어 모델(LLM)의 추론 성능은 모델 자체의 효율성뿐만 아니라, 이를 구동하는 프레임워크의 최적화 수준에 크게 좌우됩니다. 특히 GPU에서 텐서 연산을 수행할 때 발생하는 '커널 런치 오버헤드(Kernel Launch Overhead)'는 미세하지만 반복적으로 발생하여 전체 성능에 큰 영향을 미칠 수 있는 병목 지점입니다. GPU 커널 런치는 CPU가 GPU에 작업을 지시하고 결과를 기다리는 과정에서 발생하는 일종의 통신 비용으로, 작은 연산들이 빈번하게 발생할수록 이 오버헤드가 누적되어 성능 저하를 초래합니다.
sglang 프로젝트의 이번 PR은 이러한 문제에 대한 현명한 해결책을 제시합니다. FutureMap._resolve_spec_extras 메서드 내에서 여러 번의 advanced indexing gather 연산이 개별적으로 수행되면서 발생하는 GPU 커널 런치 오버헤드를 torch.compile을 활용하여 효과적으로 줄이는 최적화를 적용했습니다. 이로 인해 spec_v2 디코드 프롤로그에서 반복당 3개의 커널 런치를 줄이는 성과를 달성했습니다.
코드 분석: overlap_utils.py의 변화
이번 PR의 핵심 변경사항은 python/sglang/srt/managers/overlap_utils.py 파일에 집중되어 있습니다. 기존에는 _resolve_spec_extras 메서드 내에서 여러 개의 torch.Tensor에 대해 동일한 indices를 사용하여 개별적인 gather 연산을 수행했습니다. 각 buf[indices] 연산은 별도의 GPU 커널을 런치하게 되어, 불필요한 오버헤드를 발생시켰습니다.
Before: 개별적인 Gather 연산
다음은 PR 적용 전 _resolve_spec_extras 메서드에서 여러 gather 연산이 개별적으로 수행되던 방식입니다. 각 buf[indices] 라인이 잠재적으로 별도의 GPU 커널 런치를 유발합니다.
indices.record_stream(torch.get_device_module(self.device).current_stream())
draft_input.topk_p = self.topk_p_buf[indices]
draft_input.topk_index = self.topk_index_buf[indices]
draft_input.bonus_tokens = self.output_tokens_buf[indices]
if _DEBUG_ASSERT:
_assert_nonneg_and_invalidate(
draft_input.bonus_tokens, self.output_tokens_buf, indices
)
if spec_need_hidden_states():
draft_input.hidden_states = self.hidden_states_buf[indices]
After: torch.compile을 활용한 Gather 퓨전
PR은 새로운 헬퍼 함수 _gather_spec_extras를 도입하고, 이 함수에 @torch.compile(dynamic=True) 데코레이터를 적용하여 여러 gather 연산을 하나로 묶었습니다. 그리고 _resolve_spec_extras 메서드에서는 이 컴파일된 함수를 호출하도록 변경했습니다.
먼저, 새로 추가된 _gather_spec_extras 함수입니다.
--- a/python/sglang/srt/managers/overlap_utils.py
+++ b/python/sglang/srt/managers/overlap_utils.py
@@ -33,6 +33,25 @@ def _assert_nonneg_and_invalidate(
buf[indices] = -1
+@torch.compile(dynamic=True)
+def _gather_spec_extras(
+ indices: torch.Tensor,
+ topk_p_buf: torch.Tensor,
+ topk_index_buf: torch.Tensor,
+ output_tokens_buf: torch.Tensor,
+ hidden_states_buf: Optional[torch.Tensor],
+):
+ """Compiled gather of spec extras. `hidden_states_buf` is None when the
+ build does not capture hidden states."""
+ topk_p = topk_p_buf[indices]
+ topk_index = topk_index_buf[indices]
+ bonus_tokens = output_tokens_buf[indices]
+ hidden_states = (
+ hidden_states_buf[indices] if hidden_states_buf is not None else None
+ )
+ return topk_p, topk_index, bonus_tokens, hidden_states
+
+
def _resolve_future_token_ids_native(input_ids, future_token_ids_map):
input_ids[:] = torch.where(
input_ids < 0,
이어서, _resolve_spec_extras 메서드가 _gather_spec_extras를 호출하도록 변경된 부분입니다.
--- a/python/sglang/srt/managers/overlap_utils.py
+++ b/python/sglang/srt/managers/overlap_utils.py
@@ -135,15 +154,27 @@ def _resolve_spec_extras(self, batch: ScheduleBatch) -> None:
# FIXME: indices = batch.req_pool_indices, pinned 2 iters via
# record_batch_in_overlap; record_stream here is redundant.
indices.record_stream(torch.get_device_module(self.device).current_stream())
- draft_input.topk_p = self.topk_p_buf[indices]
- draft_input.topk_index = self.topk_index_buf[indices]
- draft_input.bonus_tokens = self.output_tokens_buf[indices]
+ hidden_states_buf = (
+ self.hidden_states_buf if spec_need_hidden_states() else None
+ )
+ (
+ draft_input.topk_p,
+ draft_input.topk_index,
+ draft_input.bonus_tokens,
+ hidden_states,
+ ) = _gather_spec_extras(
+ indices,
+ self.topk_p_buf,
+ self.topk_index_buf,
+ self.output_tokens_buf,
+ hidden_states_buf,
+ )
+ if hidden_states is not None:
+ draft_input.hidden_states = hidden_states
if _DEBUG_ASSERT:
_assert_nonneg_and_invalidate(
draft_input.bonus_tokens, self.output_tokens_buf, indices
)
- if spec_need_hidden_states():
- draft_input.hidden_states = self.hidden_states_buf[indices]
def set_input_ids_sentinel(
self, batch: ScheduleBatch, future_indices: torch.Tensor
이 변경으로 인해 topk_p, topk_index, bonus_tokens, 그리고 선택적으로 hidden_states에 대한 여러 gather 연산이 _gather_spec_extras라는 단일 컴파일된 함수 내에서 처리됩니다. 이는 PyTorch의 torch.compile 기능이 여러 연산을 하나의 최적화된 GPU 커널로 퓨전(fuse)할 수 있도록 해주기 때문에 가능합니다.
dynamic=True 옵션은 입력 텐서의 모양(shape)이 호출마다 달라질 수 있는 경우에 유용합니다. LLM 추론 환경에서는 배치 크기나 시퀀스 길이가 동적으로 변할 수 있으므로, 이러한 유연성은 매우 중요합니다.
왜 이게 좋은가: 성능 최적화의 핵심
이 PR은 LLM 추론의 성능을 향상시키는 데 있어 몇 가지 중요한 이점을 제공합니다.
-
GPU 커널 런치 오버헤드 감소: 가장 직접적인 이점은 GPU 커널 런치 횟수가 줄어든다는 것입니다. 기존에는 3~4개의 개별적인
buf[indices]연산이 각각 GPU 커널을 런치했지만, 이제는_gather_spec_extras라는 하나의torch.compile된 함수 호출이 모든 gather 연산을 처리합니다. PR 설명에 따르면, 이 최적화를 통해spec_v2디코드 프롤로그에서 반복당 3개의 커널 런치를 줄였습니다. 이는 LLM 추론의 각 반복(iteration)마다 CPU-GPU 통신 및 스케줄링 비용이 절감되어, 전체 추론 속도가 빨라지는 효과를 가져옵니다. -
연산 퓨전(Operation Fusion)을 통한 효율성 증대:
torch.compile은 파이썬 코드를 분석하여 여러 PyTorch 연산을 하나의 효율적인 저수준(low-level) GPU 커널로 퓨전할 수 있습니다. 이 경우, 동일한indices를 사용하여 여러 버퍼에서 데이터를 가져오는 gather 연산들이 하나의 커널 내에서 동시에 처리될 수 있습니다. 이는 데이터 접근 패턴을 최적화하고 GPU의 메모리 대역폭 활용도를 높여 전반적인 연산 효율성을 증대시킵니다. -
torch.compile의 강력한 활용: PyTorch 2.0부터 도입된torch.compile은 PyTorch 모델과 연산을 자동으로 최적화하는 강력한 도구입니다. 이 PR은torch.compile이 단순한 모델 컴파일을 넘어, 특정 코드 블록의 미세한 연산 패턴까지 최적화하는 데 활용될 수 있음을 보여줍니다.dynamic=True옵션은 동적인 입력 크기 변화에도 유연하게 대응할 수 있도록 하여, 실제 LLM 서비스 환경에서 더욱 효과적인 최적화를 가능하게 합니다.
일반적인 교훈
이 최적화는 다음과 같은 일반적인 교훈을 제공합니다.
- 작은 연산들의 묶음(Batching/Fusion)의 중요성: GPU 프로그래밍에서는 작은 연산들이 반복적으로 호출될 때 발생하는 오버헤드를 줄이는 것이 중요합니다. 특히 동일한 인덱스를 사용하는 여러 gather 연산과 같이 데이터 접근 패턴이 유사한 경우, 이들을 하나로 묶어 처리하는 것이 매우 효과적입니다.
- 자동 최적화 도구의 적극적인 활용:
torch.compile과 같은 프레임워크 수준의 자동 최적화 도구는 개발자가 직접 CUDA 커널을 작성하지 않고도 상당한 성능 향상을 이끌어낼 수 있는 강력한 수단입니다. 성능 병목이 의심되는 코드 영역에 이를 적용하는 것을 고려해볼 필요가 있습니다. - 프로파일링의 중요성: 이러한 최적화는 종종 프로파일링 도구를 통해 병목 지점을 정확히 식별한 후에 이루어집니다.
참고 자료
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [vllm] vLLM XPU MOE 성능 최적화: 호스트 오버헤드 감소를 위한 객체 지향적 접근
- 현재글 : [sglang] sglang의 torch.compile 활용: Advanced Indexing Gather 최적화로 LLM 추론 가속화
- 다음글 [cpython] CPython의 PySequence_GetSlice 성능 개선: 불필요한 참조 카운트 연산 제거
댓글