[Triton] BitmatrixMetadata와 RaggedTensorMetadata 도입으로 routing 모듈 리팩터링
들어가며
Triton의 MoE(Mixture of Experts) 커널에서 routing은 어떤 토큰이 어떤 expert로 갈지를 결정하는 핵심 로직이다. 기존 triton_kernels.routing 모듈은 routing 계산, sparse matrix 관리, metadata 생성이 하나로 뭉쳐 있어 재사용이 어려웠다. 이 PR은 BitmatrixMetadata와 RaggedTensorMetadata를 별도 자료구조로 분리하고, 기존 routing API를 deprecated 처리한다.
핵심 코드 분석
Before
from triton_kernels.routing import (
RoutingData,
GatherIndx,
ScatterIndx,
compute_expt_data_torch,
topk_torch,
routing_from_bitmatrix,
)
routing 모듈에서 모든 것을 가져왔고, compute_expt_data_torch 같은 내부 구현 세부사항이 공개 API로 노출되어 있었다.
After
from triton_kernels.matmul_ogs import RoutingData, GatherIndx, ScatterIndx
from triton_kernels.topk import topk_torch
from triton_kernels.topk import topk
from triton_kernels.tensor import BIT, SparseMatrix, Bitmatrix, make_ragged_tensor_metadata
RoutingData, GatherIndx, ScatterIndx는 matmul_ogs로, topk 로직은 topk 모듈로, sparse matrix 관련 코드는 tensor 모듈로 각각 분리되었다.
새로운 routing 코드는 SparseMatrix를 활용한 조합 방식으로 변경되었다:
def legacy_routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act):
sparse_logits = SparseMatrix(indx=expt_indx, vals=expt_scal, mask=bitmatrix)
dispatch_indx = sparse_logits.mask_metadata.col_sorted_indx
combine_indx = sparse_logits.mask_metadata.row_sorted_indx
ragged_batch_metadata = make_ragged_tensor_metadata(
sparse_logits.mask_metadata.col_sum, dispatch_indx.shape[0])
gate_scal = sparse_logits.vals.flatten()[combine_indx]
routing_data = RoutingData(gate_scal, ragged_batch_metadata.batch_sizes,
n_expts_tot, n_expts_act, ragged_batch_metadata)
gather_idx = GatherIndx(combine_indx, dispatch_indx)
scatter_idx = ScatterIndx(dispatch_indx, combine_indx)
return routing_data, gather_idx, scatter_idx
왜 이게 좋은가
- 관심사 분리: routing, topk, sparse matrix가 각각 독립 모듈로 분리되어 단독 테스트와 재사용이 가능하다.
- SparseMatrix 추상화:
SparseMatrix(indx, vals, mask)구조체로 sparse 연산의 의미가 명확해졌다. - 하위 호환성:
legacy_routing함수를 통해 기존 코드를 점진적으로 마이그레이션할 수 있다.
정리
+687/-774의 대규모 리팩터링으로, MoE routing 코드의 모듈 경계를 재정의했다. 테스트 파일 test_routing.py가 삭제되고 관련 테스트가 각 모듈의 테스트로 흡수된 것이 눈에 띈다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석 내용은 실제 PR diff를 기반으로 합니다.
댓글