본문으로 건너뛰기

[sglang] ROCm 아키텍처별 최적화: 런타임 디스패치로 성능 극대화

PR 링크: sgl-project/sglang#27745 상태: Merged | 변경: +293 / -78

들어가며

최근 SGLang 프로젝트의 GitHub PR에서는 ROCm 환경에서 다양한 GPU 아키텍처(gfx942, gfx950)에 대한 멀티 아키텍처 지원을 강화하고, 런타임 최적화를 통해 성능을 극대화하는 중요한 변경이 이루어졌습니다. 이 PR은 특히 FP8 양자화 커널과 TopK 커널에서 각 아키텍처의 특성을 고려한 동적 설정을 도입하여, 이전에는 불가능했던 수준의 최적화를 가능하게 합니다. 이번 글에서는 이 PR의 주요 변경 사항을 코드 diff를 중심으로 분석하고, 이러한 최적화가 왜 효과적이며 어떤 일반적인 교훈을 주는지 살펴보겠습니다.

코드 분석

1. FP8 Quantization (dsv4_norm_rope.cu)

FP8 양자화는 모델의 메모리 사용량을 줄이고 연산 속도를 높이는 데 핵심적인 역할을 합니다. 이 PR에서는 ROCm 환경에서 gfx942와 gfx950 아키텍처 간의 FP8 표현 방식 차이를 해결하기 위해 런타임 디스패치를 도입했습니다.

변경 전:

static constexpr float kFP8Max = 448.0f;

#ifndef USE_ROCM
__device__ __forceinline__ fp8x2_e4m3_t pack_fp8(float x, float y) {
  x = fmaxf(fminf(x, kFP8Max), -kFP8Max);
  y = fmaxf(fminf(y, kFP8Max), -kFP8Max);
  return __nv_fp8x2_e4m3(float2{x, y});
}
#else
// Software float -> FP8 E4M3 conversion for ROCm
__device__ __forceinline__ uint8_t cvt_float_to_fp8_e4m3(float val) {
  constexpr float kMax = kFP8Max;
  val = fmaxf(fminf(val, kMax), -kMax);
  if (val == 0.0f) return 0;

  // ... (conversion logic)

  return sign | (static_cast<uint8_t>(exp8) << 3) | static_cast<uint8_t>(mant3);
}

__device__ __forceinline__ fp8x2_e4m3_t pack_fp8(float x, float y) {
  uint8_t x8 = cvt_float_to_fp8_e4m3(x);
  uint8_t y8 = cvt_float_to_fp8_e4m3(y);
  return static_cast<uint16_t>(x8) | (static_cast<uint16_t>(y8) << 8);
}
#endif

변경 후:

#ifdef USE_ROCM
// Runtime GPU architecture detection for ROCm multi-arch FP8 support
inline static bool is_gfx950_gpu() {
  static int cached_result = -1;
  if (cached_result == -1) {
    hipDeviceProp_t prop;
    hipGetDeviceProperties(&prop, 0);
    // gfx950 uses E4M3FN (max=448), older archs (gfx942, etc.) use E4M3FNUZ (max=224)
    cached_result = (strncmp(prop.gcnArchName, "gfx950", 6) == 0) ? 1 : 0;
  }
  return cached_result == 1;
}

inline static float get_fp8_max() {
  static const float fp8_max = is_gfx950_gpu() ? 448.0f : 224.0f;
  return fp8_max;
}
#endif

static constexpr float kFP8Max = 448.0f;  // CUDA default, overridden at runtime for ROCm

#ifndef USE_ROCM
template <int FP8_MAX = 448>
__device__ __forceinline__ fp8x2_e4m3_t pack_fp8(float x, float y) {
  constexpr float kMax = static_cast<float>(FP8_MAX);
  x = fmaxf(fminf(x, kMax), -kMax);
  y = fmaxf(fminf(y, kMax), -kMax);
  return __nv_fp8x2_e4m3(float2{x, y});
}
#else
// Software float -> FP8 E4M3 conversion for ROCm
template <int FP8_MAX = 448>
__device__ __forceinline__ uint8_t cvt_float_to_fp8_e4m3(float val) {
  constexpr float kMax = static_cast<float>(FP8_MAX);
  val = fmaxf(fminf(val, kMax), -kMax);
  if (val == 0.0f) return 0;

  // ... (conversion logic)

  return sign | (static_cast<uint8_t>(exp8) << 3) | static_cast<uint8_t>(mant3);
}

