본문으로 건너뛰기

[triton] AMD Canonicalize Pointers에서 arith.select의 비대칭 fat pointer 처리 강화

PR 링크: triton-lang/triton#8779 상태: Merged | 변경: +76 / -6

들어가며

Triton의 AMD 백엔드에는 포인터를 base + offset 형태의 "fat pointer"로 정규화하는 CanonicalizePointers 패스가 있습니다. 이 패스는 32비트 정수 연산으로 오프셋을 처리하여 64비트 포인터 연산을 줄이는 최적화를 수행합니다. 그러나 arith.select 연산에서 한쪽 인자만 fat pointer로 변환된 경우(예: pointer_range = 32가 한쪽에만 설정된 경우), 기존 코드가 크래시하거나 잘못된 결과를 생성하는 문제가 있었습니다.

핵심 코드 분석

Before:

LogicalResult matchAndRewrite_(arith::SelectOp selectOp, ...) {
  if (adaptor.getTrueValue().size() != 2 ||
      adaptor.getFalseValue().size() != 2) {
    assert(adaptor.getTrueValue().size() == adaptor.getFalseValue().size());
    return success();  // 둘 다 변환 안 된 경우만 처리
  }
  // 둘 다 fat pointer인 경우만 처리
  ValueRange fatPtrFalse = adaptor.getFalseValue();
  ValueRange fatPtrTrue = adaptor.getTrueValue();
  // select(base_t, base_f), select(offset_t, offset_f)
}

After:

LogicalResult matchAndRewrite_(arith::SelectOp selectOp, ...) {
  ValueRange fatPtrFalse = adaptor.getFalseValue();
  ValueRange fatPtrTrue = adaptor.getTrueValue();

  if (fatPtrTrue.size() == 1 && fatPtrFalse.size() == 1)
    return success();  // 둘 다 미변환: 패스

  if (fatPtrTrue.size() != 2 || fatPtrFalse.size() != 2) {
    // 비대칭: 한쪽만 fat pointer인 경우
    Value trueOp, falseOp;
    if (fatPtrTrue.size() == 2) {
      trueOp = tt::AddPtrOp::create(rewriter, loc,
                selectOp.getType(), fatPtrTrue[0], fatPtrTrue[1]);
    } else {
      trueOp = fatPtrTrue[0];
    }
    if (fatPtrFalse.size() == 2) {
      falseOp = tt::AddPtrOp::create(rewriter, loc,
                selectOp.getType(), fatPtrFalse[0], fatPtrFalse[1]);
    } else {
      falseOp = fatPtrFalse[0];
    }
    auto newSelectOp = arith::SelectOp::create(rewriter, loc,
        selectOp.getType(), selectOp.getCondition(), trueOp, falseOp);
    rewriter.replaceOp(selectOp, newSelectOp);
    return success();
  }
  // 둘 다 fat pointer: 기존 로직 유지
}

핵심 변경은 fat pointer가 비대칭적으로 적용된 경우의 처리입니다. 한쪽이 (base, offset) 형태이고 다른 쪽은 원래 포인터 그대로인 경우, fat pointer 쪽을 AddPtrOp으로 다시 합쳐서 일반 포인터로 되돌린 후 select를 수행합니다.

왜 이게 좋은가

이 수정은 방어적 프로그래밍의 좋은 사례입니다. 기존 코드는 select의 양쪽 인자가 항상 동일한 형태(둘 다 변환되거나 둘 다 미변환)라고 가정했지만, tt.pointer_range = 32 속성이 한쪽에만 설정된 실제 워크로드에서 이 가정이 깨졌습니다. offset이 같다고 암묵적으로 가정하는 것 역시 정적으로 안전하지 않은 코드였습니다. 이제 모든 조합을 명시적으로 처리하므로 컴파일러 크래시가 방지됩니다.

정리

  • arith.select에서 비대칭 fat pointer 조합 처리 추가
  • 한쪽만 fat pointer인 경우 AddPtrOp으로 재합성 후 select 수행
  • 크래시 방지 및 정확한 코드 생성 보장
  • 테스트 케이스(_scalar_select) 추가

참고 자료

이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글