본문으로 건너뛰기

[triton] Triton PROTON: CUDA 그래프 프로파일링 오버헤드를 줄이고 MsgPack API를 추가하여 성능을 대폭 개선

PR 링크: triton-lang/triton#9030 상태: Merged | 변경: +1191 / -364

들어가며

Triton PROTON은 CUDA 애플리케이션의 프로파일링 및 분석을 위한 강력한 도구입니다. 특히 CUDA 그래프(CUDA Graphs)를 사용할 때, 프로파일링 과정에서 발생하는 오버헤드는 성능에 민감한 애플리케이션에 큰 영향을 줄 수 있습니다. 이번 PR은 이러한 오버헤드를 크게 줄이고, 데이터 직렬화 방식을 개선하여 전반적인 성능을 향상시키는 데 초점을 맞추고 있습니다.

주요 개선 사항은 다음과 같습니다:

  1. deactivateget_data 오버헤드 감소: CUDA 그래프 프로파일링 시 발생하는 deactivateget_data 함수의 성능을 최적화하여 3배 더 빠르게 만듭니다.
  2. get_data_msgpack API 추가: 기존 get_data (JSON 형식) 대비 10배 빠른 get_data_msgpack API를 새로 도입하여, 데이터 직렬화 및 전송 속도를 획기적으로 개선합니다.

이 글에서는 해당 PR의 코드 변경 사항을 상세히 분석하고, 이러한 최적화가 왜 효과적인지, 그리고 어떤 일반적인 교훈을 얻을 수 있는지 살펴보겠습니다.

코드 분석

1. test-requirements.txt 변경

--- a/python/test-requirements.txt
+++ b/python/test-requirements.txt
@@ -7,3 +7,4 @@
 scipy>=1.7.1
 llnl-hatchet
 expecttest
+msgpack

새로운 get_data_msgpack API를 사용하기 위해 msgpack 라이브러리가 테스트 요구사항에 추가되었습니다. 이는 데이터 직렬화에 msgpack 형식을 사용하게 되었음을 명확히 보여줍니다.

2. third_party/proton/csrc/Proton.cpp 변경

--- a/third_party/proton/csrc/Proton.cpp
+++ b/third_party/proton/csrc/Proton.cpp
@@ -173,6 +173,15 @@
       },
       pybind11::arg("sessionId"));
 
