본문으로 건너뛰기

[sglang] sglang ROCm MXFP4 어텐션에서 불필요한 contiguous copy 제거를 통한 성능 최적화

PR 링크: sgl-project/sglang#25463 상태: Merged | 변경: +21 / -5

들어가며

최근 sglang 프로젝트의 Pull Request(PR)에서는 ROCm 환경에서 MXFP4 (Mixed Precision FP4) 디코딩 경로의 어텐션 메커니즘에서 발생하는 불필요한 메모리 복사를 제거하여 성능을 개선했습니다. 특히 _use_aiter_gfx95 조건 하에서 w_vc.dtype == torch.uint8일 때, 어텐션 계산 후 발생하는 transposeflatten 연산에서 의도치 않은 contiguous 복사가 발생하여 약 4µs/레이어의 오버헤드가 발생했습니다. MoE(Mixture of Experts) 모델의 경우 수십 개의 레이어를 거치므로, 이는 전체 디코딩 단계에서 상당한 지연 시간을 유발합니다. 이 PR은 이러한 비효율성을 제거하고 성능을 향상시키는 것을 목표로 합니다.

코드 분석

python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mla.py

이 PR의 핵심 변경 사항은 forward_absorb_core 함수 내에서 발생합니다. 특히 ROCm 환경에서 MXFP4 데이터 타입을 사용할 때 어텐션 결과 처리 방식에 변화가 있습니다.

변경 전 코드:

            if _use_aiter_gfx95 and self.w_vc.dtype == torch.uint8:
                x = attn_output.transpose(0, 1)
                attn_bmm_output = torch.empty(
                    x.shape[0],
                    x.shape[1],
                    self.w_vc.shape[2],
                    device=x.device,
                    dtype=torch.bfloat16,
                )
                batched_gemm_afp4wfp4_pre_quant(
                    x,
                    self.w_vc.transpose(-2, -1),
                    self.w_scale,
                    attn_bmm_output,
                )
            else:
                # ... (기존 로직)

            if self.o_proj.weight.dtype == torch.uint8:
                attn_bmm_output = attn_bmm_output.transpose(0, 1)
                attn_bmm_output = fused_flatten_mxfp4_quant(attn_bmm_output)
            elif self.o_proj.weight.dtype == torch.float8_e4m3fn:
                # ...
            else:
                attn_bmm_output = attn_bmm_output.flatten(1, 2)

변경 전 코드에서는 batched_gemm_afp4wfp4_pre_quant 커널의 출력(attn_bmm_output)이 (heads, batch, v_head_dim) 형태를 가집니다. 이후 transpose(0, 1) 연산을 통해 (batch, heads, v_head_dim) 형태로 변경하는데, 이 과정에서 텐서가 non-contiguous 상태가 됩니다. 이어서 .flatten(1, 2) 연산이 호출될 때, PyTorch는 non-contiguous 텐서를 contiguous하게 만들기 위해 내부적으로 .contiguous()를 호출하며, 이는 추가적인 메모리 복사를 유발합니다. 프로파일링 결과 이 복사 작업이 elementwise_kernel_manual_unroll 커널로 나타나며 약 4.4µs의 시간을 소모했습니다.

변경 후 코드:

            if _use_aiter_gfx95 and self.w_vc.dtype == torch.uint8:
                x = attn_output.transpose(0, 1)
                B_heads, M_batch = x.shape[0], x.shape[1]
                N_vdim = self.w_vc.shape[2]
                # Allocate in (batch, heads, dim) so the post-GEMM
                # transpose+flatten is a free view instead of a copy.
                _bmm_buf = torch.empty(
                    M_batch,
                    B_heads,
                    N_vdim,
                    device=x.device,
                    dtype=torch.bfloat16,
                )
                attn_bmm_output = _bmm_buf.transpose(0, 1)
                batched_gemm_afp4wfp4_pre_quant(
                    x,
                    self.w_vc.transpose(-2, -1),
                    self.w_scale,
                    attn_bmm_output,
                )
            else:
                _bmm_buf = None
                # ... (기존 로직)

            if _bmm_buf is not None:
                # _bmm_buf is already (batch, heads, dim) contiguous
                if self.o_proj.weight.dtype == torch.uint8:
                    attn_bmm_output = fused_flatten_mxfp4_quant(_bmm_buf)
                elif self.o_proj.weight.dtype == torch.float8_e4m3fn:
                    attn_bmm_output = fused_flatten_fp8_group_quant(
                        _bmm_buf, group_size=128, dtype_quant=torch.float8_e4m3fn
                    )
                else:
                    attn_bmm_output = _bmm_buf.flatten(1, 2)
            elif self.o_proj.weight.dtype == torch.uint8:
                attn_bmm_output = attn_bmm_output.transpose(0, 1)
                attn_bmm_output = fused_flatten_mxfp4_quant(attn_bmm_output)
            elif self.o_proj.weight.dtype == torch.float8_e4m3fn:
                # ...

변경 후에는 batched_gemm_afp4wfp4_pre_quant 커널에 전달될 출력 버퍼(_bmm_buf)를 처음부터 원하는 최종 레이아웃인 (batch, heads, v_head_dim) 형태로 미리 할당합니다. 그리고 이 버퍼의 transpose(0, 1) 뷰를 커널의 출력으로 사용합니다. 이렇게 하면 커널의 출력은 (heads, batch, v_head_dim) 형태가 되지만, 실제 데이터는 이미 contiguous한 _bmm_buf에 쓰여집니다. 이후 .flatten(1, 2) 연산은 _bmm_buf에 대해 수행되며, 이는 이미 contiguous하므로 별도의 복사 없이 0-cost view 연산으로 처리됩니다. 결과적으로 elementwise_kernel_manual_unroll 커널이 사라지고, 약 4µs의 불필요한 복사 오버헤드가 제거되었습니다.

