본문으로 건너뛰기

[vllm] vLLM DeepSeek V4 ROCm MTP 지원: 하드웨어 최적화와 추론 성능 향상

PR 링크: vllm-project/vllm#43385 상태: Merged | 변경: +2340 / -52

들어가며

대규모 언어 모델(LLM)의 효율적인 추론은 서비스 제공에 있어 핵심적인 과제입니다. 특히 다양한 하드웨어 플랫폼에서 최적의 성능을 끌어내는 것은 더욱 중요합니다. 이번에 분석할 vLLM PR(vllm-project/vllm의 "[ROCm] [DSv4] [Perf] Support DeepSeek v4 MTP")은 DeepSeek V4 모델의 ROCm(AMD GPU) 환경에서 Multi-Token Prediction (MTP), 즉 Speculative Decoding 기능을 지원하여 추론 성능을 대폭 개선하는 것을 목표로 합니다. 이 PR은 단순히 기능을 추가하는 것을 넘어, ROCm 플랫폼에 특화된 모델링 파일을 구현하고 저수준 커널 최적화를 적용하여 하드웨어 활용도를 극대화하는 중요한 이정표를 제시합니다.

기존에는 DeepSeek V4 모델이 ROCm 환경에서 MTP 기능을 제대로 활용하지 못했기 때문에, 특히 낮은 동시성(batch size) 환경에서 잠재적인 성능 병목이 존재했습니다. 이 PR은 이러한 문제를 해결하고, ROCm에서도 NVIDIA GPU와 유사한 수준의 최적화된 추론 경험을 제공하기 위한 기반을 마련합니다.

코드 변경사항 분석

이번 PR의 핵심 변경사항은 크게 두 가지로 나눌 수 있습니다: ROCm 플랫폼에 특화된 triton_kernels 임포트 로직 개선과 DeepSeek V4 모델의 AMD 전용 구현 파일(vllm/models/deepseek_v4/amd/model.py) 도입 및 그 안에 포함된 저수준 Triton 커널 최적화입니다.

1. vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py: 플랫폼별 triton_kernels 임포트

이 변경은 ROCm 환경에서 triton_kernels 라이브러리를 올바르게 로드하기 위한 중요한 수정입니다. 기존에는 vllm.third_party.triton_kernels 경로에서 임포트했지만, ROCm 환경에서는 triton_kernels 자체 패키지에서 직접 임포트해야 하는 문제가 있었습니다. 이는 Triton 컴파일러가 기대하는 경로와 vLLM 내부 구조 간의 불일치로 인해 발생한 컴파일 오류를 해결합니다.

Before:

    try:
        from vllm.third_party.triton_kernels.tensor_details import (
            bitmatrix as _bm,
        )
        from vllm.third_party.triton_kernels.tensor_details.bitmatrix import (
            BitmatrixMetadata,
            _keyed_add,
            cdiv,
        )
        from vllm.third_party.triton_kernels.tensor_details.bitmatrix_details.sum_bitmatrix_rows import (  # noqa: E501
            sum_bitmatrix_rows,
        )

After:

    try:
        if current_platform.is_rocm():
            from triton_kernels.tensor_details import bitmatrix as _bm
            from triton_kernels.tensor_details.bitmatrix import (
                BitmatrixMetadata,
                _keyed_add,
                cdiv,
            )
            from triton_kernels.tensor_details.bitmatrix_details.sum_bitmatrix_rows import (  # noqa: E501
                sum_bitmatrix_rows,
            )
        else:
            from vllm.third_party.triton_kernels.tensor_details import (
                bitmatrix as _bm,
            )
            from vllm.third_party.triton_kernels.tensor_details.bitmatrix import (
                BitmatrixMetadata,
                _keyed_add,
                cdiv,
            )
            from vllm.third_party.triton_kernels.tensor_details.bitmatrix_details.sum_bitmatrix_rows import (  # noqa: E501
                sum_bitmatrix_rows,
            )

