본문으로 건너뛰기

[axolotl] Axolotl, Marlin W4A16 도입으로 MoE 모델 추론 속도 1.79배 향상 및 품질 개선

PR 링크: axolotl-ai-cloud/axolotl#3745 상태: Merged | 변경: +5660 / -41

들어가며

최근 대규모 언어 모델(LLM) 분야에서는 Mixture-of-Experts (MoE) 아키텍처가 큰 주목을 받고 있습니다. MoE는 여러 개의 작은 신경망(expert) 중에서 입력에 가장 적합한 expert를 선택하여 계산을 수행하는 방식으로, 모델의 크기를 늘리지 않으면서도 성능을 크게 향상시킬 수 있습니다. 하지만 MoE 모델은 여러 expert를 동시에 활성화해야 하므로, 특히 추론 시 상당한 계산 비용이 발생할 수 있습니다.

Axolotl은 이러한 MoE 모델의 효율성을 극대화하기 위해 지속적으로 최적화를 진행해왔습니다. 이번 PR은 특히 DeepSeek-V4와 같은 MoE 모델에서 sm120 아키텍처를 사용하는 경우, 기존 CUTLASS 기반의 W4A4 GEMM(General Matrix Multiply) 연산보다 훨씬 빠르고 정확한 Marlin W4A16 백엔드를 도입하는 것을 목표로 합니다. 이 변경을 통해 추론 속도가 약 1.79배 향상되고, 기존 방식에서 발생하던 활성화 양자화 오류가 제거되어 모델의 정확도 또한 개선되었습니다.

본 글에서는 이 PR에서 어떤 코드가 어떻게 변경되었는지, 그리고 이러한 변경이 왜 성능과 품질 향상으로 이어지는지 상세하게 분석해보겠습니다.

코드 분석

이번 PR의 핵심은 새로운 Marlin W4A16 백엔드를 기존 Axolotl의 MoE 파이프라인에 통합하고, 이를 sm120 GPU에서 기본으로 사용하도록 설정하는 것입니다. 주요 변경 사항은 다음과 같습니다.

1. marlin_w4a16/ 디렉토리 추가

새로운 Marlin 백엔드 관련 코드가 이 디렉토리에 포함되었습니다. 이 백엔드는 vLLM의 런타임 의존성 없이 독립적으로 작동합니다.

  • _csrc/: Marlin MoE GEMM (moe_wna16_marlin_gemm)과 gptq_marlin_repack 함수가 포함되어 있습니다. 이들은 vLLM에서 추출 및 포팅되었으며, torch::stable ABI 대신 일반적인 torch::Tensor ABI를 사용합니다. JIT 컴파일을 통해 sm_120a 용으로 빌드되고 디스크에 캐싱됩니다.
  • prep.py: NVFP4 가중치를 Marlin 형식으로 변환하는 스크립트입니다. vLLM의 순수 PyTorch 스케일 처리 헬퍼를 그대로 사용하며, torchaoNVFP4Tensor와 유사한 수준의 정확도를 보장합니다.
  • __init__.py: Marlin 백엔드의 사용 가능 여부를 확인하고, 필요한 경우 CUDA 확장을 지연 로드하는 기능을 담당합니다.

2. grouped_train.pygrouped_moe.py 수정

기존의 MoE 학습 및 추론 파이프라인에서 새로운 Marlin 백엔드를 활용하도록 수정되었습니다.

  • grouped_fp4_backend 함수 수정 (grouped_moe.py, grouped_train.py): 기존에는 sm120에서 CUTLASS를 기본으로 사용했지만, 이제는 Marlin W4A16 백엔드를 우선적으로 선택합니다. Marlin 백엔드가 사용 가능하고 mode가 "nvfp4"인 경우, Marlin을 반환합니다. 이는 sm120에서 Marlin이 최적의 선택임을 나타냅니다.

    diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/grouped_moe.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/grouped_moe.py
    index 0f63785d7d..90f18e576c 100644
    --- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/grouped_moe.py
    +++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/grouped_moe.py
    @@ -25,19 +25,19 @@
     
     
     def grouped_fp4_backend(mode: str) -> str | None:
    
  • """Best available base-GEMM backend: 'cutlass' (sm120) | 'deepgemm' (sm90/100)."""

  • """Best available base-GEMM backend: 'marlin' (sm120 W4A16, preferred) | 'cutlass' (sm120) |

  • 'deepgemm' (sm90/100) | 'chunked'. On sm120 Marlin W4A16 is ~1.79x faster than CUTLASS and

  • bit-correct (bf16 activations, no activation quantization)."""

  • try:

  •    from .marlin_w4a16 import marlin_w4a16_available
    
  •    if mode == "nvfp4" and marlin_w4a16_available():
    
  •        return "marlin"
    
  • except Exception:

  •    pass
    

    try: from .cutlass_fp4 import cutlass_fp4_available

    @@ -43,19 +43,19 @@ return "chunked" if torch.cuda.is_available() else None

