본문으로 건너뛰기

[sglang] AMD/ROCm 시작 크래시 수정: CuteDSL KDA 커널 Lazy Import 적용

PR 링크: sgl-project/sglang#21428 상태: Merged | 변경: +4 / -3

들어가며

SGLang의 KDA(Key-Driven Attention) 백엔드는 CuTe DSL과 Triton 두 가지 커널 구현을 지원합니다. 기존에는 kda_backend.py 모듈의 최상위에서 CuTe DSL 커널을 import했는데, 이 커널이 CUDA 전용이기 때문에 AMD/ROCm 환경에서는 import 시점에 크래시가 발생했습니다. 이번 PR은 CuTe DSL import를 실제로 해당 백엔드가 선택될 때까지 지연(lazy)시켜 문제를 해결합니다.

핵심 코드 분석

Top-level import를 조건부 import로 변경

Before:

from sglang.srt.layers.attention.linear.kernels.kda_cutedsl import (
    CuteDSLKDAKernel,
)
from sglang.srt.layers.attention.linear.kernels.kda_triton import TritonKDAKernel

class KDABackend:
    def __init__(self, ...):
        if decode_backend.is_cutedsl():
            if not is_cuda():
                raise ValueError("KDA CuTe DSL backend requires CUDA")
            self.decode_kernel = CuteDSLKDAKernel()

After:

from sglang.srt.layers.attention.linear.kernels.kda_triton import TritonKDAKernel

class KDABackend:
    def __init__(self, ...):
        elif decode_backend.is_cutedsl():
            if not is_cuda():
                raise ValueError("KDA CuTe DSL backend requires CUDA")
            from sglang.srt.layers.attention.linear.kernels.kda_cutedsl import (
                CuteDSLKDAKernel,
            )
            self.decode_kernel = CuteDSLKDAKernel()

import를 함수 내부로 이동하여, CuTe DSL 백엔드가 실제로 선택되지 않으면 import 자체가 실행되지 않습니다.

왜 이게 좋은가

  1. 플랫폼 호환성: AMD/ROCm에서 SGLang이 정상 시작됩니다. CuTe DSL이 필요 없는 설정에서는 CUDA 전용 코드가 로드되지 않습니다.
  2. 시작 시간 단축: 사용하지 않는 커널 모듈의 import를 건너뛰어 초기화 시간이 약간 줄어듭니다.
  3. Fail-fast 유지: CuTe DSL이 실제로 선택된 경우에는 is_cuda() 체크가 먼저 실행되어, 명확한 에러 메시지를 제공합니다.

정리

4줄 추가, 3줄 삭제의 최소 변경이지만, 멀티 플랫폼 지원에서 lazy import 패턴은 매우 중요합니다. Top-level import는 해당 모듈을 import하는 순간 실행되므로, 플랫폼 종속 코드는 반드시 필요한 시점에 import해야 합니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글