[Triton] Blackwell 2D activation-scale layout에서 ragged metadata 없이 동작하도록 수정
PR 링크: triton-lang/triton#9417 상태: Merged | 변경: +46 / -14
들어가며
Blackwell 아키텍처의 MXFP8 activation-scale layout은 입력 텐서의 차원과 ragged metadata 유무에 따라 다른 모드로 동작한다. 기존 코드는 2D 입력이면 항상 ragged metadata가 있다고 가정했지만, ragged_metadata=None인 2D 입력에서는 이 가정이 깨져 레이아웃 구성이 실패했다. 이 PR은 해당 케이스를 batched 모드로 처리한다.
핵심 코드 분석
Before: 2D 입력 시 무조건 ragged 모드 가정
@dataclass(frozen=True)
class BlackwellActMXScaleLayout(Layout):
ragged_metadata: RaggedTensorMetadata # None 허용 안 됨
class BlackwellActMXScaleLayoutTransformation(LayoutTransformation):
ragged_metadata: RaggedTensorMetadata
def __post_init__(self):
if len(self.shape) == 2:
B, M, K = 1, *self.shape
# 항상 ragged_metadata에 접근 -> None이면 에러
n_slices = self.ragged_metadata.slice_sizes.shape[0]
max_n_blocks = self.ragged_metadata.n_blocks(n_slices, M, self.ALIGN_M)
M_pad = self.ALIGN_M * max_n_blocks
mode = "ragged"
After: ragged_metadata 유무에 따라 분기
@dataclass(frozen=True)
class BlackwellActMXScaleLayout(Layout):
ragged_metadata: RaggedTensorMetadata | None # None 허용
class BlackwellActMXScaleLayoutTransformation(LayoutTransformation):
ragged_metadata: RaggedTensorMetadata | None
added_leading_batch_dim: bool = False
def __post_init__(self):
if len(self.shape) == 2:
B, M, K = 1, *self.shape
added_leading_batch_dim = True
if self.ragged_metadata is None:
# ragged 없는 2D -> batched 모드로 처리
M_pad = (M + self.ALIGN_M - 1) // self.ALIGN_M * self.ALIGN_M
mode = "batched"
else:
# 기존 ragged 모드
n_slices = self.ragged_metadata.slice_sizes.shape[0]
max_n_blocks = self.ragged_metadata.n_blocks(...)
M_pad = self.ALIGN_M * max_n_blocks
mode = "ragged"
def unswizzle_data(self, data):
if self.mode == "batched":
data = data[..., :self.M, :self.K]
if self.added_leading_batch_dim:
return data.squeeze(0) # 추가한 batch dim 제거
return data
왜 이게 좋은가
- None-safe:
ragged_metadata=None인 2D 입력을 batched 모드로 자연스럽게 처리한다. - 차원 복원:
added_leading_batch_dim플래그로 내부에서 추가한 배치 차원을unswizzle시 올바르게 제거한다. - 테스트 추가: roundtrip 테스트와
convert_layoutroundtrip 테스트로 양방향 변환을 검증한다. - 기존 동작 유지: ragged metadata가 있는 경우의 동작은 변경하지 않았다.
정리
이 PR은 Blackwell MXFP8 activation-scale layout에서 2D 입력 + ragged_metadata=None 케이스를 처리한다. ragged_metadata 타입을 Optional로 변경하고, None인 경우 batched 모드로 fallback하여 레이아웃 구성 실패를 방지한다.
참고 자료
이 글은 AI를 활용하여 PR의 핵심 변경사항을 분석하고 정리한 것입니다. 실제 코드의 맥락은 원본 PR을 참고해 주세요.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Grafana Loki] cmp.Diff 대신 cmp.Equal로 상태 비교를 단순화
- 현재글 : [Triton] Blackwell 2D activation-scale layout에서 ragged metadata 없이 동작하도록 수정
- 다음글 [ACE-Step-1.5] Apple Silicon을 위한 네이티브 MLX DiT 백엔드 도입: 2-3배 성능 향상
댓글