이 변경은 current_platform.is_rocm()을 통해 현재 실행 환경이 ROCm인지 확인하고, 그에 따라 적절한 triton_kernels 모듈을 임포트하도록 합니다. 이는 크로스 플랫폼 호환성을 보장하고, 특히 ROCm에서 Triton 커널 컴파일 시 발생할 수 있는 CompilationError를 방지하는 데 필수적입니다. 리뷰어 tjtanaa의 코멘트에서도 이 문제가 triton.compiler.errors.CompilationError와 관련되어 있음을 확인할 수 있습니다.

2. vllm/models/deepseek_v4/amd/model.py: AMD 전용 모델 구현 및 Triton 커널

이 PR의 가장 중요한 구조적 변경은 vllm/models/deepseek_v4/amd/model.py 파일이 새로 생성되었다는 점입니다. 기존에는 이 경로가 ../nvidia/model.py를 가리키는 심볼릭 링크였으나, 이제 ROCm(AMD) 플랫폼에 특화된 DeepSeek V4 모델 구현을 위한 독립적인 파일로 분리되었습니다. 이는 ROCm 환경에서 DeepSeek V4 모델의 성능을 최적화하기 위한 첫걸음이며, 하드웨어별 특성을 반영한 맞춤형 최적화를 가능하게 합니다.

새로 추가된 amd/model.py 파일에는 DeepSeek V4의 MegaMoE(Mixture-of-Experts) 레이어를 위한 입력 스테이징(staging)을 최적화하는 Triton 커널인 _deepseek_v4_stage_mega_moe_inputs_kernel이 포함되어 있습니다.

새로운 Triton 커널 (_deepseek_v4_stage_mega_moe_inputs_kernel):

@triton.jit
def _deepseek_v4_stage_mega_moe_inputs_kernel(
    hidden_states,
    x_fp8,
    x_sf,
    topk_ids,
    topk_weights,
    topk_idx_out,
    topk_weights_out,
    hidden_stride_m: tl.constexpr,
    hidden_stride_k: tl.constexpr,
    x_stride_m: tl.constexpr,
    x_stride_k: tl.constexpr,
    x_sf_stride_m: tl.constexpr,
    x_sf_stride_k: tl.constexpr,
    topk_ids_stride_m: tl.constexpr,
    topk_ids_stride_k: tl.constexpr,
    topk_weights_stride_m: tl.constexpr,
    topk_weights_stride_k: tl.constexpr,
    topk_idx_stride_m: tl.constexpr,
    topk_idx_stride_k: tl.constexpr,
    topk_weights_out_stride_m: tl.constexpr,
    topk_weights_out_stride_k: tl.constexpr,
    hidden_size: tl.constexpr,
    top_k: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_K: tl.constexpr,
    BLOCK_TOPK: tl.constexpr,
) -> None:
    token_id = tl.program_id(0)
    k_block_id = tl.program_id(1)

    k_offsets = k_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
    k_mask = k_offsets < hidden_size
    hidden = tl.load(
        hidden_states + token_id * hidden_stride_m + k_offsets * hidden_stride_k,
        mask=k_mask,
        other=0.0,
    ).to(tl.float32)

    num_groups: tl.constexpr = BLOCK_K // GROUP_K
    hidden_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K])
    amax = tl.max(hidden_groups, axis=1)
    amax = tl.maximum(amax, 1.0e-4)

    scale = amax / 448.0
    scale_bits = scale.to(tl.uint32, bitcast=True)
    scale_exp = ((scale_bits >> 23) & 0xFF) + ((scale_bits & 0x7FFFFF) != 0).to(
        tl.uint32
    )
    scale_exp = tl.minimum(tl.maximum(scale_exp, 1), 254)
    rounded_scale = (scale_exp << 23).to(tl.float32, bitcast=True)

    hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K])
    scaled = hidden_groups * (1.0 / rounded_scale)[:, None]
    scaled = tl.reshape(scaled, [BLOCK_K])
    fp8 = scaled.to(tl.float8e4nv)
    tl.store(
        x_fp8 + token_id * x_stride_m + k_offsets * x_stride_k,
        fp8,
        mask=k_mask,
    )

    scale_offsets = tl.arange(0, num_groups)
    packed_scale = tl.sum(scale_exp << (scale_offsets * 8), axis=0).to(tl.int32)
    tl.store(
        x_sf + token_id * x_sf_stride_m + k_block_id * x_sf_stride_k,
        packed_scale,
    )

    if k_block_id == 0:
        topk_offsets = tl.arange(0, BLOCK_TOPK)
        topk_mask = topk_offsets < top_k

        ids = tl.load(
            topk_ids + token_id * topk_ids_stride_m + topk_offsets * topk_ids_stride_k,
            mask=topk_mask,
            other=0,
        ).to(tl.int64)
        tl.store(
            topk_idx_out
            + token_id * topk_idx_stride_m
            + topk_offsets * topk_idx_stride_k,
            ids,
            mask=topk_mask,
        )

        weights = tl.load(
            topk_weights
            + token_id * topk_weights_stride_m
            + topk_offsets * topk_weights_stride_k,
            mask=topk_mask,
            other=0.0,
        )
        tl.store(
            topk_weights_out
            + token_id * topk_weights_out_stride_m
            + topk_offsets * topk_weights_out_stride_k,
            weights,
            mask=topk_mask,
        )