template <int FP8_MAX = 448>
__device__ __forceinline__ fp8x2_e4m3_t pack_fp8(float x, float y) {
  uint8_t x8 = cvt_float_to_fp8_e4m3<FP8_MAX>(x);
  uint8_t y8 = cvt_float_to_fp8_e4m3<FP8_MAX>(y);
  return static_cast<uint16_t>(x8) | (static_cast<uint16_t>(y8) << 8);
}
#endif

설명:

  • 런타임 GPU 감지 (is_gfx950_gpu, get_fp8_max): ROCm 커널 내에서 hipGetDeviceProperties를 사용하여 현재 실행 중인 GPU 아키텍처를 감지합니다. gfx950은 FP8 E4M3FN (최대값 448.0)을 지원하는 반면, gfx942와 같은 이전 아키텍처는 E4M3FNUZ (최대값 224.0)를 사용합니다. 이 차이를 런타임에 파악하여 get_fp8_max 함수가 적절한 최대값을 반환하도록 합니다.
  • 템플릿 파라미터화 (template <int FP8_MAX = 448>): pack_fp8 및 관련 함수들이 FP8_MAX 값을 템플릿 파라미터로 받도록 변경되었습니다. 이를 통해 컴파일 시점에 각 아키텍처에 맞는 최대값을 적용할 수 있습니다. CUDA 빌드에서는 기본값인 448.0f를 사용하고, ROCm 환경에서는 런타임 감지 결과에 따라 224.0f 또는 448.0f가 동적으로 선택됩니다.
  • Zero-overhead GPU detection: static int cached_result = -1;와 같이 정적 변수를 사용하여 GPU 아키텍처 감지 로직이 프로그램 실행 중 단 한 번만 수행되도록 캐싱합니다. 이는 런타임 오버헤드를 최소화하는 중요한 최적화 기법입니다.

이 변경으로 인해 각 ROCm 아키텍처는 자신의 FP8 하드웨어 기능을 최대한 활용할 수 있게 되어, 이전에는 고정된 최대값으로 인해 발생했던 잠재적인 오버플로우 또는 언더플로우 문제를 방지하고 정확도를 유지하면서 성능을 최적화할 수 있습니다.

2. TopK Kernels (topk.cu, deepseek_v4_topk.cu)

TopK 연산은 시퀀스 길이 예측 등에서 중요한 역할을 합니다. 이 PR은 GPU의 공유 메모리(Shared Memory, SMEM) 크기 최적화를 통해 TopK 커널의 성능을 개선했습니다.

변경 전:

#ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES
constexpr size_t kSMEM = static_cast<size_t>(SGL_TOPK_DYNAMIC_SMEM_BYTES);
#else
constexpr size_t kSMEM = 48 * 1024;  // bytes
#endif
static_assert(kSMEM % (2 * sizeof(int32_t)) == 0, "kSMEM must be a multiple of 8 bytes.");

// ...

__device__ void radix_topk(const float* __restrict__ input, int32_t* __restrict__ output, uint32_t length, uint32_t topk) {
  constexpr uint32_t RADIX = 256;
  constexpr uint32_t BLOCK_SIZE = kBlockSize;
  constexpr uint32_t SMEM_INPUT_SIZE = kSMEM / (2 * sizeof(int32_t));

  alignas(128) __shared__ uint32_t _s_histogram_buf[2][RADIX + 32];
  // ...
  extern __shared__ uint32_t s_input_idx[][SMEM_INPUT_SIZE];
  // ...
}

변경 후:

#ifdef USE_ROCM
// Runtime GPU architecture detection for ROCm multi-arch shared memory support
inline static bool is_gfx950_gpu() {
  static int cached_result = -1;  // -1 = not detected, 0 = no, 1 = yes
  if (cached_result == -1) {
    hipDeviceProp_t prop;
    hipGetDeviceProperties(&prop, 0);
    // gfx950 supports 64KB LDS, older archs (gfx942, etc.) support 48KB
    cached_result = (strncmp(prop.gcnArchName, "gfx950", 6) == 0) ? 1 : 0;
  }
  return cached_result == 1;
}

inline static size_t get_ksmem_size() {
  static const size_t smem_size = is_gfx950_gpu() ? (64 * 1024) : (48 * 1024);
  return smem_size;
}
#endif

