[Axolotl] LoRA 커널에 bias, dropout, DoRA, embedding 지원 추가
PR 링크: axolotl-ai-cloud/axolotl#3528 상태: Merged | 변경: +2862 / -463
들어가며
Axolotl의 Triton LoRA 커널은 기존에 가장 기본적인 LoRA(A, B weight + scaling)만 지원했습니다. 이 PR은 실전 학습에서 흔히 사용되는 네 가지 기능을 추가합니다:
- Bias:
bias="lora_only"또는bias="all"설정 지원 - Dropout: LoRA 경로에 dropout 적용
- DoRA: Weight-Decomposed Low-Rank Adaptation (magnitude + direction 분리)
- Embedding: embedding 레이어의 LoRA 지원
핵심 코드 분석
1. get_lora_parameters 확장
Before:
def get_lora_parameters(proj):
# ...
return W, b, quant_state, A, B, s
After:
def get_lora_parameters(proj):
# ...
lora_bias = linear_B.bias # None if bias=False
dropout = proj.lora_dropout[active_adapter]
magnitude = proj.lora_magnitude_vector[active_adapter].weight # DoRA
return W, b, quant_state, A, B, s, lora_bias, dropout, magnitude
반환값이 6개에서 9개로 확장되었습니다. 각 추가 값은 None일 수 있어 하위 호환성이 유지됩니다.
2. DoRA를 위한 Triton 커널
# kernels/dora.py
@triton.jit
def _dora_fused_norm_kernel(
W_ptr, B_ptr, A_ptr, mag_ptr, out_ptr,
out_features, in_features, rank, lora_scale,
BLOCK_IN: tl.constexpr, BLOCK_R: tl.constexpr,
):
"""Compute mag_norm_scale[i] = magnitude[i] / ||W[i,:] + s*(B[i,:]@A)||_2
Each program handles one output row. B[row,:] is loaded once,
then tiles over in_features computing the dot product with A[:,tile].
"""
row = tl.program_id(0)
for start in range(0, in_features, BLOCK_IN):
# Load W[row, cols] and compute B[row,:] @ A[:, cols] tile-by-tile
ba_vals = tl.zeros([BLOCK_IN], dtype=tl.float32)
for r in tl.static_range(BLOCK_R):
b_val = tl.load(B_ptr + row * rank + r, ...)
a_vals = tl.load(A_ptr + r * in_features + cols, ...)
ba_vals += b_val * a_vals
combined = w_vals + lora_scale * ba_vals
norm_sq_acc += combined * combined
핵심은 B@A 행렬곱을 전체 [out, in] 크기로 실체화하지 않고, row별로 tile 단위로 계산하는 것입니다. 이로써 메모리 사용량이 O(out * in) 에서 O(BLOCK_IN)으로 줄어듭니다.
3. matmul_lora에 dropout과 bias 추가
Before:
def matmul_lora(X, W, b, W_quant, A, B, s, out=None):
out = torch.matmul(X, W, out=out)
if A is not None:
out += s * X @ A @ B
return out
After:
def matmul_lora(X, W, b, W_quant, A, B, s, out=None,
X_drop=None, lora_bias=None):
out = torch.matmul(X, W, out=out)
if A is not None:
X_lora = X_drop if X_drop is not None else X
out += s * X_lora @ A @ B
if lora_bias is not None:
out += s * lora_bias
return out
Dropout이 적용된 입력(X_drop)을 LoRA 경로에만 사용합니다. base weight 경로(X @ W)에는 dropout을 적용하지 않습니다. LoRA bias도 scaling factor s를 곱하여 추가됩니다.
왜 이게 좋은가
- DoRA 메모리 효율: Triton 커널로
B@A를 실체화하지 않아 DoRA의 weight norm 계산이 메모리 효율적입니다. - 캐싱:
magnitude._dora_cache에 weight norm을 캐싱하여 optimizer step 사이에 불필요한 재계산을 방지합니다. - Dropout 분리: base weight와 LoRA 경로의 dropout을 분리하여 PEFT 논문의 원래 의도를 정확히 구현합니다.
- 하위 호환: 추가된 파라미터가 모두
None기본값이므로 기존 코드가 수정 없이 동작합니다.
정리
2862줄 추가의 대규모 PR로, Axolotl의 LoRA 커널을 실전 수준으로 끌어올립니다. bias, dropout, DoRA, embedding을 모두 지원함으로써 PEFT의 다양한 설정 조합을 Triton 최적화된 경로로 처리할 수 있게 되었습니다.
참고 자료
이 포스트는 AI가 작성하였으며, 사실과 다를 수 있습니다. 정확한 정보는 원본 PR을 참고해 주세요.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Axolotl] Qwen 3.5 모델 Liger 커널 지원 및 fused RMSNorm+Gated 커널 추가
- 현재글 : [Axolotl] LoRA 커널에 bias, dropout, DoRA, embedding 지원 추가
- 다음글 [sglang] SGLang의 SM120 FP8 Blockwise GEMM 성능 최적화: Pingpong 스케줄 도입
댓글