본문으로 건너뛰기

[Axolotl] LoRA 커널에 bias, dropout, DoRA, embedding 지원 추가

PR 링크: axolotl-ai-cloud/axolotl#3528 상태: Merged | 변경: +2862 / -463

들어가며

Axolotl의 Triton LoRA 커널은 기존에 가장 기본적인 LoRA(A, B weight + scaling)만 지원했습니다. 이 PR은 실전 학습에서 흔히 사용되는 네 가지 기능을 추가합니다:

  1. Bias: bias="lora_only" 또는 bias="all" 설정 지원
  2. Dropout: LoRA 경로에 dropout 적용
  3. DoRA: Weight-Decomposed Low-Rank Adaptation (magnitude + direction 분리)
  4. Embedding: embedding 레이어의 LoRA 지원

핵심 코드 분석

1. get_lora_parameters 확장

Before:

def get_lora_parameters(proj):
    # ...
    return W, b, quant_state, A, B, s

After:

def get_lora_parameters(proj):
    # ...
    lora_bias = linear_B.bias       # None if bias=False
    dropout = proj.lora_dropout[active_adapter]
    magnitude = proj.lora_magnitude_vector[active_adapter].weight  # DoRA
    return W, b, quant_state, A, B, s, lora_bias, dropout, magnitude

반환값이 6개에서 9개로 확장되었습니다. 각 추가 값은 None일 수 있어 하위 호환성이 유지됩니다.

2. DoRA를 위한 Triton 커널

# kernels/dora.py
@triton.jit
def _dora_fused_norm_kernel(
    W_ptr, B_ptr, A_ptr, mag_ptr, out_ptr,
    out_features, in_features, rank, lora_scale,
    BLOCK_IN: tl.constexpr, BLOCK_R: tl.constexpr,
):
    """Compute mag_norm_scale[i] = magnitude[i] / ||W[i,:] + s*(B[i,:]@A)||_2

    Each program handles one output row. B[row,:] is loaded once,
    then tiles over in_features computing the dot product with A[:,tile].
    """
    row = tl.program_id(0)
    for start in range(0, in_features, BLOCK_IN):
        # Load W[row, cols] and compute B[row,:] @ A[:, cols] tile-by-tile
        ba_vals = tl.zeros([BLOCK_IN], dtype=tl.float32)
        for r in tl.static_range(BLOCK_R):
            b_val = tl.load(B_ptr + row * rank + r, ...)
            a_vals = tl.load(A_ptr + r * in_features + cols, ...)
            ba_vals += b_val * a_vals
        combined = w_vals + lora_scale * ba_vals
        norm_sq_acc += combined * combined

핵심은 B@A 행렬곱을 전체 [out, in] 크기로 실체화하지 않고, row별로 tile 단위로 계산하는 것입니다. 이로써 메모리 사용량이 O(out * in) 에서 O(BLOCK_IN)으로 줄어듭니다.

3. matmul_lora에 dropout과 bias 추가

Before:

def matmul_lora(X, W, b, W_quant, A, B, s, out=None):
    out = torch.matmul(X, W, out=out)
    if A is not None:
        out += s * X @ A @ B
    return out

After:

def matmul_lora(X, W, b, W_quant, A, B, s, out=None,
                X_drop=None, lora_bias=None):
    out = torch.matmul(X, W, out=out)
    if A is not None:
        X_lora = X_drop if X_drop is not None else X
        out += s * X_lora @ A @ B
        if lora_bias is not None:
            out += s * lora_bias
    return out

Dropout이 적용된 입력(X_drop)을 LoRA 경로에만 사용합니다. base weight 경로(X @ W)에는 dropout을 적용하지 않습니다. LoRA bias도 scaling factor s를 곱하여 추가됩니다.

왜 이게 좋은가

  • DoRA 메모리 효율: Triton 커널로 B@A를 실체화하지 않아 DoRA의 weight norm 계산이 메모리 효율적입니다.
  • 캐싱: magnitude._dora_cache에 weight norm을 캐싱하여 optimizer step 사이에 불필요한 재계산을 방지합니다.
  • Dropout 분리: base weight와 LoRA 경로의 dropout을 분리하여 PEFT 논문의 원래 의도를 정확히 구현합니다.
  • 하위 호환: 추가된 파라미터가 모두 None 기본값이므로 기존 코드가 수정 없이 동작합니다.

정리

2862줄 추가의 대규모 PR로, Axolotl의 LoRA 커널을 실전 수준으로 끌어올립니다. bias, dropout, DoRA, embedding을 모두 지원함으로써 PEFT의 다양한 설정 조합을 Triton 최적화된 경로로 처리할 수 있게 되었습니다.

참고 자료


이 포스트는 AI가 작성하였으며, 사실과 다를 수 있습니다. 정확한 정보는 원본 PR을 참고해 주세요.

댓글

관련 포스트

PR Analysis 의 다른글