[Triton] AMD ds_read_tr의 b8/b16 타입 lowering 리워크
들어가며
AMD GPU의 ds_read_tr 명령어는 LDS(Local Data Share)에서 데이터를 읽으면서 transpose를 수행하는 전용 명령어다. 기존 lowering은 요청 출력의 LinearLayout(LL)을 변환하는 방식이었으나, b8/b16 타입에서 복잡한 레이아웃 변환이 필요했다. 이 PR은 ldmatrix와 유사하게 명령어 자체의 LL을 직접 모델링하는 방식으로 리워크한다.
핵심 코드 분석
Before
// 기존: 90줄의 chooseLLDsReadTrLayout - output LL의 prefix를 rotate
auto rotatePrefixes = [](BaseTy ®Base, std::size_t numReg,
BaseTy &laneBase, std::size_t numLane) {
BaseTy baseUnit(laneBase.begin(), laneBase.begin() + numLane);
llvm::append_range(baseUnit, ...);
std::copy(baseUnit.begin(), baseUnit.begin() + numReg, regBase.begin());
std::copy(baseUnit.begin() + numReg, baseUnit.end(), laneBase.begin());
};
출력 LL의 register/lane basis를 "rotate"하는 복잡한 로직이었고, b8에서는 추가 swap이 필요했다.
After
// 새: 명령어의 LL을 직접 모델링 + PaddedSharedEncoding 지원
auto paddedEnc = dyn_cast<PaddedSharedEncodingAttr>(srcTy.getEncoding());
LinearLayout cvtDstLL = LinearLayout::empty();
if (paddedEnc) {
const auto &sharedLL = paddedEnc.getLinearComponent();
cvtDstLL = toLinearLayout(dstTy).invertAndCompose(sharedLL);
} else {
auto sharedLL = toLinearLayout(srcTy);
cvtDstLL = toLinearLayout(dstTy).invertAndCompose(sharedLL);
}
chooseLLDsReadTrLayout 함수(90줄)가 삭제되고, b8/b16 경로가 lowerDsReadTr라는 새 함수로 통합되었다. FP4 packed along K 타입도 동일한 경로를 사용한다.
왜 이게 좋은가
- 명령어 모델링: 출력을 변환하는 대신 명령어의 동작을 직접 모델링하여 정확성이 향상되었다.
- 대폭 코드 감소: +322/-727로 약 400줄이 줄었다. 복잡한 rotate/swap 로직이 사라졌다.
- 확장성: PaddedSharedEncoding 지원이 자연스럽게 추가되었다.
정리
컴파일러에서 "명령어의 출력을 역변환"하는 대신 "명령어의 동작을 정방향으로 모델링"하는 것이 더 단순하고 정확한 결과를 가져온다는 설계 원칙을 보여주는 PR이다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.
댓글