본문으로 건너뛰기

[triton] Triton 테스트 속도 혁신: Python 루프에서 벡터화된 NumPy로의 전환

PR 링크: triton-lang/triton#10016 상태: Merged | 변경: +None / -None

들어가며

Triton은 고성능 딥러닝 모델을 위한 커스텀 커널을 쉽게 작성할 수 있도록 돕는 컴파일러입니다. 하지만 모든 소프트웨어 프로젝트와 마찬가지로, Triton 역시 테스트 스위트의 성능 저하 문제를 겪을 수 있습니다. 특히, 일부 테스트 케이스는 실행에 200초까지 소요되어 전체 개발 및 검증 주기를 더디게 만들었습니다. 이 PR은 이러한 느린 테스트의 근본 원인을 파악하고, Python 스칼라 연산을 벡터화된 NumPy 연산으로 대체함으로써 테스트 실행 시간을 획기적으로 단축하는 것을 목표로 합니다.

이 글에서는 해당 PR의 코드 변경 사항을 분석하고, 왜 이러한 변경이 성능 향상으로 이어졌는지, 그리고 이 최적화가 주는 일반적인 교훈은 무엇인지 살펴보겠습니다.

코드 분석

이번 PR의 핵심 변경 사항은 크게 두 부분으로 나눌 수 있습니다. 첫째, Triton IR(Intermediate Representation)을 다루는 C++ 코드에서 일부 수학 연산의 구현 방식을 개선했습니다. 둘째, 테스트 실행 시간을 관리하는 Python 스크립트의 로직을 수정했습니다.

1. lib/Dialect/TritonInstrument/Transforms/FpSanitizer.cpp 변경 사항

이 파일은 부동 소수점 연산의 정확성을 검증하는 FpSanitizer 기능을 구현하고 있습니다. PR에서는 주로 fpsanExp2FromInt, fpsanExp, fpsanCosSinPayload 함수 내부의 연산 로직을 최적화했습니다. 이전에는 이러한 함수들이 Python 스타일의 스칼라 연산을 LLVM IR로 직접 변환하는 방식을 사용했습니다. 이는 LLVM IR 수준에서는 루프를 통해 스칼라 값들을 순차적으로 처리하게 되어 비효율적이었습니다.

fpsanExp2FromI32 (이후 fpsanExp2FromInt로 변경) 함수 최적화:

이 함수는 정수 값을 입력받아 exp2 연산을 수행하고 그 결과를 부동 소수점 타입으로 변환합니다. 이전 구현은 32번의 반복을 통해 비트별로 연산을 수행했습니다.

Before:

-  Value y = one;
-  for (int i = 0; i < 32; ++i) {
-    y = arith::MulIOp::create(rewriter, loc, y, y);
-    auto bit = getIntConstantLike(rewriter, loc, xI.getType(),
-                                  int64_t(1ull << (31 - i)));
-    auto masked = arith::AndIOp::create(rewriter, loc, xI, bit);
-    auto isZero = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
-                                        masked, zero);
-    auto factor = arith::SelectOp::create(rewriter, loc, isZero, one, c);
-    y = arith::MulIOp::create(rewriter, loc, y, factor);
-  }
-
-  return unembedToFloat(rewriter, loc, y, floatTy);

After:

+  auto lower = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0));
+  auto upper = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(bitWidth));
+  auto step = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(1));
+  auto topBit = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(bitWidth - 1));
+  auto loop = scf::ForOp::create(rewriter, loc, lower, upper, step, one);
+  rewriter.setInsertionPointToStart(loop.getBody());
+
+  Value i = loop.getInductionVar();
+  Value y = loop.getRegionIterArgs()[0];
+  y = arith::MulIOp::create(rewriter, loc, y, y);
+  Value bitIndex = arith::SubIOp::create(rewriter, loc, rewriter.getI32Type(), topBit, i);
+  Value shift = castScalarIntToIntLike(rewriter, loc, bitIndex, xI.getType());
+  Value bit = arith::ShLIOp::create(rewriter, loc, one, shift);
+  auto masked = arith::AndIOp::create(rewriter, loc, xI, bit);
+  auto isZero = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, masked, zero);
+  auto factor = arith::SelectOp::create(rewriter, loc, isZero, one, c);
+  y = arith::MulIOp::create(rewriter, loc, y, factor);
+  scf::YieldOp::create(rewriter, loc, y);
+  rewriter.setInsertionPointAfter(loop);
+
+  return unembedToFloat(rewriter, loc, loop.getResult(0), floatTy);

변경 후에는 scf.for 루프를 사용하여 더 구조화된 방식으로 연산을 수행합니다. 이는 LLVM IR 수준에서 더 효율적인 코드 생성을 가능하게 합니다. 또한, castScalarIntToIntLike와 같은 새로운 헬퍼 함수가 추가되어 타입 변환을 더 유연하게 처리합니다.

fpsanCosSinPayload 함수 최적화:

이 함수는 코사인과 사인 연산을 위한 페이로드 계산을 담당합니다. 이전에는 이 역시 스칼라 연산의 반복으로 구현되었습니다.

Before:

