본문으로 건너뛰기

[sglang] FlashInfer v0.6.7 MXFP8 Gemm 통합: CUTLASS와 TensorRT-LLM 백엔드 분리

PR 링크: sgl-project/sglang#21576 상태: Merged | 변경: +75 / -28

들어가며

Blackwell GPU(SM100+)에서 MXFP8 양자화를 활용한 Gemm 연산은 높은 처리량의 핵심입니다. FlashInfer v0.6.7은 TensorRT-LLM 기반의 MXFP8 Gemm 커널을 새로 제공하는데, 기존 CUTLASS 백엔드와는 weight 레이아웃과 scale factor 처리 방식이 다릅니다. 이번 PR은 두 백엔드의 전처리 로직을 명확히 분리하고, TensorRT-LLM 경로에 필요한 shuffle_matrix_a/shuffle_matrix_sf_a 호출을 추가합니다.

핵심 코드 분석

1. Weight 전처리 분기

Before:

if get_fp8_gemm_runner_backend().is_flashinfer_trtllm():
    from flashinfer import block_scale_interleave
    new_swizzled = block_scale_interleave(scale_u8.contiguous()).contiguous()
# ...
copy_or_rebind_param(layer, "weight_scale_inv_swizzled", new_swizzled)

After:

if get_fp8_gemm_runner_backend().is_flashinfer_trtllm():
    from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
    copy_or_rebind_param(layer, "weight",
        shuffle_matrix_a(weight.contiguous().view(torch.uint8), epilogue_tile_m)
        .view(torch.float8_e4m3fn))
    copy_or_rebind_param(layer, "weight_scale_inv",
        shuffle_matrix_sf_a(scale_u8.contiguous().view(torch.uint8).reshape(n, k // 32),
            epilogue_tile_m, num_elts_per_sf=32).reshape_as(scale_u8).contiguous())
elif get_fp8_gemm_runner_backend().is_flashinfer_cutlass():
    from flashinfer import block_scale_interleave
    copy_or_rebind_param(layer, "weight_scale_inv",
        block_scale_interleave(scale_u8.contiguous()).contiguous())

TRT-LLM은 weight 자체도 shuffle_matrix_a로 재배열해야 하며, scale factor는 shuffle_matrix_sf_a로 처리합니다. CUTLASS는 scale만 block_scale_interleave하면 됩니다. 이전에는 swizzled scale을 별도 파라미터(weight_scale_inv_swizzled)로 저장했지만, 이제 weight_scale_inv를 직접 덮어쓰는 방식으로 단순화되었습니다.

2. Gemm 호출 시 백엔드별 분기

Before:

output = flashinfer_mm_mxfp8(q_input, weight_t, x_scale_u8, weight_scale_t,
    out_dtype=output_dtype, backend="auto")

After:

if get_fp8_gemm_runner_backend().is_flashinfer_trtllm():
    output = flashinfer_mm_mxfp8(q_input, weight_t, x_scale_u8,
        weight_scale.contiguous().view(-1),
        out_dtype=output_dtype, use_8x4_sf_layout=False, backend="trtllm")
elif get_fp8_gemm_runner_backend().is_flashinfer_cutlass():
    output = flashinfer_mm_mxfp8(q_input, weight_t, x_scale_u8, weight_scale_t,
        out_dtype=output_dtype, use_8x4_sf_layout=False, backend="cutlass")

backend="auto" 대신 명시적으로 "trtllm" 또는 "cutlass"를 지정합니다.

왜 이게 좋은가

  1. 정확한 하드웨어 활용: TRT-LLM 커널은 Blackwell의 특정 행렬 레이아웃에 최적화되어 있어, 올바른 shuffle 없이는 정확도가 보장되지 않습니다.
  2. 백엔드 독립성: 각 백엔드의 전처리 요구사항을 명확히 분리하여, 향후 새 백엔드 추가가 용이합니다.
  3. 단순화: 별도의 weight_scale_inv_swizzled 파라미터와 관련 버전 추적 코드가 제거되었습니다.

정리

MXFP8 양자화의 세 가지 백엔드(Triton, CUTLASS, TRT-LLM)의 전처리 경로를 명확히 분리한 PR입니다. 특히 weight와 scale factor의 메모리 레이아웃이 백엔드마다 다르다는 점을 코드 구조에 반영한 것이 핵심입니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글