[Triton] ext slice rematerialization 견고성 개선 — 실패 시 원본 보존
PR 링크: triton-lang/triton#9019 상태: Merged | 변경: +24 / -36
들어가며
Triton의 RemoveLayoutConversions 패스는 불필요한 레이아웃 변환을 제거하여 성능을 향상시킨다. 이 중 hoistConvertOnTopOfExtOrBroadcast 함수는 ext(확장) 연산의 backward slice를 rematerialization 후보로 탐색한다.
문제는 탐색에 실패했을 때 이미 수정된 slice와 layout 데이터가 원래 상태로 복원되지 않는다는 것이었다. 이 PR은 탐색을 복사본 위에서 수행하고, 성공 시에만 원본을 갱신하는 방식으로 견고성을 높인다.
핵심 코드 분석
Before: 원본을 직접 수정
LogicalResult LayoutRematerialization::getRematerializableSlice(
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
std::function<bool(Operation *)> stopPropagation) {
LogicalResult result = getConvertBackwardSlice(
root, rootEncoding, slice, layout, stopPropagation);
if (result.failed() || slice.empty())
return failure();
// ...
}
getConvertBackwardSlice가 slice와 layout에 직접 데이터를 추가한 뒤 나중에 실패하면, 이미 추가된 데이터가 남아서 호출자의 상태를 오염시킨다.
After: 복사본 위에서 작업 후 성공 시 이동
LogicalResult LayoutRematerialization::getRematerializableSlice(
OpOperand &root, Attribute rootEncoding, SetVector<Value> &sliceArg,
DenseMap<Value, Attribute> &layoutArg,
std::function<bool(Operation *)> stopPropagation) {
// Operate on copies of the input, we do not want to modify them
// unless we have succeeded.
auto slice = sliceArg;
auto layout = layoutArg;
LogicalResult result = getConvertBackwardSlice(
root, rootEncoding, slice, layout, stopPropagation);
if (result.failed() || slice.empty())
return failure();
// ... 추가 검증 ...
sliceArg = std::move(slice);
layoutArg = std::move(layout);
return success();
}
또한 hoistConvertOnTopOfExtOrBroadcast에서 ext slice 탐색 코드도 대폭 단순화되었다. 기존에 20줄 이상의 중복 검증 로직이 getRematerializableSlice의 보장 덕분에 불필요해졌다.
// Before: 20줄 이상의 수동 충돌 검사
for (auto [val, enc] : tempLayout) {
auto preexistingLayout = layout.find(val);
if (preexistingLayout != layout.end() &&
preexistingLayout->second != enc) {
result = failure();
break;
}
}
// After: getRematerializableSlice가 실패 시 원본을 보존하므로
// 충돌 검사가 자연스럽게 처리됨
왜 이게 좋은가
- 트랜잭션 의미론: "성공하면 전부 적용, 실패하면 아무것도 바뀌지 않음"이라는 원칙을 코드 레벨에서 보장한다.
- 코드 단순화: 호출자가 실패 후 수동 정리를 할 필요가 없어져 20줄 이상의 방어 코드가 제거되었다.
- 잠재적 버그 방지: 이전에는 탐색 실패 시 부분적으로 오염된 상태가 후속 처리에 영향을 줄 수 있었다.
정리
이 PR은 Triton 레이아웃 최적화 패스의 getRematerializableSlice 함수에 copy-on-success 패턴을 적용하여, 실패 시 원본 데이터가 보존되도록 개선했다. 결과적으로 더 안전하고 간결한 코드가 되었다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 핵심 코드와 explaination은 실제 PR diff를 기반으로 합니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Triton] Proton 프로파일러 tensor descriptor 및 two-CTA 모드 테스트 추가
- 현재글 : [Triton] ext slice rematerialization 견고성 개선 — 실패 시 원본 보존
- 다음글 [vllm] --max-model-len auto: GPU 메모리에 맞춘 자동 컨텍스트 길이 설정
댓글