본문으로 건너뛰기

[Triton] ext slice rematerialization 견고성 개선 — 실패 시 원본 보존

PR 링크: triton-lang/triton#9019 상태: Merged | 변경: +24 / -36

들어가며

Triton의 RemoveLayoutConversions 패스는 불필요한 레이아웃 변환을 제거하여 성능을 향상시킨다. 이 중 hoistConvertOnTopOfExtOrBroadcast 함수는 ext(확장) 연산의 backward slice를 rematerialization 후보로 탐색한다.

문제는 탐색에 실패했을 때 이미 수정된 slicelayout 데이터가 원래 상태로 복원되지 않는다는 것이었다. 이 PR은 탐색을 복사본 위에서 수행하고, 성공 시에만 원본을 갱신하는 방식으로 견고성을 높인다.

핵심 코드 분석

Before: 원본을 직접 수정

LogicalResult LayoutRematerialization::getRematerializableSlice(
    OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
    DenseMap<Value, Attribute> &layout,
    std::function<bool(Operation *)> stopPropagation) {
  LogicalResult result = getConvertBackwardSlice(
      root, rootEncoding, slice, layout, stopPropagation);
  if (result.failed() || slice.empty())
    return failure();
  // ...
}

getConvertBackwardSliceslicelayout에 직접 데이터를 추가한 뒤 나중에 실패하면, 이미 추가된 데이터가 남아서 호출자의 상태를 오염시킨다.

After: 복사본 위에서 작업 후 성공 시 이동

LogicalResult LayoutRematerialization::getRematerializableSlice(
    OpOperand &root, Attribute rootEncoding, SetVector<Value> &sliceArg,
    DenseMap<Value, Attribute> &layoutArg,
    std::function<bool(Operation *)> stopPropagation) {
  // Operate on copies of the input, we do not want to modify them
  // unless we have succeeded.
  auto slice = sliceArg;
  auto layout = layoutArg;
  LogicalResult result = getConvertBackwardSlice(
      root, rootEncoding, slice, layout, stopPropagation);
  if (result.failed() || slice.empty())
    return failure();
  // ... 추가 검증 ...
  sliceArg = std::move(slice);
  layoutArg = std::move(layout);
  return success();
}

또한 hoistConvertOnTopOfExtOrBroadcast에서 ext slice 탐색 코드도 대폭 단순화되었다. 기존에 20줄 이상의 중복 검증 로직이 getRematerializableSlice의 보장 덕분에 불필요해졌다.

// Before: 20줄 이상의 수동 충돌 검사
for (auto [val, enc] : tempLayout) {
    auto preexistingLayout = layout.find(val);
    if (preexistingLayout != layout.end() &&
        preexistingLayout->second != enc) {
      result = failure();
      break;
    }
}

// After: getRematerializableSlice가 실패 시 원본을 보존하므로
// 충돌 검사가 자연스럽게 처리됨

왜 이게 좋은가

  1. 트랜잭션 의미론: "성공하면 전부 적용, 실패하면 아무것도 바뀌지 않음"이라는 원칙을 코드 레벨에서 보장한다.
  2. 코드 단순화: 호출자가 실패 후 수동 정리를 할 필요가 없어져 20줄 이상의 방어 코드가 제거되었다.
  3. 잠재적 버그 방지: 이전에는 탐색 실패 시 부분적으로 오염된 상태가 후속 처리에 영향을 줄 수 있었다.

정리

이 PR은 Triton 레이아웃 최적화 패스의 getRematerializableSlice 함수에 copy-on-success 패턴을 적용하여, 실패 시 원본 데이터가 보존되도록 개선했다. 결과적으로 더 안전하고 간결한 코드가 되었다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 핵심 코드와 explaination은 실제 PR diff를 기반으로 합니다.

댓글

관련 포스트

PR Analysis 의 다른글