본문으로 건너뛰기

[sglang] DeepseekV4 모델의 입력 레이어 정규화와 FP8 양자화를 융합하여 성능 최적화

PR 링크: sgl-project/sglang#25043 상태: Merged | 변경: +0 / -0

들어가며

최근 대규모 언어 모델(LLM)의 발전 속도는 눈부십니다. 하지만 모델의 크기가 커질수록 연산량 또한 기하급수적으로 증가하여, 이를 효율적으로 처리하기 위한 최적화 기술의 중요성이 더욱 커지고 있습니다. 특히 GPU와 같은 하드웨어 가속기의 성능을 최대한 활용하는 것은 LLM 서비스의 응답 속도와 처리량에 직접적인 영향을 미칩니다.

이번 PR은 sglang 프로젝트에서 DeepseekV4 모델의 특정 연산 경로를 최적화하여 GPU 연산 효율성을 높이는 것을 목표로 합니다. 구체적으로, input_layernorm과 어텐션 메커니즘으로 이어지는 경로에서 발생하는 여러 커널 호출을 하나의 융합된 커널로 통합하고, FP8 데이터 타입을 활용하여 메모리 대역폭과 연산량을 줄이는 개선을 포함합니다.

이 글에서는 해당 PR의 코드 변경 사항을 상세히 분석하고, 이러한 최적화가 왜 성능 향상에 기여하는지, 그리고 어떤 기술적 교훈을 얻을 수 있는지 살펴보겠습니다.

코드 분석

이번 PR의 핵심 변경 사항은 python/sglang/srt/models/deepseek_v4.py 파일에 집중되어 있습니다. 주요 개선점은 다음과 같습니다.

1. 모듈 레벨 설정 및 커널 융합 준비

PR은 먼저 SGLANG_USE_AITER 환경 변수와 AMD GPU 지원 여부를 확인하는 로직을 모듈 최상단으로 옮겨 효율성을 높였습니다. 또한, GFX950 (MI355) 아키텍처에서 지원되는 융합 커널(fused_rms_fp8_group_quant)을 조건부로 임포트합니다.

Before:

# ... (이전 코드) ...
if _is_hip:
    from aiter import rope_rotate_activation
# ... (이후 코드) ...

After:

# ... (이전 코드) ...
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
_is_gfx95_supported = is_gfx95_supported()

if _use_aiter:
    from aiter import rope_rotate_activation
    if is_gfx95_supported():
        from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant
# ... (이후 코드) ...

이 변경은 _is_hip 체크를 더 구체적인 _use_aiter_is_gfx95_supported 조건으로 대체하여, 필요한 경우에만 관련 모듈을 임포트하도록 합니다. 이는 코드의 명확성을 높이고 불필요한 임포트를 방지합니다.

2. 융합 커널 헬퍼 함수 도입

_fused_rmsnorm_fp8_quant라는 새로운 헬퍼 함수가 도입되었습니다. 이 함수는 fused_rms_fp8_group_quant 커널을 호출하며, group_size=128dtype_quant=torch.float8_e4m3fn 설정을 적용합니다. 이 함수는 FP8 형식의 양자화된 결과(x_quant)와 BF16 형식의 결과(x_bf16)를 함께 반환합니다.

def _fused_rmsnorm_fp8_quant(hidden_states, weight, eps):
    x_quant, x_bf16, _, _ = fused_rms_fp8_group_quant(
        hidden_states,
        weight,
        eps,
        inp2=None,
        inp2_weight=None,
        inp2_epsilon=None,
        group_size=128,
        dtype_quant=torch.float8_e4m3fn,
        res1=None,
        output_unquantized_inp1=True,
    )
    return x_quant, x_bf16

이 헬퍼 함수는 input_layernorm과 FP8 양자화 과정을 하나의 단위로 묶어, DeepseekV4DecoderLayer.forward 메소드에서 재사용될 수 있도록 합니다.

3. DeepseekV4DecoderLayer.forward()에서의 조건부 실행

