[axolotl] Axolotl 커스텀 Triton 커널 — entropy/softmax 최대 5배 가속
PR 링크: axolotl-ai-cloud/axolotl#3510 상태: Merged | 변경: +1346 / -0
들어가며
RLHF(Reinforcement Learning from Human Feedback) 학습에서 entropy_from_logits와 selective_log_softmax는 매 step마다 호출되는 핵심 연산이다. PyTorch 기본 구현은 log_softmax + gather 또는 chunked iteration으로 수행되는데, 대형 vocabulary(151K+ tokens)에서는 중간 텐서 할당과 다중 커널 launch가 병목이 된다. 이 PR은 두 연산을 각각 single-pass Triton 커널로 대체하여 메모리 사용량을 줄이고 속도를 최대 5배 향상시킨다.
핵심 코드 분석
1. Online Entropy 커널 — 단일 패스 계산
기존 구현은 chunk 단위로 log_softmax + element-wise 곱셈 + sum을 반복했다. Triton 커널은 online algorithm으로 단일 패스에서 entropy를 계산한다.
Before (PyTorch chunked):
def entropy_from_logits_original(logits, chunk_size=128):
flat_logits = logits.reshape(-1, num_classes)
entropies = []
for chunk in flat_logits.split(chunk_size, dim=0):
logps = F.log_softmax(chunk, dim=-1)
chunk_entropy = -(torch.exp(logps) * logps).sum(-1)
entropies.append(chunk_entropy)
return torch.cat(entropies, dim=0).reshape(original_shape)
After (Triton online):
@triton.jit
def _entropy_online_kernel(logits_ptr, output_ptr, stride_row, V, BLOCK_V):
row = tl.program_id(0)
row_ptr = logits_ptr + tl.cast(row, tl.int64) * stride_row
running_max = tl.full([], float("-inf"), dtype=tl.float32)
running_sum_exp = tl.full([], 0.0, dtype=tl.float32)
running_weighted = tl.full([], 0.0, dtype=tl.float32)
for v_start in range(0, V, BLOCK_V):
offs = v_start + tl.arange(0, BLOCK_V)
mask = offs < V
x = tl.load(row_ptr + offs, mask=mask, other=float("-inf")).to(tl.float32)
block_max = tl.max(x, axis=0)
new_max = tl.maximum(running_max, block_max)
correction = tl.exp(running_max - new_max)
running_sum_exp = running_sum_exp * correction
running_weighted = running_weighted * correction
exp_x = tl.exp(x - new_max)
running_sum_exp += tl.sum(exp_x, axis=0)
running_weighted += tl.sum(exp_x * x, axis=0)
running_max = new_max
entropy = tl.log(running_sum_exp) + running_max - running_weighted / running_sum_exp
tl.store(output_ptr + row, entropy)
핵심은 running max correction 기법이다. 각 블록에서 새로운 최댓값을 발견하면, 이전까지의 sum_exp와 weighted_sum에 exp(old_max - new_max) 보정을 적용한다. 이로써 전체 vocabulary를 한 번만 순회하면서 수치적으로 안정적인 entropy를 계산한다.
2. Selective Log Softmax — Forward + Backward 모두 Fused
selective_log_softmax은 전체 softmax를 계산하지 않고 선택된 index의 log probability만 반환한다. Triton으로 forward와 backward를 모두 구현했다.
Forward 커널:
@triton.jit
def _selective_logsoftmax_fwd_kernel(
logits_ptr, index_ptr, output_ptr, logsumexp_ptr, ...
):
# Online logsumexp (한 번의 순회)
for v_start in range(0, V, BLOCK_V):
...
running_sum_exp = running_sum_exp * tl.exp(running_max - new_max)
exp_x = tl.exp(x - new_max)
running_sum_exp += tl.sum(exp_x, axis=0)
running_max = new_max
lse = tl.log(running_sum_exp) + running_max
tl.store(logsumexp_ptr + row, lse)
# Gather and subtract
selected = tl.load(logits_row_ptr + safe_indices, mask=valid_mask)
tl.store(output_row_ptr + k_offs, selected - lse, mask=valid_mask)
Backward 커널의 핵심 — Fused Scatter:
@triton.jit
def _selective_logsoftmax_bwd_kernel(...):
for v_start in range(0, V, BLOCK_V):
offs = v_start + tl.arange(0, BLOCK_V)
softmax_j = tl.exp(x - lse)
grad_j = -softmax_j * grad_sum
# Scatter: 별도 패스 없이 inline으로 처리
match = offs[:, None] == indices[None, :] # [BLOCK_V, K_BLOCK]
scatter_contrib = tl.sum(
tl.where(match, grad_out[None, :], 0.0), axis=1
)
grad_j += scatter_contrib
tl.store(grad_logits_row_ptr + offs, grad_j, mask=mask)
Backward에서 scatter를 별도 패스 없이 base gradient 계산과 동시에 처리한다. offs[:, None] == indices[None, :] broadcast 비교로 read-after-write 문제를 회피한다.
3. Monkeypatch 통합
def _apply_trl_trainer_utils_patches(self):
if trl.trainer.utils.selective_log_softmax is not selective_log_softmax:
trl.trainer.utils.selective_log_softmax = selective_log_softmax
if trl.trainer.utils.entropy_from_logits is not entropy_from_logits:
trl.trainer.utils.entropy_from_logits = entropy_from_logits
trl 라이브러리의 함수를 런타임에 교체하는 monkeypatch 패턴으로, 사용자 코드 수정 없이 Triton 커널이 적용된다.
왜 이게 좋은가
- 메모리 절약: 중간
log_softmax텐서 (vocab_size 크기)를 할당하지 않는다. 151K vocabulary 기준으로 행당 수백 KB를 절약한다. - 커널 launch 감소: chunk loop의 N번 launch를 1번으로 줄인다.
- Non-contiguous 지원: stride 기반 커널로
.contiguous()복사 없이 transpose된 텐서를 처리한다. - Autograd 통합:
torch.autograd.Function으로 래핑하여 기존 학습 루프에서 backward가 자동으로 동작한다.
정리
- Online max correction 알고리즘은 softmax 계열 연산을 단일 패스로 처리하는 핵심 기법이다.
- Triton의
tl.where+ broadcast 패턴으로 scatter를 fuse하면 별도 atomicAdd 없이 안전하게 gradient를 분배할 수 있다. - Monkeypatch 방식은 upstream 라이브러리 수정 없이 최적화를 주입하는 실용적인 전략이다.
참고 자료
- Triton 공식 튜토리얼 — Triton 커널 작성 기초
- Online Normalizer Calculation (Milakov & Gimelshein, 2018) — Online softmax 알고리즘 원논문
- trl 라이브러리 — RLHF 학습 프레임워크
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Ray] find_gcs_addresses 결과 캐싱으로 프로세스 스캔 비용 제거
- 현재글 : [axolotl] Axolotl 커스텀 Triton 커널 — entropy/softmax 최대 5배 가속
- 다음글 [axolotl] Triton LoRA 커널 Autotune 테스트 안정화: pytest-xdist 환경에서의 모듈 격리 전략
댓글