본문으로 건너뛰기

[SGLang] Linear Layer: 양자화 통합 선형 레이어의 설계

들어가며

LLM 추론에서 선형 레이어는 전체 연산의 대부분을 차지한다. SGLang의 linear.py는 양자화 메서드를 투명하게 통합하면서 Tensor Parallel을 지원하는 계층 구조를 제공한다. 이 포스트에서는 LinearBase에서 RowParallelLinear까지의 상속 구조와 가중치 로딩 메커니즘을 분석한다.

구조도

LinearBase (추상 기반)
├── ReplicatedLinear          (복제 - 모든 GPU 동일 가중치)
├── ColumnParallelLinear      (열 분할 - 출력 차원 분할)
│   ├── MergedColumnParallelLinear  (QKV 등 Fused)
│   └── QKVParallelLinear          (Q/K/V 독립 로딩)
└── RowParallelLinear         (행 분할 - 입력 차원 분할)

핵심 코드 분석

LinearBase: 양자화 메서드 자동 선택

LinearBasequant_config가 None이면 UnquantizedLinearMethod를, 있으면 해당 양자화 설정에서 적절한 메서드를 가져온다.

class LinearBase(torch.nn.Module):
    def __init__(
        self, input_size: int, output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        if quant_config is None:
            from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
            self.quant_method = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)

이 설계의 핵심은 모델 코드가 양자화 방식을 알 필요가 없다는 점이다. forward()에서는 항상 self.quant_method.apply(self, x, bias)를 호출하며, 실제 연산은 양자화 메서드가 결정한다.

ColumnParallelLinear: 출력 차원 분할

열 병렬 선형 레이어는 가중치 행렬 A를 열 방향으로 분할한다. Y = XA에서 A = [A_1, ..., A_p]로 나누어 각 GPU가 Y_i = XA_i를 계산한다.

class ColumnParallelLinear(LinearBase):
    def __init__(self, input_size, output_size, bias=True,
                 gather_output=False, ...):
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.output_size_per_partition = divide(self.output_size, tp_size)
        
    def forward(self, input_):
        bias = self.bias if not self.skip_bias_add else None
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output:
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        return output, output_bias

gather_output=True이면 All-Gather로 전체 출력을 모으고, False이면 각 GPU가 분할된 출력만 보유한다.

MergedColumnParallelLinear: QKV Fused 로딩

QKV 프로젝션처럼 여러 출력을 하나로 합친 레이어는 MergedColumnParallelLinear를 사용한다. 가중치 로딩 시 loaded_shard_id를 통해 각 부분을 올바른 위치에 배치한다.

class MergedColumnParallelLinear(ColumnParallelLinear):
    def __init__(self, input_size, output_sizes: List[int], ...):
        self.output_sizes = output_sizes
        assert all(output_size % tp_size == 0 for output_size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes), ...)

가중치 로더 V2 지원

SGLang은 20개 이상의 양자화 메서드에 대해 weight_loader_v2를 지원한다. V2 로더는 _ColumnvLLMParameter 타입의 파라미터에 대해 load_column_parallel_weight를 직접 호출한다.

WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod",
    "AWQMarlinLinearMethod",
    "Fp8LinearMethod",
    "BlockInt8LinearMethod",
    # ... 20+ methods
]

RowParallelLinear: 입력 차원 분할

행 병렬 레이어는 가중치를 행 방향으로 분할하여 각 GPU가 입력의 일부분에 대해 연산한 후 All-Reduce로 결과를 합산한다.

# RowParallelLinear.forward 핵심 흐름
output_parallel = self.quant_method.apply(self, input_)
if self.reduce_results and self.tp_size > 1:
    output_ = tensor_model_parallel_all_reduce(output_parallel)

비교/설계 근거

구분 ColumnParallel RowParallel
분할 차원 출력(열) 입력(행)
통신 All-Gather (선택적) All-Reduce (필수)
사용처 QKV Proj, Gate/Up Proj O Proj, Down Proj
입력 요구 전체 입력 분할된 입력

Column + Row를 쌍으로 사용하면 중간에 통신 없이 분할 상태를 유지할 수 있다. 예를 들어 MLP에서 Gate/Up (Column) -> Activation -> Down (Row) 순서로 처리하면 Column의 분할 출력이 Row의 분할 입력으로 바로 이어진다.

관련 포스트

  • Activation Functions: SiLU, GELU 커스텀 구현
  • Deep GEMM Wrapper: 최적화 행렬 곱 라이브러리

참고

댓글

관련 포스트

SGLang 의 다른글