DeepseekV4DecoderLayer.forward 메소드는 이제 _use_aiter_is_gfx95_supported 조건에 따라 두 가지 경로로 나뉩니다. 조건이 충족되면 새로 도입된 _fused_rmsnorm_fp8_quant 함수를 호출하여 x_quant를 생성하고, 이를 self.self_attn()에 전달합니다. 그렇지 않으면 기존의 self.input_layernorm(hidden_states)를 호출하고 x_quantNone으로 설정합니다.

Before:

# ... (이전 코드) ...
        hidden_states = self.input_layernorm(hidden_states)

        hidden_states = self.self_attn(
            x=hidden_states,
            positions=positions,
            forward_batch=forward_batch,
        )
# ... (이후 코드) ...

After:

# ... (이전 코드) ...
        if _use_aiter and _is_gfx95_supported:
            x_quant, hidden_states = _fused_rmsnorm_fp8_quant(
                hidden_states,
                self.input_layernorm.weight,
                self.input_layernorm.variance_epsilon,
            )
        else:
            x_quant = None
            hidden_states = self.input_layernorm(hidden_states)

        hidden_states = self.self_attn(
            x=hidden_states,
            positions=positions,
            forward_batch=forward_batch,
            x_quant=x_quant,
        )
# ... (이후 코드) ...

이 변경은 FP8 양자화가 필요한 경우에만 해당 연산을 수행하도록 하여, 지원되지 않는 환경에서는 기존 로직을 유지하면서 호환성을 보장합니다.

4. MQALayer.forward()_forward_prepare()에서의 FP8 처리

MQALayer_forward_prepare 함수는 x_quant 인자를 받도록 수정되었습니다. wq_awkv 연산 시, x_quant가 제공되면 이를 사용하고, 그렇지 않으면 원래의 x를 사용합니다. 이는 FP8 양자화된 결과(x_quant)를 어텐션 메커니즘의 wq_awkv 선형 계층에 전달하는 역할을 합니다.

Before (_forward_prepare):

# ... (이전 코드) ...
        # [bs, q_lora_rank]
        q, _ = self.wq_a(x)
# ... (이전 코드) ...
        # [bs, head_dim]
        kv, _ = self.wkv(x)
# ... (이후 코드) ...

After (_forward_prepare):

# ... (이전 코드) ...
        # fp8 tuple: Fp8LinearMethod.apply handles isinstance(x, tuple), skips per1x128 quant
        q, _ = self.wq_a(x_quant if x_quant is not None else x)
# ... (이전 코드) ...
        # fp8 tuple: same as wq_a
        kv, _ = self.wkv(x_quant if x_quant is not None else x)
# ... (이후 코드) ...

또한, MQALayer.forward 함수도 x_quant 인자를 받도록 수정되었습니다. 이 변경은 input_layernorm에서 생성된 FP8 결과(x_quant)가 어텐션 연산(self.self_attn)으로 전달될 수 있도록 하는 핵심적인 연결고리 역할을 합니다.

이러한 변경의 중요한 점은 Fp8LinearMethod.apply가 FP8 튜플(x_quant)을 받았을 때 내부적으로 isinstance(x, tuple) 체크를 통해 FP8 양자화 과정을 건너뛰고 FP8 데이터(x[0])와 스케일(x[1])을 직접 사용한다는 것입니다. 이는 불필요한 중복 양자화를 방지합니다.

반면, indexer, compressor, cp_all_gather와 같은 다른 소비자들은 FP8 튜플을 직접 처리하지 못합니다. 이들은 .dtype 속성을 요구하거나 .contiguous()와 같은 메소드를 호출하는데, 튜플은 이러한 속성이나 메소드를 가지지 않기 때문입니다. PR은 이러한 소비자들에게는 원래의 BF16 결과(x)를 전달하여 호환성을 유지합니다. 이는 x_quant if x_quant is not None else x와 같은 조건부 로직을 통해 구현됩니다.

왜 이게 좋은가

