본문으로 건너뛰기

[axolotl] Axolotl 커스텀 Triton 커널 — entropy/softmax 최대 5배 가속

PR 링크: axolotl-ai-cloud/axolotl#3510 상태: Merged | 변경: +1346 / -0

들어가며

RLHF(Reinforcement Learning from Human Feedback) 학습에서 entropy_from_logitsselective_log_softmax는 매 step마다 호출되는 핵심 연산이다. PyTorch 기본 구현은 log_softmax + gather 또는 chunked iteration으로 수행되는데, 대형 vocabulary(151K+ tokens)에서는 중간 텐서 할당과 다중 커널 launch가 병목이 된다. 이 PR은 두 연산을 각각 single-pass Triton 커널로 대체하여 메모리 사용량을 줄이고 속도를 최대 5배 향상시킨다.

핵심 코드 분석

1. Online Entropy 커널 — 단일 패스 계산

기존 구현은 chunk 단위로 log_softmax + element-wise 곱셈 + sum을 반복했다. Triton 커널은 online algorithm으로 단일 패스에서 entropy를 계산한다.

Before (PyTorch chunked):

def entropy_from_logits_original(logits, chunk_size=128):
    flat_logits = logits.reshape(-1, num_classes)
    entropies = []
    for chunk in flat_logits.split(chunk_size, dim=0):
        logps = F.log_softmax(chunk, dim=-1)
        chunk_entropy = -(torch.exp(logps) * logps).sum(-1)
        entropies.append(chunk_entropy)
    return torch.cat(entropies, dim=0).reshape(original_shape)

After (Triton online):

@triton.jit
def _entropy_online_kernel(logits_ptr, output_ptr, stride_row, V, BLOCK_V):
    row = tl.program_id(0)
    row_ptr = logits_ptr + tl.cast(row, tl.int64) * stride_row

    running_max = tl.full([], float("-inf"), dtype=tl.float32)
    running_sum_exp = tl.full([], 0.0, dtype=tl.float32)
    running_weighted = tl.full([], 0.0, dtype=tl.float32)

    for v_start in range(0, V, BLOCK_V):
        offs = v_start + tl.arange(0, BLOCK_V)
        mask = offs < V
        x = tl.load(row_ptr + offs, mask=mask, other=float("-inf")).to(tl.float32)

        block_max = tl.max(x, axis=0)
        new_max = tl.maximum(running_max, block_max)
        correction = tl.exp(running_max - new_max)
        running_sum_exp = running_sum_exp * correction
        running_weighted = running_weighted * correction

        exp_x = tl.exp(x - new_max)
        running_sum_exp += tl.sum(exp_x, axis=0)
        running_weighted += tl.sum(exp_x * x, axis=0)
        running_max = new_max

    entropy = tl.log(running_sum_exp) + running_max - running_weighted / running_sum_exp
    tl.store(output_ptr + row, entropy)

핵심은 running max correction 기법이다. 각 블록에서 새로운 최댓값을 발견하면, 이전까지의 sum_expweighted_sumexp(old_max - new_max) 보정을 적용한다. 이로써 전체 vocabulary를 한 번만 순회하면서 수치적으로 안정적인 entropy를 계산한다.

2. Selective Log Softmax — Forward + Backward 모두 Fused

selective_log_softmax은 전체 softmax를 계산하지 않고 선택된 index의 log probability만 반환한다. Triton으로 forward와 backward를 모두 구현했다.

Forward 커널:

@triton.jit
def _selective_logsoftmax_fwd_kernel(
    logits_ptr, index_ptr, output_ptr, logsumexp_ptr, ...
):
    # Online logsumexp (한 번의 순회)
    for v_start in range(0, V, BLOCK_V):
        ...
        running_sum_exp = running_sum_exp * tl.exp(running_max - new_max)
        exp_x = tl.exp(x - new_max)
        running_sum_exp += tl.sum(exp_x, axis=0)
        running_max = new_max

    lse = tl.log(running_sum_exp) + running_max
    tl.store(logsumexp_ptr + row, lse)

    # Gather and subtract
    selected = tl.load(logits_row_ptr + safe_indices, mask=valid_mask)
    tl.store(output_row_ptr + k_offs, selected - lse, mask=valid_mask)

Backward 커널의 핵심 — Fused Scatter:

@triton.jit
def _selective_logsoftmax_bwd_kernel(...):
    for v_start in range(0, V, BLOCK_V):
        offs = v_start + tl.arange(0, BLOCK_V)
        softmax_j = tl.exp(x - lse)
        grad_j = -softmax_j * grad_sum

        # Scatter: 별도 패스 없이 inline으로 처리
        match = offs[:, None] == indices[None, :]  # [BLOCK_V, K_BLOCK]
        scatter_contrib = tl.sum(
            tl.where(match, grad_out[None, :], 0.0), axis=1
        )
        grad_j += scatter_contrib
        tl.store(grad_logits_row_ptr + offs, grad_j, mask=mask)

Backward에서 scatter를 별도 패스 없이 base gradient 계산과 동시에 처리한다. offs[:, None] == indices[None, :] broadcast 비교로 read-after-write 문제를 회피한다.

3. Monkeypatch 통합

def _apply_trl_trainer_utils_patches(self):
    if trl.trainer.utils.selective_log_softmax is not selective_log_softmax:
        trl.trainer.utils.selective_log_softmax = selective_log_softmax
    if trl.trainer.utils.entropy_from_logits is not entropy_from_logits:
        trl.trainer.utils.entropy_from_logits = entropy_from_logits

trl 라이브러리의 함수를 런타임에 교체하는 monkeypatch 패턴으로, 사용자 코드 수정 없이 Triton 커널이 적용된다.

왜 이게 좋은가

  • 메모리 절약: 중간 log_softmax 텐서 (vocab_size 크기)를 할당하지 않는다. 151K vocabulary 기준으로 행당 수백 KB를 절약한다.
  • 커널 launch 감소: chunk loop의 N번 launch를 1번으로 줄인다.
  • Non-contiguous 지원: stride 기반 커널로 .contiguous() 복사 없이 transpose된 텐서를 처리한다.
  • Autograd 통합: torch.autograd.Function으로 래핑하여 기존 학습 루프에서 backward가 자동으로 동작한다.

정리

  • Online max correction 알고리즘은 softmax 계열 연산을 단일 패스로 처리하는 핵심 기법이다.
  • Triton의 tl.where + broadcast 패턴으로 scatter를 fuse하면 별도 atomicAdd 없이 안전하게 gradient를 분배할 수 있다.
  • Monkeypatch 방식은 upstream 라이브러리 수정 없이 최적화를 주입하는 실용적인 전략이다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글