본문으로 건너뛰기

[triton] Custom DSL Plugin Ops 지원

PR 링크: triton-lang/triton#9626 상태: Merged | 변경: +227 / -23

들어가며

Triton은 이미 out-of-tree 패스와 dialect 플러그인을 지원하지만, 프론트엔드에서 custom op을 호출하는 것은 불가능했습니다. 이 PR은 플러그인이 custom op을 등록하고, Python DSL에서 이를 호출할 수 있는 메커니즘을 추가합니다.

핵심 코드 분석

1. Custom op 열거 및 실행 API

TRITON_PLUGIN_API
tritonEnumeratePluginCustomOps(uint32_t *count, const char **handles) {
  if (!count) return TP_GENERIC_FAILURE;
  *count = 1;
  if (!handles) return TP_SUCCESS;
  handles[0] = "create_custom_op";
  return TP_SUCCESS;
}

TRITON_PLUGIN_API
tritonAddPluginCustomOp(const char *handle, TritonOpBuilder &self,
                        std::vector<mlir::Value> &operands) {
  ::mlir::Value &dst = operands[0];
  ::mlir::Value &src = operands[1];
  dst = self.create<arith::AddFOp>(src, src);
  operands[0] = dst;
  return TP_SUCCESS;
}

플러그인은 두 개의 C 함수를 export합니다: tritonEnumeratePluginCustomOps로 op 이름을 나열하고, tritonAddPluginCustomOp로 MLIR IR을 직접 생성합니다.

2. TritonPlugin에 custom op 핸들러 추가

Before:

struct TritonPlugin {
  static constexpr char DIALECT_PLUGININFO[] = "tritonGetDialectPluginInfo";
  static constexpr char ADD_PASS[] = "tritonAddPluginPass";
  static constexpr char REGISTER_PASS[] = "tritonRegisterPluginPass";
  // custom op 관련 없음
};

After:

struct TritonPlugin {
  static constexpr char ENUMERATE_CUSTOMOPS[] =
      "tritonEnumeratePluginCustomOps";
  static constexpr char ADD_CUSTOMOP[] = "tritonAddPluginCustomOp";

  llvm::Expected<TritonPluginResult>
  getCustomOpHandles(std::vector<const char *> &handles);

  llvm::Expected<TritonPluginResult>
  addCustomOp(const char *handle, TritonOpBuilder &self,
              std::vector<mlir::Value> &operands);
};

기존 pass/dialect 플러그인 구조를 따라, custom op도 열거 + 실행의 2단계 API로 설계되었습니다.

왜 이게 좋은가

  • 확장성: 서드파티가 Triton 코어를 수정하지 않고도 자체 연산을 추가할 수 있습니다.
  • 일관된 API: 기존 pass/dialect 플러그인과 동일한 패턴(enumerate + add)을 따릅니다.
  • 직접 IR 생성: TritonOpBuilder를 통해 MLIR IR을 직접 조작하므로 표현력이 높습니다.

정리

Triton 플러그인 시스템의 세 번째 축인 custom op 지원이 추가되었습니다. 서드파티 DSL 확장이 pass, dialect, custom op의 세 레벨에서 가능해졌습니다.

참고 자료


이 글은 AI의 도움을 받아 작성되었으며, 원본 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.

댓글

관련 포스트

PR Analysis 의 다른글