[sglang] Ngram Corpus를 Torch cpp_extension에서 TVM FFI로 마이그레이션
PR 링크: sgl-project/sglang#21920 상태: Merged | 변경: +271 / -116
들어가며
SGLang의 speculative decoding에서 사용하는 ngram corpus 모듈은 기존에 torch.utils.cpp_extension을 통해 C++ 코드를 빌드했다. 이 방식은 PyTorch 빌드 시스템에 강한 의존성을 가지며, SGLang 내부의 JIT kernel 인프라(TVM FFI 기반)와 분리되어 있었다. 이 PR은 ngram corpus의 C++ 코드를 jit_kernel/csrc/ngram_corpus/로 이동하고 TVM FFI를 통해 Python에 바인딩하여 빌드 경로를 통일한다.
핵심 코드 분석
1. TVM FFI 바인딩 객체 정의 (새 파일)
Before:
// torch cpp_extension 기반 - setup.py에서 별도 빌드
// PyBind11을 통한 Python 바인딩
After:
struct NgramCorpusObj : public tvm::ffi::Object {
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("sgl.NgramCorpus", NgramCorpusObj, tvm::ffi::Object);
static constexpr bool _type_mutable = true;
NgramCorpusObj(int64_t capacity, int64_t max_trie_depth, ...) {
ngram::Param param;
param.max_trie_depth = static_cast<size_t>(max_trie_depth);
// ...
ngram_ = std::make_unique<ngram::Ngram>(static_cast<size_t>(capacity), param);
}
void async_insert(const tvm::ffi::TensorView tokens_flat, const tvm::ffi::TensorView offsets) { ... }
void batch_match(const tvm::ffi::TensorView tokens_flat, ...) { ... }
};
NgramCorpusObj는 TVM FFI의 Object를 상속하여 Python에서 직접 생성/호출 가능한 객체로 등록된다. tvm::ffi::TensorView를 통해 텐서 데이터를 zero-copy로 전달받는다.
2. Python JIT 로더 통합
Before:
# torch.utils.cpp_extension.load()로 별도 빌드
After:
@cache_once
def get_ngram_corpus_cls():
module = load_jit(
"ngram_corpus",
cpp_files=[
"ngram_corpus/result.cpp",
"ngram_corpus/trie.cpp",
"ngram_corpus/ngram.cpp",
"ngram_corpus/ngram_corpus_ffi.cpp",
],
header_only=False,
)
module.register_once()
@tvm_ffi.register_object("sgl.NgramCorpus")
class NgramCorpusFFI(tvm_ffi.Object):
def match(self, batch_tokens):
tokens_flat, offsets = _to_csr(batch_tokens)
out_tokens = torch.zeros(batch_size * d, dtype=torch.int32)
out_mask = torch.zeros(batch_size * d * d, dtype=torch.uint8)
self.batch_match(tokens_flat, offsets, out_tokens, out_mask)
return out_tokens.numpy(), out_mask.numpy()
load_jit() 헬퍼를 통해 TVM FFI JIT 컴파일 파이프라인을 재사용한다. CSR(Compressed Sparse Row) 포맷으로 배치 토큰을 패킹하여 C++에 전달하는 _to_csr 유틸리티도 추가되었다.
왜 이게 좋은가
- 빌드 의존성 감소: torch cpp_extension 대신 TVM FFI를 사용하여 PyTorch 빌드 시스템 의존성 제거
- 경로 통일: 다른 JIT kernel(rmsnorm, rope 등)과 동일한
jit_kernel/경로에서 관리 - 캐싱 효율:
@cache_once데코레이터로 JIT 컴파일 결과를 캐싱하여 반복 로드 비용 제거
정리
이 PR은 기존 ngram corpus C++ 코드의 기능을 그대로 유지하면서 바인딩 레이어만 TVM FFI로 교체한 리팩터링이다. 소스 파일을 srt/speculative/cpp_ngram/에서 jit_kernel/csrc/ngram_corpus/로 이동하고, PyBind11 대신 TVM FFI의 Object 시스템과 register_object를 사용하여 Python 인터페이스를 재구성했다.
참고 자료
- TVM FFI Documentation — TVM Foreign Function Interface 공식 문서
- sgl-project/sglang#21920 — 원본 PR
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [sglang] JIT RMSNorm 커널 업데이트 - Blackwell 최적화 및 벤치마크 통합
- [sglang] fused_qknorm_rope 최적화 - interleave RoPE에서 sincosf 중복 제거
- [sglang] CI 테스트 등록 누락 수정: JIT 커널 테스트/벤치마크 파일 등록
- [sglang] SGLang 스케줄러: 사전 생성 전용 배치 병합 시 is_prefill_only 플래그 로직 개선
- [sglang] SGLang: MiniMax-M2.5 MoE 모델을 위한 FP8 FlashInfer TRT-LLM 라우팅 최적화
PR Analysis 의 다른글
- 이전글 [sglang] HiCache 메모리 누수 수정: host indices clone으로 참조 해제 보장
- 현재글 : [sglang] Ngram Corpus를 Torch cpp_extension에서 TVM FFI로 마이그레이션
- 다음글 [sglang] PD 시나리오에서 상세 캐시 히트 분류 수정
댓글