// ...

template <size_t SMEM_SIZE>
__device__ void radix_topk_impl(const float* __restrict__ input, int32_t* __restrict__ output, uint32_t length, uint32_t topk) {
  constexpr uint32_t RADIX = 256;
  constexpr uint32_t BLOCK_SIZE = kBlockSize;
  constexpr uint32_t SMEM_INPUT_SIZE = SMEM_SIZE / (2 * sizeof(int32_t));

  alignas(128) __shared__ uint32_t _s_histogram_buf[2][RADIX + 32];
  // ...
  extern __shared__ uint32_t s_input_idx_1d[];
  uint32_t (*s_input_idx)[SMEM_INPUT_SIZE] = (uint32_t (*)[SMEM_INPUT_SIZE])s_input_idx_1d;
  // ...
}

// ... in deepseek_v4_topk_transform_512 function
#ifdef USE_ROCM
  const size_t kSmem_runtime = get_ksmem_size();
  if (kSmem_runtime == 64 * 1024) {
    setup_kernel_smem_runtime<deepseek_v4_topk_transform_kernel<64 * 1024>>(kSmem_runtime);
    deepseek_v4_topk_transform_kernel<64 * 1024><<<grid, block, kSmem_runtime, stream>>>(params);
  } else {
    setup_kernel_smem_runtime<deepseek_v4_topk_transform_kernel<48 * 1024>>(kSmem_runtime);
    deepseek_v4_topk_transform_kernel<48 * 1024><<<grid, block, kSmem_runtime, stream>>>(params);
  }
#else
  setup_kernel_smem_runtime<deepseek_v4_topk_transform_kernel<kSMEM>>(kSMEM);
  deepseek_v4_topk_transform_kernel<kSMEM><<<grid, block, kSMEM, stream>>>(params);
#endif

설명:

  • 런타임 공유 메모리 크기 감지 (is_gfx950_gpu, get_ksmem_size): FP8 커널과 유사하게, ROCm 환경에서 GPU 아키텍처를 감지하여 공유 메모리 크기를 결정합니다. gfx950은 64KB의 공유 메모리를 지원하는 반면, gfx942는 48KB를 지원합니다. get_ksmem_size 함수는 이 정보를 바탕으로 적절한 크기를 반환합니다.
  • 템플릿화된 커널 및 공유 메모리 설정: radix_topk_impl 함수는 SMEM_SIZE를 템플릿 인자로 받아, 컴파일 시점에 해당 크기에 맞게 공유 메모리 접근 로직(SMEM_INPUT_SIZE)이 결정되도록 합니다. 또한, deepseek_v4_topk_transform_512 함수 내에서 런타임에 감지된 kSmem_runtime 값에 따라 적절한 템플릿 인자(48KB 또는 64KB)를 가진 커널을 선택하고, cudaFuncSetAttribute를 통해 MaxDynamicSharedMemorySize를 설정합니다. 이는 커널이 실제로 필요로 하는 공유 메모리 양을 GPU에 알려주어 효율적인 메모리 할당을 가능하게 합니다.
  • extern __shared__ 사용: 공유 메모리 선언 방식이 extern __shared__ uint32_t s_input_idx[][SMEM_INPUT_SIZE];에서 extern __shared__ uint32_t s_input_idx_1d[]; uint32_t (*s_input_idx)[SMEM_INPUT_SIZE] = (uint32_t (*)[SMEM_INPUT_SIZE])s_input_idx_1d;로 변경되었습니다. 이는 동적으로 할당되는 공유 메모리의 크기를 유연하게 처리하기 위한 표준적인 방법입니다.

이 최적화를 통해 각 GPU 아키텍처는 사용 가능한 공유 메모리 용량을 최대한 활용할 수 있게 됩니다. 더 큰 공유 메모리는 더 많은 데이터를 캐싱하여 메모리 접근 지연 시간을 줄이고, 결과적으로 TopK 연산의 성능을 향상시킵니다.

3. Build System (setup_rocm.py)

빌드 시스템에서의 변경은 주로 최신 컴파일러 옵션과의 호환성을 맞추는 데 중점을 둡니다.

변경 전 (암시적): 이전에는 --amdgpu-target 플래그를 사용했을 수 있습니다.

변경 후:

# Fixed deprecated --amdgpu-target flag to --offload-arch

설명:

