[triton] Proton 커널 런처에 더 많은 메타데이터 전달
PR 링크: triton-lang/triton#9575 상태: Merged | 변경: +94 / -53
들어가며
Proton 프로파일러는 GPU 커널 실행 시 텐서/스칼라 메트릭을 수집하기 위해 별도의 metric 커널을 launch합니다. 기존에는 커널 포인터와 스트림만 전달했지만, 이 PR은 numThreads와 sharedMemBytes 등 런치 설정을 함께 전달하여 metric 커널의 GPU 자원 사용을 정밀하게 제어합니다.
핵심 코드 분석
Before:
m.def("set_metric_kernels",
[](uintptr_t tensorMetricKernel, uintptr_t scalarMetricKernel,
uintptr_t stream) {
SessionManager::instance().setMetricKernels(
reinterpret_cast<void *>(tensorMetricKernel),
reinterpret_cast<void *>(scalarMetricKernel),
reinterpret_cast<void *>(stream));
});
After:
struct MetricKernelLaunchConfig {
void *kernel{nullptr};
unsigned int numThreads{1};
unsigned int sharedMemBytes{0};
};
struct MetricKernelLaunchState {
MetricKernelLaunchConfig tensor{};
MetricKernelLaunchConfig scalar{};
void *stream{nullptr};
};
m.def("set_metric_kernels",
[](uintptr_t tensorMetricKernel, uintptr_t scalarMetricKernel,
uintptr_t stream, unsigned int tensorMetricKernelNumThreads,
unsigned int tensorMetricKernelSharedMemBytes, ...) {
// MetricKernelLaunchState로 통합
}, pybind11::arg("tensorMetricKernelNumThreads") = 1,
pybind11::arg("tensorMetricKernelSharedMemBytes") = 0, ...);
왜 이게 좋은가
기존에 분산되어 있던 3개의 thread-local 변수를 MetricKernelLaunchState 하나로 통합하여 상태 관리가 단순해졌습니다. Python API는 기본값을 제공하여 하위 호환성을 유지하면서도, 복잡한 metric 커널이 필요한 경우 스레드 수와 shared memory를 조절할 수 있습니다. 이는 향후 더 정교한 on-device 메트릭 처리(예: histogram, reduction)를 위한 기반이 됩니다.
정리
Metric 커널 런치 설정을 MetricKernelLaunchConfig/State 구조체로 통합하고, Python API에 numThreads와 sharedMemBytes 매개변수를 추가하여 metric 커널의 자원 제어를 가능하게 했습니다.
참고 자료
이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [vllm] --performance-mode: 워크로드별 최적화 프로파일
- 현재글 : [triton] Proton 커널 런처에 더 많은 메타데이터 전달
- 다음글 [faster-qwen3-tts] 생성 요청 직렬화 및 모델 캐싱 도입
댓글