본문으로 건너뛰기

[triton] Matmul에서 Split-K Reduction과 Inter-Expert Reduction 분리

PR 링크: triton-lang/triton#8483 상태: Merged | 변경: +262 / -398

들어가며

이 PR은 Triton Kernels의 matmul_ogs에서 split-k reduction 로직을 inter-expert reduction에서 분리하는 리팩터링입니다. 이전에는 FusedActivation에 reduction_n 파라미터가 포함되어 두 reduction이 결합되어 있었는데, 이를 분리하여 각각 독립적으로 제어할 수 있게 합니다.

핵심 코드 분석

Before

@dataclass(frozen=True)
class FusedActivation:
    specs: FnSpecs = FnSpecs.default()
    fn_args: tuple[object] = tuple()
    reduction_n: int = 1  # reduction이 activation에 결합됨

After

@dataclass(frozen=True)
class FusedActivation:
    specs: FnSpecs = FnSpecs.default()
    fn_args: tuple[object] = tuple()
    # reduction_n 제거 - FnSpecs로 이동

@dataclass(frozen=True)
class FnSpecs:
    name: str
    fn: "triton.runtime.jit.JITFunction"
    fn_arg_names: tuple[str]
    fn_arg_do_not_specialize: tuple[str] = tuple()
    reduction_n: int = 1  # FnSpecs 레벨로 이동

SpecializationModule 도입

specializations = SpecializationModule("matmul_ogs",
    kernels=[("_matmul_ogs", _matmul_ogs), ("_p_matmul_ogs", _p_matmul_ogs)],
    closure_args={
        "epilogue": ClosureArg("EPILOGUE_FN", "epilogue_fn_args"),
        "activation": ClosureArg("ACTIVATION_FN", "activation_fn_args"),
    },
)

왜 이게 좋은가

  1. 관심사 분리: split-k reduction과 inter-expert reduction이 독립적으로 제어되어 MoE 파이프라인에서 각각 최적화할 수 있습니다.
  2. 코드 간소화: _reduce_grouped 커널이 제거되고 독립 reduce 모듈로 대체되어 코드 중복이 줄었습니다 (-398줄).
  3. 확장성: SpecializationModule 패턴으로 커널 특수화를 일반화하여 새로운 커널 추가가 용이해졌습니다.

정리

matmul과 reduction의 결합을 느슨하게 만든 리팩터링으로, 코드량이 줄면서도 기능적 유연성은 높아졌습니다. MoE 워크로드에서 Expert Parallelism을 제대로 활용하려면 이런 분리가 필수적입니다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.

댓글

관련 포스트

PR Analysis 의 다른글