--amdgpu-target 플래그는 더 이상 사용되지 않으며, 대신 --offload-arch 플래그를 사용하도록 수정되었습니다. 이는 ROCm 컴파일러의 최신 버전에 맞춰 빌드 프로세스를 업데이트하는 필수적인 변경입니다. 이 변경은 코드 자체의 성능 최적화라기보다는, 빌드 과정의 정확성과 최신성을 보장하는 역할을 합니다.

왜 이게 좋은가?

이 PR의 핵심적인 가치는 '런타임 최적화''멀티 아키텍처 지원 강화'에 있습니다. 이전에는 특정 GPU 아키텍처에 맞춰 커널을 컴파일하거나, 모든 아키텍처에 대해 보수적인 기본 설정을 사용해야 했습니다. 하지만 이 PR은 다음과 같은 이점을 제공합니다:

  1. 성능 극대화: 각 GPU 아키텍처(gfx942, gfx950)의 고유한 하드웨어 특성(FP8 지원 범위, 공유 메모리 크기)을 런타임에 감지하고 이를 커널 실행에 반영합니다. 이를 통해 각 아키텍처는 자신의 성능 잠재력을 최대한 발휘할 수 있습니다. 예를 들어, gfx950의 더 큰 공유 메모리나 넓은 FP8 범위를 활용하여 연산 속도를 높일 수 있습니다.
  2. 정확도 유지: FP8 양자화에서 아키텍처별 FP8 최대값을 정확히 적용함으로써, 이전의 고정된 최대값으로 인해 발생할 수 있었던 데이터 손실이나 부정확한 표현을 방지합니다. 이는 모델의 정확도를 유지하면서 양자화의 이점을 누릴 수 있게 합니다.
  3. 코드 재사용성 및 유지보수성 향상: 단일 코드베이스로 여러 ROCm 아키텍처를 지원할 수 있게 되었습니다. 이는 코드 중복을 줄이고, 새로운 아키텍처 지원 추가 또는 기존 아키텍처 최적화 시 유지보수성을 크게 향상시킵니다.
  4. Zero-overhead 추상화: GPU 아키텍처 감지 로직이 static 변수를 통해 캐싱되어 프로그램 실행 중 단 한 번만 수행됩니다. 이는 런타임 성능에 거의 영향을 주지 않으면서도 동적인 최적화를 가능하게 합니다.

일반적인 교훈:

  • 하드웨어 특성 인지: GPU 커널 개발 시, 타겟 아키텍처의 하드웨어 특성(레지스터 파일 크기, 공유 메모리 용량, 명령어 세트, 데이터 타입 지원 등)을 깊이 이해하는 것이 중요합니다. 이를 통해 최적의 성능을 이끌어낼 수 있습니다.
  • 런타임 디스패치 활용: 다양한 하드웨어 환경을 지원해야 할 경우, 컴파일 타임 분기(compile-time branching)보다는 런타임 디스패치(runtime dispatch)가 더 유연하고 효율적인 솔루션이 될 수 있습니다. 특히, GPU 아키텍처 감지 후 커널 실행 방식을 동적으로 결정하는 것은 성능 최적화에 효과적입니다.
  • 템플릿 메타프로그래밍: C++ 템플릿을 활용하여 컴파일 타임에 다양한 설정을 적용하는 것은 코드 중복을 줄이고 가독성을 높이는 좋은 방법입니다. 이 PR에서는 SMEM_SIZEFP8_MAX와 같은 값을 템플릿 인자로 사용하여 유연성을 확보했습니다.
  • 빌드 시스템 관리: 최신 컴파일러 및 라이브러리 버전과의 호환성을 유지하기 위해 빌드 시스템(setup.py, CMake 등)을 정기적으로 업데이트하는 것이 중요합니다.

결론

이 PR은 SGLang 프로젝트가 ROCm 환경에서 다양한 GPU 아키텍처를 더욱 효과적으로 지원하고 성능을 최적화하는 데 크게 기여했습니다. 런타임에 GPU 특성을 감지하고 이에 맞춰 커널 동작을 동적으로 조정하는 접근 방식은, 복잡한 하드웨어 환경에서 최고 성능을 달성하기 위한 모범 사례를 보여줍니다. 이러한 최적화는 모델 추론 속도를 향상시키고, 더 넓은 범위의 AMD GPU 사용자들에게 SGLang의 이점을 제공할 것입니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글