[triton] Triton PROTON: CUDA 그래프 프로파일링 오버헤드를 줄이고 MsgPack API를 추가하여 성능을 대폭 개선
PR 링크: triton-lang/triton#9030 상태: Merged | 변경: +1191 / -364
들어가며
Triton PROTON은 CUDA 애플리케이션의 프로파일링 및 분석을 위한 강력한 도구입니다. 특히 CUDA 그래프(CUDA Graphs)를 사용할 때, 프로파일링 과정에서 발생하는 오버헤드는 성능에 민감한 애플리케이션에 큰 영향을 줄 수 있습니다. 이번 PR은 이러한 오버헤드를 크게 줄이고, 데이터 직렬화 방식을 개선하여 전반적인 성능을 향상시키는 데 초점을 맞추고 있습니다.
주요 개선 사항은 다음과 같습니다:
deactivate및get_data오버헤드 감소: CUDA 그래프 프로파일링 시 발생하는deactivate및get_data함수의 성능을 최적화하여 3배 더 빠르게 만듭니다.get_data_msgpackAPI 추가: 기존get_data(JSON 형식) 대비 10배 빠른get_data_msgpackAPI를 새로 도입하여, 데이터 직렬화 및 전송 속도를 획기적으로 개선합니다.
이 글에서는 해당 PR의 코드 변경 사항을 상세히 분석하고, 이러한 최적화가 왜 효과적인지, 그리고 어떤 일반적인 교훈을 얻을 수 있는지 살펴보겠습니다.
코드 분석
1. test-requirements.txt 변경
--- a/python/test-requirements.txt
+++ b/python/test-requirements.txt
@@ -7,3 +7,4 @@
scipy>=1.7.1
llnl-hatchet
expecttest
+msgpack
새로운 get_data_msgpack API를 사용하기 위해 msgpack 라이브러리가 테스트 요구사항에 추가되었습니다. 이는 데이터 직렬화에 msgpack 형식을 사용하게 되었음을 명확히 보여줍니다.
2. third_party/proton/csrc/Proton.cpp 변경
--- a/third_party/proton/csrc/Proton.cpp
+++ b/third_party/proton/csrc/Proton.cpp
@@ -173,6 +173,15 @@
},
pybind11::arg("sessionId"));
+ m.def(
+ "get_data_msgpack",
+ [](size_t sessionId) {
+ auto data = SessionManager::instance().getDataMsgPack(sessionId);
+ return pybind11::bytes(reinterpret_cast<const char *>(data.data()),
+ data.size());
+ },
+ pybind11::arg("sessionId"));
+
m.def(
"clear_data",
[](size_t sessionId) { SessionManager::instance().clearData(sessionId); },
새로운 get_data_msgpack 함수가 Python 바인딩에 추가되었습니다. 이 함수는 SessionManager로부터 getDataMsgPack을 호출하여 데이터를 가져온 후, pybind11::bytes 형태로 반환합니다. 이는 기존 get_data 함수가 내부적으로 JSON 직렬화를 수행하는 것과 달리, 더 효율적인 MsgPack 직렬화 결과를 직접 반환하도록 설계되었음을 시사합니다.
3. third_party/proton/csrc/include/Data/Data.h 변경
--- a/third_party/proton/csrc/include/Data/Data.h
+++ b/third_party/proton/csrc/include/Data/Data.h
@@ -35,6 +35,15 @@
/// Add a single metric to the data.
virtual void addMetric(size_t scopeId, std::shared_ptr<Metric> metric) = 0;
+ /// Add an op and a metric with one call.
+ /// The default implementation forwards to addOp + addMetric.
+ virtual void addOpAndMetric(size_t scopeId, const std::string &opName,
+ std::shared_ptr<Metric> metric) {
+ scopeId = this->addOp(scopeId, opName);
+ this->addMetric(scopeId, metric);
+ }
+
/// Add multiple metrics to the data.
virtual void
addMetrics(size_t scopeId,
@@ -49,6 +58,9 @@
/// To Json
virtual std::string toJsonString() const = 0;
+ /// To MsgPack
+ virtual std::vector<uint8_t> toMsgPack() const = 0;
+
/// Dump the data to the given output format.
void dump(const std::string &outputFormat);
Data 클래스에 두 가지 중요한 가상 함수가 추가되었습니다:
addOpAndMetric: 기존의addOp과addMetric을 분리하여 호출하는 대신, 한 번에 연산 이름과 메트릭을 추가할 수 있게 합니다. 이는 내부적으로 두 번의scopeId조회 및 잠금 해제를 방지하여 효율성을 높일 수 있습니다. 기본 구현은 기존 방식대로 호출하지만, 이를 오버라이드하는 파생 클래스에서 최적화가 가능합니다.toMsgPack(): JSON 직렬화에 대응하는 MsgPack 직렬화 메서드가 추가되었습니다. 이는get_data_msgpackAPI의 기반이 됩니다.
4. third_party/proton/csrc/include/Data/Metric.h 변경
--- a/third_party/proton/csrc/include/Data/Metric.h
+++ b/third_party/proton/csrc/include/Data/Metric.h
@@ -65,17 +65,17 @@
virtual ~Metric() = default;
- virtual const std::string getName() const = 0;
+ virtual const std::string &getName() const = 0;
- virtual const std::string getValueName(int valueId) const = 0;
+ virtual const std::string &getValueName(int valueId) const = 0;
virtual bool isProperty(int valueId) const = 0;
virtual bool isExclusive(int valueId) const = 0;
- std::vector<MetricValueType> getValues() const { return values; }
+ const std::vector<MetricValueType> &getValues() const { return values; }
- MetricValueType getValue(int valueId) { return values[valueId]; }
+ const MetricValueType &getValue(int valueId) const { return values[valueId]; }
/// Update a specific value id with the new value.
void updateValue(int valueId, MetricValueType value) {
@@ -115,7 +115,7 @@
}
/// Update all values with another metric.
- void updateMetric(Metric &other) {
+ void updateMetric(const Metric &other) {
for (int i = 0; i < values.size(); ++i) {
updateValue(i, other.values[i]);
}
@@ -125,7 +125,6 @@
private:
const MetricKind kind;
- const std::string name;
protected:
std::vector<MetricValueType> values;
@@ -153,9 +152,9 @@
std::visit([&](auto &&v) { this->values[0] = v; }, value);
}
- const std::string getName() const override { return "FlexibleMetric"; }
+ const std::string &getName() const override { return name; }
- const std::string getValueName(int valueId) const override {
+ const std::string &getValueName(int valueId) const override {
return valueName;
}
@@ -166,6 +165,7 @@
private:
bool property{};
bool exclusive{};
+ const static inline std::string name = "FlexibleMetric";
std::string valueName;
};
@@ -196,15 +196,15 @@
this->values[StreamId] = streamId;
}
- virtual const std::string getName() const { return "KernelMetric"; }
+ const std::string &getName() const override { return name; }
- virtual const std::string getValueName(int valueId) const {
+ const std::string &getValueName(int valueId) const override {
return VALUE_NAMES[valueId];
}
- virtual bool isProperty(int valueId) const { return PROPERTY[valueId]; }
+ bool isProperty(int valueId) const override { return PROPERTY[valueId]; }
- virtual bool isExclusive(int valueId) const { return EXCLUSIVE[valueId]; }
+ bool isExclusive(int valueId) const override { return EXCLUSIVE[valueId]; }
private:
const static inline bool PROPERTY[kernelMetricKind::Count] = {
@@ -215,6 +215,7 @@
"start_time (ns)", "end_time (ns)", "count", "time (ns)",
"device_id", "device_type", "stream_id",
};
+ const static inline std::string name = "KernelMetric";
};
class PCSamplingMetric : public Metric {
@@ -254,17 +255,15 @@
this->values[PCSamplingMetricKind::NumStalledSamples] = stalledSamples;
}
- virtual const std::string getName() const { return "PCSamplingMetric"; }
+ const std::string &getName() const override { return name; }
- virtual const std::string getValueName(int valueId) const {
+ const std::string &getValueName(int valueId) const override {
return VALUE_NAMES[valueId];
}
- virtual bool isProperty(int valueId) const { return false; }
+ bool isProperty(int valueId) const override { return false; }
+ bool isExclusive(int valueId) const override { return false; }
- virtual bool isExclusive(int valueId) const { return false; }
-
-private:
const static inline std::string VALUE_NAMES[PCSamplingMetricKind::Count] = {
"num_samples",
"num_stalled_samples",
@@ -287,6 +286,7 @@
"stalled_sleeping",
"stalled_selected",
};
+ const static inline std::string name = "PCSamplingMetric";
};
class CycleMetric : public Metric {
@@ -336,15 +336,15 @@
this->values[PostFinalTime] = postFinalTime;
}
- virtual const std::string getName() const { return "CycleMetric"; }
+ const std::string &getName() const override { return name; }
- virtual const std::string getValueName(int valueId) const {
+ const std::string &getValueName(int valueId) const override {
return VALUE_NAMES[valueId];
}
- virtual bool isProperty(int valueId) const { return PROPERTY[valueId]; }
+ bool isProperty(int valueId) const override { return PROPERTY[valueId]; }
- virtual bool isExclusive(int valueId) const { return EXCLUSIVE[valueId]; }
+ bool isExclusive(int valueId) const override { return EXCLUSIVE[valueId]; }
private:
const static inline bool PROPERTY[CycleMetricKind::Count] = {
@@ -358,6 +358,7 @@
"kernel_id", "kernel_name", "block_id", "processor_id",
"unit_id", "device_id", "device_type", "time_shift_cost",
"init_time", "pre_final_time", "post_final_time"};
+ const static inline std::string name = "CycleMetric";
};
/// Each TensorMetric represents a scalar metric stored in a device buffer.
Metric 클래스 계층 구조에서 몇 가지 중요한 변경이 이루어졌습니다:
- 반환 타입 변경:
getName(),getValueName(),getValues(),getValue()와 같은 함수들이const std::string을 반환하는 대신const std::string &를 반환하도록 변경되었습니다. 이는 불필요한 문자열 복사를 방지하여 성능을 향상시킵니다. 또한,std::vector<MetricValueType> getValues() const가const std::vector<MetricValueType> &getValues() const로 변경되어 벡터 자체의 복사를 피합니다. Metric::name멤버 제거:Metric클래스에서name멤버 변수가 제거되었습니다. 대신, 각 파생 클래스(FlexibleMetric,KernelMetric,PCSamplingMetric,CycleMetric)에서 `const static inline std::string name =
참고 자료
- https://github.com/triton-lang/triton/blob/main/python/test-requirements.txt
- https://github.com/triton-lang/triton/blob/main/third_party/proton/csrc/Proton.cpp
- https://github.com/triton-lang/triton/blob/main/third_party/proton/csrc/include/Data/Data.h
- https://github.com/triton-lang/triton/blob/main/third_party/proton/csrc/include/Data/Metric.h
- https://github.com/triton-lang/triton/blob/main/third_party/proton/csrc/include/Data/TraceData.h
- https://github.com/triton-lang/triton/blob/main/third_party/proton/csrc/include/Data/TreeData.h
- https://github.com/triton-lang/triton/blob/main/third_party/proton/csrc/include/Profiler/Cupti/CuptiPCSampling.h
- https://github.com/triton-lang/triton/blob/main/third_party/proton/csrc/include/Profiler/GPUProfiler.h
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Ray Data] StreamingRepartition과 MapBatches 퓨전 규칙 개선
- 현재글 : [triton] Triton PROTON: CUDA 그래프 프로파일링 오버헤드를 줄이고 MsgPack API를 추가하여 성능을 대폭 개선
- 다음글 [triton] Triton AMD 백엔드 최적화: Subtiling을 통한 GEMM 성능 향상
댓글