본문으로 건너뛰기

[sglang] AMD GPU에서 FP8 MLA를 활용한 Diffusion 모델 성능 최적화

PR 링크: sgl-project/sglang#20319 상태: Merged | 변경: +0 / -0

들어가며

최근 고성능 생성형 AI 모델의 추론 속도를 높이기 위해 하드웨어 가속기별 최적화가 필수적입니다. 본 PR은 AMD MI355X GPU 환경에서 Diffusion 모델의 추론 성능을 개선하기 위해, 기존의 FP8 per-tensor flash attention 대신 FP8 MLA(Multi-Head Latent Attention) ASM 커널을 도입했습니다. 이 최적화는 특히 Prefill 단계의 연산 효율을 극대화하여 전체 추론 시간을 크게 단축하는 것을 목표로 합니다.

코드 분석

1. python/sglang/multimodal_gen/runtime/layers/attention/backends/aiter.py

핵심 변경 사항은 flash_attn_fp8_pertensor_func를 제거하고, 고도로 최적화된 mla_prefill_ps_asm_fwdmla_reduce_v1 커널을 사용하는 _mla_prefill_ps_attention 함수를 구현한 것입니다.

Before (기존 방식)

# 기존에는 범용 FP8 per-tensor flash attention을 사용
# 이 방식은 MLA 구조에 특화된 최적화가 부족함
output = aiter.flash_attn_fp8_pertensor_func(q, k, v, ...)

After (최적화된 방식)

# MLA 커널의 제약사항을 준수하며 ASM 커널 호출
# qk_head_dim=192 제약을 맞추기 위해 zero-padding 적용
pad_qk = _MLA_PREFILL_QK_HEAD_DIM - D_q
if pad_qk > 0:
    q = torch.nn.functional.pad(q, (0, pad_qk))
    k = torch.nn.functional.pad(k, (0, pad_qk))

# ASM 커널과 reduce 커널을 순차적으로 실행
aiter.mla_prefill_ps_asm_fwd(...)
aiter.mla_reduce_v1(...)

또한, _build_mla_prefill_metadata 함수를 추가하여 ASM 커널이 요구하는 복잡한 persistent-scheduling 메타데이터(indptr, work partitioning 등)를 동적으로 생성하도록 했습니다. 이는 커널이 GPU 메모리 상에서 효율적으로 작업을 분할하여 병렬성을 극대화할 수 있게 합니다.

왜 이게 좋은가

이번 최적화의 핵심은 하드웨어 특화 커널(ASM)을 활용하여 연산 밀도를 높인 점입니다. 특히 MI355X와 같은 최신 AMD GPU 아키텍처에서 FP8 연산 성능을 최대한 끌어내도록 설계되었습니다.

  • 성능 수치: Wan2.2-T2V 모델 기준, 81프레임 생성 시 Denoising Stage에서 약 19.4%의 속도 향상을 달성했습니다.
  • 범용성 고려: 모든 모델이 MLA 커널의 제약(v_head_dim=128 등)을 만족하지는 않습니다. 이를 위해 _can_use_mla_prefill 가드 함수를 두어, 조건 불만족 시 자동으로 BF16 flash attention으로 폴백(fallback)하도록 설계하여 안정성을 확보했습니다.
  • 교훈: 특정 아키텍처(gfx950)에 최적화된 ASM 커널을 사용할 때는 정밀한 shape guard와 메타데이터 관리가 필수적입니다. 또한, 컴파일러 최적화와 커스텀 커널 간의 복잡한 상호작용을 고려하여 @torch.compiler.disable을 전략적으로 배치하는 것이 디버깅 비용을 줄이는 현명한 선택이 될 수 있습니다.

리뷰어 피드백 반영

리뷰 과정에서 @torch.compiler.disable을 전체 forward 함수에 적용한 점에 대해 논의가 있었습니다. 유지보수 측면에서는 특정 분기만 비활성화하는 것이 이상적이지만, 현재는 FP8/BF16 혼합 경로와 커스텀 커널 호출이 복잡하게 얽혀 있어, 안정성을 위해 전체 비활성화를 선택했습니다. 이는 성능과 안정성 사이의 트레이드오프를 고려한 실용적인 결정입니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글