본문으로 건너뛰기

[SGLang] Deep GEMM Wrapper: 최적화 행렬 곱 라이브러리

들어가며

DeepGEMM은 DeepSeek에서 개발한 FP8 GEMM 라이브러리로, Hopper(SM90) 이상의 GPU에서 최적화된 행렬 곱 연산을 제공한다. SGLang은 deep_gemm_wrapper/ 패키지를 통해 이 라이브러리를 래핑하여 JIT 사전 컴파일, SM 수 제어, Blackwell 지원 등을 추가한다.

구조도

deep_gemm_wrapper/
├── __init__.py       ── entrypoint 재수출
├── configurer.py     ── GPU 능력 검사, 활성화 플래그
├── entrypoint.py     ── GEMM 래퍼 함수 4종
└── compile_utils.py  ── JIT 사전 컴파일, 워밍업 실행기

                     ┌─────────────────────┐
                     │   configurer.py      │
                     │  ENABLE_JIT_DEEPGEMM │
                     │  DEEPGEMM_BLACKWELL  │
                     └────────┬────────────┘
                              │
              ┌───────────────┼───────────────┐
              │               │               │
    ┌─────────▼──────┐ ┌─────▼──────┐ ┌──────▼──────────┐
    │ gemm_nt_f8f8   │ │ grouped    │ │ compile_utils   │
    │ bf16           │ │ masked/    │ │ 사전컴파일 &    │
    │                │ │ contig     │ │ 워밍업           │
    └────────────────┘ └────────────┘ └─────────────────┘

핵심 코드 분석

활성화 조건 검사 (configurer.py)

DeepGEMM은 SM90(Hopper) 이상에서만 동작한다. configurer.py는 GPU 능력과 환경변수를 확인하여 활성화 여부를 결정한다.

def _compute_enable_deep_gemm():
    sm_version = get_device_sm()
    if sm_version < 90:
        return False
    try:
        import deep_gemm
    except ImportError:
        return False
    return envs.SGLANG_ENABLE_JIT_DEEPGEMM.get()

ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
DEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and is_blackwell_supported()
DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL

Blackwell(SM100) GPU에서는 UE8M0 스케일 포맷이 추가로 활성화된다.

GEMM 래퍼 함수 (entrypoint.py)

entrypoint.py는 4가지 GEMM 연산을 래핑한다. 각 함수는 입력 검증, 커널 타입 결정, 사전 컴파일 훅, SM 수 제어를 처리한다.

def gemm_nt_f8f8bf16(
    lhs: Tuple[torch.Tensor, torch.Tensor],
    rhs: Tuple[torch.Tensor, torch.Tensor],
    out: torch.Tensor,
):
    m, k = lhs[0].shape
    n, _ = rhs[0].shape
    kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16
    _sanity_check_input(lhs)
    _sanity_check_input(rhs)
    with compile_utils.deep_gemm_execution_hook(m, n, k, 1, kernel_type):
        deep_gemm.fp8_gemm_nt(lhs, rhs, out)

4가지 커널 타입은 다음과 같다.

커널 타입 설명 용도
GEMM_NT_F8F8BF16 일반 FP8 GEMM Dense 모델 선형 레이어
GEMM_NT_BF16BF16F32 BF16 GEMM FP8 미지원 시
GROUPED_GEMM_NT_F8F8BF16_MASKED Grouped Masked GEMM MoE 전문가 연산
GROUPED_GEMM_NT_F8F8BF16_CONTIG Grouped Contiguous GEMM MoE 연속 레이아웃

SM 수 동적 제어

연산-통신 오버랩 시 GPU의 SM(Streaming Multiprocessor)을 분할하여 일부는 연산, 나머지는 통신에 할당한다.

@contextmanager
def configure_deep_gemm_num_sms(num_sms):
    if num_sms is None or not ENABLE_JIT_DEEPGEMM:
        yield
    else:
        original_num_sms = deep_gemm.get_num_sms()
        deep_gemm.set_num_sms(num_sms)
        try:
            yield
        finally:
            deep_gemm.set_num_sms(original_num_sms)

JIT 사전 컴파일 (compile_utils.py)

DeepGEMM의 JIT 컴파일은 첫 실행 시 10-20분이 소요될 수 있다. SGLang은 서버 시작 시 예상되는 모든 M 값에 대해 사전 컴파일을 수행한다.

class DeepGemmKernelType(IntEnum):
    GROUPED_GEMM_NT_F8F8BF16_MASKED = auto()
    GROUPED_GEMM_NT_F8F8BF16_CONTIG = auto()
    GEMM_NT_F8F8BF16 = auto()
    GEMM_NT_BF16BF16F32 = auto()

Fast Warmup 모드에서는 대수적으로 간격을 넓히며 샘플링하여 컴파일 시간을 단축한다.

if _FAST_WARMUP:
    _BUILTIN_M_LIST += list(range(1, 1025))   # decode 성능 보장
    next_m, sample_step = 1024, 2
    while next_m < max_prefill_bs:
        _BUILTIN_M_LIST += list(range(next_m, 2 * next_m, sample_step))
        next_m = next_m * 2
        sample_step = sample_step * 2

이렇게 하면 전체 16K개 M 값 대신 약 3K개만 컴파일하면서도 실제 사용되는 대부분의 배치 크기를 커버한다.

워밍업 실행기

각 커널 타입에 대해 전용 워밍업 실행기가 더미 텐서를 생성하고 컴파일을 수행한다.

class _NormalWarmupExecutor(_BaseWarmupExecutor):
    def __init__(self, max_m, n, k, num_groups):
        self.lhs_q, self.lhs_s = _empty_token_fp8((max_m, k))
        self.rhs_q, self.rhs_s = _empty_block_fp8((n, k))
        self.out = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16)

    def execute(self, m):
        deep_gemm.fp8_gemm_nt(
            (self.lhs_q[:m], self.lhs_s[:m]),
            (self.rhs_q, self.rhs_s), self.out[:m])

메모리 예산을 확인하여 OOM을 방지하고, Symmetric Memory 할당을 임시 비활성화하여 단일 GPU에서 안전하게 컴파일한다.

설계 근거

  1. 래퍼 패턴: DeepGEMM API를 직접 노출하지 않고 래핑함으로써, 사전 컴파일, SM 분할, 입력 검증을 투명하게 추가
  2. 점진적 컴파일: 첫 실행 시 전체 M 범위를 컴파일하여 서빙 중 JIT 지연을 제거
  3. 메모리 안전성: 가용 GPU 메모리를 확인하고 초과 시 max_m을 줄여 OOM 방지

관련 포스트

  • Linear Layer: 양자화 통합 선형 레이어의 설계
  • Batch Overlap: 연산-통신 오버랩 최적화

참고

댓글

관련 포스트

SGLang 의 다른글