본문으로 건너뛰기

[triton] GSan AxisInfo 기반 Shadow Update 중복 제거로 2~10배 성능 향상

PR 링크: triton-lang/triton#9832 상태: Merged | 변경: +146 / -19

들어가며

Triton의 Global Sanitizer(GSan)는 GPU 커널의 메모리 접근 오류를 런타임에 감지하는 계측(instrumentation) 도구입니다. 그러나 계측 코드가 삽입되면 커널 성능이 크게 저하됩니다. 특히 sub-word 크기(예: float16)의 요소들이 연속으로 저장된 경우, 각 포인터마다 개별적으로 shadow 메모리를 업데이트하는 것은 심각한 중복 작업이었습니다. 이 PR은 AxisInfo 분석을 활용하여 인접한 포인터들을 병합함으로써 GSan 계측 오버헤드를 대폭 줄였습니다.

핵심 코드 분석

1. AxisInfo를 활용한 벡터 크기 결정

기존 코드에서는 각 텐서 요소에 대해 개별적으로 shadow update를 수행했습니다. 새로운 코드는 AxisInfo의 contiguity와 mask alignment를 분석하여 병합 가능한 요소 수를 결정합니다.

Before:

emitTensorAccessRuntimeCall(rewriter, loc, gsanGlobalStatePtr, ptrElems,
                            maskElems, regMask, threadPred, bytesPerElem,
                            op.getIsStore());

After:

unsigned mergeVec = getVecSize(op);
if (mergeVec > 1) {
  SmallVector<Value> mergedPtrElems;
  SmallVector<Value> mergedMaskElems;
  for (unsigned i = 0; i < numElems; i += mergeVec) {
    mergedPtrElems.push_back(ptrElems[i]);
    if (maskElems.empty()) continue;
    Value mergedMask = maskElems[i];
    for (unsigned j = maskAlign; j < mergeVec; j += maskAlign) {
      mergedMask = arith::OrIOp::create(rewriter, loc, mergedMask, maskElems[i + j]);
    }
    mergedMaskElems.push_back(mergedMask);
  }
  ptrElems = std::move(mergedPtrElems);
  maskElems = std::move(mergedMaskElems);
  bytesPerElem *= mergeVec;
}

핵심은 getVecSize 메서드입니다. contiguity(연속된 요소 수)와 mask alignment를 비교하여 안전하게 병합 가능한 최대 크기를 계산합니다. 특히 shadow granularity(4바이트)보다 작은 요소(예: fp16 = 2바이트)의 경우, mask alignment를 최소 4/bytesPerElem으로 올려서 shadow 갱신 단위에 맞춥니다.

2. TMA 로드에 대한 AxisInfo 힌트 설정

TMA(Tensor Memory Accelerator) 기반 로드의 경우, 컴파일러가 contiguity 정보를 자동으로 추론하기 어렵습니다. 이를 해결하기 위해 TMA 포인터에 직접 contiguity와 divisibility 속성을 설정합니다.

Before: (TMA 로드에 대한 힌트 없음)

After:

static void setTMAPtrAxisHints(OpBuilder &builder, Value ptr) {
  auto rank = ptrTy.getRank();
  SmallVector<int32_t> contiguity(rank, 1);
  contiguity.back() = ptrTy.getShape().back();
  SmallVector<int32_t> divisibility(rank, 1);
  divisibility.back() = 16;
  def->setDiscardableAttr("tt.contiguity",
                          DenseIntElementsAttr::get(attrTy, contiguity));
  def->setDiscardableAttr("tt.divisibility",
                          DenseIntElementsAttr::get(attrTy, divisibility));
}

이 힌트가 있어야 AxisInfo 분석이 TMA 로드의 연속성을 인식하고, 앞서 설명한 병합 최적화를 적용할 수 있습니다.

왜 이게 좋은가

이 최적화의 핵심 가치는 분석 정보의 재활용입니다. Triton 컴파일러는 이미 vectorized load/store를 위해 AxisInfo 분석을 수행합니다. 동일한 분석 결과를 sanitizer 계측에도 활용함으로써, 추가적인 분석 비용 없이 중복 작업을 제거했습니다.

실제 효과도 인상적입니다. FP16 기반 matmul에서 GSan이 활성화된 상태로 2배, TMA 기반 FP16 matmul에서는 10배의 속도 향상을 보였습니다. 이는 sanitizer의 실용성을 크게 높여주는 변화입니다. 개발 중 항상 sanitizer를 켜놓고도 합리적인 성능을 유지할 수 있게 되었기 때문입니다.

정리

  • AxisInfo의 contiguity 속성으로 인접 포인터를 병합하여 shadow update 횟수를 대폭 감소
  • TMA 로드에 contiguity/divisibility 힌트를 추가하여 최적화 적용 범위 확대
  • FP16 matmul 기준 GSan 오버헤드 2~10배 감소

참고 자료

이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글