[Triton] 4-warp persistent 커널 재활성화
들어가며
Triton의 persistent matmul 커널에서 4-warp 설정이 ptxas 버그로 인해 비활성화되어 있었다. 이후 코드 변경으로 문제가 되던 코드 시퀀스가 사라져 이 PR에서 재활성화한다.
핵심 코드 분석
Before
if isinstance(b_mx_scale_layout, HopperMXScaleLayout) and b_mx_scale_layout.num_warps == 4:
# TODO: persistent kernel is broken due with 4 warps due to a ptxas bug
supports_persistent = False
if weight_dtype_str.startswith("mxfloat4") and b_hbm_swizzling and num_warps == 4:
pytest.skip("Disabled due to ptxas bug")
After
두 제한 모두 제거되었다. 또한 warp specialization의 조건이 조정되었다:
# Workaround for compile error in hopper warp specialization
warp_specialize=FLATTEN_LOOPS, # 이전: True
왜 이게 좋은가
- 성능 복원: 4-warp persistent 커널은 특정 workload에서 8-warp보다 효율적일 수 있다.
- 최소 변경: +2/-6으로 불필요한 제한만 제거했다.
정리
외부 도구(ptxas)의 버그로 인한 workaround가 불필요해지면 즉시 제거하는 것이 코드 건강성에 중요하다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.
댓글