본문으로 건너뛰기

[triton] Triton AMD StreamK GEMM 커널의 Race Condition 해결: 동기화 로직 최적화 분석

PR 링크: triton-lang/triton#10594 상태: Merged | 변경: +3 / -8

들어가며

고성능 GPU 컴퓨팅에서 GEMM(General Matrix Multiply) 연산은 핵심적인 작업입니다. 특히 Triton의 StreamK 기법은 작업을 여러 CTA(Cooperative Thread Array)로 분할하여 병렬성을 극대화하는 강력한 전략입니다. 하지만 분산된 작업 단위 간의 동기화는 매우 까다로운 문제입니다. 이번에 분석할 PR은 AMD GPU 환경의 StreamK GEMM 커널에서 발생하던 Race Condition 문제를 해결합니다. 특정 CTA가 다른 CTA의 작업이 완료되었다고 오판하여 발생하는 데이터 불일치 문제를 어떻게 수정했는지 살펴봅니다.

코드 분석

1. f16_gemm_streamk_gfx1250.py 내 불필요한 동기화 제거

기존 코드에서는 process_streamk_tilesprocess_streamk_tiles_8warps 함수 내에서 P 버퍼 초기화 후 ttgl.barrier()ttgl.store(locks_ptr + pid, 0)를 호출하고 있었습니다. 하지만 이는 오히려 경쟁 상태를 유발하거나 불필요한 오버헤드를 발생시켰습니다.

Before:

    ttgl.barrier()
    ttgl.store(locks_ptr + pid, 0)

After:

    # 삭제됨

이 변경은 각 CTA가 독립적으로 P 버퍼를 초기화하는 과정에서, 잘못된 시점에 락을 초기화함으로써 발생하는 동기화 오류를 방지합니다. 불필요한 배리어를 제거함으로써 커널 실행의 흐름을 단순화했습니다.

2. locks_device 초기화 방식 개선

커널 내부에서 락을 초기화하는 대신, 호스트 측에서 미리 0으로 초기화된 텐서를 할당하도록 변경되었습니다.

Before:

    locks = torch.empty(num_sms, dtype=torch.int32)
    locks_device = locks.cuda()

After:

    locks_device = torch.zeros(num_sms, dtype=torch.int32, device=a_device.device)

torch.empty는 메모리 할당 시점에 쓰레기 값을 포함할 수 있습니다. 이를 torch.zeros로 변경하여 명시적으로 0으로 초기화함으로써, 커널 진입 전부터 락 상태가 안전하게 보장되도록 개선했습니다.

왜 이게 좋은가

이번 수정은 두 가지 측면에서 매우 중요합니다.

  1. Race Condition 방지: 기존에는 locks_ptr를 커널 내부에서 초기화하면서, 다른 CTA가 이를 참조하는 시점과 충돌이 발생할 가능성이 있었습니다. 호스트에서 미리 초기화된 값을 전달함으로써 동기화의 원천적인 불안정성을 제거했습니다.
  2. 성능 오버헤드 감소: 커널 내부의 불필요한 ttgl.barrier() 호출은 GPU 워프 간의 실행 흐름을 강제로 멈추게 하여 성능 저하를 유발합니다. 이를 제거함으로써 커널의 파이프라인 효율성을 높였습니다.

일반적인 교훈으로, 분산 병렬 시스템에서는 '상태 초기화는 가능한 한 작업 시작 전(호스트 측)에서 완료하고, 커널 내부의 동기화는 최소화해야 한다'는 원칙을 다시 한번 확인할 수 있습니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글