[Triton] tl.histogram을 ATOMS.POPC.INC로 변경
들어가며
Triton의 tl.histogram 연산은 GPU에서 히스토그램을 계산한다. 기존 구현은 warp 내에서 ballot + popcount로 warp-level 히스토그램을 만든 뒤 shared memory에서 합산하는 복잡한 알고리즘이었다. 이 PR은 이를 간단한 atomic add 방식으로 교체하여 코드를 대폭 단순화하고 성능도 개선한다.
핵심 코드 분석
Before
// 약 90줄의 warp-level histogram + cross-warp reduction
SmallVector<Value> computeWarpLevelHistogram(...) {
// ballot for each bit of bin index (log2(num_bins) ballots)
for (int j = 0; j < numBits; ++j) {
Value bitSet = b.and_(value, b.i32_val(1 << j));
Value cmp = b.icmp_ne(bitSet, zero);
Value bit = targetInfo.ballot(rewriter, loc, ...);
ballotBits.push_back(bit);
}
// XOR + AND to get indicator for each bin
for (int k = 0; k < warpLevelHistogram.size(); k++) {
// ... bitwise magic to identify which elements match bin k
warpLevelHistogram[k] = b.add(warpLevelHistogram[k], bitCount);
}
}
warp 크기 이상의 bin이 필요하고, 복잡한 bitwise 연산이 필요했다.
After
static void atomicAddOne(Value ptr, Location loc, ConversionPatternRewriter &rewriter) {
auto b = TritonLLVMOpBuilder(loc, rewriter);
LLVM::AtomicRMWOp::create(rewriter, loc, LLVM::AtomicBinOp::add, ptr,
b.i32_val(1), LLVM::AtomicOrdering::monotonic);
}
// 각 element에 대해 범위 체크 후 atomic add
Value numBinsValue = b.i32_val(numBins);
for (int i = 0; i < srcValues.size(); ++i) {
Value updatePred = b.icmp_ult(srcValues[i], numBinsValue);
if (!maskValues.empty())
updatePred = b.and_(updatePred, maskValues[i]);
// if (pred) atomicAddOne(shared_mem[srcValues[i]])
auto [prevBlock, ifBlock, thenBlock] = createIfBlock(rewriter, loc, updatePred);
rewriter.setInsertionPointToStart(ifBlock);
Value sharedMemPtr = b.gep(..., srcValues[i]);
atomicAddOne(sharedMemPtr, loc, rewriter);
}
각 값에 대해 범위를 체크하고(icmp_ult), 유효하면 shared memory의 해당 bin에 atomic add 1을 수행한다. NVIDIA에서는 이것이 ATOMS.POPC.INC 명령어로 컴파일된다.
왜 이게 좋은가
- 코드 단순화: 약 90줄의 복잡한 ballot/popcount 로직이 20줄의 atomic add로 교체되었다.
- 범위 외 값 처리:
icmp_ult로 bin 범위 밖의 값을 predicate-out하여 정확성이 향상되었다. - num_bins 제약 완화: 기존에는
numBins >= numThreadsPerWarp패딩이 필요했지만 이제 불필요하다.
정리
+49/-107로 코드가 절반 이상 줄었고, 테스트에서 실제 SASS에 ATOMS.POPC.INC가 포함되는지 검증한다. 더 단순한 알고리즘이 하드웨어 지원 덕분에 더 빠를 수 있다는 좋은 사례다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.
댓글