[Triton] WGMMA wait op의 출력 constraint 타입별 분기 수정
PR 링크: triton-lang/triton#8579 상태: Merged | 변경: +36 / -3
들어가며
WGMMA wait 명령어는 비동기 WGMMA가 완료될 때까지 기다리면서, 결과 레지스터를 inline assembly의 출력 constraint로 묶어준다. 기존에는 모든 레지스터에 대해 =f(float32) 또는 =r(int32) constraint만 사용했는데, f16 같은 16비트 타입에 =r을 사용하면 LLVM이 값을 32비트로 확장하는 cvt 명령어를 추가로 생성했다.
핵심 코드 분석
Before
Constraints getOutputConstraints(ttn::WGMMAWaitGroupOp op) const {
auto outputStructType = cast<LLVM::LLVMStructType>(op.getType());
uint32_t numOutputRegs = outputStructType.getBody().size();
std::string output =
outputStructType.getBody().front().isF32() ? "=f" : "=r";
return Constraints(numOutputRegs, output);
}
모든 필드에 동일한 constraint를 사용한다.
After
Constraints getOutputConstraints(ttn::WGMMAWaitGroupOp op) const {
auto outputStructType = cast<LLVM::LLVMStructType>(op.getType());
Constraints constraints;
mlir::DataLayout dl(op->getParentOfType<mlir::ModuleOp>());
for (auto ty : outputStructType.getBody()) {
auto bitwidth = dl.getTypeSizeInBits(ty);
switch (bitwidth) {
case 64: c = "=l"; break;
case 32: c = ty.isF32() ? "=f" : "=r"; break;
case 16: c = "=h"; break;
default: llvm::report_fatal_error(...);
}
constraints.push_back(c);
}
return constraints;
}
각 필드의 bitwidth에 따라 올바른 constraint를 선택한다.
왜 이게 좋은가
- 불필요한 cvt 제거: f16에
=hconstraint를 사용하면 LLVM이 16비트 레지스터를 직접 사용하여, 32비트 확장 명령어가 사라진다. - ptxas 최적화 가능: no-op
mov.b32시퀀스가 남더라도 ptxas가 이를 제거할 수 있다. - 타입 안전성: 64비트(
=l), 32비트(=f/=r), 16비트(=h) 각각에 정확한 PTX constraint를 매핑한다.
정리
Inline assembly의 constraint는 코드 생성 품질에 직접적인 영향을 미친다. 타입에 맞지 않는 constraint는 컴파일러가 불필요한 타입 변환을 삽입하는 원인이 된다. 이 PR은 struct의 각 필드별로 올바른 constraint를 선택하여 이 문제를 해결한다.
참고 자료
이 글은 AI 도구의 도움을 받아 작성되었습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [vllm] torch.compile로 Qwen Vision 모델 속도 향상
- 현재글 : [Triton] WGMMA wait op의 출력 constraint 타입별 분기 수정
- 다음글 [Triton] AMD amdgpu.async_wait Op 도입으로 비동기 트랜잭션 의미론 명확화
댓글