본문으로 건너뛰기

[triton] Proton 커널 런처에 더 많은 메타데이터 전달

PR 링크: triton-lang/triton#9575 상태: Merged | 변경: +94 / -53

들어가며

Proton 프로파일러는 GPU 커널 실행 시 텐서/스칼라 메트릭을 수집하기 위해 별도의 metric 커널을 launch합니다. 기존에는 커널 포인터와 스트림만 전달했지만, 이 PR은 numThreadssharedMemBytes 등 런치 설정을 함께 전달하여 metric 커널의 GPU 자원 사용을 정밀하게 제어합니다.

핵심 코드 분석

Before:

m.def("set_metric_kernels",
    [](uintptr_t tensorMetricKernel, uintptr_t scalarMetricKernel,
       uintptr_t stream) {
        SessionManager::instance().setMetricKernels(
            reinterpret_cast<void *>(tensorMetricKernel),
            reinterpret_cast<void *>(scalarMetricKernel),
            reinterpret_cast<void *>(stream));
    });

After:

struct MetricKernelLaunchConfig {
  void *kernel{nullptr};
  unsigned int numThreads{1};
  unsigned int sharedMemBytes{0};
};

struct MetricKernelLaunchState {
  MetricKernelLaunchConfig tensor{};
  MetricKernelLaunchConfig scalar{};
  void *stream{nullptr};
};

m.def("set_metric_kernels",
    [](uintptr_t tensorMetricKernel, uintptr_t scalarMetricKernel,
       uintptr_t stream, unsigned int tensorMetricKernelNumThreads,
       unsigned int tensorMetricKernelSharedMemBytes, ...) {
        // MetricKernelLaunchState로 통합
    }, pybind11::arg("tensorMetricKernelNumThreads") = 1,
       pybind11::arg("tensorMetricKernelSharedMemBytes") = 0, ...);

왜 이게 좋은가

기존에 분산되어 있던 3개의 thread-local 변수를 MetricKernelLaunchState 하나로 통합하여 상태 관리가 단순해졌습니다. Python API는 기본값을 제공하여 하위 호환성을 유지하면서도, 복잡한 metric 커널이 필요한 경우 스레드 수와 shared memory를 조절할 수 있습니다. 이는 향후 더 정교한 on-device 메트릭 처리(예: histogram, reduction)를 위한 기반이 됩니다.

정리

Metric 커널 런치 설정을 MetricKernelLaunchConfig/State 구조체로 통합하고, Python API에 numThreads와 sharedMemBytes 매개변수를 추가하여 metric 커널의 자원 제어를 가능하게 했습니다.

참고 자료

이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글