[Triton] bf16/fp16 x mxfp 조합의 num_stages 조정 — shared memory 초과 방지
PR 링크: triton-lang/triton#8773 상태: Merged | 변경: +17 / -3
들어가며
MXFP(Microscaling Floating Point)는 소규모 블록 단위로 스케일링을 적용하는 저정밀 포맷이다. fp16/bf16 입력과 mxfp weight를 행렬 곱셈할 때, weight가 연산 전에 on-the-fly로 업캐스트된다. 문제는 num_stages(파이프라인 단계 수) 계산 시 업캐스트 후의 크기가 반영되지 않아 shared memory를 초과하는 에러가 발생한다는 것이다.
핵심 코드 분석
Before: 원래 weight 크기로 num_stages 계산
def compute_num_stages(block_m, block_k, block_n,
lhs_dtype, rhs_dtype, ...):
weight_size = bitwidth(rhs_dtype) / 8
stage_size = (block_m * block_k * lhs_dtype.itemsize
+ block_k * block_n * weight_size)
# weight_size가 mxfp8의 1바이트로 계산됨
# 실제로는 fp16으로 업캐스트되어 2바이트 사용
After: 업캐스트를 반영한 크기 사용
def compute_num_stages(block_m, block_k, block_n,
lhs_dtype, rhs_dtype, ...):
weight_size = bitwidth(rhs_dtype) / 8
if (precision_config.b_mx_scale is not None
and lhs_dtype in [torch.float16, torch.bfloat16]):
# For fp16/bf16 x mxfp, we upcast weight on the fly,
# so size smem_capacity accordingly.
weight_size = 2 # 업캐스트 후 2바이트
stage_size = (block_m * block_k * lhs_dtype.itemsize
+ block_k * block_n * weight_size)
이 수정이 없으면 다음과 같은 에러가 발생한다:
triton.runtime.errors.OutOfResources: out of resource: shared memory,
Required: 263356, Hardware limit: 232448.
# x.shape = [2048, >=4096] bf16 x [32, >=4096, >=4096] float8_e4m3fn
# block_m=64, block_n=256, block_k=128 → num_stages=4 (너무 많음)
왜 이게 좋은가
- 런타임 에러 방지: num_stages가 정확하게 계산되어 shared memory 초과로 인한
OutOfResources에러가 사라진다. - 실제 메모리 사용량 반영: on-the-fly 업캐스트로 실제 shared memory에서 소비되는 크기를 계산에 반영한다.
- 최소한의 변경: 조건부
weight_size = 2한 줄로 문제를 정확히 해결한다.
정리
이 PR은 fp16/bf16 x mxfp 행렬 곱셈에서 weight의 on-the-fly 업캐스트를 shared memory 크기 계산에 반영하여, 파이프라인 단계 수가 하드웨어 한계를 초과하지 않도록 한다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 핵심 코드와 explaination은 실제 PR diff를 기반으로 합니다.
관련 포스트
- [triton] AMD GPU Descriptor Encoding 최적화 패스 추가
- [triton] Triton 2CTA Block-Scaled Matmul — cuBLAS 대비 성능 비교
- [triton] AMD gfx1250 MXFP Flash Attention 예제 커널 업데이트
- [Triton] Blackwell 2D activation-scale layout에서 ragged metadata 없이 동작하도록 수정
- [triton] AMD: PartitionedSharedEncodingAttr의 LLVM lowering 지원으로 공유 메모리 파티셔닝 구현
PR Analysis 의 다른글
- 이전글 [vllm] group_topk 커널 최적화 - 1.9% Throughput, 2.1% TPOT 개선
- 현재글 : [Triton] bf16/fp16 x mxfp 조합의 num_stages 조정 — shared memory 초과 방지
- 다음글 [Triton] preload에 optional device 인자 추가
댓글