본문으로 건너뛰기

[SGLang] Block-wise INT8: 블록 단위 정수 양자화

들어가며

INT8 양자화는 부동소수점 대신 정수 연산을 활용하여 추론 속도를 높인다. SGLang의 Block-wise INT8은 가중치를 블록 단위로 나누어 각 블록에 독립적인 스케일을 할당함으로써, 텐서 전체를 하나의 스케일로 표현하는 방식보다 높은 정확도를 달성한다. python/sglang/srt/layers/quantization/blockwise_int8.py에서 구현되어 있다.

구조도

BlockInt8Config
├── is_checkpoint_int8_serialized: bool
├── activation_scheme: "dynamic" (블록 양자화는 dynamic만 지원)
├── weight_block_size: [block_n, block_k]
└── ignored_layers: List[str]

BlockInt8Config.get_quant_method()
├── LinearBase → BlockInt8LinearMethod
└── FusedMoE   → BlockInt8MoEMethod

블록 양자화 메모리 레이아웃:
┌──────────────────────────────┐
│  Weight [N, K]  (int8)       │
├──────────────────────────────┤
│  Scale [N/bn, K/bk] (fp32)  │  bn=block_n, bk=block_k
└──────────────────────────────┘

핵심 코드 분석

1. BlockInt8Config: 설정과 검증

블록 INT8은 INT8 직렬화된 체크포인트와 동적 활성화 스킴만 지원한다.

class BlockInt8Config(QuantizationConfig):
    def __init__(self, is_checkpoint_int8_serialized=False,
                 activation_scheme="dynamic",
                 ignored_layers=None,
                 weight_block_size=None):
        if weight_block_size is not None:
            if not is_checkpoint_int8_serialized:
                raise ValueError(
                    "The block-wise quantization only supports "
                    "int8-serialized checkpoint for now.")
            if len(weight_block_size) != 2:
                raise ValueError(
                    "The quantization block size of weight must "
                    f"have 2 dimensions.")
            if activation_scheme != "dynamic":
                raise ValueError(
                    "The block-wise quantization only supports "
                    "dynamic activation scheme for now.")

세 가지 제약이 있다: (1) INT8 직렬화 체크포인트 필수, (2) 블록 크기 2차원, (3) 동적 활성화만 지원.

2. 가중치 생성과 블록 스케일

가중치는 INT8로, 스케일은 블록 단위 float32로 생성한다.

class BlockInt8LinearMethod(LinearMethodBase):
    def create_weights(self, layer, input_size_per_partition,
                       output_partition_sizes, ...):
        block_n, block_k = (
            self.quant_config.weight_block_size[0],
            self.quant_config.weight_block_size[1],
        )
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition, input_size_per_partition,
                dtype=torch.int8),
            input_dim=1, output_dim=0, weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        scale = BlockQuantScaleParameter(
            data=torch.empty(
                (output_size_per_partition + block_n - 1) // block_n,
                (input_size_per_partition + block_k - 1) // block_k,
                dtype=torch.float32,
            ),
            input_dim=1, output_dim=0, weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale_inv", scale)

스케일 텐서의 크기는 ceil(N/block_n) x ceil(K/block_k)이다. 예를 들어 가중치가 [4096, 4096]이고 블록 크기가 [128, 128]이면 스케일은 [32, 32]가 된다.

3. Tensor Parallel 정렬 검증

블록 양자화는 Tensor Parallel 파티셔닝과 블록 크기가 정렬되어야 한다.

tp_size = get_tensor_model_parallel_world_size()
# Row parallel: 입력 차원 정렬
if tp_size > 1 and input_size // input_size_per_partition == tp_size:
    if input_size_per_partition % block_k != 0:
        raise ValueError(
            f"Weight input_size_per_partition = "
            f"{input_size_per_partition} is not divisible by "
            f"weight quantization block_k = {block_k}.")
# Column parallel: 출력 차원 정렬
for output_partition_size in output_partition_sizes:
    if output_partition_size % block_n != 0:
        raise ValueError(...)

이 검증은 TP 분할 경계가 블록 경계와 일치하도록 보장한다.

4. 추론 적용

추론 시에는 apply_w8a8_block_int8_linear 커널을 호출한다.

def apply(self, layer, x, bias=None):
    return apply_w8a8_block_int8_linear(
        input=x,
        weight=layer.weight,
        block_size=self.quant_config.weight_block_size,
        weight_scale=layer.weight_scale_inv,
        input_scale=None,  # 동적 활성화: 런타임에 계산
        bias=bias,
    )

input_scale=None은 동적 활성화 양자화를 의미한다. 커널 내부에서 입력의 스케일을 토큰별로 계산한다.

5. MoE 전용 BlockInt8

MoE 레이어는 전문가별로 독립적인 가중치와 스케일을 관리한다.

class BlockInt8MoEMethod(FusedMoEMethodBase):
    def create_weights(self, layer, num_experts, hidden_size,
                       intermediate_size_per_partition, ...):
        # w13: gate + up projection (fused)
        w13_weight = torch.nn.Parameter(
            torch.empty(num_experts,
                        2 * intermediate_size_per_partition,
                        hidden_size, dtype=torch.int8),
            requires_grad=False,
        )
        # w13 스케일
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts,
                2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                (hidden_size + block_k - 1) // block_k,
                dtype=torch.float32),
            requires_grad=False,
        )

MoE에서 w13은 gate와 up projection이 결합(fused)된 형태이므로, 출력 차원이 2 * intermediate_size이다.

6. MoE Triton 커널 실행

MoE 추론은 Triton 기반 Fused MoE 커널을 사용한다.

def apply(self, layer, dispatch_output):
    quant_info = TritonMoeQuantInfo(
        w13_weight=layer.w13_weight,
        w2_weight=layer.w2_weight,
        use_int8_w8a8=True,
        w13_scale=layer.w13_weight_scale_inv,
        w2_scale=layer.w2_weight_scale_inv,
        a13_scale=layer.w13_input_scale,
        a2_scale=layer.w2_input_scale,
        block_shape=self.quant_config.weight_block_size,
    )
    return self.runner.run(dispatch_output, quant_info)

use_int8_w8a8=True 플래그로 INT8 전용 커널 경로를 선택한다.

텐서별 vs 블록별 양자화 비교

항목 텐서별 INT8 블록별 INT8
스케일 수 텐서당 1개 블록당 1개
스케일 오버헤드 무시 가능 0.1-1%
이상치 처리 전체에 영향 블록 내 격리
정확도 낮음 높음
구현 복잡도 낮음 높음
커널 요구사항 표준 INT8 GEMM 블록 스케일 INT8 GEMM
최소 GPU SM80 SM80

설계 근거

  1. 블록 크기 선택: 일반적으로 [128, 128]이 사용된다. 이는 GPU GEMM 타일 크기와 일치하여 커널 오버헤드를 최소화한다.
  2. 동적 활성화만 지원: 블록 양자화는 정적 보정 없이도 높은 정확도를 달성하므로, 동적 활성화로 충분하다.
  3. FP32 스케일: INT8의 제한된 범위(-128~127)를 보상하기 위해 스케일은 FP32로 저장하여 정밀한 역변환을 보장한다.
  4. Triton MoE 전용 러너: MoE는 전문가별 독립 양자화가 필요하며, Triton 커널이 이를 효율적으로 처리한다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글