본문으로 건너뛰기

[Triton] BitmatrixMetadata와 RaggedTensorMetadata 도입으로 routing 모듈 리팩터링

들어가며

Triton의 MoE(Mixture of Experts) 커널에서 routing은 어떤 토큰이 어떤 expert로 갈지를 결정하는 핵심 로직이다. 기존 triton_kernels.routing 모듈은 routing 계산, sparse matrix 관리, metadata 생성이 하나로 뭉쳐 있어 재사용이 어려웠다. 이 PR은 BitmatrixMetadataRaggedTensorMetadata를 별도 자료구조로 분리하고, 기존 routing API를 deprecated 처리한다.

핵심 코드 분석

Before

from triton_kernels.routing import (
    RoutingData,
    GatherIndx,
    ScatterIndx,
    compute_expt_data_torch,
    topk_torch,
    routing_from_bitmatrix,
)

routing 모듈에서 모든 것을 가져왔고, compute_expt_data_torch 같은 내부 구현 세부사항이 공개 API로 노출되어 있었다.

After

from triton_kernels.matmul_ogs import RoutingData, GatherIndx, ScatterIndx
from triton_kernels.topk import topk_torch
from triton_kernels.topk import topk
from triton_kernels.tensor import BIT, SparseMatrix, Bitmatrix, make_ragged_tensor_metadata

RoutingData, GatherIndx, ScatterIndxmatmul_ogs로, topk 로직은 topk 모듈로, sparse matrix 관련 코드는 tensor 모듈로 각각 분리되었다.

새로운 routing 코드는 SparseMatrix를 활용한 조합 방식으로 변경되었다:

def legacy_routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act):
    sparse_logits = SparseMatrix(indx=expt_indx, vals=expt_scal, mask=bitmatrix)
    dispatch_indx = sparse_logits.mask_metadata.col_sorted_indx
    combine_indx = sparse_logits.mask_metadata.row_sorted_indx
    ragged_batch_metadata = make_ragged_tensor_metadata(
        sparse_logits.mask_metadata.col_sum, dispatch_indx.shape[0])
    gate_scal = sparse_logits.vals.flatten()[combine_indx]
    routing_data = RoutingData(gate_scal, ragged_batch_metadata.batch_sizes,
                               n_expts_tot, n_expts_act, ragged_batch_metadata)
    gather_idx = GatherIndx(combine_indx, dispatch_indx)
    scatter_idx = ScatterIndx(dispatch_indx, combine_indx)
    return routing_data, gather_idx, scatter_idx

왜 이게 좋은가

  • 관심사 분리: routing, topk, sparse matrix가 각각 독립 모듈로 분리되어 단독 테스트와 재사용이 가능하다.
  • SparseMatrix 추상화: SparseMatrix(indx, vals, mask) 구조체로 sparse 연산의 의미가 명확해졌다.
  • 하위 호환성: legacy_routing 함수를 통해 기존 코드를 점진적으로 마이그레이션할 수 있다.

정리

+687/-774의 대규모 리팩터링으로, MoE routing 코드의 모듈 경계를 재정의했다. 테스트 파일 test_routing.py가 삭제되고 관련 테스트가 각 모듈의 테스트로 흡수된 것이 눈에 띈다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.

댓글