[sglang] AMD GPU에서 FP8 MLA를 활용한 Diffusion 모델 성능 최적화
PR 링크: sgl-project/sglang#20319 상태: Merged | 변경: +0 / -0
들어가며
최근 고성능 생성형 AI 모델의 추론 속도를 높이기 위해 하드웨어 가속기별 최적화가 필수적입니다. 본 PR은 AMD MI355X GPU 환경에서 Diffusion 모델의 추론 성능을 개선하기 위해, 기존의 FP8 per-tensor flash attention 대신 FP8 MLA(Multi-Head Latent Attention) ASM 커널을 도입했습니다. 이 최적화는 특히 Prefill 단계의 연산 효율을 극대화하여 전체 추론 시간을 크게 단축하는 것을 목표로 합니다.
코드 분석
1. python/sglang/multimodal_gen/runtime/layers/attention/backends/aiter.py
핵심 변경 사항은 flash_attn_fp8_pertensor_func를 제거하고, 고도로 최적화된 mla_prefill_ps_asm_fwd와 mla_reduce_v1 커널을 사용하는 _mla_prefill_ps_attention 함수를 구현한 것입니다.
Before (기존 방식)
# 기존에는 범용 FP8 per-tensor flash attention을 사용
# 이 방식은 MLA 구조에 특화된 최적화가 부족함
output = aiter.flash_attn_fp8_pertensor_func(q, k, v, ...)
After (최적화된 방식)
# MLA 커널의 제약사항을 준수하며 ASM 커널 호출
# qk_head_dim=192 제약을 맞추기 위해 zero-padding 적용
pad_qk = _MLA_PREFILL_QK_HEAD_DIM - D_q
if pad_qk > 0:
q = torch.nn.functional.pad(q, (0, pad_qk))
k = torch.nn.functional.pad(k, (0, pad_qk))
# ASM 커널과 reduce 커널을 순차적으로 실행
aiter.mla_prefill_ps_asm_fwd(...)
aiter.mla_reduce_v1(...)
또한, _build_mla_prefill_metadata 함수를 추가하여 ASM 커널이 요구하는 복잡한 persistent-scheduling 메타데이터(indptr, work partitioning 등)를 동적으로 생성하도록 했습니다. 이는 커널이 GPU 메모리 상에서 효율적으로 작업을 분할하여 병렬성을 극대화할 수 있게 합니다.
왜 이게 좋은가
이번 최적화의 핵심은 하드웨어 특화 커널(ASM)을 활용하여 연산 밀도를 높인 점입니다. 특히 MI355X와 같은 최신 AMD GPU 아키텍처에서 FP8 연산 성능을 최대한 끌어내도록 설계되었습니다.
- 성능 수치: Wan2.2-T2V 모델 기준, 81프레임 생성 시 Denoising Stage에서 약 19.4%의 속도 향상을 달성했습니다.
- 범용성 고려: 모든 모델이 MLA 커널의 제약(v_head_dim=128 등)을 만족하지는 않습니다. 이를 위해
_can_use_mla_prefill가드 함수를 두어, 조건 불만족 시 자동으로 BF16 flash attention으로 폴백(fallback)하도록 설계하여 안정성을 확보했습니다. - 교훈: 특정 아키텍처(gfx950)에 최적화된 ASM 커널을 사용할 때는 정밀한 shape guard와 메타데이터 관리가 필수적입니다. 또한, 컴파일러 최적화와 커스텀 커널 간의 복잡한 상호작용을 고려하여
@torch.compiler.disable을 전략적으로 배치하는 것이 디버깅 비용을 줄이는 현명한 선택이 될 수 있습니다.
리뷰어 피드백 반영
리뷰 과정에서 @torch.compiler.disable을 전체 forward 함수에 적용한 점에 대해 논의가 있었습니다. 유지보수 측면에서는 특정 분기만 비활성화하는 것이 이상적이지만, 현재는 FP8/BF16 혼합 경로와 커스텀 커널 호출이 복잡하게 얽혀 있어, 안정성을 위해 전체 비활성화를 선택했습니다. 이는 성능과 안정성 사이의 트레이드오프를 고려한 실용적인 결정입니다.
참고 자료
- https://pytorch.org/docs/stable/generated/torch.compile.html
- https://docs.sglang.ai/developer_guide/contribution_guide.html
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [sglang] [AMD/ROCm] Temporal Unfolding을 통한 VAE Conv3D 성능 최적화 분석
- [sglang] SGLang의 AMD AITER AllReduce 최적화: 하드코딩된 제약 제거 및 성능 개선
- [sglang] SGLang의 AMD GPU 최적화: RMSNorm과 FP8 Per-token Quantization 커널 융합
- [sglang] SGLang AMD 환경에서의 GLM-5-FP8 성능 벤치마크 도입 및 최적화
- [sglang] AMD ROCm 환경에서의 DeepSeek-V4 성능 최적화: Aiter MHC 커널 통합 분석
PR Analysis 의 다른글
- 이전글 [sglang] SGLang: ROCm 환경에서 RMSNorm 최적화 - Triton에서 aiter 커널로 전환
- 현재글 : [sglang] AMD GPU에서 FP8 MLA를 활용한 Diffusion 모델 성능 최적화
- 다음글 [sglang] [AMD/ROCm] Temporal Unfolding을 통한 VAE Conv3D 성능 최적화 분석
댓글