이번 PR의 주요 목표는 input_layernorm → 어텐션 경로에서 발생하는 불필요한 커널 호출을 줄여 성능을 향상시키는 것입니다. 구체적으로 다음과 같은 이점들이 있습니다.

  1. 커널 호출 수 감소: PR 설명에 따르면, GFX950 (MI355) 환경에서 SGLANG_USE_AITER=1 설정 시, 각 레이어마다 2개의 dynamic_per_group_scaled_quant_kernel 호출이 제거됩니다. 45개 레이어를 가진 모델의 경우, 총 90개의 커널 호출이 줄어드는 효과가 있습니다. 프로파일링 결과에서도 dynamic_per_group_scaled_quant_kernel의 호출 수가 2,376에서 1,370으로 약 1,006건 감소한 것이 확인되었습니다.
  2. FP8 활용: FP8 데이터 타입을 사용함으로써 연산에 필요한 메모리 대역폭을 줄이고, 일부 경우 연산 자체의 속도를 높일 수 있습니다. 특히 FP8 양자화와 RMSNorm 연산을 융합한 fused_rms_fp8_group_quant 커널은 GPU의 연산 효율성을 극대화합니다.
  3. 성능 향상: 이러한 최적화는 실제 속도 향상으로 이어집니다. PR에 포함된 GSM8K 5-shot 정확도 테스트 결과, 이전 대비 처리량(Output throughput)이 311.765 token/s로 측정되었습니다. (이전 수치는 제공되지 않았지만, 커널 호출 수 감소와 FP8 활용은 분명한 성능 개선 요인입니다.)
  4. 코드 명확성 및 유지보수성: 모듈 레벨 설정을 통합하고 헬퍼 함수를 도입함으로써 코드의 가독성과 유지보수성이 향상되었습니다. 조건부 로직을 통해 다양한 하드웨어 및 환경 설정에 대한 지원을 유연하게 관리할 수 있습니다.

일반적 교훈

  • 연산 융합 (Kernel Fusion): 여러 개의 작은 커널 호출을 하나의 큰 커널로 융합하는 것은 GPU 연산에서 매우 효과적인 최적화 기법입니다. 이는 커널 실행 오버헤드를 줄이고, 데이터 재사용성을 높이며, 메모리 접근 패턴을 개선할 수 있습니다.
  • 데이터 타입 활용: FP8과 같은 저정밀도 데이터 타입을 적절히 활용하면 메모리 및 연산 효율성을 크게 향상시킬 수 있습니다. 다만, 데이터 타입 변환 및 호환성 문제를 신중하게 고려해야 합니다.
  • 조건부 로직: 특정 하드웨어 또는 환경 설정에 의존적인 최적화를 적용할 때는, 조건부 로직을 사용하여 호환성을 유지하고 불필요한 연산을 피하는 것이 중요합니다.
  • 프로파일링의 중요성: 성능 병목 지점을 정확히 파악하고 최적화 효과를 검증하기 위해서는 상세한 프로파일링이 필수적입니다.

리뷰 피드백 반영

PR 설명에는 별도의 리뷰 댓글이 제공되지 않았습니다. 하지만 PR 설명 자체에 Consumer compatibility 섹션이 포함되어 있어, 다양한 소비자(indexer, compressor 등)와의 호환성 문제를 미리 인지하고 해결하려는 노력이 엿보입니다. FP8 튜플을 직접 처리하지 못하는 소비자들에게는 원래의 BF16 데이터를 전달하는 방식은 이러한 호환성 문제를 해결하기 위한 중요한 설계 결정입니다.

References

  • torch.compile - 이 PR에서 사용된 최적화 기법과 관련될 수 있는 PyTorch의 컴파일 기능에 대한 공식 문서입니다. (직접적인 함수는 아니지만, 성능 최적화 맥락에서 관련성이 있습니다.)
  • aiter.ops.triton.fused_fp8_quant.fused_rms_fp8_group_quant - PR에서 사용된 융합 커널의 소스 코드 (공식 문서 링크가 없을 경우 소스 코드로 대체)
  • DeepseekV4DecoderLayer - 수정된 DeepseekV4DecoderLayer 클래스의 위치 (소스 코드 링크)
  • MQALayer - 수정된 MQALayer 클래스의 위치 (소스 코드 링크)

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글