본문으로 건너뛰기

[triton] Gluon TMA Op Verifier 강화 및 Illegal Instruction Sanitize 모드 추가

PR 링크: triton-lang/triton#9112 상태: Merged | 변경: +218 / -102

들어가며

Triton의 TMA(Tensor Memory Accelerator) 연산들은 descriptor와 tensor 사이의 타입 일치가 필수적입니다. 기존 verifier는 shape이 정확히 일치해야 했는데, 이는 rank reduction이나 reshape를 허용하지 않는 문제가 있었습니다. 이 PR은 element 총 수 기반의 검증으로 변경하고, gather/scatter의 verifier도 공통 함수로 통합합니다.

핵심 코드 분석

Before - shape 완전 일치 요구

static LogicalResult verifyDescriptorLoadStoreType(
    Operation *op, TensorDescType desc, RankedTensorType tensor) {
  if (blockShape == tensorShape &&
      block.getElementType() == tensor.getElementType())
    return success();
  return op->emitOpError("tensor descriptor block and tensor types must match");
}

After - element 수 기반 검증

LogicalResult verifyDescriptorLoadStoreOp(Operation *op, TensorDescType desc,
                                          ShapedType tensor) {
  unsigned blockNumels = product(blockShape);
  unsigned tensorNumels = product(tensorShape);
  if (blockNumels != tensorNumels) {
    return op->emitOpError("descriptor block and tensor must have the same "
                           "number of elements");
  }
  return success();
}

Gather/Scatter verifier 통합

// DescriptorGatherOp::verifyResultType 제거
// 공통 함수로 통합
LogicalResult verifyGatherScatterOp(Operation *op, ShapedType blockType,
                                    ShapedType resultType,
                                    ShapedType indicesType);

왜 이게 좋은가

  1. 유연한 검증: rank reduction이나 reshape를 허용하면서도 element 수 일치를 보장합니다.
  2. 코드 통합: gather/scatter, load/store, reduce 모두 같은 verifier 함수를 사용합니다.
  3. 더 나은 에러 메시지: element 수와 element type을 각각 구체적으로 보고합니다.

정리

TMA 연산의 verifier를 element 수 기반으로 완화하면서 더 구체적인 에러 메시지를 제공하도록 개선한 PR입니다. gather/scatter/load/store/reduce의 verifier 코드 중복도 해소했습니다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.

댓글

관련 포스트

PR Analysis 의 다른글