본문으로 건너뛰기

[sglang] SGLang Triton 커널 최적화: libdevice.tanh 도입과 2D Strided Tensor 지원

PR 링크: sgl-project/sglang#23157 상태: Merged | 변경: +None / -None

들어가며

고성능 LLM 서빙 엔진인 SGLang 프로젝트에서 최근 fused_softcap_kernel의 안정성과 유연성을 크게 향상시킨 업데이트가 있었습니다. 이번 PR의 핵심은 두 가지입니다. 첫째, 수동으로 구현된 tanh 연산에서 발생하던 수치적 정밀도 손실(Catastrophic Cancellation) 문제를 libdevice.tanh를 통해 해결한 것입니다. 둘째, 기존에 1D Flattened Tensor만 처리할 수 있었던 커널을 확장하여 2D Strided Tensor를 올바르게 처리할 수 있도록 리팩토링했습니다.

이 글에서는 실제 코드 변경 사항을 통해 왜 이러한 변화가 시니어 엔지니어링 관점에서 중요한지 분석해 보겠습니다.

1. 수치적 안정성 확보: libdevice.tanh 도입

기존 코드에서는 tanh 함수를 지수 함수(exp)를 사용하여 직접 구현했습니다. 하지만 이는 입력값이 매우 작을 때 심각한 정밀도 문제를 야기합니다.

Before

# Manual tanh implementation using exp
exp2x = tl.exp(2 * x)
x = (exp2x - 1) / (exp2x + 1)

x가 0에 매우 가까울 때, exp(2x)는 1에 수렴합니다. 이때 exp2x - 1 연산에서 Catastrophic Cancellation(유효자리 상실)이 발생하여 부동 소수점 정밀도가 급격히 떨어집니다. 이는 모델의 최종 Logits 값에 오차를 발생시켜 생성 품질에 영향을 줄 수 있습니다.

After

from triton.language.extra import libdevice

# ... 중략 ...
x = x / softcapping_value
x = libdevice.tanh(x)
x = x * softcapping_value

개선된 코드에서는 NVIDIA의 고도로 최적화된 수학 라이브러리인 libdevice.tanh를 직접 호출합니다. libdevice는 하드웨어 수준에서 수치적 안정성이 검증된 알고리즘을 사용하므로, 입력값의 범위와 상관없이 신뢰할 수 있는 결과를 보장합니다.

2. 유연성 확장: 2D Strided Tensor 지원

기존 커널은 모든 텐서를 1차원으로 간주(numel())하고 처리했습니다. 하지만 실제 딥러닝 프레임워크에서 텐서는 메모리 상에 연속적이지 않은(non-contiguous) 상태로 존재할 수 있으며, 특히 2D 레이아웃에서 행(row) 간의 간격(stride)이 다를 수 있습니다.

Before (1D Indexing)

def fused_softcap_kernel(full_logits_ptr, softcapping_value, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0).to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(full_logits_ptr + offsets, mask=mask)
    # ...

기존 방식은 full_logits_ptr로부터 선형적으로 데이터를 읽어옵니다. 만약 텐서가 특정 차원에서 슬라이싱되었거나 Stride가 부여된 경우, 이 방식은 잘못된 메모리 주소에 접근하게 됩니다.

After (2D Grid & Stride Handling)

def fused_softcap_kernel(full_logits_ptr, softcapping_value, ncols, row_stride, BLOCK_SIZE: tl.constexpr):
    row = tl.program_id(1).to(tl.int64) # 2D Grid의 행 인덱스
    pid = tl.program_id(0).to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < ncols

    # Row Stride를 고려한 포인터 계산
    row_ptr = full_logits_ptr + row * row_stride
    x = tl.load(row_ptr + offsets, mask=mask)
    # ...
    tl.store(row_ptr + offsets, x, mask=mask)

호스트 코드(fused_softcap)에서도 이를 지원하기 위한 로직이 추가되었습니다.

if full_logits.is_contiguous():
    nrows, ncols = 1, full_logits.numel()
    row_stride = ncols
else:
    assert full_logits.ndim == 2, "non-contiguous softcap requires 2D tensor"
    nrows, ncols = full_logits.shape
    row_stride = full_logits.stride(0)

grid = ((ncols + BLOCK_SIZE - 1) // BLOCK_SIZE, nrows)

이제 커널은 program_id(1)을 통해 행을 구분하고, row_stride를 곱해 정확한 행의 시작 주소를 찾습니다. 이는 torch.view나 슬라이싱으로 인해 메모리 레이아웃이 복잡해진 상황에서도 커널이 안전하게 동작하도록 만듭니다.

왜 이게 좋은 최적화인가?

  1. 수치적 견고함 (Numerical Robustness): tanh와 같은 비선형 함수를 직접 구현할 때는 부동 소수점 연산의 한계를 항상 고려해야 합니다. 검증된 라이브러리(libdevice)를 사용하는 것은 유지보수와 정확도 측면에서 최선의 선택입니다.
  2. 범용성 (Generality): is_contiguous() 체크와 Stride 기반 인덱싱을 통해, 커널 호출 전에 매번 .contiguous()를 호출하여 메모리를 새로 할당하고 복사해야 하는 오버헤드를 방지할 수 있습니다.
  3. Triton의 강점 활용: Triton의 2D Grid 시스템을 활용하여 하드웨어의 병렬성을 더 명확하게 매핑했습니다. 이는 대규모 배치 사이즈 처리 시 효율적인 스케줄링을 가능하게 합니다.

결론

이번 PR은 단순한 기능 추가를 넘어, 딥러닝 커널이 갖추어야 할 정밀도메모리 레이아웃 대응 능력을 동시에 개선한 훌륭한 사례입니다. 특히 SGLang과 같이 고성능을 지향하는 프로젝트에서 이러한 디테일한 최적화는 시스템 전체의 신뢰성을 결정짓는 중요한 요소가 됩니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글