-  Value c = one;
-  Value s = zero;
-  for (int bit = static_cast<int>(bitWidth) - 1; bit >= 0; --bit) {
-    Value cc = arith::MulIOp::create(rewriter, loc, c, c);
-    Value ss = arith::MulIOp::create(rewriter, loc, s, s);
-    Value cDouble = arith::SubIOp::create(rewriter, loc, cc, ss);
-    Value cs = arith::MulIOp::create(rewriter, loc, c, s);
-    Value sDouble = arith::MulIOp::create(rewriter, loc, two, cs);
-
-    Value ac = arith::MulIOp::create(rewriter, loc, a, cDouble);
-    Value bs = arith::MulIOp::create(rewriter, loc, b, sDouble);
-    Value cInc = arith::SubIOp::create(rewriter, loc, ac, bs);
-    Value as = arith::MulIOp::create(rewriter, loc, a, sDouble);
-    Value bc = arith::MulIOp::create(rewriter, loc, b, cDouble);
-    Value sInc = arith::AddIOp::create(rewriter, loc, as, bc);
-
-    auto bitMask = getUIntConstantLike(rewriter, loc, intTy, uint64_t{1} << bit);
-    auto masked = arith::AndIOp::create(rewriter, loc, xI, bitMask);
-    auto isZero = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, masked, zero);
-    c = arith::SelectOp::create(rewriter, loc, isZero, cDouble, cInc);
-    s = arith::SelectOp::create(rewriter, loc, isZero, sDouble, sInc);
-  }
-
-  return {c, s};

After:

+  auto lower = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0));
+  auto upper = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(bitWidth));
+  auto step = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(1));
+  auto topBit = arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(bitWidth - 1));
+  SmallVector<Value> initArgs{one, zero};
+  auto loop = scf::ForOp::create(rewriter, loc, lower, upper, step, initArgs);
+  rewriter.setInsertionPointToStart(loop.getBody());
+
+  Value bit = loop.getInductionVar();
+  Value c = loop.getRegionIterArgs()[0];
+  Value s = loop.getRegionIterArgs()[1];
+  Value cc = arith::MulIOp::create(rewriter, loc, c, c);
+  Value ss = arith::MulIOp::create(rewriter, loc, s, s);
+  Value cDouble = arith::SubIOp::create(rewriter, loc, cc, ss);
+  Value cs = arith::MulIOp::create(rewriter, loc, c, s);
+  Value sDouble = arith::MulIOp::create(rewriter, loc, two, cs);
+
+  Value ac = arith::MulIOp::create(rewriter, loc, a, cDouble);
+  Value bs = arith::MulIOp::create(rewriter, loc, b, sDouble);
+  Value cInc = arith::SubIOp::create(rewriter, loc, ac, bs);
+  Value as = arith::MulIOp::create(rewriter, loc, a, sDouble);
+  Value bc = arith::MulIOp::create(rewriter, loc, b, cDouble);
+  Value sInc = arith::AddIOp::create(rewriter, loc, as, bc);
+
+  Value bitIndex = arith::SubIOp::create(rewriter, loc, rewriter.getI32Type(), topBit, bit);
+  Value shift = castScalarIntToIntLike(rewriter, loc, bitIndex, intTy);
+  Value bitMask = arith::ShLIOp::create(rewriter, loc, one, shift);
+  auto masked = arith::AndIOp::create(rewriter, loc, xI, bitMask);
+  auto isZero = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, masked, zero);
+  c = arith::SelectOp::create(rewriter, loc, isZero, cDouble, cInc);
+  s = arith::SelectOp::create(rewriter, loc, isZero, sDouble, sInc);
+  scf::YieldOp::create(rewriter, loc, ValueRange{c, s});
+  rewriter.setInsertionPointAfter(loop);
+
+  return {loop.getResult(0), loop.getResult(1)};

마찬가지로 scf.for 루프를 사용하여 더 효율적인 LLVM IR 생성을 유도합니다. 리뷰어 apgoucher는 Pawel의 MMA 에뮬레이션 루프가 scf::for를 사용하므로 이를 복사하는 것이 합리적이라고 언급했습니다. 또한, lezcano 리뷰어는 scf_to_cf 변환이 FpSanitizer 변환 이후에 수행되는지 확인했고, 그렇다면 scf를 사용하는 것이 올바르다고 확인했습니다.

2. python/test/conftest.py 변경 사항

이 파일은 pytest 테스트 실행을 위한 설정을 담당합니다. PR에서는 테스트 케이스의 수를 제한하는 로직이 추가되었습니다.

Before: (관련 없음 - 기존 로직)

After:

+def _top_level_test_key(item):
+    nodeid = item.nodeid
+    bracket = nodeid.find("[")
+    return nodeid if bracket == -1 else nodeid[:bracket]
+
+def _case_key(item):
+    return item.name
+
+def _sha256_hex(s: str) -> str:
+    return hashlib.sha256(s.encode("utf-8")).hexdigest()
+
+
def pytest_collection_modifyitems(config, items):
+    max_cases = config.getoption("--max-cases-per-test")
+    if max_cases <= 0:
+        return
+
+    groups = defaultdict(list)
+    for item in items:
+        groups[_top_level_test_key(item)].append(item)
+
+    kept = []
+    deselected = []
+    for group in groups.values():
+        ordered = sorted(group, key=lambda item: _sha256_hex(_case_key(item)))
+        kept.extend(ordered[:max_cases])
+        deselected.extend(ordered[max_cases:])
+
+    if deselected:
+        config.hook.pytest_deselected(items=deselected)
+
+    items[:] = kept

pytest_collection_modifyitems 함수는 테스트 수집 후 실행 전에 항목을 수정하는 pytest 훅입니다. 이 코드는 --max-cases-per-test 옵션을 통해 각 최상위 테스트 함수별로 실행할 최대 테스트 케이스 수를 제한합니다. 테스트 케이스를 SHA256 해시를 기준으로 정렬한 후 상위 max_cases 개수만 남기고 나머지는 비활성화(deselect)합니다. 이는 매우 많은 테스트 케이스를 가진 일부 테스트가 전체 실행 시간을 과도하게 차지하는 것을 방지하기 위한 조치입니다.

리뷰어 lezcano는 이 변경이 너무

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글