이 Triton 커널은 hidden_statesfloat8e4nv 형식으로 양자화하고, 동적으로 스케일 팩터(scale factor)를 계산하여 x_sf에 압축된 형태로 저장합니다. 또한, topk_idstopk_weightstopk_idx_outtopk_weights_out으로 스테이징합니다. 이 과정은 MegaMoE 레이어의 다음 연산을 위해 데이터를 효율적으로 준비하는 역할을 합니다.

Python 래퍼 (_stage_deepseek_v4_mega_moe_inputs):

def _stage_deepseek_v4_mega_moe_inputs(
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    x_fp8: torch.Tensor,
    x_sf: torch.Tensor,
    topk_idx_out: torch.Tensor,
    topk_weights_out: torch.Tensor,
) -> None:
    num_tokens, hidden_size = hidden_states.shape
    if num_tokens == 0:
        return
    if hidden_size % 128 != 0:
        raise ValueError(
            "DeepSeek V4 MegaMoE input staging requires hidden_size to be "
            "a multiple of 128."
        )
    top_k = topk_ids.shape[1]
    if topk_weights.shape != topk_ids.shape:
        raise ValueError(
            "DeepSeek V4 MegaMoE input staging requires topk_weights and "
            "topk_ids to have the same shape."
        )

    block_k = 128
    grid = (num_tokens, triton.cdiv(hidden_size, block_k))
    block_topk = triton.next_power_of_2(top_k)
    _deepseek_v4_stage_mega_moe_inputs_kernel[grid](
        hidden_states,
        x_fp8,
        x_sf,
        topk_ids,
        topk_weights,
        topk_idx_out,
        topk_weights_out,
        hidden_states.stride(0),
        hidden_states.stride(1),
        x_fp8.stride(0),
        x_fp8.stride(1),
        x_sf.stride(0),
        x_sf.stride(1),
        topk_ids.stride(0),
        topk_ids.stride(1),
        topk_weights.stride(0),
        topk_weights.stride(1),
        topk_idx_out.stride(0),
        topk_idx_out.stride(1),
        topk_weights_out.stride(0),
        topk_weights_out.stride(1),
        hidden_size,
        top_k,
        BLOCK_K=block_k,
        GROUP_K=32,
        BLOCK_TOPK=block_topk,
        num_warps=4,
    )

이 Python 래퍼는 Triton 커널을 호출하기 전에 입력 텐서의 유효성을 검사하고, Triton 커널 실행에 필요한 그리드(grid) 및 블록(block) 크기를 설정합니다. hidden_size가 128의 배수여야 한다는 제약 조건은 커널의 효율적인 메모리 접근 패턴을 위한 것으로 보입니다.

왜 이게 좋은 최적화/개선인가?

이 PR은 여러 측면에서 vLLM의 DeepSeek V4 모델 추론에 있어 중요한 최적화 및 개선을 가져옵니다.

