[SGLang] RoPE 변형: 로타리 위치 인코딩의 다양한 구현
들어가며
RoPE(Rotary Position Embedding)는 현대 LLM의 사실상 표준 위치 인코딩이다. 하지만 훈련 시 컨텍스트 길이를 넘어서는 추론을 위해 다양한 스케일링 기법이 등장했다. SGLang은 rotary_embedding/ 패키지 아래 10가지 이상의 RoPE 변형을 구현하며, 팩토리 패턴으로 모델 설정에 따라 자동 선택한다.
구조도
rotary_embedding/
├── base.py ── RotaryEmbedding (기본)
│ └── LinearScalingRotaryEmbedding
├── rope_variant.py
│ ├── DynamicNTKScalingRotaryEmbedding
│ ├── DynamicNTKAlphaRotaryEmbedding
│ ├── Llama3RotaryEmbedding
│ ├── DeepseekScalingRotaryEmbedding
│ ├── Phi3LongRoPEScaledRotaryEmbedding
│ ├── FourierRotaryEmbedding
│ ├── Gemma4RotaryEmbedding
│ └── DualChunkRotaryEmbedding
├── yarn.py ── YaRNScalingRotaryEmbedding
├── mrope.py ── MRotaryEmbedding (멀티모달)
├── factory.py ── get_rope() 팩토리 함수
└── utils.py ── apply_rotary_emb 헬퍼
핵심 코드 분석
기본 RoPE: cos/sin 캐시 사전 계산
기본 RotaryEmbedding은 초기화 시 max_position_embeddings까지의 cos/sin 캐시를 계산하고 버퍼로 저장한다.
class RotaryEmbedding(MultiPlatformOp):
def _compute_inv_freq(self, base):
inv_freq = 1.0 / (
base ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
/ self.rotary_dim)
)
return inv_freq
def _compute_cos_sin_cache(self):
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
CUDA에서는 캐시를 FP32로 유지하여 수치 안정성을 보장하고, apply_rope_with_cos_sin_cache_inplace로 in-place 연산한다.
동적 캐시 확장
캐시 길이를 초과하는 위치가 들어오면 점진적으로 확장한다. 재할당 빈도를 줄이기 위해 정렬(align) 단위로 확장한다.
def _ensure_cos_sin_cache_length(self, needed_max_pos: int):
cur_len = int(self.cos_sin_cache.shape[0])
if needed_max_pos < cur_len:
return
align = envs.SGLANG_ROPE_CACHE_ALIGN.get()
new_len = ((needed_max_pos + align) // align) * align
# 새 위치만 증분 계산
t_new = torch.arange(cur_len, new_len, ...)
freqs_new = torch.einsum("i,j->ij", t_new, inv_freq)
new_rows = torch.cat((freqs_new.cos(), freqs_new.sin()), dim=-1)
self.cos_sin_cache = torch.cat((self.cos_sin_cache, new_rows), dim=0)
YaRN: 회전 보간과 외삽의 혼합
YaRN은 저주파 차원은 외삽(extrapolation), 고주파 차원은 보간(interpolation)을 적용한다. beta_fast와 beta_slow로 경계를 결정한다.
class YaRNScalingRotaryEmbedding(RotaryEmbedding):
def _compute_inv_freq(self, scaling_factor):
pos_freqs = self.base ** (torch.arange(0, self.rotary_dim, 2, ...) / self.rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
low, high = yarn_find_correction_range(
self.beta_fast, self.beta_slow, self.rotary_dim, self.base,
self.max_position_embeddings)
inv_freq_mask = (1 - yarn_linear_ramp_mask(low, high, ...)) * self.extrapolation_factor
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + \
inv_freq_extrapolation * inv_freq_mask
return inv_freq
추가로 magnitude scaling(mscale)을 적용하여 어텐션 분포가 스케일링 후에도 안정적으로 유지되게 한다.
Llama3 RoPE: 주파수별 차등 스케일링
Llama 3은 low_freq_factor와 high_freq_factor를 사용하여 주파수 대역별로 다른 스케일링을 적용한다.
팩토리 패턴: get_rope()
factory.py의 get_rope()는 rope_scaling 딕셔너리의 rope_type 필드를 보고 적절한 RoPE 클래스를 생성한다. 동일 설정에 대해 캐싱(_ROPE_DICT)하여 중복 생성을 방지한다.
def get_rope(head_size, rotary_dim, max_position, base, ..., rope_scaling=None):
key = (head_size, rotary_dim, max_position, base, ...)
if key in _ROPE_DICT:
return _ROPE_DICT[key]
if rope_scaling is None:
rotary_emb = RotaryEmbedding(...)
else:
scaling_type = rope_scaling["rope_type"]
if scaling_type == "llama3":
rotary_emb = Llama3RotaryEmbedding(...)
elif scaling_type == "yarn":
rotary_emb = YaRNScalingRotaryEmbedding(...)
elif scaling_type == "deepseek_yarn":
rotary_emb = DeepseekScalingRotaryEmbedding(...)
# ... 10+ variants
_ROPE_DICT[key] = rotary_emb
return rotary_emb
CUDA 최적화 forward
CUDA에서는 sglang.jit_kernel.rope의 apply_rope_with_cos_sin_cache_inplace를 사용하여 query와 key에 in-place로 RoPE를 적용한다. head_size가 64, 128, 256, 512인 경우 최적화 커널을 사용하고, 그 외에는 fallback 커널로 전환한다.
def forward_cuda(self, positions, query, key, ...):
if not self.use_fallback_kernel:
q_rope = query.view(batch_size, -1, self.head_size)
k_rope = key.view(batch_size, -1, self.head_size)
apply_rope_with_cos_sin_cache_inplace(
positions=positions, q=q_rope, k=k_rope,
cos_sin_cache=self.cos_sin_cache, is_neox=self.is_neox_style)
스케일링 방식 비교
| 방식 | 원리 | 대표 모델 |
|---|---|---|
| Linear | 위치를 factor로 나눔 | CodeLlama |
| Dynamic NTK | base를 동적으로 조정 | - |
| YaRN | 주파수별 보간/외삽 혼합 | Qwen |
| Llama3 | 대역별 차등 스케일링 | Llama 3 |
| DeepSeek YaRN | YaRN + 별도 mscale | DeepSeek-V2/V3 |
| Phi3 LongRoPE | Short/Long factor 전환 | Phi-3 |
관련 포스트
- Model Configuration 시스템: 모델 설정 관리
- Hardware Backends: MLX, NPU, XPU 하드웨어 추상화
참고
- 소스 코드:
python/sglang/srt/layers/rotary_embedding/ - RoFormer: Su et al., "RoFormer: Enhanced Transformer with Rotary Position Embedding" (2021)
- YaRN: Peng et al., "YaRN: Efficient Context Window Extension of Large Language Models" (2023)
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] Activation Functions: SiLU, GELU 커스텀 구현
- 현재글 : [SGLang] RoPE 변형: 로타리 위치 인코딩의 다양한 구현
- 다음글 [SGLang] Deep GEMM Wrapper: 최적화 행렬 곱 라이브러리
댓글