본문으로 건너뛰기

[Triton] 4-warp persistent 커널 재활성화

들어가며

Triton의 persistent matmul 커널에서 4-warp 설정이 ptxas 버그로 인해 비활성화되어 있었다. 이후 코드 변경으로 문제가 되던 코드 시퀀스가 사라져 이 PR에서 재활성화한다.

핵심 코드 분석

Before

if isinstance(b_mx_scale_layout, HopperMXScaleLayout) and b_mx_scale_layout.num_warps == 4:
    # TODO: persistent kernel is broken due with 4 warps due to a ptxas bug
    supports_persistent = False
if weight_dtype_str.startswith("mxfloat4") and b_hbm_swizzling and num_warps == 4:
    pytest.skip("Disabled due to ptxas bug")

After

두 제한 모두 제거되었다. 또한 warp specialization의 조건이 조정되었다:

# Workaround for compile error in hopper warp specialization
warp_specialize=FLATTEN_LOOPS,  # 이전: True

왜 이게 좋은가

  • 성능 복원: 4-warp persistent 커널은 특정 workload에서 8-warp보다 효율적일 수 있다.
  • 최소 변경: +2/-6으로 불필요한 제한만 제거했다.

정리

외부 도구(ptxas)의 버그로 인한 workaround가 불필요해지면 즉시 제거하는 것이 코드 건강성에 중요하다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.

댓글