본문으로 건너뛰기

[SGLang] Warmup: GPU 초기화와 JIT 사전 컴파일

들어가며

LLM 서빙에서 서버가 시작되고 첫 번째 요청이 도착하면, JIT 커널 컴파일, CUDA Graph 캡처, Triton 커널 자동 튜닝 등이 발생하여 첫 요청의 지연 시간이 매우 길어진다. SGLang은 서버 시작 시 Warmup 단계를 통해 이 "cold start" 문제를 해결한다.

이 글에서는 python/sglang/srt/entrypoints/warmup.py와 ModelRunner의 초기화 과정을 중심으로 Warmup 시스템을 분석한다.

Warmup이 필요한 이유

서버 시작 후 첫 요청까지 거치는 지연 요소들이다.

서버 시작
  │
  ├── 모델 로딩 (수십 초)          ← ModelRunner.load_model()
  ├── 메모리 풀 할당               ← init_memory_pool()
  ├── Attention 백엔드 초기화       ← init_attention_backend()
  ├── CUDA Graph 캡처 (수십 초)    ← CudaGraphRunner.__init__()
  ├── Piecewise Graph 캡처         ← PiecewiseCudaGraphRunner.__init__()
  ├── JIT 커널 예열                ← kernel_warmup()
  └── 커스텀 Warmup 실행           ← execute_warmups()
  │
  ▼
첫 요청 처리 (즉시 응답 가능)

커스텀 Warmup 레지스트리

SGLang은 데코레이터 기반의 Warmup 레지스트리를 제공한다. 사용자가 도메인별 예열 함수를 등록할 수 있다.

_warmup_registry = {}

def warmup(name: str):
    def decorator(fn):
        _warmup_registry[name] = fn
        return fn
    return decorator

등록된 warmup 함수들은 서버 시작 시 execute_warmups()로 실행된다.

async def execute_warmups(
    disaggregation_mode: str,
    warmup_names: List[str],
    tokenizer_manager: TokenizerManager,
):
    for warmup_name in warmup_names:
        if warmup_name not in _warmup_registry:
            logger.warning(f"Could not find custom warmup {warmup_name}")
            continue
        logger.info(f"Running warmup {warmup_name}")
        await _warmup_registry[warmup_name](disaggregation_mode, tokenizer_manager)

voice_chat Warmup: 실제 예제

voice_chat warmup은 실시간 음성 채팅을 위해 fused_moe Triton 커널을 사전 캐시한다.

@warmup("voice_chat")
async def voice_chat(disaggregation_mode: str, tokenizer_manager: TokenizerManager):
    # fused_moe triton 커널을 사전 캐시
    # 이 예열 없이는 실시간 음성 추론이 깨짐
    for i in tqdm.trange(1, 512):
        size = i * 4
        generate_req_input = GenerateReqInput(
            input_ids=(np.random.randint(2**16, size=[size])).tolist(),
            sampling_params={
                "max_new_tokens": 30,
                "temperature": 0.8,
                "stop_token_ids": [1],
                "min_p": 0.0,
            },
        )
        if disaggregation_mode != "null":
            generate_req_input.bootstrap_room = 0
            generate_req_input.bootstrap_host = FAKE_BOOTSTRAP_HOST

        await tokenizer_manager.generate_request(generate_req_input, None).__anext__()

1부터 511까지 다양한 입력 크기로 요청을 보내 MoE 커널의 모든 형상을 사전 컴파일한다. size = i * 4이므로 4토큰부터 2044토큰까지 4토큰 간격으로 예열한다.

ModelRunner의 내장 Warmup

ModelRunner 초기화 과정 자체가 핵심 warmup이다. initialize() 메서드에서 순차적으로 수행된다.

1. 커널 Warmup

커스텀 CUDA/Triton 커널을 사전 실행하여 JIT 캐시를 채운다.

def initialize(self, pre_model_load_memory):
    ...
    if self.device == "cuda":
        self.init_cublas()
        self.init_attention_backend()
        self.kernel_warmup()       # ← 커널 예열
        self.init_device_graphs()  # ← CUDA Graph 캡처

