본문으로 건너뛰기

[Triton] Blackwell 2D activation-scale layout에서 ragged metadata 없이 동작하도록 수정

PR 링크: triton-lang/triton#9417 상태: Merged | 변경: +46 / -14

들어가며

Blackwell 아키텍처의 MXFP8 activation-scale layout은 입력 텐서의 차원과 ragged metadata 유무에 따라 다른 모드로 동작한다. 기존 코드는 2D 입력이면 항상 ragged metadata가 있다고 가정했지만, ragged_metadata=None인 2D 입력에서는 이 가정이 깨져 레이아웃 구성이 실패했다. 이 PR은 해당 케이스를 batched 모드로 처리한다.

핵심 코드 분석

Before: 2D 입력 시 무조건 ragged 모드 가정

@dataclass(frozen=True)
class BlackwellActMXScaleLayout(Layout):
    ragged_metadata: RaggedTensorMetadata  # None 허용 안 됨

class BlackwellActMXScaleLayoutTransformation(LayoutTransformation):
    ragged_metadata: RaggedTensorMetadata

    def __post_init__(self):
        if len(self.shape) == 2:
            B, M, K = 1, *self.shape
            # 항상 ragged_metadata에 접근 -> None이면 에러
            n_slices = self.ragged_metadata.slice_sizes.shape[0]
            max_n_blocks = self.ragged_metadata.n_blocks(n_slices, M, self.ALIGN_M)
            M_pad = self.ALIGN_M * max_n_blocks
            mode = "ragged"

After: ragged_metadata 유무에 따라 분기

@dataclass(frozen=True)
class BlackwellActMXScaleLayout(Layout):
    ragged_metadata: RaggedTensorMetadata | None  # None 허용

class BlackwellActMXScaleLayoutTransformation(LayoutTransformation):
    ragged_metadata: RaggedTensorMetadata | None
    added_leading_batch_dim: bool = False

    def __post_init__(self):
        if len(self.shape) == 2:
            B, M, K = 1, *self.shape
            added_leading_batch_dim = True
            if self.ragged_metadata is None:
                # ragged 없는 2D -> batched 모드로 처리
                M_pad = (M + self.ALIGN_M - 1) // self.ALIGN_M * self.ALIGN_M
                mode = "batched"
            else:
                # 기존 ragged 모드
                n_slices = self.ragged_metadata.slice_sizes.shape[0]
                max_n_blocks = self.ragged_metadata.n_blocks(...)
                M_pad = self.ALIGN_M * max_n_blocks
                mode = "ragged"

    def unswizzle_data(self, data):
        if self.mode == "batched":
            data = data[..., :self.M, :self.K]
            if self.added_leading_batch_dim:
                return data.squeeze(0)  # 추가한 batch dim 제거
            return data

왜 이게 좋은가

  1. None-safe: ragged_metadata=None인 2D 입력을 batched 모드로 자연스럽게 처리한다.
  2. 차원 복원: added_leading_batch_dim 플래그로 내부에서 추가한 배치 차원을 unswizzle 시 올바르게 제거한다.
  3. 테스트 추가: roundtrip 테스트와 convert_layout roundtrip 테스트로 양방향 변환을 검증한다.
  4. 기존 동작 유지: ragged metadata가 있는 경우의 동작은 변경하지 않았다.

정리

이 PR은 Blackwell MXFP8 activation-scale layout에서 2D 입력 + ragged_metadata=None 케이스를 처리한다. ragged_metadata 타입을 Optional로 변경하고, None인 경우 batched 모드로 fallback하여 레이아웃 구성 실패를 방지한다.

참고 자료


이 글은 AI를 활용하여 PR의 핵심 변경사항을 분석하고 정리한 것입니다. 실제 코드의 맥락은 원본 PR을 참고해 주세요.

댓글

관련 포스트

PR Analysis 의 다른글