[faster-qwen3-tts] SDPA 전환으로 BF16 StaticCache hidden-state 발산 수정
PR 링크: andimarafioti/faster-qwen3-tts#48 상태: Merged | 변경: +652 / -623
들어가며
faster-qwen3-tts의 CUDA graph 기반 추론에서 ICL(In-Context Learning) voice cloning 모드의 parity 테스트가 실패하는 문제가 발견되었다. 원인은 eager attention과 StaticCache의 조합에서 BF16 GEMM 커널이 패딩된 K-sequence 길이(2048)와 실제 prefill 길이에 따라 다른 누적 결과를 생성하기 때문이었다.
핵심 코드 분석
Attention 구현 기본값 변경
Before:
@classmethod
def from_pretrained(cls, model_name, ..., attn_implementation="eager", ...):
"""Load Qwen3-TTS model and prepare CUDA graphs.
attn_implementation: Attention implementation (use "eager" on Jetson)
"""
After:
@classmethod
def from_pretrained(cls, model_name, ..., attn_implementation="sdpa", ...):
"""Load Qwen3-TTS model and prepare CUDA graphs.
attn_implementation: Attention implementation ("sdpa" or "flash_attention_2")
"""
근본 원인과 해결
# 테스트 코드의 설명 주석
# sdpa is required for bfloat16 CUDA-graph correctness: with
# StaticCache padded to max_seq_len the eager BF16 GEMM kernel
# accumulates differently for different K-sequence lengths
# (2048 vs the actual prefill length), causing hidden-state
# divergence that grows step-by-step.
# sdpa's tiled kernel skips fully-masked K blocks, giving
# identical results to DynamicCache regardless of StaticCache
# padding length.
테스트 fixture OOM 수정
Before:
@pytest.fixture(scope="module")
def parity_fixture():
...
return dict(base=base, fast=fast, ...)
After:
@pytest.fixture(scope="class")
def parity_fixture():
...
data = dict(base=base, fast=fast, ...)
yield data
del data["base"]
del data["fast"]
torch.cuda.empty_cache()
gc.collect()
scope="module"에서 scope="class"로 변경하여, 각 테스트 클래스 종료 시 모델을 해제한다. 기존에는 모든 fixture가 모듈 끝까지 유지되어 24GB GPU에서 OOM이 발생했다.
왜 이게 좋은가
- 수치 정확성: SDPA의 tiled 커널은 fully-masked K 블록을 skip하므로 StaticCache 패딩과 무관하게 일관된 결과를 준다.
- OOM 해결: fixture를 class scope으로 좁혀 4개 fixture(8개 모델)가 동시에 GPU 메모리에 올라가는 것을 방지한다.
- 성능 영향 없음: SDPA는 eager보다 같거나 빠르며, Jetson에서도 정상 동작한다.
정리
StaticCache + eager attention + BF16의 조합이 만드는 미묘한 수치 발산은 재현하기 어렵고 디버깅이 까다롭다. SDPA로의 전환이 근본적인 해결책이며, fixture 생명주기 관리는 GPU 테스트의 필수 고려사항이다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었습니다. 코드 분석과 해석에서 오류가 있을 수 있으니, 정확한 내용은 원본 PR을 참고해주세요.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Open WebUI] 저장 버튼 스피너 인라인 레이아웃 수정
- 현재글 : [faster-qwen3-tts] SDPA 전환으로 BF16 StaticCache hidden-state 발산 수정
- 다음글 [Axolotl] MXFP4 양자화 지원 추가
댓글