[triton] AMD GPU Descriptor Encoding 최적화 패스 추가
PR 링크: triton-lang/triton#9792 상태: Merged | 변경: +568 / -194
들어가며
Tensor descriptor를 사용한 메모리 접근(gather/scatter/load)은 shared memory encoding이 성능에 큰 영향을 미칩니다. 이 PR은 AMD GFX1250 타겟에 OptimizeDescriptorEncoding 패스를 추가하여, descriptor 연산에 padded shared encoding을 자동으로 적용합니다.
핵심 코드 분석
1. PaddedSharedEncoding을 DescriptorMemoryLayouts에 통합
Before:
// SwizzledSharedEncoding만 처리
if (auto swizEnc = dyn_cast<ttg::SwizzledSharedEncodingAttr>(encoding)) {
// CGA layout 업데이트
}
// 다른 encoding 유형은 에러
constexpr auto msg = "Internal Error: Unhandled tensor descriptor encoding";
After:
if (auto paddedEnc = dyn_cast<ttg::PaddedSharedEncodingAttr>(encoding)) {
auto existingCga = paddedEnc.getCGALayout();
if (!existingCga)
return paddedEnc;
auto newCgaEnc =
ttg::updateCGALayoutForShape(cgaLayout, tensorType.getShape());
// interval padding 보존하면서 CGA layout 업데이트
SmallVector<std::pair<unsigned, unsigned>> intervalPads;
for (auto [interval, padding] :
llvm::zip(paddedEnc.getIntervals(), paddedEnc.getPaddings()))
intervalPads.push_back({interval, padding});
return ttg::PaddedSharedEncodingAttr::get(ctx, intervalPads, order,
shape, newCgaEnc);
}
2. Dot operand에 따른 padded encoding 전파
테스트에서 descriptor load가 dot 연산의 operand로 사용될 때, 각 operand에 맞는 padded encoding이 자동으로 추론됩니다:
// CHECK-DAG: #[[$PADDED_A:.*]] = #ttg.padded_shared<[128:+8] {
// CHECK-DAG: #[[$PADDED_B:.*]] = #ttg.padded_shared<[128:+16] {
// CHECK: tt.make_tensor_descriptor {{.*}} : <f16>, <tensor<512x32xf16, #[[$PADDED_A]]>>
// CHECK: tt.make_tensor_descriptor {{.*}} : <f16>, <tensor<32x64xf16, #[[$PADDED_B]]
A 행렬(512x32)은 8 element padding, B 행렬(32x64)은 16 element padding이 적용됩니다.
왜 이게 좋은가
- Bank conflict 감소: Padded encoding은 shared memory bank conflict을 방지하여 데이터 접근 성능을 향상시킵니다.
- 자동 최적화: 사용자가 직접 encoding을 지정할 필요 없이 컴파일러가 최적의 padding을 계산합니다.
- while loop 전파: Descriptor가 제어 흐름(while loop)을 통해 전달될 때도 encoding이 올바르게 전파됩니다.
정리
AMD GFX1250에서 tensor descriptor 연산의 shared memory 성능을 자동으로 최적화하는 패스입니다. Padded encoding 선택, CGA layout 업데이트, 제어 흐름 전파까지 포괄적으로 처리합니다.
참고 자료
이 글은 AI의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [CPython] SyntaxError 재초기화 시 메모리 누수 수정
- 현재글 : [triton] AMD GPU Descriptor Encoding 최적화 패스 추가
- 다음글 [Ray] LLM 추론 벤치마크 엔진에 동시성 모드와 일정 QPS 모드 추가
댓글