본문으로 건너뛰기

[vLLM] Tree Attention: 투기적 디코딩용 트리 어텐션

들어가며

투기적 디코딩에서 드래프트 토큰을 검증할 때, 토큰들이 선형 시퀀스가 아닌 트리 구조를 형성한다. 예를 들어 첫 위치에서 3개, 각각에서 다시 2개의 후보를 생성하면 트리가 된다. vLLM의 Tree Attention(vllm/v1/attention/backends/tree_attn.py)은 이 트리 구조에 맞는 어텐션 마스크를 생성하여 한 번의 forward pass로 모든 경로를 검증한다.

핵심 구조/코드 분석

트리 어텐션 바이어스 생성

def _prepare_tree_attn_bias(sorted_tree_choices, depth_counts, dtype, device):
    tree_len = len(sorted_tree_choices) + 1  # +1: 루트 노드
    tree_attn_mask = torch.full((tree_len, tree_len), -torch.inf, device=device, dtype=dtype)

    # 대각선: 자기 자신에 attend
    for i in range(tree_len):
        tree_attn_mask[i, i] = 0

    # 루트: 모든 토큰이 attend
    tree_attn_mask[:, 0] = 0

    # 조상 노드에 attend
    for i in range(len(depth_counts)):
        for j in range(depth_counts[i]):
            cur_tree_choice = sorted_tree_choices[start + j]
            if len(cur_tree_choice) == 1:
                continue
            ancestor_idx = []
            for c in range(len(cur_tree_choice) - 1):
                ancestor_idx.append(
                    sorted_tree_choices.index(cur_tree_choice[:c + 1]) + 1
                )
            tree_attn_mask[j + start + 1, ancestor_idx] = 0

핵심 아이디어는 -inf 마스크다. 트리에서 각 토큰은 자신, 루트, 그리고 자신의 조상 노드에만 attend할 수 있다. 형제 노드나 다른 경로의 토큰은 -inf로 마스킹되어 softmax에서 0이 된다.

TreeAttentionMetadataBuilder

class TreeAttentionMetadataBuilder(AttentionMetadataBuilder):
    def __init__(self, kv_cache_spec, layer_names, vllm_config, device):
        spec_token_tree = spec.speculative_token_tree if spec else None
        tree_choices = ast.literal_eval(spec_token_tree) if spec_token_tree else [(0,)]
        depth_counts = _get_depth_counts(tree_choices)
        self.tree_attn_bias = _prepare_tree_attn_bias(
            tree_choices, depth_counts, dtype=torch.float32, device=device,
        )
        self.reorder_batch_threshold = self.tree_attn_bias.shape[0]

speculative_token_tree는 트리 구조를 Python 리스트로 표현한다. 예: [(0,), (0,0), (0,1), (1,), (1,0)]. 각 튜플은 루트에서의 경로를 나타낸다. reorder_batch_threshold는 트리 크기와 동일하게 설정되어, 트리 크기 이하의 쿼리를 가진 요청을 decode로 분류한다.

Prefill/Decode 분리 처리

def forward(self, layer, query, key, value, kv_cache, attn_metadata, output, ...):
    key_cache, value_cache = kv_cache.unbind(0)

    if prefill_meta := attn_metadata.prefill_metadata:
        unified_attention(
            q=query[num_decode_tokens:num_actual_tokens],
            k=key_cache, v=value_cache,
            out=output[num_decode_tokens:num_actual_tokens],
            causal=True,  # Prefill은 causal attention
            ...
        )

    if decode_meta := attn_metadata.decode_metadata:
        unified_attention(
            q=query[:num_decode_tokens],
            k=key_cache, v=value_cache,
            out=output[:num_decode_tokens],
            causal=True,
            qq_bias=decode_meta.tree_attn_bias,  # 트리 바이어스 적용
            ...
        )

Prefill은 일반 causal attention이고, decode에서만 qq_bias로 트리 어텐션 바이어스를 적용한다. unified_attention은 Triton 기반 통합 어텐션 커널이다.

드래프팅용 메타데이터

def build_for_drafting(self, common_attn_metadata, draft_index):
    orig_tree_attn_bias = self.tree_attn_bias
    if draft_index == 0:
        self.tree_attn_bias = torch.empty(0)  # 루트: prefill 사용
    else:
        start, end = 1, 1 + common_attn_metadata.max_query_len
        self.tree_attn_bias = self.tree_attn_bias[start:end, start:end].contiguous()
    attn_metadata = self.build(0, common_attn_metadata, fast_build=True)
    self.tree_attn_bias = orig_tree_attn_bias  # 복원
    return attn_metadata

드래프팅 시 draft_index=0(루트)에서는 일반 prefill을 사용하고, 이후 인덱스에서는 트리 바이어스의 서브트리를 잘라서 사용한다.

KV Cache 형상

@staticmethod
def get_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size, cache_dtype_str="auto"):
    if block_size % 16 != 0:
        raise ValueError("Block size must be a multiple of 16.")
    return (2, num_blocks, block_size, num_kv_heads, head_size)

(2, num_blocks, block_size, num_kv_heads, head_size) 형상에서 첫 차원 2는 K와 V다. 블록 크기는 반드시 16의 배수여야 한다.

왜 이 설계인가

  1. qq_bias를 통한 트리 마스크: Q-Q 바이어스로 트리 마스크를 구현하면, 기존 어텐션 커널(unified_attention)을 수정하지 않고 트리 어텐션을 적용할 수 있다. -inf 값이 softmax를 통과하면서 자연스럽게 마스킹된다.

  2. 사전 계산된 바이어스: 트리 구조는 요청 간에 변하지 않으므로(동일한 투기 설정), 바이어스를 초기화 시 한 번만 계산하고 재사용한다. 매 디코딩 스텝마다 마스크를 재생성하는 오버헤드가 없다.

  3. forward_includes_kv_cache_update = False: 일반 어텐션 백엔드는 forward 안에서 KV 캐시 업데이트를 포함하지만, Tree Attention은 do_kv_cache_update를 별도 메서드로 분리했다. 이는 투기적 디코딩에서 검증과 캐시 업데이트의 타이밍이 다를 수 있기 때문이다.

참고 자료

댓글

관련 포스트

vLLM 의 다른글