[vLLM] Tree Attention: 투기적 디코딩용 트리 어텐션
들어가며
투기적 디코딩에서 드래프트 토큰을 검증할 때, 토큰들이 선형 시퀀스가 아닌 트리 구조를 형성한다. 예를 들어 첫 위치에서 3개, 각각에서 다시 2개의 후보를 생성하면 트리가 된다. vLLM의 Tree Attention(vllm/v1/attention/backends/tree_attn.py)은 이 트리 구조에 맞는 어텐션 마스크를 생성하여 한 번의 forward pass로 모든 경로를 검증한다.
핵심 구조/코드 분석
트리 어텐션 바이어스 생성
def _prepare_tree_attn_bias(sorted_tree_choices, depth_counts, dtype, device):
tree_len = len(sorted_tree_choices) + 1 # +1: 루트 노드
tree_attn_mask = torch.full((tree_len, tree_len), -torch.inf, device=device, dtype=dtype)
# 대각선: 자기 자신에 attend
for i in range(tree_len):
tree_attn_mask[i, i] = 0
# 루트: 모든 토큰이 attend
tree_attn_mask[:, 0] = 0
# 조상 노드에 attend
for i in range(len(depth_counts)):
for j in range(depth_counts[i]):
cur_tree_choice = sorted_tree_choices[start + j]
if len(cur_tree_choice) == 1:
continue
ancestor_idx = []
for c in range(len(cur_tree_choice) - 1):
ancestor_idx.append(
sorted_tree_choices.index(cur_tree_choice[:c + 1]) + 1
)
tree_attn_mask[j + start + 1, ancestor_idx] = 0
핵심 아이디어는 -inf 마스크다. 트리에서 각 토큰은 자신, 루트, 그리고 자신의 조상 노드에만 attend할 수 있다. 형제 노드나 다른 경로의 토큰은 -inf로 마스킹되어 softmax에서 0이 된다.
TreeAttentionMetadataBuilder
class TreeAttentionMetadataBuilder(AttentionMetadataBuilder):
def __init__(self, kv_cache_spec, layer_names, vllm_config, device):
spec_token_tree = spec.speculative_token_tree if spec else None
tree_choices = ast.literal_eval(spec_token_tree) if spec_token_tree else [(0,)]
depth_counts = _get_depth_counts(tree_choices)
self.tree_attn_bias = _prepare_tree_attn_bias(
tree_choices, depth_counts, dtype=torch.float32, device=device,
)
self.reorder_batch_threshold = self.tree_attn_bias.shape[0]
speculative_token_tree는 트리 구조를 Python 리스트로 표현한다. 예: [(0,), (0,0), (0,1), (1,), (1,0)]. 각 튜플은 루트에서의 경로를 나타낸다. reorder_batch_threshold는 트리 크기와 동일하게 설정되어, 트리 크기 이하의 쿼리를 가진 요청을 decode로 분류한다.
Prefill/Decode 분리 처리
def forward(self, layer, query, key, value, kv_cache, attn_metadata, output, ...):
key_cache, value_cache = kv_cache.unbind(0)
if prefill_meta := attn_metadata.prefill_metadata:
unified_attention(
q=query[num_decode_tokens:num_actual_tokens],
k=key_cache, v=value_cache,
out=output[num_decode_tokens:num_actual_tokens],
causal=True, # Prefill은 causal attention
...
)
if decode_meta := attn_metadata.decode_metadata:
unified_attention(
q=query[:num_decode_tokens],
k=key_cache, v=value_cache,
out=output[:num_decode_tokens],
causal=True,
qq_bias=decode_meta.tree_attn_bias, # 트리 바이어스 적용
...
)
Prefill은 일반 causal attention이고, decode에서만 qq_bias로 트리 어텐션 바이어스를 적용한다. unified_attention은 Triton 기반 통합 어텐션 커널이다.
드래프팅용 메타데이터
def build_for_drafting(self, common_attn_metadata, draft_index):
orig_tree_attn_bias = self.tree_attn_bias
if draft_index == 0:
self.tree_attn_bias = torch.empty(0) # 루트: prefill 사용
else:
start, end = 1, 1 + common_attn_metadata.max_query_len
self.tree_attn_bias = self.tree_attn_bias[start:end, start:end].contiguous()
attn_metadata = self.build(0, common_attn_metadata, fast_build=True)
self.tree_attn_bias = orig_tree_attn_bias # 복원
return attn_metadata
드래프팅 시 draft_index=0(루트)에서는 일반 prefill을 사용하고, 이후 인덱스에서는 트리 바이어스의 서브트리를 잘라서 사용한다.
KV Cache 형상
@staticmethod
def get_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size, cache_dtype_str="auto"):
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
(2, num_blocks, block_size, num_kv_heads, head_size) 형상에서 첫 차원 2는 K와 V다. 블록 크기는 반드시 16의 배수여야 한다.
왜 이 설계인가
-
qq_bias를 통한 트리 마스크: Q-Q 바이어스로 트리 마스크를 구현하면, 기존 어텐션 커널(
unified_attention)을 수정하지 않고 트리 어텐션을 적용할 수 있다.-inf값이 softmax를 통과하면서 자연스럽게 마스킹된다. -
사전 계산된 바이어스: 트리 구조는 요청 간에 변하지 않으므로(동일한 투기 설정), 바이어스를 초기화 시 한 번만 계산하고 재사용한다. 매 디코딩 스텝마다 마스크를 재생성하는 오버헤드가 없다.
-
forward_includes_kv_cache_update = False: 일반 어텐션 백엔드는 forward 안에서 KV 캐시 업데이트를 포함하지만, Tree Attention은
do_kv_cache_update를 별도 메서드로 분리했다. 이는 투기적 디코딩에서 검증과 캐시 업데이트의 타이밍이 다를 수 있기 때문이다.
참고 자료
관련 포스트
vLLM 의 다른글
- 이전글 [vLLM] 기타 Model Layers: Pooler, Resampler, Vocab Parallel Embedding 등
- 현재글 : [vLLM] Tree Attention: 투기적 디코딩용 트리 어텐션
- 다음글 [vLLM] Warmup: 커널 JIT 사전 컴파일
댓글