[SGLang] Warmup: GPU 초기화와 JIT 사전 컴파일
들어가며
LLM 서빙에서 서버가 시작되고 첫 번째 요청이 도착하면, JIT 커널 컴파일, CUDA Graph 캡처, Triton 커널 자동 튜닝 등이 발생하여 첫 요청의 지연 시간이 매우 길어진다. SGLang은 서버 시작 시 Warmup 단계를 통해 이 "cold start" 문제를 해결한다.
이 글에서는 python/sglang/srt/entrypoints/warmup.py와 ModelRunner의 초기화 과정을 중심으로 Warmup 시스템을 분석한다.
Warmup이 필요한 이유
서버 시작 후 첫 요청까지 거치는 지연 요소들이다.
서버 시작
│
├── 모델 로딩 (수십 초) ← ModelRunner.load_model()
├── 메모리 풀 할당 ← init_memory_pool()
├── Attention 백엔드 초기화 ← init_attention_backend()
├── CUDA Graph 캡처 (수십 초) ← CudaGraphRunner.__init__()
├── Piecewise Graph 캡처 ← PiecewiseCudaGraphRunner.__init__()
├── JIT 커널 예열 ← kernel_warmup()
└── 커스텀 Warmup 실행 ← execute_warmups()
│
▼
첫 요청 처리 (즉시 응답 가능)
커스텀 Warmup 레지스트리
SGLang은 데코레이터 기반의 Warmup 레지스트리를 제공한다. 사용자가 도메인별 예열 함수를 등록할 수 있다.
_warmup_registry = {}
def warmup(name: str):
def decorator(fn):
_warmup_registry[name] = fn
return fn
return decorator
등록된 warmup 함수들은 서버 시작 시 execute_warmups()로 실행된다.
async def execute_warmups(
disaggregation_mode: str,
warmup_names: List[str],
tokenizer_manager: TokenizerManager,
):
for warmup_name in warmup_names:
if warmup_name not in _warmup_registry:
logger.warning(f"Could not find custom warmup {warmup_name}")
continue
logger.info(f"Running warmup {warmup_name}")
await _warmup_registry[warmup_name](disaggregation_mode, tokenizer_manager)
voice_chat Warmup: 실제 예제
voice_chat warmup은 실시간 음성 채팅을 위해 fused_moe Triton 커널을 사전 캐시한다.
@warmup("voice_chat")
async def voice_chat(disaggregation_mode: str, tokenizer_manager: TokenizerManager):
# fused_moe triton 커널을 사전 캐시
# 이 예열 없이는 실시간 음성 추론이 깨짐
for i in tqdm.trange(1, 512):
size = i * 4
generate_req_input = GenerateReqInput(
input_ids=(np.random.randint(2**16, size=[size])).tolist(),
sampling_params={
"max_new_tokens": 30,
"temperature": 0.8,
"stop_token_ids": [1],
"min_p": 0.0,
},
)
if disaggregation_mode != "null":
generate_req_input.bootstrap_room = 0
generate_req_input.bootstrap_host = FAKE_BOOTSTRAP_HOST
await tokenizer_manager.generate_request(generate_req_input, None).__anext__()
1부터 511까지 다양한 입력 크기로 요청을 보내 MoE 커널의 모든 형상을 사전 컴파일한다. size = i * 4이므로 4토큰부터 2044토큰까지 4토큰 간격으로 예열한다.
ModelRunner의 내장 Warmup
ModelRunner 초기화 과정 자체가 핵심 warmup이다. initialize() 메서드에서 순차적으로 수행된다.
1. 커널 Warmup
커스텀 CUDA/Triton 커널을 사전 실행하여 JIT 캐시를 채운다.
def initialize(self, pre_model_load_memory):
...
if self.device == "cuda":
self.init_cublas()
self.init_attention_backend()
self.kernel_warmup() # ← 커널 예열
self.init_device_graphs() # ← CUDA Graph 캡처
2. CUDA Graph 캡처
CudaGraphRunner 생성 시 모든 배치 크기에 대해 CUDA Graph를 캡처한다.
class CudaGraphRunner:
def __init__(self, model_runner):
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
# 모든 배치 크기에 대해 캡처
with freeze_gc(self.model_runner.server_args.enable_cudagraph_gc):
with graph_capture() as graph_capture_context:
self.stream = graph_capture_context.stream
for bs in reversed(self.capture_bs):
graph, output_buffers = self.capture_one_batch_size(bs, forward)
self.graphs[bs] = graph
self.output_buffers[bs] = output_buffers
3. Piecewise CUDA Graph 컴파일
Prefill용 Piecewise Graph도 서버 시작 시 컴파일한다.
class PiecewiseCudaGraphRunner:
def __init__(self, model_runner):
with enable_piecewise_cuda_graph():
# Warmup: JIT 커널 예열
self.warmup_compile(num_tokens=self.capture_num_tokens[0])
# torch.compile 설치 후 모든 크기에 대해 컴파일
install_torch_compiled(patched_model, ...)
with enable_piecewise_cuda_graph_compile():
for num_tokens in reversed(self.capture_num_tokens):
self.warmup_compile(num_tokens=num_tokens)
# 최종 CUDA Graph 캡처
self.capture()
Warmup 타임라인
전체 Warmup 과정을 시간 순서로 정리한다.
t=0 모델 가중치 로딩 시작
│
t=T1 메모리 풀 할당
│
t=T2 Attention 백엔드 초기화
│
t=T3 커널 Warmup (Triton JIT 캐시)
│
t=T4 CUDA Graph 캡처
│ ┌─ bs=1024 캡처
│ ├─ bs=512 캡처
│ ├─ ...
│ └─ bs=1 캡처
│
t=T5 Piecewise CUDA Graph 컴파일+캡처
│ ┌─ torch.compile 설치
│ ├─ 각 토큰 수별 warmup_compile
│ └─ 최종 캡처
│
t=T6 커스텀 Warmup 실행 (선택적)
│ └─ voice_chat, 기타 등록된 warmup
│
t=T7 서버 Ready → 첫 요청 처리 가능
Disaggregation 모드 지원
Warmup은 Prefill/Decode 분리(Disaggregation) 환경도 고려한다.
async def execute_warmups(
disaggregation_mode: str,
warmup_names: List[str],
tokenizer_manager: TokenizerManager,
):
Disaggregation 모드에서는 bootstrap 정보를 더미 값으로 설정한다.
if disaggregation_mode != "null":
generate_req_input.bootstrap_room = 0
generate_req_input.bootstrap_host = FAKE_BOOTSTRAP_HOST
설계 근거: 레지스트리 패턴의 장점
왜 Warmup을 레지스트리 패턴으로 구현했는가?
| 특징 | 설명 |
|---|---|
| 확장성 | @warmup("name") 데코레이터로 간단 등록 |
| 선택적 실행 | 서버 설정에서 필요한 warmup만 지정 |
| 도메인 특화 | voice_chat, vision 등 모델별 커스텀 예열 |
| 비동기 지원 | async 함수로 논블로킹 실행 |
커스텀 warmup은 서버 설정의 warmup_names 리스트로 제어된다. 불필요한 예열을 건너뛰어 시작 시간을 단축할 수 있다.
Warmup 없이 발생하는 문제
Warmup을 건너뛰면 다음 문제가 발생한다.
- 첫 요청 지연: Triton 커널의 JIT 컴파일로 수 초~수십 초 지연
- CUDA Graph 미캡처: 첫 Decode 배치에서 커널 런칭 오버헤드 발생
- MoE 커널 미캐시: 음성 채팅 같은 실시간 서비스에서 타임아웃
- 불안정한 성능: 초기 요청들의 레이턴시가 불규칙
관련 포스트
- CUDA Graphs: 커널 런칭 오버헤드 제거
- Piecewise CUDA Graph: 분할 그래프 컴파일 전략
- torch.compile & Inductor: PyTorch 컴파일러 통합
- Model Runner: 포워드 패스 실행 엔진의 핵심
참고
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] torch.compile & Inductor: PyTorch 컴파일러 통합
- 현재글 : [SGLang] Warmup: GPU 초기화와 JIT 사전 컴파일
- 다음글 [SGLang] FP8: 8비트 부동소수점 양자화의 구현과 성능
댓글