[SGLang] Deep GEMM Wrapper: 최적화 행렬 곱 라이브러리
들어가며
DeepGEMM은 DeepSeek에서 개발한 FP8 GEMM 라이브러리로, Hopper(SM90) 이상의 GPU에서 최적화된 행렬 곱 연산을 제공한다. SGLang은 deep_gemm_wrapper/ 패키지를 통해 이 라이브러리를 래핑하여 JIT 사전 컴파일, SM 수 제어, Blackwell 지원 등을 추가한다.
구조도
deep_gemm_wrapper/
├── __init__.py ── entrypoint 재수출
├── configurer.py ── GPU 능력 검사, 활성화 플래그
├── entrypoint.py ── GEMM 래퍼 함수 4종
└── compile_utils.py ── JIT 사전 컴파일, 워밍업 실행기
┌─────────────────────┐
│ configurer.py │
│ ENABLE_JIT_DEEPGEMM │
│ DEEPGEMM_BLACKWELL │
└────────┬────────────┘
│
┌───────────────┼───────────────┐
│ │ │
┌─────────▼──────┐ ┌─────▼──────┐ ┌──────▼──────────┐
│ gemm_nt_f8f8 │ │ grouped │ │ compile_utils │
│ bf16 │ │ masked/ │ │ 사전컴파일 & │
│ │ │ contig │ │ 워밍업 │
└────────────────┘ └────────────┘ └─────────────────┘
핵심 코드 분석
활성화 조건 검사 (configurer.py)
DeepGEMM은 SM90(Hopper) 이상에서만 동작한다. configurer.py는 GPU 능력과 환경변수를 확인하여 활성화 여부를 결정한다.
def _compute_enable_deep_gemm():
sm_version = get_device_sm()
if sm_version < 90:
return False
try:
import deep_gemm
except ImportError:
return False
return envs.SGLANG_ENABLE_JIT_DEEPGEMM.get()
ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
DEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and is_blackwell_supported()
DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL
Blackwell(SM100) GPU에서는 UE8M0 스케일 포맷이 추가로 활성화된다.
GEMM 래퍼 함수 (entrypoint.py)
entrypoint.py는 4가지 GEMM 연산을 래핑한다. 각 함수는 입력 검증, 커널 타입 결정, 사전 컴파일 훅, SM 수 제어를 처리한다.
def gemm_nt_f8f8bf16(
lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
out: torch.Tensor,
):
m, k = lhs[0].shape
n, _ = rhs[0].shape
kernel_type = compile_utils.DeepGemmKernelType.GEMM_NT_F8F8BF16
_sanity_check_input(lhs)
_sanity_check_input(rhs)
with compile_utils.deep_gemm_execution_hook(m, n, k, 1, kernel_type):
deep_gemm.fp8_gemm_nt(lhs, rhs, out)
4가지 커널 타입은 다음과 같다.
| 커널 타입 | 설명 | 용도 |
|---|---|---|
GEMM_NT_F8F8BF16 |
일반 FP8 GEMM | Dense 모델 선형 레이어 |
GEMM_NT_BF16BF16F32 |
BF16 GEMM | FP8 미지원 시 |
GROUPED_GEMM_NT_F8F8BF16_MASKED |
Grouped Masked GEMM | MoE 전문가 연산 |
GROUPED_GEMM_NT_F8F8BF16_CONTIG |
Grouped Contiguous GEMM | MoE 연속 레이아웃 |
SM 수 동적 제어
연산-통신 오버랩 시 GPU의 SM(Streaming Multiprocessor)을 분할하여 일부는 연산, 나머지는 통신에 할당한다.
@contextmanager
def configure_deep_gemm_num_sms(num_sms):
if num_sms is None or not ENABLE_JIT_DEEPGEMM:
yield
else:
original_num_sms = deep_gemm.get_num_sms()
deep_gemm.set_num_sms(num_sms)
try:
yield
finally:
deep_gemm.set_num_sms(original_num_sms)
JIT 사전 컴파일 (compile_utils.py)
DeepGEMM의 JIT 컴파일은 첫 실행 시 10-20분이 소요될 수 있다. SGLang은 서버 시작 시 예상되는 모든 M 값에 대해 사전 컴파일을 수행한다.
class DeepGemmKernelType(IntEnum):
GROUPED_GEMM_NT_F8F8BF16_MASKED = auto()
GROUPED_GEMM_NT_F8F8BF16_CONTIG = auto()
GEMM_NT_F8F8BF16 = auto()
GEMM_NT_BF16BF16F32 = auto()
Fast Warmup 모드에서는 대수적으로 간격을 넓히며 샘플링하여 컴파일 시간을 단축한다.
if _FAST_WARMUP:
_BUILTIN_M_LIST += list(range(1, 1025)) # decode 성능 보장
next_m, sample_step = 1024, 2
while next_m < max_prefill_bs:
_BUILTIN_M_LIST += list(range(next_m, 2 * next_m, sample_step))
next_m = next_m * 2
sample_step = sample_step * 2
이렇게 하면 전체 16K개 M 값 대신 약 3K개만 컴파일하면서도 실제 사용되는 대부분의 배치 크기를 커버한다.
워밍업 실행기
각 커널 타입에 대해 전용 워밍업 실행기가 더미 텐서를 생성하고 컴파일을 수행한다.
class _NormalWarmupExecutor(_BaseWarmupExecutor):
def __init__(self, max_m, n, k, num_groups):
self.lhs_q, self.lhs_s = _empty_token_fp8((max_m, k))
self.rhs_q, self.rhs_s = _empty_block_fp8((n, k))
self.out = torch.empty((max_m, n), device="cuda", dtype=torch.bfloat16)
def execute(self, m):
deep_gemm.fp8_gemm_nt(
(self.lhs_q[:m], self.lhs_s[:m]),
(self.rhs_q, self.rhs_s), self.out[:m])
메모리 예산을 확인하여 OOM을 방지하고, Symmetric Memory 할당을 임시 비활성화하여 단일 GPU에서 안전하게 컴파일한다.
설계 근거
- 래퍼 패턴: DeepGEMM API를 직접 노출하지 않고 래핑함으로써, 사전 컴파일, SM 분할, 입력 검증을 투명하게 추가
- 점진적 컴파일: 첫 실행 시 전체 M 범위를 컴파일하여 서빙 중 JIT 지연을 제거
- 메모리 안전성: 가용 GPU 메모리를 확인하고 초과 시 max_m을 줄여 OOM 방지
관련 포스트
- Linear Layer: 양자화 통합 선형 레이어의 설계
- Batch Overlap: 연산-통신 오버랩 최적화
참고
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] RoPE 변형: 로타리 위치 인코딩의 다양한 구현
- 현재글 : [SGLang] Deep GEMM Wrapper: 최적화 행렬 곱 라이브러리
- 다음글 [SGLang] Sparsity Algorithms: QUEST와 DeepSeek NSA 희소 패턴
댓글