본문으로 건너뛰기

[sglang] Ngram Corpus를 Torch cpp_extension에서 TVM FFI로 마이그레이션

PR 링크: sgl-project/sglang#21920 상태: Merged | 변경: +271 / -116

들어가며

SGLang의 speculative decoding에서 사용하는 ngram corpus 모듈은 기존에 torch.utils.cpp_extension을 통해 C++ 코드를 빌드했다. 이 방식은 PyTorch 빌드 시스템에 강한 의존성을 가지며, SGLang 내부의 JIT kernel 인프라(TVM FFI 기반)와 분리되어 있었다. 이 PR은 ngram corpus의 C++ 코드를 jit_kernel/csrc/ngram_corpus/로 이동하고 TVM FFI를 통해 Python에 바인딩하여 빌드 경로를 통일한다.

핵심 코드 분석

1. TVM FFI 바인딩 객체 정의 (새 파일)

Before:

// torch cpp_extension 기반 - setup.py에서 별도 빌드
// PyBind11을 통한 Python 바인딩

After:

struct NgramCorpusObj : public tvm::ffi::Object {
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("sgl.NgramCorpus", NgramCorpusObj, tvm::ffi::Object);
  static constexpr bool _type_mutable = true;

  NgramCorpusObj(int64_t capacity, int64_t max_trie_depth, ...) {
    ngram::Param param;
    param.max_trie_depth = static_cast<size_t>(max_trie_depth);
    // ...
    ngram_ = std::make_unique<ngram::Ngram>(static_cast<size_t>(capacity), param);
  }

  void async_insert(const tvm::ffi::TensorView tokens_flat, const tvm::ffi::TensorView offsets) { ... }
  void batch_match(const tvm::ffi::TensorView tokens_flat, ...) { ... }
};

NgramCorpusObj는 TVM FFI의 Object를 상속하여 Python에서 직접 생성/호출 가능한 객체로 등록된다. tvm::ffi::TensorView를 통해 텐서 데이터를 zero-copy로 전달받는다.

2. Python JIT 로더 통합

Before:

# torch.utils.cpp_extension.load()로 별도 빌드

After:

@cache_once
def get_ngram_corpus_cls():
    module = load_jit(
        "ngram_corpus",
        cpp_files=[
            "ngram_corpus/result.cpp",
            "ngram_corpus/trie.cpp",
            "ngram_corpus/ngram.cpp",
            "ngram_corpus/ngram_corpus_ffi.cpp",
        ],
        header_only=False,
    )
    module.register_once()

    @tvm_ffi.register_object("sgl.NgramCorpus")
    class NgramCorpusFFI(tvm_ffi.Object):
        def match(self, batch_tokens):
            tokens_flat, offsets = _to_csr(batch_tokens)
            out_tokens = torch.zeros(batch_size * d, dtype=torch.int32)
            out_mask = torch.zeros(batch_size * d * d, dtype=torch.uint8)
            self.batch_match(tokens_flat, offsets, out_tokens, out_mask)
            return out_tokens.numpy(), out_mask.numpy()

load_jit() 헬퍼를 통해 TVM FFI JIT 컴파일 파이프라인을 재사용한다. CSR(Compressed Sparse Row) 포맷으로 배치 토큰을 패킹하여 C++에 전달하는 _to_csr 유틸리티도 추가되었다.

왜 이게 좋은가

  • 빌드 의존성 감소: torch cpp_extension 대신 TVM FFI를 사용하여 PyTorch 빌드 시스템 의존성 제거
  • 경로 통일: 다른 JIT kernel(rmsnorm, rope 등)과 동일한 jit_kernel/ 경로에서 관리
  • 캐싱 효율: @cache_once 데코레이터로 JIT 컴파일 결과를 캐싱하여 반복 로드 비용 제거

정리

이 PR은 기존 ngram corpus C++ 코드의 기능을 그대로 유지하면서 바인딩 레이어만 TVM FFI로 교체한 리팩터링이다. 소스 파일을 srt/speculative/cpp_ngram/에서 jit_kernel/csrc/ngram_corpus/로 이동하고, PyBind11 대신 TVM FFI의 Object 시스템과 register_object를 사용하여 Python 인터페이스를 재구성했다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글