[sglang] AMD에서 MoE Gate router gemm을 tgemm.mm으로 교체
PR 링크: sgl-project/sglang#21657 상태: Merged | 변경: +8 / -32
들어가며
AMD GPU에서 DeepSeek-V2의 MoE gate router는 hidden_states와 weight의 행렬곱을 수행한다. 기존에는 행렬 크기(M<=256)에 따라 gemm_a16w16_atomic과 gemm_a16w16을 수동으로 분기하고, zero-allocated 출력 버퍼를 관리하는 복잡한 로직이 있었다. 이 PR은 이를 aiter의 tgemm.mm 자동 디스패처 하나로 교체한다.
핵심 코드 분석
Router GEMM 함수 단순화
Before:
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic
from sglang.srt.utils import BumpAllocator
def aiter_dsv3_router_gemm(
hidden_states, weight, gemm_output_zero_allocator=None,
):
M = hidden_states.shape[0]
N = weight.shape[0]
y = None
if M <= 256:
if gemm_output_zero_allocator != None:
y = gemm_output_zero_allocator.allocate(M * N).view(M, N)
else:
y = torch.zeros((M, N), dtype=torch.float32, device=hidden_states.device)
if y is not None:
logits = gemm_a16w16_atomic(hidden_states, weight, y=y).to(hidden_states.dtype)
else:
logits = gemm_a16w16(hidden_states, weight)
return logits
After:
from aiter.tuned_gemm import tgemm
def aiter_dsv3_router_gemm(hidden_states, weight):
"""Use aiter tuned GEMM dispatcher (tgemm.mm) to automatically select the GEMM kernel."""
return tgemm.mm(hidden_states, weight, otype=hidden_states.dtype)
26줄이 3줄로 줄었다. tgemm.mm이 내부적으로 행렬 크기에 따라 최적의 GEMM 커널을 자동 선택한다.
호출측 인자 정리
Before:
elif _use_aiter_gfx95 and hidden_states.shape[0] <= 256 and self.weight.shape[0] <= 256:
logits = aiter_dsv3_router_gemm(
hidden_states, self.weight, gemm_output_zero_allocator
)
After:
elif _use_aiter:
logits = aiter_dsv3_router_gemm(hidden_states, self.weight)
gfx95(MI300X) 전용이었던 조건이 모든 AMD GPU(_use_aiter)로 확장되고, 크기 조건과 allocator 인자가 불필요해졌다.
왜 이게 좋은가
- 성능 회귀 해결:
--context-length 13824로 gpt-oss 실행 시 발생하던 성능 회귀 수정 (#21691 관련) - 코드 단순화: 수동 크기 분기와 BumpAllocator 의존성 제거
- AMD GPU 범용 지원: gfx95 전용에서 모든 aiter 지원 AMD GPU로 확장
정리
aiter 라이브러리의 tgemm.mm 자동 디스패처를 활용하여 수동 GEMM 커널 분기 로직을 제거한 깔끔한 리팩터링이다. 코드 복잡도는 크게 줄었고, AMD GPU 전반에서 최적 GEMM 커널이 자동 선택된다.
참고 자료
- sgl-project/sglang#21657 — 원본 PR
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [llm-compressor] GPTQ Block Quantization 지원
- 현재글 : [sglang] AMD에서 MoE Gate router gemm을 tgemm.mm으로 교체
- 다음글 [sglang] CI에서 NVIDIA wheel 로컬 캐싱으로 830MB 반복 다운로드 방지
댓글