1. ROCm 플랫폼 지원 강화 및 안정성 확보

triton_kernels 임포트 로직 수정은 ROCm 환경에서 vLLM이 Triton 커널을 안정적으로 컴파일하고 실행할 수 있도록 합니다. 이는 ROCm 사용자들이 DeepSeek V4 모델을 vLLM에서 문제없이 활용할 수 있는 기반을 마련하며, 플랫폼 간 호환성을 높이는 데 기여합니다. 기존의 컴파일 오류를 해결함으로써 개발 및 배포의 안정성이 크게 향상됩니다.

2. 하드웨어별 최적화의 시작점

vllm/models/deepseek_v4/amd/model.py의 도입은 vLLM이 특정 하드웨어 플랫폼(여기서는 ROCm)의 특성을 최대한 활용하기 위한 아키텍처적 개선입니다. NVIDIA와 AMD GPU는 아키텍처적 차이가 존재하므로, 각 플랫폼에 맞는 최적화된 커널과 모델 구현은 필수적입니다. 이 파일은 ROCm에 특화된 DeepSeek V4 모델의 성능 튜닝을 위한 전용 공간을 제공하며, 향후 더 많은 ROCm 최적화가 이루어질 수 있는 발판을 마련합니다.

3. FP8 양자화를 통한 성능 향상

_deepseek_v4_stage_mega_moe_inputs_kernel에서 hidden_statesfloat8e4nv로 양자화하는 것은 메모리 대역폭(bandwidth)과 연산량을 크게 줄여줍니다. FP8은 bfloat16이나 float16보다 훨씬 적은 비트를 사용하므로, GPU 메모리에서 데이터를 읽고 쓰는 데 필요한 시간이 단축되고, 더 많은 데이터를 동시에 처리할 수 있게 됩니다. 이는 특히 대규모 모델에서 추론 속도 향상에 결정적인 역할을 합니다.

  • 동적 스케일링: amax를 계산하고 이를 기반으로 scale을 동적으로 조정하는 것은 FP8 양자화의 핵심입니다. 이는 낮은 정밀도에서도 수치적 안정성을 유지하며, 모델의 정확도를 크게 저하시키지 않으면서 성능을 극대화하는 기법입니다.
  • 스케일 팩터 압축: scale_expint32로 압축하여 저장하는 것은 메모리 효율성을 더욱 높이는 세부적인 최적화입니다.

4. Speculative Decoding (MTP)을 통한 추론 속도 개선

PR 설명과 벤치마크 결과에 따르면, MTP(Multi-Token Prediction) 기능은 특히 낮은 동시성(low batch size) 환경에서 상당한 성능 향상을 가져옵니다.

Max concurrency No-MTP out tok/s MTP out tok/s Out tok/s delta
1 20.10 35.10 +74.7%
8 131.81 181.72 +37.9%
  • 낮은 동시성에서의 큰 이득: max concurrency가 1일 때 출력 토큰 처리량(output tokens/s)이 74.7% 증가하고, 8일 때 37.9% 증가하는 것을 확인할 수 있습니다. 이는 MTP가 적은 요청을 처리할 때 첫 토큰 생성 시간(TTFT)과 전체 추론 시간(E2E latency)을 크게 단축시켜 사용자 경험을 개선함을 의미합니다.
  • TTFT 및 E2E Latency 개선: mean TTFT ms는 1일 때 441.1ms에서 318.4ms로, 8일 때 1562.4ms에서 984.4ms로 감소했습니다. mean E2E ms 또한 1일 때 12737.3ms에서 7292.5ms로, 8일 때 15520.4ms에서 10532.8ms로 크게 줄어들었습니다.
  • 높은 동시성에서의 트레이드오프: max concurrency가 32 이상에서는 MTP의 성능 이득이 줄어들거나 오히려 감소하는 경향을 보입니다. 이는 MTP가 드래프트 모델의 추측을 활용하는 방식이 높은 동시성 환경에서는 오버헤드가 될 수 있음을 시사합니다. PR 설명에서도

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글