본문으로 건너뛰기

[sglang] [AMD] Triton 커널 퓨전을 통한 Qwen3.5 MoE 라우팅 최적화 분석

PR 링크: sgl-project/sglang#22844 상태: Merged | 변경: +None / -None

들어가며

대규모 언어 모델(LLM)의 추론 성능을 최적화할 때, 개별 연산의 속도만큼 중요한 것이 바로 커널 런칭 오버헤드(Kernel Launch Overhead)를 줄이는 것입니다. 특히 Mixture-of-Experts(MoE) 구조를 사용하는 모델에서는 토큰별로 전문가를 할당하는 라우팅 로직이 크리티컬 패스(Critical Path)에 위치합니다.

최근 SGLang 프로젝트에 반영된 이 PR은 Qwen3.5 모델의 MoE 라우팅 과정에서 발생하는 비효율을 해결합니다. 기존에는 공유 전문가(Shared Expert)의 ID와 가중치를 추가하기 위해 4번의 별도 커널(2개의 elementwise 연산 + 2개의 concat 연산)을 호출해야 했습니다. 이번 업데이트는 이를 단일 Fused Triton Kernel로 통합하여 AMD GPU 환경에서의 서빙 성능을 크게 개선했습니다.

코드 분석: 무엇이 바뀌었나?

1. 기존 방식의 문제점 (Before)

기존 python/sglang/srt/models/qwen2_moe.py의 로직은 PyTorch의 고수준 API를 사용하여 가독성은 좋지만, 내부적으로 여러 번의 GPU 커널 호출을 유발했습니다.

# Before: qwen2_moe.py
shared_expert_id = self.num_experts
shared_ids = torch.full(
    (M, self.num_fused_shared_experts),
    shared_expert_id,
    dtype=topk_output.topk_ids.dtype,
    device=topk_output.topk_ids.device,
)
shared_weights = shared_weights.expand(M, self.num_fused_shared_experts).to(
    topk_output.topk_weights.dtype
)
fused_topk_ids = torch.cat([topk_output.topk_ids, shared_ids], dim=-1)
fused_topk_weights = torch.cat(
    [topk_output.topk_weights, shared_weights], dim=-1
)

위 코드에서 torch.full, expand, 그리고 두 번의 torch.cat은 각각 별도의 GPU 작업을 생성합니다. 특히 Qwen3.5처럼 토큰별로 학습된 shared_expert_gate 가중치를 사용하는 경우, 이러한 오버헤드는 무시할 수 없는 수준이 됩니다.

2. Triton을 이용한 커널 퓨전 (After)

새롭게 도입된 _fused_append_shared_experts_with_weights_kernel은 기존의 복잡한 과정을 한 번의 메모리 읽기/쓰기 사이클 내에서 처리하도록 설계되었습니다.

# After: fused_moe_triton_kernels.py
@triton.jit
def _fused_append_shared_experts_with_weights_kernel(
    topk_ids_ptr, topk_weights_ptr, shared_weights_ptr, 
    out_ids_ptr, out_weights_ptr, N_BASE, K: tl.constexpr, S: tl.constexpr, ...
):
    pid = tl.program_id(0)
    # ... 인덱스 계산 로직 ...

    # 1. 기존 topk 데이터 로드 및 저장
    ids = tl.load(topk_ids_ptr + ids_row_ptr + offs_k, mask=mask_k)
    ws = tl.load(topk_weights_ptr + ids_row_ptr + offs_k, mask=mask_k)
    tl.store(out_ids_ptr + out_row_ptr + offs_k, ids, mask=mask_k)
    tl.store(out_weights_ptr + out_row_ptr + offs_k, ws, mask=mask_k)

    # 2. 공유 전문가(Shared Expert) 데이터 생성 및 저장
    shared_ids = tl.cast(N_BASE + offs_s, ids.dtype)
    shared_ws = tl.load(shared_weights_ptr + pid * S + offs_s, mask=mask_s)
    tl.store(out_ids_ptr + out_row_ptr + K + offs_s, shared_ids, mask=mask_s)
    tl.store(out_weights_ptr + out_row_ptr + K + offs_s, shared_ws, mask=mask_s)

이 커널은 한 번의 실행으로 기존 topk_ids/weights를 복사함과 동시에, 끝부분에 공유 전문가 정보를 덧붙입니다. 파이썬 레벨에서의 torch.cat 호출 없이 GPU 내부 메모리 레이아웃을 직접 제어하므로 매우 효율적입니다.

왜 이게 좋은 최적화인가?

1. 커널 런칭 오버헤드 제거

GPU 연산은 실제 계산 시간보다 커널을 준비하고 실행하는 오버헤드가 더 클 때가 많습니다. 특히 MoE 라우팅처럼 데이터 크기는 작지만 빈번하게 발생하는 연산에서 4개의 커널을 1개로 줄이는 것은 지연 시간(Latency) 단축에 직결됩니다.

2. 메모리 대역폭 효율화

기존 방식은 중간 결과물(shared_ids, shared_weights)을 생성하기 위해 임시 메모리에 쓰고 다시 읽는 과정을 거칩니다. 퓨전된 커널은 레지스터 수준에서 데이터를 처리하고 최종 결과만 메모리에 쓰기 때문에 메모리 대역폭(Memory Bandwidth)을 절약합니다.

3. 실제 성능 향상 수치

벤치마크 결과에 따르면, 이 최적화를 통해 다음과 같은 성능 향상이 확인되었습니다:

  • TTT (Tokens Throughput): 최대 +4.16% 향상
  • Median TPOT (Time Per Output Token):3.99% 개선
  • Median E2EL (End-to-End Latency):3.93% 감소

정확도 측면에서도 GSM8K 테스트 결과 baseline과 동일한 0.95 수준을 유지하여, 최적화로 인한 수치적 손실이 없음을 증명했습니다.

결론

이번 PR은 현대적인 LLM 서빙 엔진에서 Triton과 같은 커널 언어가 왜 필수적인지를 잘 보여줍니다. 고수준 프레임워크의 편의성을 유지하면서도, 병목이 발생하는 지점(Hotspot)을 정밀하게 타격하여 커널을 통합하는 기법은 시니어 엔지니어가 갖춰야 할 핵심 역량 중 하나입니다.

AMD GPU 환경에서 Qwen3.5와 같은 최신 모델을 서빙하고자 한다면, 이러한 하위 레벨의 퓨전 최적화가 전체 시스템의 처리량을 결정짓는 중요한 요소가 될 것입니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글