본문으로 건너뛰기

[Triton] Gluon Dialect verifier 강화 및 에러 메시지 개선

PR 링크: triton-lang/triton#8981 상태: Merged | 변경: +239 / -135

들어가며

Triton의 Gluon dialect에서 잘못된 파라미터가 전달되면 컴파일 시 illegal instruction이 생성되거나 런타임에 크래시가 발생할 수 있다. 이 PR은 verifier를 강화하여 잘못된 입력을 조기에 감지하고, 디버깅에 도움이 되는 에러 메시지를 제공한다.

핵심 코드 분석

NVMMASharedEncodingAttr 검증 추가

LogicalResult
NVMMASharedEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                unsigned swizzlingByteWidth, bool transposed,
                                unsigned elementBitWidth, bool fp4Padded,
                                CGAEncodingAttr CGALayout) {
  if (elementBitWidth == 0)
    return emitError() << "elementBitWidth must be non-zero";
  return success();
}

getTMABlockShape를 fallible하게 변경

// Before: fatal error로 크래시
SmallVector<int64_t> getTMABlockShape(...) {
  // ...
  if (blockShape[contigDim] < contigDimSize) {
    llvm::report_fatal_error(
        "Block shape is too small for the swizzle byte size");
  }
}

// After: 에러를 반환하여 상위에서 처리 가능
FailureOr<SmallVector<int64_t>> getTMABlockShape(
    ArrayRef<int64_t> shapePerCTA, int elementBitWidth, int swizzleBytes,
    bool fp4Padded, bool isTransposed, bool packedSize,
    function_ref<InFlightDiagnostic()> emitError) {
  // ...
  if (blockShape[contigDim] < contigDimSize) {
    return emitError() << "block shape along the contiguous dimension "
                       << contigDim
                       << " is too small for the swizzle byte size "
                       << swizzleBytes;
  }
  return blockShape;
}

MemDescType 검증에 TMA block shape 체크 추가

LogicalResult MemDescType::verify(...) {
  if (auto enc = dyn_cast<NVMMASharedEncodingAttr>(encoding)) {
    SmallVector<int64_t> shapePerCTA(getShapePerCTA(enc, allocShape));
    auto blockShape = ArrayRef(shapePerCTA).take_back(enc.getRank());
    if (failed(getTMABlockShape(blockShape, enc.getElementBitWidth(),
                                enc.getSwizzlingByteWidth(), enc.getFp4Padded(),
                                enc.getTransposed(), /*packedSize=*/false,
                                emitError)))
      return failure();
  }
}

왜 이게 좋은가

  1. 조기 에러 감지: 잘못된 파라미터를 verifier 단계에서 잡아 illegal instruction 생성을 방지한다.
  2. 실행 가능한 에러 메시지: "Block shape is too small" 대신 구체적인 차원, 기대값, 실제값을 포함한 메시지를 제공한다.
  3. report_fatal_error 제거: LLVM fatal error 대신 FailureOr를 반환하여 상위 호출자가 적절히 처리할 수 있다.
  4. getTMABlockShape 공유: TritonNvidiaGPU에서 TritonGPU로 이동하여 verifier에서도 재사용 가능해졌다.

정리

이 PR은 Gluon dialect의 verifier를 강화하여 NVMMASharedEncoding 검증, TMA 함수 파라미터 검증, DotOpMMASmemLoader fallible 변환을 추가했다. 런타임 크래시 대신 컴파일 타임 에러를 제공하고, 에러 메시지에 디버깅에 필요한 구체적 정보를 포함한다.

참고 자료


이 글은 AI를 활용하여 PR의 핵심 변경사항을 분석하고 정리한 것입니다. 실제 코드의 맥락은 원본 PR을 참고해 주세요.

댓글

관련 포스트

PR Analysis 의 다른글