본문으로 건너뛰기

[vLLM] Pooling Tasks: 임베딩, 분류, 스코어링

들어가며

LLM은 텍스트 생성만을 위한 것이 아니다. 문장 임베딩, 텍스트 분류, 유사도 스코어링 등 다양한 태스크에 활용된다. vLLM은 이러한 비생성 태스크를 Pooling이라는 통합 프레임워크로 지원한다. 관련 코드는 vllm/v1/pool/vllm/model_executor/layers/pooler/에 위치한다.

공식 문서: https://docs.vllm.ai/en/latest/models/pooling_models.html

공식 문서

vLLM 공식 문서: Pooling Models

핵심 구조/코드 분석

Pooler 추상 인터페이스

모든 풀링 모델의 핵심은 Pooler 추상 클래스이다.

class Pooler(nn.Module, ABC):
    """The interface required for all poolers used in pooling models in vLLM."""

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        raise NotImplementedError

nn.Module을 상속하므로 PyTorch 모델 그래프의 일부로 동작한다. get_supported_tasks()는 해당 풀러가 임베딩, 분류, 토큰 임베딩 중 어떤 태스크를 지원하는지 선언한다.

PoolingMetadata: 풀링의 핵심 메타데이터

@dataclass
class PoolingMetadata:
    prompt_lens: torch.Tensor          # CPU Tensor
    prompt_token_ids: torch.Tensor | None  # Model-device tensor
    prompt_token_ids_cpu: torch.Tensor | None
    pooling_params: list[PoolingParams]
    pooling_states: list[PoolingStates]
    pooling_cursor: PoolingCursor | None = None

PoolingMetadata는 각 요청의 프롬프트 길이, 토큰 ID, 풀링 파라미터를 담는다. PoolingCursor는 chunked prefill에서 부분 처리를 추적하는 역할이다.

PoolingCursor: Chunked Prefill 지원

@dataclass
class PoolingCursor:
    first_token_indices_gpu: torch.Tensor
    last_token_indices_gpu: torch.Tensor
    prompt_lens_cpu: torch.Tensor
    seq_lens_cpu: torch.Tensor
    num_scheduled_tokens_cpu: torch.Tensor

    def is_partial_prefill(self):
        return not torch.all(
            self.prompt_lens_cpu == self.num_scheduled_tokens_cpu
        )

긴 입력을 청크 단위로 처리할 때, 현재 청크가 전체 프롬프트의 일부인지(partial prefill) 판단하는 메서드가 있다. 풀링은 전체 시퀀스를 본 후에만 의미가 있으므로, 부분 처리 시에는 hidden state를 캐시에 누적한다.

PoolingParamsUpdate: 태스크별 설정

@dataclass(frozen=True)
class PoolingParamsUpdate:
    requires_token_ids: bool = False

    def apply(self, params: PoolingParams) -> None:
        params.requires_token_ids = self.requires_token_ids

특정 풀러가 토큰 ID를 필요로 하는 경우(예: 분류 모델에서 특정 토큰 위치를 찾아야 할 때) 이 업데이트 객체를 통해 파라미터를 조정한다. frozen=True 데이터클래스로 불변성을 보장한다.

PoolingStates: 청크 간 상태 유지

class PoolingStates:
    def __init__(self):
        self.hidden_states_cache: list[torch.Tensor] = []

    def clean(self):
        self.hidden_states_cache.clear()

Chunked prefill에서 청크마다 생성되는 hidden state를 누적하고, 마지막 청크에서 이를 결합하여 풀링을 수행한다.

풀링 태스크의 종류

vLLM이 지원하는 풀링 태스크는 다음과 같다:

  • embed: 시퀀스 레벨 임베딩 (문장 임베딩)
  • classify: 텍스트 분류
  • score: 유사도/관련도 스코어링
  • reward: 리워드 모델 스코어
  • token_embed: 토큰 레벨 임베딩
  • token_classify: 토큰 레벨 분류 (NER 등)

각 태스크는 vllm/model_executor/layers/pooler/ 아래의 세부 풀러 구현과 매핑된다. seqwise/(시퀀스 레벨)와 tokwise/(토큰 레벨) 디렉토리로 분류되어 있다.

왜 이 설계인가

  1. 통합 인터페이스: 생성과 비생성 태스크를 같은 엔진에서 처리한다. PagedAttention과 continuous batching의 이점을 임베딩/분류에도 그대로 적용할 수 있다.

  2. Chunked Prefill과의 호환: 임베딩 모델도 긴 입력을 처리해야 하므로, chunked prefill을 지원한다. PoolingStates로 중간 hidden state를 캐시하고, 마지막 청크에서 통합 풀링을 수행한다.

  3. 태스크 선언적 등록: 각 풀러가 get_supported_tasks()로 지원 태스크를 선언하면, 프레임워크가 자동으로 적합한 API 엔드포인트를 활성화한다. 새로운 풀링 태스크를 추가할 때 풀러만 구현하면 된다.

  4. Pooler를 nn.Module로: PyTorch의 모듈 시스템에 통합되어, 양자화나 컴파일 최적화가 자연스럽게 적용된다.

정리

vLLM의 Pooling 시스템은 LLM 서빙 엔진을 범용 트랜스포머 추론 엔진으로 확장한다. 임베딩, 분류, 스코어링을 위한 별도 서빙 인프라 없이, vLLM 하나로 모든 트랜스포머 기반 태스크를 처리할 수 있다는 것이 핵심 가치이다.

댓글

관련 포스트

vLLM 의 다른글