본문으로 건너뛰기

[Triton] WGMMA wait op의 출력 constraint 타입별 분기 수정

PR 링크: triton-lang/triton#8579 상태: Merged | 변경: +36 / -3

들어가며

WGMMA wait 명령어는 비동기 WGMMA가 완료될 때까지 기다리면서, 결과 레지스터를 inline assembly의 출력 constraint로 묶어준다. 기존에는 모든 레지스터에 대해 =f(float32) 또는 =r(int32) constraint만 사용했는데, f16 같은 16비트 타입에 =r을 사용하면 LLVM이 값을 32비트로 확장하는 cvt 명령어를 추가로 생성했다.

핵심 코드 분석

Before

Constraints getOutputConstraints(ttn::WGMMAWaitGroupOp op) const {
  auto outputStructType = cast<LLVM::LLVMStructType>(op.getType());
  uint32_t numOutputRegs = outputStructType.getBody().size();
  std::string output =
      outputStructType.getBody().front().isF32() ? "=f" : "=r";
  return Constraints(numOutputRegs, output);
}

모든 필드에 동일한 constraint를 사용한다.

After

Constraints getOutputConstraints(ttn::WGMMAWaitGroupOp op) const {
  auto outputStructType = cast<LLVM::LLVMStructType>(op.getType());
  Constraints constraints;
  mlir::DataLayout dl(op->getParentOfType<mlir::ModuleOp>());
  for (auto ty : outputStructType.getBody()) {
    auto bitwidth = dl.getTypeSizeInBits(ty);
    switch (bitwidth) {
    case 64: c = "=l"; break;
    case 32: c = ty.isF32() ? "=f" : "=r"; break;
    case 16: c = "=h"; break;
    default: llvm::report_fatal_error(...);
    }
    constraints.push_back(c);
  }
  return constraints;
}

각 필드의 bitwidth에 따라 올바른 constraint를 선택한다.

왜 이게 좋은가

  • 불필요한 cvt 제거: f16에 =h constraint를 사용하면 LLVM이 16비트 레지스터를 직접 사용하여, 32비트 확장 명령어가 사라진다.
  • ptxas 최적화 가능: no-op mov.b32 시퀀스가 남더라도 ptxas가 이를 제거할 수 있다.
  • 타입 안전성: 64비트(=l), 32비트(=f/=r), 16비트(=h) 각각에 정확한 PTX constraint를 매핑한다.

정리

Inline assembly의 constraint는 코드 생성 품질에 직접적인 영향을 미친다. 타입에 맞지 않는 constraint는 컴파일러가 불필요한 타입 변환을 삽입하는 원인이 된다. 이 PR은 struct의 각 필드별로 올바른 constraint를 선택하여 이 문제를 해결한다.

참고 자료


이 글은 AI 도구의 도움을 받아 작성되었습니다.

댓글

관련 포스트

PR Analysis 의 다른글