[sglang] SGLang Triton 커널 최적화: libdevice.tanh 도입과 2D Strided Tensor 지원
PR 링크: sgl-project/sglang#23157 상태: Merged | 변경: +None / -None
들어가며
고성능 LLM 서빙 엔진인 SGLang 프로젝트에서 최근 fused_softcap_kernel의 안정성과 유연성을 크게 향상시킨 업데이트가 있었습니다. 이번 PR의 핵심은 두 가지입니다. 첫째, 수동으로 구현된 tanh 연산에서 발생하던 수치적 정밀도 손실(Catastrophic Cancellation) 문제를 libdevice.tanh를 통해 해결한 것입니다. 둘째, 기존에 1D Flattened Tensor만 처리할 수 있었던 커널을 확장하여 2D Strided Tensor를 올바르게 처리할 수 있도록 리팩토링했습니다.
이 글에서는 실제 코드 변경 사항을 통해 왜 이러한 변화가 시니어 엔지니어링 관점에서 중요한지 분석해 보겠습니다.
1. 수치적 안정성 확보: libdevice.tanh 도입
기존 코드에서는 tanh 함수를 지수 함수(exp)를 사용하여 직접 구현했습니다. 하지만 이는 입력값이 매우 작을 때 심각한 정밀도 문제를 야기합니다.
Before
# Manual tanh implementation using exp
exp2x = tl.exp(2 * x)
x = (exp2x - 1) / (exp2x + 1)
x가 0에 매우 가까울 때, exp(2x)는 1에 수렴합니다. 이때 exp2x - 1 연산에서 Catastrophic Cancellation(유효자리 상실)이 발생하여 부동 소수점 정밀도가 급격히 떨어집니다. 이는 모델의 최종 Logits 값에 오차를 발생시켜 생성 품질에 영향을 줄 수 있습니다.
After
from triton.language.extra import libdevice
# ... 중략 ...
x = x / softcapping_value
x = libdevice.tanh(x)
x = x * softcapping_value
개선된 코드에서는 NVIDIA의 고도로 최적화된 수학 라이브러리인 libdevice.tanh를 직접 호출합니다. libdevice는 하드웨어 수준에서 수치적 안정성이 검증된 알고리즘을 사용하므로, 입력값의 범위와 상관없이 신뢰할 수 있는 결과를 보장합니다.
2. 유연성 확장: 2D Strided Tensor 지원
기존 커널은 모든 텐서를 1차원으로 간주(numel())하고 처리했습니다. 하지만 실제 딥러닝 프레임워크에서 텐서는 메모리 상에 연속적이지 않은(non-contiguous) 상태로 존재할 수 있으며, 특히 2D 레이아웃에서 행(row) 간의 간격(stride)이 다를 수 있습니다.
Before (1D Indexing)
def fused_softcap_kernel(full_logits_ptr, softcapping_value, n_elements, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0).to(tl.int64)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(full_logits_ptr + offsets, mask=mask)
# ...
기존 방식은 full_logits_ptr로부터 선형적으로 데이터를 읽어옵니다. 만약 텐서가 특정 차원에서 슬라이싱되었거나 Stride가 부여된 경우, 이 방식은 잘못된 메모리 주소에 접근하게 됩니다.
After (2D Grid & Stride Handling)
def fused_softcap_kernel(full_logits_ptr, softcapping_value, ncols, row_stride, BLOCK_SIZE: tl.constexpr):
row = tl.program_id(1).to(tl.int64) # 2D Grid의 행 인덱스
pid = tl.program_id(0).to(tl.int64)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < ncols
# Row Stride를 고려한 포인터 계산
row_ptr = full_logits_ptr + row * row_stride
x = tl.load(row_ptr + offsets, mask=mask)
# ...
tl.store(row_ptr + offsets, x, mask=mask)
호스트 코드(fused_softcap)에서도 이를 지원하기 위한 로직이 추가되었습니다.
if full_logits.is_contiguous():
nrows, ncols = 1, full_logits.numel()
row_stride = ncols
else:
assert full_logits.ndim == 2, "non-contiguous softcap requires 2D tensor"
nrows, ncols = full_logits.shape
row_stride = full_logits.stride(0)
grid = ((ncols + BLOCK_SIZE - 1) // BLOCK_SIZE, nrows)
이제 커널은 program_id(1)을 통해 행을 구분하고, row_stride를 곱해 정확한 행의 시작 주소를 찾습니다. 이는 torch.view나 슬라이싱으로 인해 메모리 레이아웃이 복잡해진 상황에서도 커널이 안전하게 동작하도록 만듭니다.
왜 이게 좋은 최적화인가?
- 수치적 견고함 (Numerical Robustness):
tanh와 같은 비선형 함수를 직접 구현할 때는 부동 소수점 연산의 한계를 항상 고려해야 합니다. 검증된 라이브러리(libdevice)를 사용하는 것은 유지보수와 정확도 측면에서 최선의 선택입니다. - 범용성 (Generality):
is_contiguous()체크와 Stride 기반 인덱싱을 통해, 커널 호출 전에 매번.contiguous()를 호출하여 메모리를 새로 할당하고 복사해야 하는 오버헤드를 방지할 수 있습니다. - Triton의 강점 활용: Triton의 2D Grid 시스템을 활용하여 하드웨어의 병렬성을 더 명확하게 매핑했습니다. 이는 대규모 배치 사이즈 처리 시 효율적인 스케줄링을 가능하게 합니다.
결론
이번 PR은 단순한 기능 추가를 넘어, 딥러닝 커널이 갖추어야 할 정밀도와 메모리 레이아웃 대응 능력을 동시에 개선한 훌륭한 사례입니다. 특히 SGLang과 같이 고성능을 지향하는 프로젝트에서 이러한 디테일한 최적화는 시스템 전체의 신뢰성을 결정짓는 중요한 요소가 됩니다.
참고 자료
- https://triton-lang.org/main/python-api/triton.language.html#triton.language.extra.libdevice.tanh
- https://docs.nvidia.com/cuda/libdevice-users-guide/index.html
- https://pytorch.org/docs/stable/generated/torch.Tensor.stride.html
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [sglang] SGLang의 디코드 성능 향상을 위한 Temperature 및 Softmax 커널 융합
- [sglang] SGLang 성능 최적화: torch.cuda.empty_cache() 호출 제어를 통한 가중치 업데이트 병목 해결
- [sglang] SGLang의 FA3 디코드 최적화: get_scheduler_metadata 도입
- [sglang] AMD ROCm 환경에서의 성능 최적화: Triton을 활용한 Fused QK GemmaRMSNorm 구현
- [triton] Triton Gluon Attention 커널의 Autotuning을 통한 성능 최적화 분석
PR Analysis 의 다른글
- 이전글 [sglang] SGLang 고성능 서빙: 비동기 알림 배치 처리와 SSE 고속 경로 최적화 분석
- 현재글 : [sglang] SGLang Triton 커널 최적화: libdevice.tanh 도입과 2D Strided Tensor 지원
- 다음글 [vllm] vLLM CPU 성능 최적화: NEON 하드웨어를 위한 고속 Exp 연산 도입
댓글