[sglang] TRT-LLM Sparse MLA 커널의 prefill 배치 지원
PR 링크: sgl-project/sglang#21783 상태: Merged | 변경: +12 / -14
들어가며
DeepSeek의 NSA(Native Sparse Attention) 구현에서 TRT-LLM sparse MLA 커널은 decode 배치에서만 올바르게 동작하고 있었다. prefill 배치에서는 decode용 page table 변환 함수가 사용되어 잘못된 attention 결과를 생성할 수 있었다. 이 PR은 prefill/decode를 구분하여 각각 올바른 page table 변환 함수를 호출하도록 수정한다.
핵심 코드 분석
Prefill용 page table 변환 분기 추가
Before:
def _forward_trtllm(self, ...):
if envs.SGLANG_NSA_FUSE_TOPK.get():
page_table_1 = topk_indices
else:
# decode용 변환만 존재
page_table_1 = transform_index_page_table_decode(
page_table=metadata.page_table_1,
topk_indices=topk_indices,
page_size=1,
)
After:
def _forward_trtllm(self, ..., is_prefill: bool = False):
if envs.SGLANG_NSA_FUSE_TOPK.get():
page_table_1 = topk_indices
elif is_prefill:
page_table_1 = transform_index_page_table_prefill(
page_table=metadata.page_table_1,
topk_indices=topk_indices,
extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
page_size=1,
)
else:
page_table_1 = transform_index_page_table_decode(...)
is_prefill 파라미터를 추가하고, prefill일 때 transform_index_page_table_prefill을 호출한다. 이 함수는 extend_lens_cpu를 추가로 받아 가변 길이 prefill 시퀀스에 대한 page table을 올바르게 구성한다.
호출 측에서도 forward_extend에서 is_prefill=True를 전달:
# forward_extend에서 호출 시
self._forward_trtllm(..., is_prefill=True)
왜 이게 좋은가
- 정확도 개선: prefill 배치에서 잘못된 page table로 인한 attention 계산 오류 수정
- 임시 workaround 제거: 기존에 prefill 시 dense attention으로 fallback하던 128K threshold 제거
- 경고 메시지 정리: "TRTLLM sparse MLA kernel requires MHA as prefill impl" 경고 불필요
정리
TRT-LLM sparse MLA 커널이 prefill과 decode에서 서로 다른 page table 레이아웃을 필요로 하는 문제를 is_prefill 파라미터 하나로 깔끔하게 해결했다. 기존에 prefill을 MHA로 fallback시키던 임시 코드도 함께 제거되었다.
참고 자료
- sgl-project/sglang#21783 — 원본 PR
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [sglang] HiRadixCache에서 TTL 기반 hard pin 기능 제거
- 현재글 : [sglang] TRT-LLM Sparse MLA 커널의 prefill 배치 지원
- 다음글 [sglang] Multi-GPU VLM 서빙에서 ShmPointerMMData broadcast race condition 수정
댓글