본문으로 건너뛰기

[Triton] tl.histogram을 ATOMS.POPC.INC로 변경

들어가며

Triton의 tl.histogram 연산은 GPU에서 히스토그램을 계산한다. 기존 구현은 warp 내에서 ballot + popcount로 warp-level 히스토그램을 만든 뒤 shared memory에서 합산하는 복잡한 알고리즘이었다. 이 PR은 이를 간단한 atomic add 방식으로 교체하여 코드를 대폭 단순화하고 성능도 개선한다.

핵심 코드 분석

Before

// 약 90줄의 warp-level histogram + cross-warp reduction
SmallVector<Value> computeWarpLevelHistogram(...) {
    // ballot for each bit of bin index (log2(num_bins) ballots)
    for (int j = 0; j < numBits; ++j) {
        Value bitSet = b.and_(value, b.i32_val(1 << j));
        Value cmp = b.icmp_ne(bitSet, zero);
        Value bit = targetInfo.ballot(rewriter, loc, ...);
        ballotBits.push_back(bit);
    }
    // XOR + AND to get indicator for each bin
    for (int k = 0; k < warpLevelHistogram.size(); k++) {
        // ... bitwise magic to identify which elements match bin k
        warpLevelHistogram[k] = b.add(warpLevelHistogram[k], bitCount);
    }
}

warp 크기 이상의 bin이 필요하고, 복잡한 bitwise 연산이 필요했다.

After

static void atomicAddOne(Value ptr, Location loc, ConversionPatternRewriter &rewriter) {
    auto b = TritonLLVMOpBuilder(loc, rewriter);
    LLVM::AtomicRMWOp::create(rewriter, loc, LLVM::AtomicBinOp::add, ptr,
                              b.i32_val(1), LLVM::AtomicOrdering::monotonic);
}

// 각 element에 대해 범위 체크 후 atomic add
Value numBinsValue = b.i32_val(numBins);
for (int i = 0; i < srcValues.size(); ++i) {
    Value updatePred = b.icmp_ult(srcValues[i], numBinsValue);
    if (!maskValues.empty())
        updatePred = b.and_(updatePred, maskValues[i]);
    // if (pred) atomicAddOne(shared_mem[srcValues[i]])
    auto [prevBlock, ifBlock, thenBlock] = createIfBlock(rewriter, loc, updatePred);
    rewriter.setInsertionPointToStart(ifBlock);
    Value sharedMemPtr = b.gep(..., srcValues[i]);
    atomicAddOne(sharedMemPtr, loc, rewriter);
}

각 값에 대해 범위를 체크하고(icmp_ult), 유효하면 shared memory의 해당 bin에 atomic add 1을 수행한다. NVIDIA에서는 이것이 ATOMS.POPC.INC 명령어로 컴파일된다.

왜 이게 좋은가

  • 코드 단순화: 약 90줄의 복잡한 ballot/popcount 로직이 20줄의 atomic add로 교체되었다.
  • 범위 외 값 처리: icmp_ult로 bin 범위 밖의 값을 predicate-out하여 정확성이 향상되었다.
  • num_bins 제약 완화: 기존에는 numBins >= numThreadsPerWarp 패딩이 필요했지만 이제 불필요하다.

정리

+49/-107로 코드가 절반 이상 줄었고, 테스트에서 실제 SASS에 ATOMS.POPC.INC가 포함되는지 검증한다. 더 단순한 알고리즘이 하드웨어 지원 덕분에 더 빠를 수 있다는 좋은 사례다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.

댓글