[SGLang] DFlash: Flash 기반 고속 드래프팅
들어가며
DFlash는 SGLang의 speculative decoding 알고리즘 중 하나로, 드래프트 모델이 블록 단위로 토큰을 한 번에 생성하는 방식이다. EAGLE이 autoregressive하게 한 토큰씩 드래프트를 생성하는 반면, DFlash는 mask token을 사용하여 블록 전체를 병렬로 생성한다. 이를 통해 드래프팅 단계 자체의 latency를 크게 줄인다. 검증은 선형(non-tree) 구조로 수행하며, 타겟 모델의 hidden states를 활용하여 드래프트 KV 캐시를 효율적으로 관리한다.
구조도
┌──────────────────────────────────────────────────────┐
│ DFlashWorker │
│ │
│ ┌──────────────┐ ┌──────────────────────────┐ │
│ │ Target Worker │ │ Draft Worker (TpModelWorker│ │
│ │ (target_worker│ │ 별도 attention backend) │ │
│ └──────┬───────┘ └───────────┬──────────────┘ │
│ │ │ │
│ Target Extend: Draft: │
│ hidden_states ─────────▶ target_hidden │
│ │ │ │
│ │ ┌────▼─────────┐ │
│ │ │ Block Draft │ │
│ │ │ [M][M]...[M] │ │
│ │ │ mask_token_id │ │
│ │ │ → block_size │ │
│ │ │ tokens │ │
│ │ └────┬──────────┘ │
│ │ │ │
│ │ ┌────▼─────────┐ │
│ │ │DFlashVerify │ │
│ │ │Input │ │
│ ▼ │ .draft_token │ │
│ ┌──────────────┐ │ .positions │ │
│ │Target Verify │◀────│ .custom_mask │ │
│ │(linear verify)│ └──────────────┘ │
│ └──────┬───────┘ │
│ │ │
│ ▼ │
│ accept_len + bonus → DFlashDraftInput (다음 step) │
└──────────────────────────────────────────────────────┘
핵심 코드 분석
1. DFlashWorker 초기화
python/sglang/srt/speculative/dflash_worker.py에서 DFlashWorker는 별도의 드래프트 워커를 생성한다.
class DFlashWorker:
def __init__(self, server_args, ..., target_worker):
self.target_worker = target_worker
self.draft_window_size = server_args.speculative_dflash_draft_window_size
self.use_compact_draft_cache = self.draft_window_size is not None
draft_server_args = deepcopy(server_args)
draft_server_args.attention_backend = draft_backend # flashinfer/fa3/fa4
self.draft_worker = TpModelWorker(
server_args=draft_server_args, is_draft_worker=True,
req_to_token_pool=shared_req_to_token_pool,
token_to_kv_pool_allocator=target_token_to_kv_pool_allocator,
)
self.block_size = server_args.speculative_num_draft_tokens
draft_window_size가 설정되면 compact draft cache 모드가 활성화되어, 최근 윈도우만 드래프트 KV 캐시에 유지한다. 드래프트 attention backend는 flashinfer, fa3, fa4만 지원한다.
2. Mask Token 기반 블록 드래프팅
DFlash의 핵심은 mask token을 사용한 병렬 드래프트 생성이다.
self._mask_token = draft_config.mask_token
self._mask_token_id_override = draft_config.mask_token_id
self._mask_token_id = self._resolve_mask_token_id(
mask_token=self._mask_token,
mask_token_id=self._mask_token_id_override,
)
드래프트 모델은 [mask] 토큰 블록을 입력으로 받아 한 번의 forward로 block_size개의 토큰을 동시에 예측한다. EAGLE처럼 step별로 반복하지 않아 드래프팅 latency가 줄어든다.
3. DFlashDraftInput: 상태 관리
python/sglang/srt/speculative/dflash_info.py에서 드래프트 상태를 관리한다.
@dataclass
class DFlashDraftInput(SpecInput):
verified_id: torch.Tensor # 현재 토큰 (요청당 1개)
target_hidden: torch.Tensor # [sum(ctx_lens), K * hidden_size]
ctx_lens: torch.Tensor # 요청당 context 길이
draft_seq_lens: torch.Tensor # 드래프트 KV 캐시에 반영된 길이
def __post_init__(self):
super().__init__(spec_input_type=SpecInputType.DFLASH_DRAFT)
target_hidden은 타겟 모델에서 추출한 hidden states로, 드래프트 모델의 KV 캐시를 직접 구성하는 데 사용된다. ctx_lens가 요청별로 몇 개의 새 토큰이 아직 드래프트 캐시에 반영되지 않았는지를 추적한다.
4. DFlashVerifyInput: 선형 검증
@dataclass
class DFlashVerifyInput(SpecInput):
draft_token: torch.Tensor
positions: torch.Tensor
draft_token_num: int
topk: int = 1 # DFlash는 항상 선형(non-tree)
custom_mask: torch.Tensor | None = None
topk = 1로 고정된 것이 DFlash의 특징이다. 트리 구조가 아닌 선형 체인으로 검증하여 구현이 단순하고, custom mask는 표준 causal masking을 따른다.
5. 검증 로직: accept_len + bonus
def verify(self, *, batch, logits_output, page_size):
candidates = self.draft_token.view(bs, self.draft_token_num)
if not sampling_info.is_all_greedy and is_dflash_sampling_verify_available():
accept_len, bonus = compute_dflash_sampling_accept_len_and_bonus(
candidates=candidates,
next_token_logits=logits_output.next_token_logits,
sampling_info=sampling_info,
)
else:
target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
accept_len, bonus = compute_dflash_accept_len_and_bonus(
candidates=candidates, target_predict=target_predict.view(bs, self.draft_token_num),
)
DFlash는 선형 체인이므로 검증이 단순하다. 첫 번째 불일치 위치까지를 accept_len으로, 불일치 위치의 타겟 예측을 bonus 토큰으로 반환한다.
6. Fused KV Materialization
DFlash는 타겟의 hidden states에서 드래프트 KV 캐시를 직접 구성하는 fused 경로를 제공한다.
self._use_fused_kv_materialize = is_cuda()
if self._use_fused_kv_materialize:
self._init_fused_kv_helper()
def _init_fused_kv_helper(self):
FusedKVMaterializeHelper = _get_fused_kv_materialize_helper()
self._fused_kv_helper = FusedKVMaterializeHelper(
layers=layers, rotary_emb=rotary_emb,
num_kv_heads=first_attn.num_kv_heads,
head_dim=first_attn.head_dim, device=self.device,
)
Fused KV materialization은 QKV projection + RoPE + KV cache write를 하나의 Triton 커널로 통합하여, 레이어별 sequential 처리를 병렬화한다.
7. 페이지 정렬 캐시 해제
def _compute_paged_keep_slots(*, prefix_lens, commit_lens, draft_token_num, page_size):
extended_lens = prefix_lens + int(draft_token_num)
new_lens = prefix_lens + commit_lens.to(seq_dtype)
aligned_new_lens = ((new_lens + page_size - 1) // page_size) * page_size
keep_lens = torch.minimum(aligned_new_lens, extended_lens)
keep_slots = (keep_lens - prefix_lens).to(torch.int64)
keep_slots.clamp_(min=0, max=int(draft_token_num))
return keep_slots
paged KV cache에서는 페이지 단위로만 해제할 수 있으므로, 수락된 토큰 이후의 전체 페이지만 반환한다.
EAGLE vs DFlash 비교
| 항목 | EAGLE | DFlash |
|---|---|---|
| 드래프팅 방식 | Autoregressive (step별) | 블록 병렬 (mask token) |
| 드래프팅 latency | N * forward | 1 * forward |
| 검증 구조 | 트리 (topk > 1) | 선형 (topk = 1) |
| KV 캐시 구성 | 드래프트 모델 자체 생성 | 타겟 hidden states에서 직접 구성 |
| Overlap V2 지원 | 예 | 아니오 |
| 적합한 모델 | 범용 | 블록 드래프팅 지원 모델 |
설계 근거
블록 드래프팅: autoregressive 드래프트는 step마다 GPU idle time이 발생한다. 블록 단위 생성은 한 번의 forward로 전체 드래프트를 완료하여 latency를 최소화한다.
선형 검증의 단순함: 트리 검증은 복잡한 인덱스 관리가 필요하지만, 선형 검증은 첫 불일치까지만 찾으면 되므로 구현이 간단하고 커널 오버헤드가 적다.
Draft Windowing: draft_window_size로 드래프트 KV 캐시를 최근 윈도우로 제한하여, 긴 시퀀스에서도 메모리 사용량을 일정하게 유지한다.
관련 포스트
참고
관련 포스트
- [sglang] DeepSeek-V4의 Latency 최적화: Fused mHC Post/Pre Kernel 도입
- [sglang] sglang ROCm MXFP4 어텐션에서 불필요한 contiguous copy 제거를 통한 성능 최적화
- [sglang] sglang의 torch.compile 활용: Advanced Indexing Gather 최적화로 LLM 추론 가속화
- [sglang] sglang diffusion 모델 성능 향상: Cache-DiT와 torch.compile의 최적화된 적용 순서
- [sglang] NixlKVManager 성능 향상: 비동기 및 멀티스레드 KV 전송 도입
SGLang 의 다른글
- 이전글 [SGLang] N-gram Draft: 모델 프리 투기적 디코딩
- 현재글 : [SGLang] DFlash: Flash 기반 고속 드래프팅
- 다음글 [SGLang] EAGLE CUDA Graph: 드래프트 모델 가속
댓글