[axolotl] Axolotl: Triton 커널을 활용한 Entropy 및 Selective Log Softmax 최적화
PR 링크: axolotl-ai-cloud/axolotl#3510 상태: Merged | 변경: +1346 / -0
들어가며
대규모 언어 모델(LLM) 훈련은 막대한 컴퓨팅 자원을 요구하며, 작은 최적화 하나하나가 전체 훈련 시간에 큰 영향을 미칩니다. 특히, 모델의 출력 로짓(logits)에서 엔트로피를 계산하거나 특정 로짓에 대해서만 로그 소프트맥스를 적용하는 연산은 훈련 과정에서 빈번하게 발생하며, 이들의 효율성은 전체 훈련 속도를 좌우하는 중요한 요소입니다. 기존 PyTorch 구현은 범용성을 위해 설계되었지만, 특정 패턴의 연산에서는 비효율적일 수 있습니다. axolotl-ai-cloud/axolotl 레포지토리의 이 PR은 이러한 문제를 해결하기 위해 NVIDIA의 Triton 언어를 사용하여 커스텀 GPU 커널을 구현함으로써, entropy_from_logits와 selective_log_softmax 연산의 성능을 획기적으로 개선합니다.
이 PR의 핵심 목표는 다음과 같습니다:
entropy_from_logits연산의 속도 및 메모리 효율성 개선.selective_log_softmax연산의 순방향(forward) 및 역방향(backward) 전파 속도 및 메모리 효율성 개선.torch.compile이 역방향 전파를 지원하지 않는selective_log_softmax의 한계를 극복.
코드 분석: 무엇이 왜 좋은 최적화/개선인가
이 PR은 주로 두 가지 핵심 연산인 entropy_from_logits와 selective_log_softmax에 대한 Triton 커널 구현을 통해 최적화를 달성합니다. 각 연산에 대한 벤치마크 파일이 추가되어 성능 개선을 명확히 보여줍니다.
1. benchmarks/bench_entropy.py
이 파일은 entropy_from_logits 함수의 성능을 벤치마킹합니다. 기존 PyTorch 구현과 Triton 커널 구현을 비교하여 시간 및 메모리 사용량을 측정합니다.
기존 entropy_from_logits_original 구현
# Before: benchmarks/bench_entropy.py
def entropy_from_logits_original(logits: torch.Tensor, chunk_size: int = 128):
"""Original chunked implementation (reference)."""
original_shape = logits.shape[:-1]
num_classes = logits.shape[-1]
flat_logits = logits.reshape(-1, num_classes)
entropies = []
for chunk in flat_logits.split(chunk_size, dim=0):
logps = F.log_softmax(chunk, dim=-1)
chunk_entropy = -(torch.exp(logps) * logps).sum(-1)
entropies.append(chunk_entropy)
return torch.cat(entropies, dim=0).reshape(original_shape)
문제점:
- 청크(Chunking) 처리:
flat_logits.split(chunk_size, dim=0)를 사용하여 로짓을 청크로 나누어 처리합니다. 이는 큰 텐서를 GPU 메모리에 한 번에 로드하지 않기 위한 전략일 수 있지만, 각 청크마다 별도의 커널 호출 및 메모리 할당/해제가 발생하여 오버헤드가 큽니다. - 여러 단계의 연산:
F.log_softmax,torch.exp, 곱셈, 뺄셈,sum등 여러 PyTorch 연산이 순차적으로 호출됩니다. 각 연산은 별도의 GPU 커널을 실행하며, 중간 결과물을 GPU 메모리에 쓰고 읽는 과정에서 불필요한 메모리 대역폭 소모(memory bound)와 레이턴시가 발생합니다. - 메모리 비효율성:
logps,chunk_entropy등 중간 텐서들이 생성되어 추가적인 메모리를 사용합니다. - 비연속(Non-contiguous) 텐서 처리:
logits.contiguous()호출이 없으므로, 입력logits가 비연속(non-contiguous) 메모리 레이아웃을 가질 경우,F.log_softmax내부에서 암묵적인contiguous()복사가 발생하여 성능 저하를 유발할 수 있습니다.
Triton 커널 entropy_from_logits 구현 (추정)
PR diff에는 Triton 커널의 실제 코드가 포함되어 있지 않지만, 벤치마크 파일에서 axolotl.monkeypatch.trainer.utils.entropy_from_logits를 사용하고 있으며, PR 설명에 따르면 이는 Triton 커널로 구현되었습니다. Triton 커널은 다음과 같은 방식으로 최적화를 수행했을 것입니다.
# After: axolotl/monkeypatch/trainer/utils.py (Triton kernel, conceptual representation)
# @triton.jit
# def _entropy_from_logits_kernel(logits_ptr, output_ptr, ...):
# # Load logits, compute log_softmax, exp, multiply, sum in a single fused kernel
# # Avoid intermediate tensor allocations
# # Handle strided access directly
#
def entropy_from_logits(logits: torch.Tensor):
# Call the Triton kernel
# ...
개선점:
- 커널 퓨전(Kernel Fusion): Triton은 여러 연산을 하나의 GPU 커널로 묶는 커널 퓨전을 가능하게 합니다.
log_softmax,exp, 곱셈, 뺄셈,sum연산을 단일 커널 내에서 처리함으로써, 중간 결과물을 GPU 레지스터나 공유 메모리에 유지하고 글로벌 메모리 접근을 최소화합니다. 이는 메모리 대역폭 병목 현상을 크게 줄여줍니다. - 메모리 접근 패턴 최적화: Triton 커널은 GPU의 하드웨어 특성에 맞춰 메모리 접근 패턴을 최적화할 수 있습니다. 예를 들어, Coalesced Memory Access를 활용하여 GPU의 처리량을 극대화합니다.
- 비연속 텐서 지원: Triton 커널은 스트라이드(strided) 메모리 접근을 직접 처리할 수 있으므로, 입력
logits가 비연속이더라도 명시적인contiguous()복사 없이 효율적으로 연산을 수행할 수 있습니다. 벤치마크의benchmark_noncontiguous섹션에서 이 장점이 명확히 드러납니다. - 오버헤드 감소: 청크 처리로 인한 반복적인 커널 호출 및 동기화 오버헤드를 제거합니다.
2. benchmarks/bench_selective_logsoftmax.py
이 파일은 selective_log_softmax 함수의 성능을 벤치마킹합니다. 특정 인덱스에 해당하는 로짓에 대해서만 로그 소프트맥스를 계산하는 연산입니다.
기존 selective_log_softmax_original 구현 (추정)
PR diff에는 selective_log_softmax_original의 코드가 직접 포함되어 있지 않지만, 일반적으로 PyTorch에서는 다음과 유사하게 구현될 수 있습니다.
# Before: axolotl/monkeypatch/trainer/utils.py (conceptual representation)
# def selective_log_softmax_original(logits: torch.Tensor, index: torch.Tensor):
# # Create a mask or gather relevant logits
# # Apply log_softmax to selected logits
# # Scatter results back or return a sparse tensor
# # This often involves multiple PyTorch ops and intermediate tensors
문제점:
- 복잡한 인덱싱: 특정 인덱스에 해당하는 로짓만 선택하여 연산하는 과정은 PyTorch의 기본 연산으로는 여러 단계의 인덱싱, 마스킹,
log_softmax적용, 결과 병합 등의 복잡한 과정을 거쳐야 합니다. 이는 여러 커널 호출과 중간 텐서 생성을 유발합니다. - 역방향 전파의 어려움: PR 설명에 따르면
torch.compile조차 이 연산의 역방향 전파를 제대로 처리하지 못하는 문제가 있었습니다. 이는 복잡한 인덱싱과 조건부 연산이 섞여 있을 때torch.autograd시스템이 효율적인 역방향 그래프를 생성하기 어렵기 때문일 수 있습니다.
Triton 커널 selective_log_softmax 구현 (추정)
마찬가지로 Triton 커널의 실제 코드는 없지만, PR 설명과 벤치마크 결과를 통해 그 장점을 유추할 수 있습니다.
# After: axolotl/monkeypatch/trainer/utils.py (Triton kernel, conceptual representation)
# @triton.jit
# def _selective_log_softmax_kernel(logits_ptr, index_ptr, output_ptr, ...):
# # Load logits and index for each element
# # Compute log_softmax for selected elements directly in the kernel
# # Store results
#
# @triton.jit
# def _selective_log_softmax_backward_kernel(grad_output_ptr, logits_ptr, index_ptr, grad_logits_ptr, ...):
# # Compute gradients for selected elements directly
#
def selective_log_softmax(logits: torch.Tensor, index: torch.Tensor):
# Call the Triton forward kernel
# Define custom autograd.Function to call Triton backward kernel
# ...
개선점:
- 단일 커널 내 선택적 연산: Triton 커널은 각 스레드가 자신의
logits와index를 로드하여 필요한log_softmax계산을 수행하고 결과를 저장하는 과정을 단일 커널 내에서 효율적으로 처리할 수 있습니다. 이는 인덱싱 및 마스킹 오버헤드를 제거하고 메모리 접근을 최적화합니다. - 순방향 및 역방향 전파 지원: Triton은 커스텀
torch.autograd.Function을 통해 순방향 및 역방향 커널을 모두 구현할 수 있습니다. 이를 통해torch.compile이 지원하지 못했던selective_log_softmax의 역방향 전파를 효율적으로 처리할 수 있게 됩니다. 이는 훈련 안정성과 성능에 매우 중요합니다. - 메모리 효율성: 중간 텐서 생성을 최소화하여 메모리 사용량을 줄입니다.
왜 이게 좋은가
이 PR은 LLM 훈련의 핵심 연산들을 Triton 커널로 최적화하여 다음과 같은 중요한 이점을 제공합니다.
1. 성능 향상
PR 설명과 벤치마크 결과는 상당한 성능 향상을 보여줍니다.
-
entropy_from_logits:- 네이티브 PyTorch 대비 ~5배 빠름.
torch.compile대비 ~3배 빠름.- 벤치마크 결과에서
B=16, L=4096(약 65K rows) 설정에서original이 100ms 이상 걸리는 반면,triton은 10ms대로 줄어들어 약 10배의 속도 향상을 보입니다. - 비연속 텐서의 경우,
original+copy는contiguous()호출로 인한 오버헤드가 추가되어 성능이 더 떨어지지만,triton-strided는 여전히 높은 성능을 유지합니다.
-
selective_log_softmax:- 순방향 + 역방향 전파(fwd+bwd)에서 ~3배 빠름.
torch.compile이 역방향 전파를 지원하지 않는 문제를 해결.- 벤치마크 결과에서
B=16, L=4096설정에서original의 FWD+BWD 시간이 200ms 이상인 반면,triton은 60ms대로 줄어들어 약 3배의 속도 향상을 보입니다.
이러한 속도 향상은 LLM 훈련 시 에폭당 시간을 크게 단축시켜 전체 훈련 비용을 절감하고 연구 반복 속도를 높이는 데 기여합니다.
2. 메모리 효율성
벤치마크 결과는 피크 메모리 오버헤드(peak overhead) 감소도 명확히 보여줍니다.
entropy_from_logits:B=16, L=4096설정에서original은 100MB 이상의 피크 오버헤드를 보이는 반면,triton은 10MB 미만으로 약 90MB 이상 절약합니다.selective_log_softmax: 유사하게original대비triton이 훨씬 적은 메모리 오버헤드를 가집니다.
메모리 사용량 감소는 더 큰 배치 크기(batch size)나 더 긴 시퀀스 길이(sequence length)로 모델을 훈련할 수 있게 하여, GPU 자원 활용도를 극대화하고 OOM(Out Of Memory) 오류 발생 가능성을 줄여줍니다.
3. torch.compile의 한계 극복
selective_log_softmax의 경우, torch.compile이 역방향 전파를 제대로 지원하지 못하는 문제가 있었습니다. Triton 커널은 torch.autograd.Function을 통해 커스텀 역방향 커널을 구현할 수 있으므로, 이러한 torch.compile의 한계를 극복하고 해당 연산에 대한 완전한 순방향/역방향 최적화를 제공합니다. 이는 axolotl과 같은 훈련 프레임워크의 안정성과 유연성을 높이는 데 중요합니다.
일반적 교훈
이 PR은 다음과 같은 중요한 교훈을 제공합니다.
- 커널 퓨전의 중요성: 여러 PyTorch 연산을 조합하여 사용하는 경우, 중간 텐서 생성과 글로벌 메모리 접근으로 인한 오버헤드가 발생하기 쉽습니다. Triton과 같은 도구를 사용하여 이러한 연산들을 하나의 커널로 퓨전하면 GPU 성능을 크게 향상시킬 수 있습니다.
- 메모리 접근 패턴 최적화: GPU는 메모리 대역폭에 의해 성능이 좌우되는 경우가 많습니다. Triton을 사용하면 데이터 로드 및 저장 패턴을 GPU 하드웨어에 최적화하여 캐시 효율성을 높이고 Coalesced Memory Access를 활용할 수 있습니다.
torch.compile의 보완재:torch.compile은 많은 경우에 훌륭한 최적화 도구이지만, 모든 연산 패턴에 대해 완벽하게 작동하지 않을 수 있습니다. 특히 복잡한 인덱싱이나 조건부 로직이 포함된 연산, 또는autograd그래프가 복잡해지는 경우에는 Triton과 같은 커스텀 커널이torch.compile의 한계를 보완하고 더 깊은 최적화를 제공할 수 있습니다.- 벤치마킹의 중요성: 최적화의 효과를 정량적으로 측정하고 검증하기 위해서는 정확한 벤치마킹이 필수적입니다. 이 PR에서 추가된 벤치마크 파일들은 최적화의 가치를 명확히 보여줍니다.
결론
axolotl의 이 PR은 Triton 커널을 활용하여 LLM 훈련의 핵심 연산인 entropy_from_logits와 selective_log_softmax의 성능을 획기적으로 개선했습니다. 이는 훈련 속도 향상, 메모리 사용량 감소, 그리고 torch.compile의 한계 극복이라는 세 가지 중요한 이점을 가져옵니다. 이러한 최적화는 대규모 모델 훈련의 효율성을 높이고, 더 많은 연구와 개발을 가능하게 하는 데 크게 기여할 것입니다. 커스텀 GPU 커널 개발은 복잡하지만, 특정 병목 현상을 해결하는 데 있어 가장 강력한 도구 중 하나임을 다시 한번 보여주는 사례입니다.
참고 자료
- https://pytorch.org/docs/stable/generated/torch.nn.functional.log_softmax.html
- https://pytorch.org/docs/stable/generated/torch.exp.html
- https://pytorch.org/docs/stable/generated/torch.sum.html
- https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html
- https://pytorch.org/docs/stable/notes/extending.html#extending-torch-autograd
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [axolotl] Triton LoRA 커널 Autotune 테스트 안정화: pytest-xdist 환경에서의 모듈 격리 전략
- 현재글 : [axolotl] Axolotl: Triton 커널을 활용한 Entropy 및 Selective Log Softmax 최적화
- 다음글 [triton] ConSan Multi-CTA 지원 추가
댓글