본문으로 건너뛰기

[sglang] SGLang의 AMD GPU 최적화: RMSNorm과 FP8 Per-token Quantization 커널 융합

PR 링크: sgl-project/sglang#21403 상태: Merged | 변경: +None / -None

들어가며

대규모 언어 모델(LLM)의 추론 성능을 극대화하기 위해서는 GPU 메모리 대역폭을 효율적으로 사용하는 것이 핵심입니다. 특히 FP8과 같은 저정밀도 연산을 사용할 때, 여러 연산을 별도의 커널로 실행하면 불필요한 글로벌 메모리 읽기/쓰기(Round-trip)가 발생합니다. 이번 SGLang PR은 GLM-4.7-FP8 모델을 위해 RMSNormFP8 per-token quantization 연산을 하나의 커널로 융합하여 성능을 개선했습니다.

코드 분석

1. communicator.py: 커널 융합 로직 추가

기존에는 RMSNorm과 Quantization이 분리되어 있었으나, aiter 라이브러리의 add_rmsnorm_quant를 활용하여 이를 하나로 합쳤습니다. group_size=0을 설정하여 per-token quantization을 명시적으로 처리합니다.

# Before: 별도의 RMSNorm 레이어 호출
hidden_states = self.input_layernorm(hidden_states)

# After: Fused 커널 호출
_aiter_rmsnorm_quant(
    out_fp8,
    hidden_states,
    scale,
    weight,
    epsilon,
    0,  # group_size=0 → per-token
)

또한, quant_format 체크 로직을 "fp8" in quant_format에서 quant_format == "fp8"로 변경하여, fp8_per_token 경로가 기존의 fused_rms_fp8_group_quant 경로와 충돌하지 않도록 정밀하게 제어했습니다.

2. glm4_moe.py: 자동 감지 및 설정

모델 초기화 시점에 CompressedTensorsW8A8Fp8 전략을 사용하는지 감지하여, 자동으로 fp8_per_token 모드를 활성화합니다.

# GLM-4.7-FP8 모델에서 FP8 per-token 감지 로직
if isinstance(scheme, CompressedTensorsW8A8Fp8) and scheme.strategy == QuantizationStrategy.CHANNEL:
    return "fp8_per_token"

3. fp8_utils.py: 데이터 타입 처리 개선

apply_fp8_ptpc_linear 함수가 튜플 형태의 입력(fp8_tensor, scale)을 직접 처리할 수 있도록 수정하여, 중복적인 per-token quantization 연산을 방지했습니다.

# 튜플 입력 처리로 중복 연산 제거
if isinstance(input, tuple):
    q_input, x_scale = input
    output = aiter.gemm_a8w8_bpreshuffle(q_input, weight, x_scale, weight_scale, None, torch.bfloat16)

왜 이게 좋은가

이번 최적화의 핵심은 메모리 접근 횟수 감소입니다. RMSNorm 결과를 메모리에 썼다가 다시 읽어서 Quantization을 수행하는 대신, 레지스터 수준에서 연산을 완료하고 바로 FP8 결과물을 생성합니다.

  • 성능 향상: MI355X 환경에서 ITL(Inter-Token Latency) Decode 속도가 약 1% 향상되었습니다.
  • 정확도 유지: GSM8K 벤치마크 결과가 0.948에서 0.943으로, 오차 범위 내에서 안정적인 정확도를 유지함을 확인했습니다.
  • 교훈: 커널 융합(Kernel Fusion)은 메모리 바운드(Memory-bound) 연산이 많은 LLM 추론에서 가장 효과적인 최적화 기법 중 하나입니다. 특히 ROCm/aiter 환경에서는 커널 호출 오버헤드를 줄이는 것이 성능의 병목을 해결하는 열쇠입니다.

리뷰어 피드백 반영

리뷰 과정에서 apply_fp8_ptpc_linear가 범용 함수로 오해받을 수 있다는 지적이 있었습니다. 이에 따라 해당 함수가 aiter 경로에서만 사용됨을 명시하는 docstring을 추가하여 코드의 의도를 명확히 했습니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글