본문으로 건너뛰기

[SGLang] ForwardBatch: ScheduleBatch에서 GPU 텐서로의 변환

들어가며

SGLang의 배치 처리는 3단계 데이터 변환을 거친다. ScheduleBatch(CPU 스케줄링) → ModelWorkerBatch(GPU 전송 준비) → ForwardBatch(GPU 텐서). ForwardBatch는 이 파이프라인의 최종 단계로, GPU에서 모델이 직접 소비하는 텐서 집합이다.

이 글에서는 python/sglang/srt/model_executor/forward_batch_info.py를 중심으로 ForwardBatch의 설계를 분석한다.

배치 데이터 흐름

소스 코드의 주석이 이 흐름을 명확히 설명한다.

ScheduleBatch → ModelWorkerBatch → ForwardBatch

- ScheduleBatch: scheduler.py에서 관리. 고수준 스케줄링 데이터.
  대부분의 데이터가 CPU에 존재.
- ModelWorkerBatch: tp_worker.py에서 관리.
  GPU 모델 포워드에 필요한 데이터의 부분집합.
  CPU에서 GPU로 전환되는 단계.
- ForwardBatch: model_runner.py에서 관리.
  저수준 텐서 데이터. 대부분 GPU 텐서.

ForwardMode: 실행 모드 열거형

ForwardBatch의 동작은 ForwardMode에 의해 결정된다.

class ForwardMode(IntEnum):
    EXTEND = auto()          # Prefill (시스템 프롬프트 처리)
    DECODE = auto()          # 토큰 1개 생성
    MIXED = auto()           # Chunked Prefill (EXTEND + DECODE 혼합)
    IDLE = auto()            # DP Attention 패딩용 빈 배치
    TARGET_VERIFY = auto()   # Speculative Decoding 검증
    DRAFT_EXTEND = auto()    # Draft 모델 확장
    SPLIT_PREFILL = auto()   # PD Multiplexing 분할 프리필
    DLLM_EXTEND = auto()     # Diffusion LLM

모드별 헬퍼 메서드가 분기 로직을 단순하게 만든다.

def is_extend(self, include_draft_extend_v2=False):
    return (
        self == ForwardMode.EXTEND
        or self == ForwardMode.MIXED
        or self == ForwardMode.DRAFT_EXTEND
        or self == ForwardMode.TARGET_VERIFY
        or self == ForwardMode.SPLIT_PREFILL
        or self == ForwardMode.DLLM_EXTEND
    )

def is_cuda_graph(self):
    return (
        self == ForwardMode.DECODE
        or self == ForwardMode.TARGET_VERIFY
        or self == ForwardMode.IDLE
        or self == ForwardMode.DLLM_EXTEND
    )

ForwardBatch 핵심 필드

ForwardBatch는 @dataclass로 정의되며, 60개 이상의 필드를 가진다. 핵심 필드를 분류하면 다음과 같다.

┌─────────────── 공통 필드 ────────────────┐
 forward_mode    : ForwardMode           
 batch_size      : int                   
 input_ids       : torch.Tensor          
 req_pool_indices: torch.Tensor          
 seq_lens        : torch.Tensor          
 out_cache_loc   : torch.Tensor          
 positions       : torch.Tensor          
 seq_lens_sum    : int                   
├─────────────── Extend 전용 ─────────────┤
 extend_num_tokens    : int              
 extend_seq_lens      : torch.Tensor     
 extend_prefix_lens   : torch.Tensor     
 extend_start_loc     : torch.Tensor     
├─────────────── Decode 전용 ─────────────┤
 (공통 필드로 충분)                        
├─────────────── DP Attention ────────────┤
 global_num_tokens_gpu: torch.Tensor     
 dp_padding_mode      : DpPaddingMode    
 can_run_dp_cuda_graph: bool             
├─────────────── Speculative ─────────────┤
 spec_info            : SpecInput        
 spec_algorithm       : SpeculativeAlgorithm│
 capture_hidden_mode  : CaptureHiddenMode│
└─────────────────────────────────────────┘

init_new: 배치 변환의 핵심

ForwardBatch.init_new()ModelWorkerBatch를 받아 ForwardBatch를 구성한다. 이 과정에서 CPU 데이터를 GPU 텐서로 변환한다.

@classmethod
def init_new(cls, batch: ModelWorkerBatch, model_runner: ModelRunner):
    ret = cls(
        forward_mode=batch.forward_mode,
        batch_size=len(batch.seq_lens),
        input_ids=batch.input_ids,
        req_pool_indices=batch.req_pool_indices,
        seq_lens=batch.seq_lens,
        out_cache_loc=batch.out_cache_loc,
        seq_lens_sum=batch.seq_lens_sum,
        req_to_token_pool=model_runner.req_to_token_pool,
        token_to_kv_pool=model_runner.token_to_kv_pool,
        attn_backend=model_runner.attn_backend,
        ...
    )

ModelRunner의 메모리 풀과 Attention 백엔드 참조가 이 단계에서 주입된다. ForwardBatch는 이 참조를 통해 KV 캐시에 접근한다.

Position 계산

모드에 따라 Position 계산 방식이 달라진다.

Decode 모드 - 시퀀스 길이에서 1을 빼서 마지막 위치를 구한다.

if ret.forward_mode.is_decode() or ret.forward_mode.is_target_verify():
    if ret.positions is None:
        ret.positions = clamp_position(batch.seq_lens)

Extend 모드 - prefix_lens와 seq_lens로 연속 position을 계산한다.

else:
    ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, dtype=torch.int32).to(device)
    ret.extend_prefix_lens = torch.tensor(batch.extend_prefix_lens, dtype=torch.int32).to(device)
    positions, ret.extend_start_loc = compute_position(
        model_runner.server_args.attention_backend,
        ret.extend_prefix_lens,
        ret.extend_seq_lens,
        ret.extend_num_tokens,
    )

DP Attention을 위한 global_num_tokens

Data Parallel Attention에서는 각 DP 워커의 토큰 수를 공유해야 한다. global_num_tokens_gpu가 이 역할을 한다.

if batch.global_num_tokens is not None:
    global_num_tokens = batch.global_num_tokens
    ret.global_num_tokens_gpu = torch.tensor(
        global_num_tokens, dtype=torch.int64
    ).to(device, non_blocking=True)

CaptureHiddenMode: Hidden State 캡처

Speculative Decoding에서 target 모델의 hidden state를 draft 모델에 전달해야 한다. CaptureHiddenMode가 캡처 범위를 결정한다.

class CaptureHiddenMode(IntEnum):
    NULL = 0   # 캡처 안 함
    LAST = 1   # 마지막 토큰만 캡처
    FULL = 2   # 모든 토큰 캡처

    def need_capture(self):
        return self != CaptureHiddenMode.NULL

설계 근거: 3단계 변환이 필요한 이유

왜 ScheduleBatch에서 바로 GPU 텐서를 만들지 않는가? 세 가지 관심사를 분리하기 위해서이다.

단계 관심사 실행 위치
ScheduleBatch 요청 우선순위, 프리필/디코드 분류 CPU
ModelWorkerBatch TP 분산에 필요한 최소 데이터 CPU → GPU 전송
ForwardBatch 텐서 레이아웃, Attention 메타데이터 GPU

이 분리 덕분에 Scheduler는 GPU 텐서 레이아웃을 몰라도 되고, ModelRunner는 스케줄링 정책을 몰라도 된다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글