본문으로 건너뛰기

[faster-qwen3-tts] SDPA 전환으로 BF16 StaticCache hidden-state 발산 수정

PR 링크: andimarafioti/faster-qwen3-tts#48 상태: Merged | 변경: +652 / -623

들어가며

faster-qwen3-tts의 CUDA graph 기반 추론에서 ICL(In-Context Learning) voice cloning 모드의 parity 테스트가 실패하는 문제가 발견되었다. 원인은 eager attention과 StaticCache의 조합에서 BF16 GEMM 커널이 패딩된 K-sequence 길이(2048)와 실제 prefill 길이에 따라 다른 누적 결과를 생성하기 때문이었다.

핵심 코드 분석

Attention 구현 기본값 변경

Before:

@classmethod
def from_pretrained(cls, model_name, ..., attn_implementation="eager", ...):
    """Load Qwen3-TTS model and prepare CUDA graphs.
    attn_implementation: Attention implementation (use "eager" on Jetson)
    """

After:

@classmethod
def from_pretrained(cls, model_name, ..., attn_implementation="sdpa", ...):
    """Load Qwen3-TTS model and prepare CUDA graphs.
    attn_implementation: Attention implementation ("sdpa" or "flash_attention_2")
    """

근본 원인과 해결

# 테스트 코드의 설명 주석
# sdpa is required for bfloat16 CUDA-graph correctness: with
# StaticCache padded to max_seq_len the eager BF16 GEMM kernel
# accumulates differently for different K-sequence lengths
# (2048 vs the actual prefill length), causing hidden-state
# divergence that grows step-by-step.
# sdpa's tiled kernel skips fully-masked K blocks, giving
# identical results to DynamicCache regardless of StaticCache
# padding length.

테스트 fixture OOM 수정

Before:

@pytest.fixture(scope="module")
def parity_fixture():
    ...
    return dict(base=base, fast=fast, ...)

After:

@pytest.fixture(scope="class")
def parity_fixture():
    ...
    data = dict(base=base, fast=fast, ...)
    yield data
    del data["base"]
    del data["fast"]
    torch.cuda.empty_cache()
    gc.collect()

scope="module"에서 scope="class"로 변경하여, 각 테스트 클래스 종료 시 모델을 해제한다. 기존에는 모든 fixture가 모듈 끝까지 유지되어 24GB GPU에서 OOM이 발생했다.

왜 이게 좋은가

  1. 수치 정확성: SDPA의 tiled 커널은 fully-masked K 블록을 skip하므로 StaticCache 패딩과 무관하게 일관된 결과를 준다.
  2. OOM 해결: fixture를 class scope으로 좁혀 4개 fixture(8개 모델)가 동시에 GPU 메모리에 올라가는 것을 방지한다.
  3. 성능 영향 없음: SDPA는 eager보다 같거나 빠르며, Jetson에서도 정상 동작한다.

정리

StaticCache + eager attention + BF16의 조합이 만드는 미묘한 수치 발산은 재현하기 어렵고 디버깅이 까다롭다. SDPA로의 전환이 근본적인 해결책이며, fixture 생명주기 관리는 GPU 테스트의 필수 고려사항이다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석과 해석에서 오류가 있을 수 있으니, 정확한 내용은 원본 PR을 참고해주세요.

댓글

관련 포스트

PR Analysis 의 다른글