+  m.def(
+      "get_data_msgpack",
+      [](size_t sessionId) {
+        auto data = SessionManager::instance().getDataMsgPack(sessionId);
+        return pybind11::bytes(reinterpret_cast<const char *>(data.data()),
+                               data.size());
+      },
+      pybind11::arg("sessionId"));
+
   m.def(
       "clear_data",
       [](size_t sessionId) { SessionManager::instance().clearData(sessionId); },

새로운 get_data_msgpack 함수가 Python 바인딩에 추가되었습니다. 이 함수는 SessionManager로부터 getDataMsgPack을 호출하여 데이터를 가져온 후, pybind11::bytes 형태로 반환합니다. 이는 기존 get_data 함수가 내부적으로 JSON 직렬화를 수행하는 것과 달리, 더 효율적인 MsgPack 직렬화 결과를 직접 반환하도록 설계되었음을 시사합니다.

3. third_party/proton/csrc/include/Data/Data.h 변경

--- a/third_party/proton/csrc/include/Data/Data.h
+++ b/third_party/proton/csrc/include/Data/Data.h
@@ -35,6 +35,15 @@
   /// Add a single metric to the data.
   virtual void addMetric(size_t scopeId, std::shared_ptr<Metric> metric) = 0;
 
+  /// Add an op and a metric with one call.
+  /// The default implementation forwards to addOp + addMetric.
+  virtual void addOpAndMetric(size_t scopeId, const std::string &opName,
+                              std::shared_ptr<Metric> metric) {
+    scopeId = this->addOp(scopeId, opName);
+    this->addMetric(scopeId, metric);
+  }
+
   /// Add multiple metrics to the data.
   virtual void
   addMetrics(size_t scopeId,
@@ -49,6 +58,9 @@
   /// To Json
   virtual std::string toJsonString() const = 0;
 
+  /// To MsgPack
+  virtual std::vector<uint8_t> toMsgPack() const = 0;
+
   /// Dump the data to the given output format.
   void dump(const std::string &outputFormat);
 

Data 클래스에 두 가지 중요한 가상 함수가 추가되었습니다:

  • addOpAndMetric: 기존의 addOpaddMetric을 분리하여 호출하는 대신, 한 번에 연산 이름과 메트릭을 추가할 수 있게 합니다. 이는 내부적으로 두 번의 scopeId 조회 및 잠금 해제를 방지하여 효율성을 높일 수 있습니다. 기본 구현은 기존 방식대로 호출하지만, 이를 오버라이드하는 파생 클래스에서 최적화가 가능합니다.
  • toMsgPack(): JSON 직렬화에 대응하는 MsgPack 직렬화 메서드가 추가되었습니다. 이는 get_data_msgpack API의 기반이 됩니다.

4. third_party/proton/csrc/include/Data/Metric.h 변경

--- a/third_party/proton/csrc/include/Data/Metric.h
+++ b/third_party/proton/csrc/include/Data/Metric.h
@@ -65,17 +65,17 @@
 
   virtual ~Metric() = default;
 
-  virtual const std::string getName() const = 0;
+  virtual const std::string &getName() const = 0;
 
-  virtual const std::string getValueName(int valueId) const = 0;
+  virtual const std::string &getValueName(int valueId) const = 0;
 
   virtual bool isProperty(int valueId) const = 0;
 
   virtual bool isExclusive(int valueId) const = 0;
 
-  std::vector<MetricValueType> getValues() const { return values; }
+  const std::vector<MetricValueType> &getValues() const { return values; }
 
-  MetricValueType getValue(int valueId) { return values[valueId]; }
+  const MetricValueType &getValue(int valueId) const { return values[valueId]; }
 
   /// Update a specific value id with the new value.
   void updateValue(int valueId, MetricValueType value) {
@@ -115,7 +115,7 @@
   }
 
   /// Update all values with another metric.
-  void updateMetric(Metric &other) {
+  void updateMetric(const Metric &other) {
     for (int i = 0; i < values.size(); ++i) {
       updateValue(i, other.values[i]);
     }
@@ -125,7 +125,6 @@
 
 private:
   const MetricKind kind;
-  const std::string name;
 
 protected:
   std::vector<MetricValueType> values;
@@ -153,9 +152,9 @@
     std::visit([&](auto &&v) { this->values[0] = v; }, value);
   }
 
-  const std::string getName() const override { return "FlexibleMetric"; }
+  const std::string &getName() const override { return name; }
 
-  const std::string getValueName(int valueId) const override {
+  const std::string &getValueName(int valueId) const override {
     return valueName;
   }
 
@@ -166,6 +165,7 @@
 private:
   bool property{};
   bool exclusive{};
+  const static inline std::string name = "FlexibleMetric";
   std::string valueName;
 };
 
@@ -196,15 +196,15 @@
     this->values[StreamId] = streamId;
   }
 
-  virtual const std::string getName() const { return "KernelMetric"; }
+  const std::string &getName() const override { return name; }
 
-  virtual const std::string getValueName(int valueId) const {
+  const std::string &getValueName(int valueId) const override {
     return VALUE_NAMES[valueId];
   }
 
-  virtual bool isProperty(int valueId) const { return PROPERTY[valueId]; }
+  bool isProperty(int valueId) const override { return PROPERTY[valueId]; }
 
-  virtual bool isExclusive(int valueId) const { return EXCLUSIVE[valueId]; }
+  bool isExclusive(int valueId) const override { return EXCLUSIVE[valueId]; }
 
 private:
   const static inline bool PROPERTY[kernelMetricKind::Count] = {
@@ -215,6 +215,7 @@
       "start_time (ns)", "end_time (ns)", "count",     "time (ns)",
       "device_id",       "device_type",   "stream_id",
   };
+  const static inline std::string name = "KernelMetric";
 };
 
 class PCSamplingMetric : public Metric {
@@ -254,17 +255,15 @@
     this->values[PCSamplingMetricKind::NumStalledSamples] = stalledSamples;
   }
 
-  virtual const std::string getName() const { return "PCSamplingMetric"; }
+  const std::string &getName() const override { return name; }
 
-  virtual const std::string getValueName(int valueId) const {
+  const std::string &getValueName(int valueId) const override {
     return VALUE_NAMES[valueId];
   }
 
-  virtual bool isProperty(int valueId) const { return false; }
+  bool isProperty(int valueId) const override { return false; }
+  bool isExclusive(int valueId) const override { return false; }
 
-  virtual bool isExclusive(int valueId) const { return false; }
-
-private:
   const static inline std::string VALUE_NAMES[PCSamplingMetricKind::Count] = {
       "num_samples",
       "num_stalled_samples",
@@ -287,6 +286,7 @@
       "stalled_sleeping",
       "stalled_selected",
   };
+  const static inline std::string name = "PCSamplingMetric";
 };
 
 class CycleMetric : public Metric {
@@ -336,15 +336,15 @@
     this->values[PostFinalTime] = postFinalTime;
   }
 
-  virtual const std::string getName() const { return "CycleMetric"; }
+  const std::string &getName() const override { return name; }
 
-  virtual const std::string getValueName(int valueId) const {
+  const std::string &getValueName(int valueId) const override {
     return VALUE_NAMES[valueId];
   }
 
-  virtual bool isProperty(int valueId) const { return PROPERTY[valueId]; }
+  bool isProperty(int valueId) const override { return PROPERTY[valueId]; }
 
-  virtual bool isExclusive(int valueId) const { return EXCLUSIVE[valueId]; }
+  bool isExclusive(int valueId) const override { return EXCLUSIVE[valueId]; }
 
 private:
   const static inline bool PROPERTY[CycleMetricKind::Count] = {
@@ -358,6 +358,7 @@
       "kernel_id",   "kernel_name",    "block_id",       "processor_id",
       "unit_id",     "device_id",      "device_type",    "time_shift_cost",
       "init_time",   "pre_final_time", "post_final_time"};
+  const static inline std::string name = "CycleMetric";
 };
 
 /// Each TensorMetric represents a scalar metric stored in a device buffer.

Metric 클래스 계층 구조에서 몇 가지 중요한 변경이 이루어졌습니다:

  • 반환 타입 변경: getName(), getValueName(), getValues(), getValue()와 같은 함수들이 const std::string을 반환하는 대신 const std::string &를 반환하도록 변경되었습니다. 이는 불필요한 문자열 복사를 방지하여 성능을 향상시킵니다. 또한, std::vector<MetricValueType> getValues() constconst std::vector<MetricValueType> &getValues() const로 변경되어 벡터 자체의 복사를 피합니다.
  • Metric::name 멤버 제거: Metric 클래스에서 name 멤버 변수가 제거되었습니다. 대신, 각 파생 클래스(FlexibleMetric, KernelMetric, PCSamplingMetric, CycleMetric)에서 `const static inline std::string name =

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글