-def _route(idx, E, dev): +def _route(idx, E, dev, tile=TILE): flat = idx.reshape(-1) order = flat.argsort() rep = torch.arange(idx.size(0), device=dev).repeat_interleave(idx.size(1))[order] exp_sorted = flat[order] counts = torch.bincount(flat, minlength=E)

  • ptiles = (counts + TILE - 1) // TILE
  • roff = torch.cat([ptiles.new_zeros(1), ptiles.cumsum(0)]) * TILE
  • ptiles = (counts + tile - 1) // tile
  • roff = torch.cat([ptiles.new_zeros(1), ptiles.cumsum(0)]) * tile coff = torch.cat([counts.new_zeros(1), counts.cumsum(0)]) local = torch.arange(exp_sorted.numel(), device=dev) - coff[exp_sorted] padded_row = roff[exp_sorted] + local m_indices = torch.repeat_interleave(torch.arange(E, dtype=torch.int32, device=dev), ptiles)
  • Mt = int(ptiles.sum()) * TILE
  • Mt = int(ptiles.sum()) * tile return rep, padded_row, m_indices, counts, Mt

@@ -96,13 +96,24 @@ def grouped_fp4_moe_forward(hidden, idx, wts, gate_up_nv, down_nv, limit, mode, Idim = down_nv.qdata.size(2) * 2 # down K = I (packed K/2) dev = hidden.device backend = backend or grouped_fp4_backend(mode)

  • rep, padded_row, m_indices, counts, Mt = _route(idx, E, dev)
  • if backend == "marlin":

  •    from .marlin_w4a16.backend import MARLIN_TILE
    
  •    tile = MARLIN_TILE
    
  • else:

  •    tile = TILE
    
  • rep, padded_row, m_indices, counts, Mt = _route(idx, E, dev, tile) wflat = wts.reshape(-1)[idx.reshape(-1).argsort()]

    A = hidden.new_zeros(Mt, H) A[padded_row] = hidden[rep]

  • if backend == "cutlass":
  • marlin_base = None

  • if backend == "marlin":

  •    from .marlin_w4a16.backend import build_marlin_forward_base, marlin_base_forward
    
  •    marlin_base = build_marlin_forward_base(gate_up_nv, down_nv)
    
  •    gu = marlin_base_forward(marlin_base, 0, A, m_indices).float()
    
  • elif backend == "cutlass": from .cutlass_fp4.grouped import quant_act

    gu_eng = _engine(Mt, 2 * (down_nv.qdata.size(2) * 2), H, E, mode)  # N = 2I
    

@@ -125,7 +136,9 @@ def grouped_fp4_moe_forward(hidden, idx, wts, gate_up_nv, down_nv, limit, mode, g, u = gu.chunk(2, dim=-1) h = (F.silu(g.clamp(max=limit)) * u.clamp(min=-limit, max=limit)).to(hidden.dtype)

  • if backend == "cutlass":
  • if backend == "marlin":

  •    dn = marlin_base_forward(marlin_base, 1, h.contiguous(), m_indices)
    
  • elif backend == "cutlass": from .cutlass_fp4.grouped import quant_act

    dn_eng = _engine(Mt, H, down_nv.qdata.size(2) * 2, E, mode)  # K = I
    

- **`_train_backend` 함수 수정 (`grouped_train.py`)**: 학습 시에도 `sm120` GPU에서 Marlin 백엔드를 우선적으로 사용하도록 변경되었습니다.

- **`_base_forward` 함수 수정 (`grouped_train.py`)**: `_base_forward` 함수가 Marlin 백엔드를 지원하도록 확장되었습니다. Marlin 백엔드의 경우, `build_marlin_forward_base`와 `marlin_base_forward` 함수를 사용하여 게이트-업(gate_up) 및 다운(down) 연산을 수행합니다.

- **`_route` 함수 수정 (`grouped_moe.py`)**: Marlin 백엔드는 `MARLIN_TILE` (기본값 64)을 사용하며, 이는 기존 `TILE` (기본값 128)보다 작은 패딩 크기를 허용합니다. 이 변경은 `_route` 함수에 `tile` 인자를 추가하여 반영되었습니다.

```diff
diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/grouped_moe.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/grouped_moe.py
index 0f63785d7d..90f18e576c 100644
--- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/grouped_moe.py
+++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/grouped_moe.py
@@ -1,9 +1,12 @@
"""Unified grouped NVFP4 MoE forward for DeepSeek-V4 experts (config-gated, dsv4_fp4_grouped_mode).

-Contiguous-grouped: tokens sorted by expert, padded per-expert to TILE(128); ONE grouped gate_up
-GEMM -> clamped-SwiGLU -> ONE grouped down GEMM, with GPU-vectorized routing/pack/scatter (no
-per-expert Python loop). The base GEMM auto-dispatches to the best fp4 path:
-    DeepGEMM (sm90/sm100)  ->  CUTLASS grouped (sm120)  ->  chunked-dequant (any GPU, fallback).
+Contiguous-grouped: tokens sorted by expert, padded per-expert to TILE; ONE grouped gate_up GEMM
+-> clamped-SwiGLU -> ONE grouped down GEMM, with GPU-vectorized routing/pack/scatter (no per-expert
+Python loop). The base GEMM auto-dispatches to the best fp4 path:
+    Marlin W4A16 (sm120, pad-64)  ->  DeepGEMM (sm90/sm100)  ->  CUTLASS grouped (sm120, pad-128)
+->  chunked-dequant (any GPU, fallback).
Marlin W4A16 (bf16 activations, bit-correct) is preferred on sm120: ~1.79x faster than CUTLASS and
no activation-quant error. See marlin_w4a16/.
@@ -25,19 +28,19 @@


def grouped_fp4_backend(mode: str) -> str | None:
   """Best available base-GEMM backend: 'cutlass' (sm120) | 'deepgemm' (sm90/100)."""
-    try:
-        from .cutlass_fp4 import cutlass_fp4_available
-        if cutlass_fp4_available():
-            return "cutlass"
-    except Exception:
-        pass
-    try:
-        from .deepgemm_fp4 import deepgemm_fp4_available
-        if deepgemm_fp4_available():
-            return "deepgemm"
-    except Exception:
-        pass
+    try:
+        from .marlin_w4a16 import marlin_w4a16_available
+
+        if mode == "nvfp4" and marlin_w4a16_available():
+            return "marlin"
+    except Exception:
+        pass
+    try:
+        from .cutlass_fp4 import cutlass_fp4_available
+
+        if cutlass_fp4_available():
+            return "cutlass"
+    except Exception:
+        pass
+    try:
+        from .deepgemm_fp4 import deepgemm_fp4_available
+
+        if deepgemm_fp4_available():
+            return "deepgemm"
+    except Exception:
+        pass
   return "chunked" if torch.cuda.is_available() else None


-def _route(idx, E, dev):
+def _route(idx, E, dev, tile=TILE):
   flat = idx.reshape(-1)
   order = flat.argsort()
   rep = torch.arange(idx.size(0), device=dev).repeat_interleave(idx.size(1))[order]
   exp_sorted = flat[order]
   counts = torch.bincount(flat, minlength=E)
-    ptiles = (counts + TILE - 1) // TILE
-    roff = torch.cat([ptiles.new_zeros(1), ptiles.cumsum(0)]) * TILE
+    ptiles = (counts + tile - 1) // tile
+    roff = torch.cat([ptiles.new_zeros(1), ptiles.cumsum(0)]) * tile
   coff = torch.cat([counts.new_zeros(1), counts.cumsum(0)])
   local = torch.arange(exp_sorted.numel(), device=dev) - coff[exp_sorted]
   padded_row = roff[exp_sorted] + local
   m_indices = torch.repeat_interleave(torch.arange(E, dtype=torch.int32, device=dev), ptiles)
-    Mt = int(ptiles.sum()) * TILE
+    Mt = int(ptiles.sum()) * tile
   return rep, padded_row, m_indices, counts, Mt


def grouped_fp4_moe_forward(hidden, idx, wts, gate_up_nv, down_nv, limit, mode, backend=None, mxfp4_cache=None):
@@ -96,13 +99,24 @@ def grouped_fp4_moe_forward(hidden, idx, wts, gate_up_nv, down_nv, limit, mode,
   Idim = down_nv.qdata.size(2) * 2  # down K = I (packed K/2)
   dev = hidden.device
   backend = backend or grouped_fp4_backend(mode)
-    rep, padded_row, m_indices, counts, Mt = _route(idx, E, dev)
+    if backend == "marlin":
+        from .marlin_w4a16.backend import MARLIN_TILE
+        tile = MARLIN_TILE
+    else:
+        tile = TILE
+    rep, padded_row, m_indices, counts, Mt = _route(idx, E, dev, tile)
   wflat = wts.reshape(-1)[idx.reshape(-1).argsort()]

   A = hidden.new_zeros(Mt, H)
   A[padded_row] = hidden[rep]

-    if backend == "cutlass":
+    marlin_base = None
+    if backend == "marlin":
+        from .marlin_w4a16.backend import build_marlin_forward_base, marlin_base_forward
+
+        marlin_base = build_marlin_forward_base(gate_up_nv, down_nv)
+        gu = marlin_base_forward(marlin_base, 0, A, m_indices).float()
+    elif backend == "cutlass":
       from .cutlass_fp4.grouped import quant_act

       gu_eng = _engine(Mt, 2 * (down_nv.qdata.size(2) * 2), H, E, mode)  # N = 2I
@@ -125,7 +139,9 @@ def grouped_fp4_moe_forward(hidden, idx, wts, gate_up_nv, down_nv, limit, mode,
   g, u = gu.chunk(2, dim=-1)
   h = (F.silu(g.clamp(max=limit)) * u.clamp(min=-limit, max=limit)).to(hidden.dtype)

-    if backend == "cutlass":
+    if backend == "marlin":
+        dn = marlin_base_forward(marlin_base, 1, h.contiguous(), m_indices)
+    elif backend == "cutlass":
       from .cutlass_fp4.grouped import quant_act

       dn_eng = _engine(Mt, H, down_nv.qdata.size(2) * 2, E, mode)  # K = I
  • grouped_fp4_available 함수 수정 (grouped_train.py): Marlin 백엔드의 사용 가능 여부를 확인하는 로직이 추가되었습니다.

3. dequant_grouped.py 수정

역전파(backward) 과정에서 사용되는 _grouped_dx_fp8_kernel 함수가 수정되었습니다. 기존에는 BM (Block Matrix) 값이 고정되어 있었으나, 이제는 block_m 인자를 통해 외부에서 전달받도록 변경되었습니다. 이 block_m 값은 Marlin 백엔드를 사용할 때 64로, CUTLASS 백엔드를 사용할 때 128로 설정되어, 각 백엔드의 특성에 맞는 최적의 패딩을 적용합니다.

diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/dequant_grouped.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/dequant_grouped.py
index 51cdb31935..246e2bdd8f 100644
--- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/dequant_grouped.py
+++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/dequant_grouped.py
@@ -187,10 +187,13 @@ def nvfp4_dequant_fp8(qdata: torch.Tensor, scale: torch.Tensor, per_tensor: torc
     return out
 
 
+# BM is the per-expert tile (= the routing pad TILE); passed by the caller (128 for the cutlass
+# path, 64 for the marlin path) rather than fixed, so one expert weight maps to each BM-row block.
+# It is a constexpr (and an autotune key) so BN/BK/warps/stages still autotune per (N, K, BM).
 @triton.autotune(
-    configs=[triton.Config({

## 참고 자료
- https://pytorch.org/docs/stable/generated/torch.compile.html

> ⚠️ **알림:** 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글