본문으로 건너뛰기

[flashinfer] FlashInfer: Wide Vector 최적화와 1900줄의 코드 삭제로 달성한 성능 개선

PR 링크: flashinfer-ai/flashinfer#3147 상태: Merged | 변경: +None / -None

들어가며

LLM 추론 가속화 라이브러리인 FlashInfer에서 최근 매우 흥미로운 성능 최적화 PR이 머지되었습니다. 이번 Ameyn/wide vec t1 PR의 핵심은 복잡하게 얽혀있던 레거시 커널들을 정리하고, 메모리 접근 효율을 극대화한 **gdn_wide_vec_kernel**을 도입한 것입니다.

이 PR은 단순히 코드를 추가하는 것에 그치지 않고, 약 1,900줄의 데드 코드를 삭제하면서도 성능을 대폭 끌어올렸습니다. 특히 NVIDIA B200 GPU 환경에서 이론적 최대 DRAM 대역폭의 82%에 도달하는 놀라운 수치를 보여주었습니다. 무엇이 이러한 차이를 만들었는지 코드 수준에서 분석해 보겠습니다.

1. 레거시 제거와 커널 단순화

기존에는 다양한 배치 크기와 워크로드에 대응하기 위해 여러 종류의 커널(cooprow, ilp, mtp 등)이 혼재되어 있었습니다. 이는 유지보수를 어렵게 할 뿐만 아니라 최적화의 초점을 흐트러뜨렸습니다.

이번 PR에서는 성능이 검증된 gdn_wide_vec_kernel을 메인으로 세우고, 아주 작은 배치 사이즈를 위한 mtp_ilp4 커널만을 남겼습니다.

제거된 커널 이유
gdn_decode_bf16state_cooprow_kernel 작은 배치에서 정확도 이슈 및 wide_vec으로 대체 가능
gdn_decode_bf16state_ilp_kernel 특정 shape(Qwen3.5 등)에서 도달 불가능한 경로
gdn_decode_bf16state_mtp_kernel (ILP=8) wide_vec의 tile_v 확장으로 인해 불필요해짐

2. Wide Vector 최적화: LDG.E.128의 활용

가장 핵심적인 변화는 메모리 로드/스토어 방식의 개선입니다. 새로운 gdn_wide_vec_kernel은 128비트 벡터 연산(LDG.E.128, STG.E.128)을 적극적으로 활용하도록 설계되었습니다.

Before: 개별 요소 접근 방식 (개념적)

기존 커널들은 스레드당 처리하는 데이터의 양이 적거나, 메모리 정렬(Alignment)이 최적화되지 않아 여러 번의 메모리 트랜잭션을 발생시켰습니다.

After: Wide Vector 로드

# flashinfer/gdn_kernels/gdn_decode_bf16_state.py 내 주석 및 구조 변경
# 128 threads/CTA = 8 groups × 16 threads
# vec=8 BF16 (128-bit) -> LDG.E.128 / STG.E.128 fast path
# ILP=4 V-rows per thread

이 방식은 한 번의 명령어로 8개의 BF16 요소(128비트)를 동시에 읽어옵니다. 이는 메모리 버스 대역폭을 꽉 채워 사용할 수 있게 하며, 특히 메모리 대역폭에 민감한(Memory-bound) 디코딩 단계에서 엄청난 성능 향상을 가져옵니다.

3. Split-Pool 지원 및 유연한 인덱싱

최신 서빙 엔진들은 투기적 디코딩(Speculative Decoding)이나 MTP(Multi-Token Prediction)를 위해 읽기 전용 상태와 쓰기 전용 상태를 분리하는 Split-Pool 방식을 사용합니다. 이번 PR은 이를 커널 수준에서 네이티브하게 지원합니다.

# benchmarks/bench_gdn_decode.py
-    state: torch.Tensor,  # [B, HV, V, K] - K-last layout (pretranspose)
+    state: torch.Tensor,  # [pool_size, HV, V, K] BF16 (K-last layout)
...
+    if pool_mode == "split":
+        output_state_indices = torch.arange(
+            batch_size, 2 * batch_size, dtype=torch.int32, device="cuda"
+        )
+    else:
+        output_state_indices = None

위 코드처럼 initial_state_indicesoutput_state_indices를 다르게 설정함으로써, 커널 내부에서 추가적인 메모리 복사 없이도 읽기/쓰기 위치를 자유롭게 제어할 수 있게 되었습니다. 이는 Constexpr[bool] same_pool을 통해 컴파일 타임에 최적화되어, 단일 풀 사용 시 오버헤드가 전혀 없습니다.

4. 성능 분석: 왜 이게 좋은가?

PR에서 공개된 벤치마크 결과는 압도적입니다.

  • 대역폭 효율: B200 GPU에서 6.57 TB/s를 기록했습니다. 이는 이론적 최대치인 8 TB/s의 82%에 달하는 수치로, 커널이 하드웨어 한계치에 근접하게 최적화되었음을 의미합니다.
  • 속도 향상: 기존 베이스라인 대비 T=2(두 번째 토큰 예측) 환경에서 최대 1.23배(23%)의 속도 향상을 보였습니다.
  • 안정성: intermediate_states 인덱싱 버그(OOB)를 수정하여 대규모 배치 처리 시 발생하던 cudaErrorIllegalAddress 문제를 해결했습니다.

일반적인 교훈

  1. Less is More: 1,900줄의 코드를 삭제하면서 성능을 올린 것은, 특화된 여러 커널보다 잘 설계된 범용 'Fast Path' 커널 하나가 더 강력할 수 있음을 보여줍니다.
  2. Vectorization: CUDA 커널 최적화의 기본은 메모리 정렬과 벡터 로드입니다. 128비트 단위 접근은 현대 GPU 최적화의 필수 요소입니다.
  3. Interface Alignment: 커널의 성능뿐만 아니라, 상위 래퍼(gdn_decode.py)에서 사용자 API를 유지하면서 내부적으로 최적화된 풀 모드로 자동 전환해주는 설계가 인상적입니다.

마치며

이번 FlashInfer의 업데이트는 LLM 추론 엔진이 나아가야 할 방향을 잘 보여줍니다. 하드웨어의 특성을 극한으로 활용하는 커널 설계와, 이를 뒷받침하는 깔끔한 코드 구조의 조화가 돋보이는 PR이었습니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글