[triton] tl.cat 연산을 permute+reshape+join으로 재구현하여 결정적(deterministic) 동작 보장
PR 링크: triton-lang/triton#8769 상태: Merged | 변경: +40 / -187
들어가며
Triton의 tl.cat 연산은 두 텐서를 연결(concatenate)하는 기본 연산입니다. 기존 구현은 전용 CatOp을 사용했는데, 이 방식은 요소 순서가 비결정적(non-deterministic)으로 재배치될 수 있는 문제가 있었습니다. 이번 PR은 CatOp을 완전히 제거하고, 기존의 permute, reshape, join 연산 조합으로 대체하여 tl.cat이 항상 결정적 결과를 생성하도록 개선합니다.
핵심 코드 분석
Before: 전용 CatOp 정의 및 LLVM 변환
// TritonOps.td - CatOp 정의
def TT_CatOp : TT_Op<"cat", [NoMemoryEffect,
SameTypeOperands,
SameOperandsAndResultElementType]> {
let summary = "concatenate 2 tensors";
let arguments = (ins TT_Tensor:$lhs, TT_Tensor:$rhs);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)";
}
// ViewOpToLLVM.cpp - CatOp의 LLVM 변환
struct CatOpConversion : public ConvertOpToLLVMPattern<CatOp> {
LogicalResult matchAndRewrite(CatOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto lhsVals = unpackLLElements(loc, adaptor.getLhs(), rewriter);
auto rhsVals = unpackLLElements(loc, adaptor.getRhs(), rewriter);
SmallVector<Value> retVals;
for (Value v : lhsVals) retVals.push_back(v);
for (Value v : rhsVals) retVals.push_back(v);
// 단순 연결 - 순서 보장 없음
Value ret = packLLElements(loc, typeConverter, retVals, rewriter, resultTy);
rewriter.replaceOp(op, ret);
return success();
}
};
After: CatOp 제거, 기존 연산 조합 사용
// Dialect.cpp - isExpensiveCat 함수 제거
// CatOp 관련 비용 분석 함수도 함께 제거
// bool isExpensiveCat(CatOp cat, Attribute targetEncoding) { ... } 삭제
// InferLayoutUtils.cpp - CatOp을 layout-varying 목록에서 제거
bool encodingsMayVary(Operation *op) {
return isa<triton::JoinOp, triton::SplitOp, triton::ReshapeOp,
triton::TransOp>(op);
// CatOp 제거됨
}
변경 후 tl.cat은 내부적으로 permute -> reshape -> join -> reshape -> permute 순서로 구현됩니다. linear layout 덕분에 성능 저하가 거의 없습니다.
왜 이게 좋은가
- 결정적 결과 보장: 동일 입력에 대해 항상 동일한 출력 순서를 보장하여 디버깅과 테스트가 용이해집니다.
- 코드 단순화: 전용 Op과 그에 따른 변환 패턴, 비용 분석 함수 등 187줄의 코드가 제거되었습니다.
- 임의 차원 지원: 기존에는 제한적이던 연결 차원을 모든 차원으로 확장할 수 있습니다.
can_reorder옵션 폐지: 비결정적 동작이 사라지면서 이 옵션이 불필요해졌습니다.
정리
이 PR은 "전용 연산을 만들기보다 기존 연산의 조합으로 해결하라"는 컴파일러 설계 원칙을 잘 보여줍니다. 전용 CatOp은 최적화 기회가 제한적이고 유지보수 부담이 컸지만, 기존 연산 조합은 각 연산의 최적화가 자연스럽게 적용됩니다. 단, 이 변경은 이후 회귀 이슈(PR #8878)로 한 번 revert되었다가 재적용되었습니다.
참고 자료
- triton-lang/triton#8769
- triton-lang/triton#8878 (Revert PR)
이 글은 AI(Claude)의 도움을 받아 작성되었으며, PR의 실제 diff를 기반으로 분석한 내용입니다.
관련 포스트
- [triton] AMD Canonicalize Pointers에서 arith.select의 비대칭 fat pointer 처리 강화
- [triton] Warp Specialization: 데이터 플로우 그래프 기반의 개선된 파티션 스케줄링 패스
- [Triton] WarpSpecializePartitionsOp에 명시적 캡처 전달 — IR 구조 정합성 개선
- [triton] CGAEncodingAttr::getDefault를 get1CTALayout/get1DLayout로 분리하여 multi-CTA 지원
- [Triton] AMD scf.if else 분기 누락 버그 수정 — deduceMinCountBetweeOps
PR Analysis 의 다른글
- 이전글 [ultralytics] Ultralytics 8.3.229: COCO Segmentation 평가 300% 가속화 분석
- 현재글 : [triton] tl.cat 연산을 permute+reshape+join으로 재구현하여 결정적(deterministic) 동작 보장
- 다음글 [Loki] 인메모리 레이트 트래커로 UpdateRates RPC 구현
댓글