본문으로 건너뛰기

[SGLang] BitsAndBytes: QLoRA와 NF4 동적 양자화

들어가며

BitsAndBytes는 QLoRA 학습에서 널리 사용되는 양자화 라이브러리이다. SGLang은 BitsAndBytes로 양자화된 모델을 직접 로딩하여 추론할 수 있도록 python/sglang/srt/layers/quantization/bitsandbytes.py에서 통합 레이어를 제공한다. 4비트(NF4/FP4)와 8비트 양자화를 모두 지원하며, MoE 레이어에 대한 역양자화 경로도 포함한다.

구조도

BitsAndBytesConfig
├── load_in_8bit: bool
├── load_in_4bit: bool (기본값)
├── bnb_4bit_quant_type: "fp4" | "nf4"
├── bnb_4bit_compute_dtype: "float32" | "bfloat16"
├── bnb_4bit_use_double_quant: bool
├── llm_int8_threshold: 6.0
└── llm_int8_skip_modules: List[str]

BitsAndBytesConfig.get_quant_method()
├── LinearBase → BitsAndBytesLinearMethod
│   ├── load_in_8bit → _apply_8bit_weight()
│   └── load_in_4bit → _apply_4bit_weight()
└── FusedMoE   → BitsAndBytesMoEMethod
    └── _apply_4bit_dequant() → fused_moe()

핵심 코드 분석

1. BitsAndBytesConfig: 다양한 설정 옵션

BitsAndBytes는 INT8과 NF4/FP4 두 가지 모드를 지원한다.

class BitsAndBytesConfig(QuantizationConfig):
    def __init__(self, load_in_8bit=False, load_in_4bit=True,
                 bnb_4bit_compute_dtype="float32",
                 bnb_4bit_quant_storage="uint8",
                 bnb_4bit_quant_type="fp4",
                 bnb_4bit_use_double_quant=False,
                 llm_int8_threshold=6.0, ...):
        self.load_in_8bit = load_in_8bit
        self.load_in_4bit = load_in_4bit
        self.bnb_4bit_quant_type = bnb_4bit_quant_type
        self.llm_int8_threshold = llm_int8_threshold

llm_int8_threshold은 INT8 모드에서 이상치(outlier) 채널을 FP16으로 처리할 임계값이다. 기본값 6.0은 활성화의 절대값이 6을 넘는 채널을 이상치로 분류한다.

2. 레이어 스킵 로직

BitsAndBytes는 모듈 이름 기반으로 양자화를 제외할 레이어를 결정한다.

def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]):
    components = prefix.split(".")
    # 컴포넌트 단위 매칭
    substr_check = any(
        module_name in components
        for module_name in llm_int8_skip_modules
    )
    # 프리픽스 단위 매칭
    set_components = set(
        ".".join(components[:i+1]) for i in range(len(components))
    )
    set_llm_int8_skip_modules = set(llm_int8_skip_modules)
    prefix_check = len(set_llm_int8_skip_modules & set_components) != 0
    return substr_check or prefix_check

두 가지 매칭 방식을 OR 연산한다: 컴포넌트 이름 매칭("lm_head" in ["model", "lm_head"])과 프리픽스 매칭("model.lm_head" == "model.lm_head").

3. 4비트 가중치 생성

4비트 가중치는 pack factor에 따라 압축된 uint8 텐서로 생성된다.

def create_qweight_for_4bit():
    quant_ratio = calculate_quant_ratio(params_dtype)
    total_size = input_size_per_partition * sum(output_partition_sizes)
    qweight = torch.nn.Parameter(
        torch.empty(total_size // quant_ratio, 1, dtype=torch.uint8),
        requires_grad=False,
    )
    set_weight_attrs(qweight, {
        "input_dim": 0, "output_dim": 0,
        "pack_factor": quant_ratio,
        "use_bitsandbytes_4bit": True,
    })
    return qweight

quant_ratio는 원본 dtype의 비트 수를 uint8(8비트)로 나눈 값이다. BF16의 경우 16 / 8 = 2이므로, 총 원소 수가 절반으로 줄어든다.

4. INT8 추론: MatmulLtState 관리

INT8 추론은 bitsandbytes의 matmul 함수와 상태 객체를 사용한다.

def _apply_8bit_weight(self, layer, x, bias=None):
    from bitsandbytes import MatmulLtState, matmul

    qweight = layer.weight
    quant_states = qweight.bnb_quant_state
    matmul_states = qweight.matmul_state

    for i in range(len(quant_states)):
        if generation == 0 or generation == 1:
            matmul_states[i] = MatmulLtState()
            matmul_states[i].CB = qweight[offsets[i]:offsets[i+1]]
            matmul_states[i].SCB = quant_states[i].to(x.device)
            matmul_states[i].threshold = self.quant_config.llm_int8_threshold
            matmul_states[i].has_fp16_weights = (
                self.quant_config.llm_int8_has_fp16_weight
            )

generation 카운터로 프로파일링 실행(0)과 첫 번째 추론(1) 시에만 상태를 초기화한다. 이후 생성에서는 캐시된 상태를 재사용하여 오버헤드를 줄인다.

5. 4비트 추론: Custom Op

4비트 추론은 @register_custom_op으로 등록된 커스텀 연산을 사용한다.

@register_custom_op(mutates_args=["out"])
def apply_bnb_4bit(x, weight, offsets, out):
    from bitsandbytes import matmul_4bit

    quant_states = weight.bnb_quant_state
    current_index = 0
    for i in range(len(quant_states)):
        output_size = quant_states[i].shape[0]
        out[:, current_index:current_index + output_size] = matmul_4bit(
            x, weight[offsets[i]:offsets[i+1]].t(), quant_states[i]
        )
        current_index += output_size

@register_custom_optorch.compile과 CUDA Graph 캡처에 호환되도록 함수를 등록한다. mutates_args=["out"]은 out 텐서가 제자리(in-place) 수정됨을 컴파일러에 알린다.

6. MoE 4비트 역양자화

MoE에서는 전문가 가중치를 역양자화한 뒤 fused_moe 커널에 전달한다.

class BitsAndBytesMoEMethod(FusedMoEMethodBase):
    def _apply_4bit_dequant(self, layer):
        from bitsandbytes.functional import dequantize_4bit

        w13 = dequantize_4bit(
            layer.w13_weight.reshape(-1, 1),
            layer.w13_weight.bnb_quant_state,
        )
        w2 = dequantize_4bit(
            layer.w2_weight.reshape(-1, 1),
            layer.w2_weight.bnb_quant_state,
        )
        w13 = w13.reshape(layer.w13_weight.experts_shape)
        w2 = w2.reshape(layer.w2_weight.experts_shape)
        return w13, w2

MoE의 역양자화는 hot path에서 실행되므로 성능 오버헤드가 있다. 코드 주석에 TODO로 개선이 필요함을 명시하고 있다.

설계 근거

  1. 호환성 우선: BitsAndBytes는 HuggingFace 생태계에서 QLoRA 학습 모델의 표준 포맷이다. 직접 로딩 지원으로 변환 없이 추론이 가능하다.
  2. Generation 카운터: INT8 상태 초기화를 프로파일링과 첫 실행에만 수행하여, 이후 추론의 오버헤드를 제거한다.
  3. Custom Op 등록: torch.compile과 CUDA Graph 호환을 위해 bitsandbytes 연산을 커스텀 Op으로 래핑한다.
  4. MoE 역양자화: 전문가별 독립 역양자화 후 fused_moe로 전달하는 구조는 간단하지만, 매 추론마다 역양자화 비용이 발생하는 트레이드오프가 있다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글