이 변경은 _use_aiter_gfx95self.w_vc.dtype == torch.uint8 조건 하에서만 적용되어, MXFP4를 사용하는 특정 ROCm 환경(MI355X/gfx950)에만 영향을 미칩니다. 다른 백엔드나 플랫폼은 변경되지 않습니다.

왜 이게 좋은가?

성능 향상

이 PR의 가장 큰 장점은 명확한 성능 향상입니다. 프로파일링 결과에 따르면, 이전에는 어텐션 계산 후 transposeflatten 연산 사이에 약 4.4µs의 contiguous 복사 오버헤드가 존재했습니다. MoE 모델의 경우 61개의 레이어를 거치므로, 이는 레이어당 4µs * 61 = 244µs, 전체 디코딩 단계에서는 약 268µs의 불필요한 지연 시간을 발생시켰습니다.

변경 후에는 이 복사 작업이 제거되어 해당 오버헤드가 사라졌습니다. 실제 벤치마크 결과는 다음과 같은 성능 향상을 보여줍니다:

  • Output Throughput (tokens/s):
    • 1k/1k 시나리오에서 최대 +2.8% 향상
    • 8k/1k 시나리오에서 최대 +2.3% 향상
  • Mean TPOT (ms):
    • 1k/1k 시나리오에서 최대 +2.9% 향상
    • 8k/1k 시나리오에서 최대 +2.3% 향상

이러한 수치는 특히 긴 시퀀스(8k/1k) 처리 시에도 꾸준한 성능 개선을 보여주며, 모델의 전반적인 처리량을 높이고 응답 시간을 단축하는 데 기여합니다.

일반적인 교훈

  1. 메모리 레이아웃과 Contiguity의 중요성: 딥러닝 모델, 특히 GPU에서 실행되는 모델의 성능은 메모리 접근 패턴과 데이터의 연속성(contiguity)에 크게 좌우됩니다. transpose, view, flatten과 같은 연산은 텐서의 모양만 변경할 뿐 데이터를 재배열하지 않을 수 있습니다. 하지만 이러한 연산 후 데이터가 연속적이지 않다면(non-contiguous), 후속 연산에서 암시적인 복사가 발생하여 성능 저하의 원인이 됩니다. 이 PR은 이러한 점을 명확히 보여줍니다.
  2. 커널 수준에서의 최적화: Triton과 같은 커스텀 커널을 사용할 때, 메모리 할당 및 데이터 이동 방식을 세밀하게 제어할 수 있습니다. 이 PR은 batched_gemm_afp4wfp4_pre_quant 커널이 임의의 출력 스트라이드를 지원한다는 점을 활용하여, 커널 자체에서 원하는 레이아웃으로 데이터를 쓰도록 유도했습니다. 이는 라이브러리 API의 기능을 깊이 이해하고 활용하는 것이 얼마나 중요한지를 보여줍니다.
  3. 프로파일링 기반 최적화: 성능 병목 현상을 정확히 파악하기 위해서는 프로파일링이 필수적입니다. 이 PR은 프로파일링 트레이스에서 elementwise_kernel_manual_unroll (즉, direct_copy_kernel_cuda)을 식별하고, 그 원인이 되는 불필요한 복사를 제거하는 방식으로 진행되었습니다. 이는 체계적인 성능 개선 접근 방식의 중요성을 강조합니다.
  4. 조건부 최적화: 모든 코드 경로에 동일한 최적화를 적용하는 것은 위험할 수 있습니다. 이 PR은 특정 하드웨어(ROCm, gfx950)와 특정 데이터 타입(MXFP4, torch.uint8) 조건 하에서만 최적화를 적용하여, 다른 환경에서의 부작용을 최소화했습니다. 이는 코드의 안정성을 유지하면서 성능을 개선하는 좋은 전략입니다.

정확성 검증

성능 최적화 과정에서 모델의 정확성이 저하되지 않는 것이 중요합니다. 이 PR은 GSM8K 데이터셋을 사용한 정확성 검증에서 exact_match 지표가 0.932로, 요구되는 임계값(≥ 0.90)을 상회하며 통과했습니다. 이는 성능 개선이 정확성 희생 없이 이루어졌음을 의미합니다.

리뷰 댓글 분석

리뷰 댓글은 주로 CI(Continuous Integration) 관련 요청과 확인에 집중되었습니다. @Fridge003, @ch-wan, @HaiShaw에게 run-ci 라벨을 추가해달라는 요청이 있었고, @HaiShaw/tag-and-rerun-ci 명령어로 이를 처리했습니다. 이후 amd-bot의 CI 상태 보고에 따르면, PR에서 수정한 코드는 ROCm 환경의 특정 조건(_use_aiter_gfx95 and self.w_vc.dtype == torch.uint8)에만 해당하므로, 다른 AMD GPU(MI325, gfx942)나 NVIDIA GPU, NPU 환경에서 실행되는 테스트 케이스의 실패는 이 PR과 직접적인 관련이 없다고 분석되었습니다. 이는 이 PR이 매우 제한적인 범위에만 영향을 미치며, 해당 범위 외의 코드에는 영향을 주지 않음을 시사합니다. 따라서 CI 실패는 대부분 PR의 변경 사항과는 무관한 것으로 간주되어도 무방합니다.

References

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글