2. CUDA Graph 캡처

CudaGraphRunner 생성 시 모든 배치 크기에 대해 CUDA Graph를 캡처한다.

class CudaGraphRunner:
    def __init__(self, model_runner):
        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)

        # 모든 배치 크기에 대해 캡처
        with freeze_gc(self.model_runner.server_args.enable_cudagraph_gc):
            with graph_capture() as graph_capture_context:
                self.stream = graph_capture_context.stream
                for bs in reversed(self.capture_bs):
                    graph, output_buffers = self.capture_one_batch_size(bs, forward)
                    self.graphs[bs] = graph
                    self.output_buffers[bs] = output_buffers

3. Piecewise CUDA Graph 컴파일

Prefill용 Piecewise Graph도 서버 시작 시 컴파일한다.

class PiecewiseCudaGraphRunner:
    def __init__(self, model_runner):
        with enable_piecewise_cuda_graph():
            # Warmup: JIT 커널 예열
            self.warmup_compile(num_tokens=self.capture_num_tokens[0])

            # torch.compile 설치 후 모든 크기에 대해 컴파일
            install_torch_compiled(patched_model, ...)

            with enable_piecewise_cuda_graph_compile():
                for num_tokens in reversed(self.capture_num_tokens):
                    self.warmup_compile(num_tokens=num_tokens)

            # 최종 CUDA Graph 캡처
            self.capture()

Warmup 타임라인

전체 Warmup 과정을 시간 순서로 정리한다.

t=0    모델 가중치 로딩 시작
       │
t=T1   메모리 풀 할당
       │
t=T2   Attention 백엔드 초기화
       │
t=T3   커널 Warmup (Triton JIT 캐시)
       │
t=T4   CUDA Graph 캡처
       │  ┌─ bs=1024 캡처
       │  ├─ bs=512 캡처
       │  ├─ ...
       │  └─ bs=1 캡처
       │
t=T5   Piecewise CUDA Graph 컴파일+캡처
       │  ┌─ torch.compile 설치
       │  ├─ 각 토큰 수별 warmup_compile
       │  └─ 최종 캡처
       │
t=T6   커스텀 Warmup 실행 (선택적)
       │  └─ voice_chat, 기타 등록된 warmup
       │
t=T7   서버 Ready → 첫 요청 처리 가능

Disaggregation 모드 지원

Warmup은 Prefill/Decode 분리(Disaggregation) 환경도 고려한다.

async def execute_warmups(
    disaggregation_mode: str,
    warmup_names: List[str],
    tokenizer_manager: TokenizerManager,
):

Disaggregation 모드에서는 bootstrap 정보를 더미 값으로 설정한다.

if disaggregation_mode != "null":
    generate_req_input.bootstrap_room = 0
    generate_req_input.bootstrap_host = FAKE_BOOTSTRAP_HOST

설계 근거: 레지스트리 패턴의 장점

왜 Warmup을 레지스트리 패턴으로 구현했는가?

특징 설명
확장성 @warmup("name") 데코레이터로 간단 등록
선택적 실행 서버 설정에서 필요한 warmup만 지정
도메인 특화 voice_chat, vision 등 모델별 커스텀 예열
비동기 지원 async 함수로 논블로킹 실행

커스텀 warmup은 서버 설정의 warmup_names 리스트로 제어된다. 불필요한 예열을 건너뛰어 시작 시간을 단축할 수 있다.

Warmup 없이 발생하는 문제

Warmup을 건너뛰면 다음 문제가 발생한다.

  1. 첫 요청 지연: Triton 커널의 JIT 컴파일로 수 초~수십 초 지연
  2. CUDA Graph 미캡처: 첫 Decode 배치에서 커널 런칭 오버헤드 발생
  3. MoE 커널 미캐시: 음성 채팅 같은 실시간 서비스에서 타임아웃
  4. 불안정한 성능: 초기 요청들의 레이턴시가 불규칙

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글