본문으로 건너뛰기

[sglang] SGLang의 디코드 성능 향상을 위한 Temperature 및 Softmax 커널 융합

PR 링크: sgl-project/sglang#20501 상태: Merged | 변경: +None / -None

들어가며

LLM의 디코드(Decode) 단계는 매 토큰 생성마다 반복되는 연산으로, 전체 추론 성능을 결정짓는 핵심 병목 구간입니다. 기존 SGLang의 sampler.pylogits.div_(temperature)torch.softmax를 별도의 CUDA 커널로 실행했습니다. 이 방식은 두 번의 커널 호출과 6번의 전역 메모리 패스(4번 읽기, 2번 쓰기)를 유발하여, 특히 Llama 3(128K)나 Qwen(152K)과 같은 대규모 어휘 사전(Large-vocab) 모델에서 성능 저하를 야기했습니다. 본 PR은 이 두 연산을 하나의 Triton 커널로 융합하여 메모리 대역폭 사용을 획기적으로 줄였습니다.

코드 분석

1. python/sglang/srt/layers/fused_sampling.py (신규)

핵심 변경 사항은 Triton을 이용한 커널 융합입니다. 어휘 사전 크기에 따라 두 가지 전략을 사용합니다.

  • Single-pass kernel: 어휘 사전이 32,768 이하일 때 사용하며, 데이터를 레지스터에 로드하여 1번의 읽기와 1번의 쓰기만으로 연산을 완료합니다.
  • Multi-pass kernel: 대규모 어휘 사전을 위해 3-pass online softmax 방식을 채택했습니다. triton.autotune을 통해 BLOCK_SIZEnum_warps를 최적화합니다.
# Before (sampler.py)
logits.div_(sampling_info.temperatures)
logits[:] = torch.softmax(logits, dim=-1)

# After (fused_sampling.py - Single-pass 예시)
x = tl.load(logits_ptr + row_idx * logits_stride + offsets, mask=mask)
x = (x / temp).to(tl.float32)
x_max = tl.max(x, axis=0)
exp_x = tl.exp(x - x_max)
prob = exp_x / tl.sum(exp_x, axis=0)
tl.store(output_ptr + row_idx * output_stride + offsets, prob, mask=mask)

2. python/sglang/srt/layers/sampler.py (수정)

단순히 커널을 교체하는 것을 넘어, 성능과 정확성을 모두 잡기 위해 하이브리드 디스패치 로직을 도입했습니다.

# 하이브리드 디스패치 도입
if batch_size < 128:
    # 소규모 배치에서는 PyTorch 네이티브 연산이 더 빠름
    logits.div_(temperatures)
    torch.softmax(logits, dim=-1, out=logits)
else:
    # 대규모 배치에서는 융합된 Triton 커널 사용
    fused_temperature_softmax_inplace(logits, temperatures)

왜 이게 좋은가

이번 최적화의 핵심은 메모리 패스 감소커널 호출 오버헤드 제거입니다.

  1. 성능 향상: 벤치마크 결과, 배치 사이즈가 128 이상인 경우 기존 대비 약 2~4배의 속도 향상을 보였습니다. 특히 대규모 어휘 사전 환경에서 그 효과가 극대화됩니다.
  2. 정확성 유지: 초기 구현에서 PyTorch와 연산 순서가 달라 구조화된 출력(Grammar-constrained decoding)에서 문제가 발생했으나, 3-pass 연산 방식을 채택하여 PyTorch와 수학적으로 동일한 결과를 보장하도록 수정되었습니다.
  3. 하이브리드 전략: 무조건적인 융합이 아닌, 배치 사이즈에 따른 분기 처리를 통해 소규모 배치에서의 오버헤드까지 고려했습니다.

교훈: 커널 융합은 단순히 연산을 합치는 것뿐만 아니라, 메모리 접근 패턴(Memory Access Pattern)을 최적화하는 것이 핵심입니다. 또한, 최적화된 커널이라도 특정 워크로드(소규모 배치)에서는 네이티브 라이브러리가 더 효율적일 수 있음을 인지하고 하이브리드 전략을 취하는 것이 중요합니다.

참고 자료

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글