본문으로 건너뛰기

[SGLang] RoPE 변형: 로타리 위치 인코딩의 다양한 구현

들어가며

RoPE(Rotary Position Embedding)는 현대 LLM의 사실상 표준 위치 인코딩이다. 하지만 훈련 시 컨텍스트 길이를 넘어서는 추론을 위해 다양한 스케일링 기법이 등장했다. SGLang은 rotary_embedding/ 패키지 아래 10가지 이상의 RoPE 변형을 구현하며, 팩토리 패턴으로 모델 설정에 따라 자동 선택한다.

구조도

rotary_embedding/
├── base.py      ── RotaryEmbedding (기본)
│                   └── LinearScalingRotaryEmbedding
├── rope_variant.py
│   ├── DynamicNTKScalingRotaryEmbedding
│   ├── DynamicNTKAlphaRotaryEmbedding
│   ├── Llama3RotaryEmbedding
│   ├── DeepseekScalingRotaryEmbedding
│   ├── Phi3LongRoPEScaledRotaryEmbedding
│   ├── FourierRotaryEmbedding
│   ├── Gemma4RotaryEmbedding
│   └── DualChunkRotaryEmbedding
├── yarn.py      ── YaRNScalingRotaryEmbedding
├── mrope.py     ── MRotaryEmbedding (멀티모달)
├── factory.py   ── get_rope() 팩토리 함수
└── utils.py     ── apply_rotary_emb 헬퍼

핵심 코드 분석

기본 RoPE: cos/sin 캐시 사전 계산

기본 RotaryEmbedding은 초기화 시 max_position_embeddings까지의 cos/sin 캐시를 계산하고 버퍼로 저장한다.

class RotaryEmbedding(MultiPlatformOp):
    def _compute_inv_freq(self, base):
        inv_freq = 1.0 / (
            base ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
                     / self.rotary_dim)
        )
        return inv_freq

    def _compute_cos_sin_cache(self):
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache

CUDA에서는 캐시를 FP32로 유지하여 수치 안정성을 보장하고, apply_rope_with_cos_sin_cache_inplace로 in-place 연산한다.

동적 캐시 확장

캐시 길이를 초과하는 위치가 들어오면 점진적으로 확장한다. 재할당 빈도를 줄이기 위해 정렬(align) 단위로 확장한다.

def _ensure_cos_sin_cache_length(self, needed_max_pos: int):
    cur_len = int(self.cos_sin_cache.shape[0])
    if needed_max_pos < cur_len:
        return
    align = envs.SGLANG_ROPE_CACHE_ALIGN.get()
    new_len = ((needed_max_pos + align) // align) * align
    # 새 위치만 증분 계산
    t_new = torch.arange(cur_len, new_len, ...)
    freqs_new = torch.einsum("i,j->ij", t_new, inv_freq)
    new_rows = torch.cat((freqs_new.cos(), freqs_new.sin()), dim=-1)
    self.cos_sin_cache = torch.cat((self.cos_sin_cache, new_rows), dim=0)

YaRN: 회전 보간과 외삽의 혼합

YaRN은 저주파 차원은 외삽(extrapolation), 고주파 차원은 보간(interpolation)을 적용한다. beta_fastbeta_slow로 경계를 결정한다.

class YaRNScalingRotaryEmbedding(RotaryEmbedding):
    def _compute_inv_freq(self, scaling_factor):
        pos_freqs = self.base ** (torch.arange(0, self.rotary_dim, 2, ...) / self.rotary_dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = yarn_find_correction_range(
            self.beta_fast, self.beta_slow, self.rotary_dim, self.base,
            self.max_position_embeddings)
        inv_freq_mask = (1 - yarn_linear_ramp_mask(low, high, ...)) * self.extrapolation_factor
        inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + \
                   inv_freq_extrapolation * inv_freq_mask
        return inv_freq

추가로 magnitude scaling(mscale)을 적용하여 어텐션 분포가 스케일링 후에도 안정적으로 유지되게 한다.

Llama3 RoPE: 주파수별 차등 스케일링

Llama 3은 low_freq_factorhigh_freq_factor를 사용하여 주파수 대역별로 다른 스케일링을 적용한다.

팩토리 패턴: get_rope()

factory.pyget_rope()rope_scaling 딕셔너리의 rope_type 필드를 보고 적절한 RoPE 클래스를 생성한다. 동일 설정에 대해 캐싱(_ROPE_DICT)하여 중복 생성을 방지한다.

def get_rope(head_size, rotary_dim, max_position, base, ..., rope_scaling=None):
    key = (head_size, rotary_dim, max_position, base, ...)
    if key in _ROPE_DICT:
        return _ROPE_DICT[key]
    
    if rope_scaling is None:
        rotary_emb = RotaryEmbedding(...)
    else:
        scaling_type = rope_scaling["rope_type"]
        if scaling_type == "llama3":
            rotary_emb = Llama3RotaryEmbedding(...)
        elif scaling_type == "yarn":
            rotary_emb = YaRNScalingRotaryEmbedding(...)
        elif scaling_type == "deepseek_yarn":
            rotary_emb = DeepseekScalingRotaryEmbedding(...)
        # ... 10+ variants
    _ROPE_DICT[key] = rotary_emb
    return rotary_emb

CUDA 최적화 forward

CUDA에서는 sglang.jit_kernel.ropeapply_rope_with_cos_sin_cache_inplace를 사용하여 query와 key에 in-place로 RoPE를 적용한다. head_size가 64, 128, 256, 512인 경우 최적화 커널을 사용하고, 그 외에는 fallback 커널로 전환한다.

def forward_cuda(self, positions, query, key, ...):
    if not self.use_fallback_kernel:
        q_rope = query.view(batch_size, -1, self.head_size)
        k_rope = key.view(batch_size, -1, self.head_size)
        apply_rope_with_cos_sin_cache_inplace(
            positions=positions, q=q_rope, k=k_rope,
            cos_sin_cache=self.cos_sin_cache, is_neox=self.is_neox_style)

스케일링 방식 비교

방식 원리 대표 모델
Linear 위치를 factor로 나눔 CodeLlama
Dynamic NTK base를 동적으로 조정 -
YaRN 주파수별 보간/외삽 혼합 Qwen
Llama3 대역별 차등 스케일링 Llama 3
DeepSeek YaRN YaRN + 별도 mscale DeepSeek-V2/V3
Phi3 LongRoPE Short/Long factor 전환 Phi-3

관련 포스트

  • Model Configuration 시스템: 모델 설정 관리
  • Hardware Backends: MLX, NPU, XPU 하드웨어 추상화

참고

  • 소스 코드: python/sglang/srt/layers/rotary_embedding/
  • RoFormer: Su et al., "RoFormer: Enhanced Transformer with Rotary Position Embedding" (2021)
  • YaRN: Peng et al., "YaRN: Efficient Context Window Extension of Large Language Models" (2023)

댓글

관련 포스트

SGLang 의 다른글