본문으로 건너뛰기

[SGLang] Activation Functions: SiLU, GELU 커스텀 구현

들어가며

LLM의 MLP 블록에서 활성화 함수는 게이트 메커니즘과 결합되어 사용된다. 예를 들어 LLaMA의 SwiGLU는 SiLU(x_gate) * x_up 형태다. 이를 별도로 계산하면 중간 텐서 할당과 메모리 읽기/쓰기가 이중으로 발생한다. SGLang은 이를 단일 커널로 융합(Fuse)하여 성능을 높인다.

구조도

activation.py
├── SiluAndMul      ← LLaMA, Mistral 등 SwiGLU 모델
├── GeluAndMul      ← Gemma 등 GEGLU 모델
├── NewGELU         ← GPT-NeoX 스타일 GELU
├── QuickGELU       ← CLIP, Siglip 등
├── XIELU           ← 실험적 활성화 (arxiv:2411.13010)
├── ReLU2           ← 제곱 ReLU
├── ScaledActivation← AWQ 양자화용 스케일 활성화
└── _ACTIVATION_REGISTRY ← 이름 기반 조회

핵심 코드 분석

MultiPlatformOp 기반 분기

SGLang의 활성화 함수는 MultiPlatformOp을 상속하여 플랫폼별 최적 구현을 자동 선택한다. CUDA, CPU (AMX), NPU, XPU 각각에 대해 별도 forward_* 메서드를 정의한다.

class SiluAndMul(MultiPlatformOp):
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        silu_and_mul(x, out)
        return out

forward_native는 PyTorch 기본 연산을 사용하고, forward_cudasgl_kernel의 Fused 커널을 호출한다. Fused 커널은 입력 텐서를 반으로 나누어 한 번의 GPU 커널 실행으로 SiLU(x[:d]) * x[d:]를 계산한다.

NPU/CPU 플랫폼별 최적화

NPU에서는 Ascend의 npu_swiglu를 직접 호출하고, CPU에서는 Intel AMX 지원 여부를 확인하여 최적화된 커널을 사용한다.

def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
    out = torch_npu.npu_swiglu(x)
    return out

def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
    if _is_cpu_amx_available:
        out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
        return out
    else:
        return self.forward_native(x)

GeluAndMul: tanh vs none 분기

GELU에는 정확한 버전(none)과 tanh 근사 버전이 있다. GeluAndMul은 생성자에서 approximate 파라미터를 받아 적절한 Fused 커널을 선택한다.

class GeluAndMul(MultiPlatformOp):
    def __init__(self, approximate="tanh"):
        super().__init__()
        self.approximate = approximate

    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        if self.approximate == "tanh":
            gelu_tanh_and_mul(x, out)
        elif self.approximate == "none":
            gelu_and_mul(x, out)
        return out

ScaledActivation: 양자화 대응

AWQ와 같은 양자화 방식은 활성화 후 스케일 팩터를 적용해야 한다. ScaledActivation은 임의의 활성화 함수를 감싸서 스케일링을 추가한다.

class ScaledActivation(nn.Module):
    def __init__(self, act_module, intermediate_size, input_is_parallel=True, ...):
        self.act = act_module
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_size_per_partition = divide(intermediate_size, tp_size)
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition, dtype=params_dtype))

    def forward(self, x):
        return self.act(x) / self.scales

Tensor Parallel 환경에서는 intermediate_sizetp_size로 나누어 각 GPU에 해당하는 스케일만 보유한다.

활성화 함수 레지스트리

get_act_fn 함수는 문자열 이름으로 활성화 함수를 조회한다. 양자화 설정에 따라 ScaledActivation으로 자동 래핑된다.

_ACTIVATION_REGISTRY = {
    "gelu": nn.GELU(),
    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
    "gelu_new": NewGELU(),
    "relu2": ReLU2(),
    "xielu": XIELU(),
}

def get_act_fn(act_fn_name, quant_config=None, intermediate_size=None, ...):
    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
    if quant_config is not None and act_fn_name in quant_config.get_scaled_act_names():
        return ScaledActivation(act_fn, intermediate_size, ...)
    return act_fn

Fused 커널의 성능 이점

구분 Native (PyTorch) Fused (sgl_kernel)
커널 호출 3회 (silu, mul, slice) 1회
중간 텐서 2개 추가 할당 0개
메모리 읽기 3회 1회

Fused 커널은 입력을 한 번만 읽고, 중간 결과를 레지스터에 유지하므로 메모리 대역폭 병목을 크게 줄인다.

관련 포스트

  • Linear Layer: 양자화 통합 선형 레이어의 설계
  • sgl-kernel: 커스텀 C++/CUDA 커널 라이브러리

참고

댓글

관련 포스트

SGLang 의 다른글