[Axolotl] GRPO 트레이너에 batch flattening/packing 지원 추가
PR 링크: axolotl-ai-cloud/axolotl#3552 상태: Merged | 변경: +1307 / -52
들어가며
GRPO(Group Relative Policy Optimization) 트레이닝에서 scoring forward pass는 상당한 GPU 시간을 소비합니다. 문제는 배치 내 시퀀스 길이가 다를 때 padding 토큰이 많아지는데, 이 padding에 대해서도 attention을 계산한다는 것입니다. 이 PR은 "batch flattening" 기법을 도입하여 padding 토큰을 forward pass 전에 제거하고, FlashAttention의 variable-length attention(cu_seq_lens)을 활용합니다.
핵심 코드 분석
1. Batch flattening 설정 추가
# AsyncGRPOConfig
batch_flattening: bool = field(
default=False,
metadata={
"help": "Use batch flattening for the scoring forward pass. "
"Removes padding tokens before the forward pass, reducing "
"attention FLOPs proportional to the padding ratio. "
"Requires flash_attention_2 attention implementation."
},
)
2. Flattened forward pass 구현
def _get_per_token_logps_flattened(self, model, input_ids, attention_mask,
logits_to_keep, batch_size=None,
prompt_mask=None) -> torch.Tensor:
"""Compute per-token log-probs using batch flattening (padding-free).
1. Chunks the batch into sub-batches of ``batch_size`` sequences
2. Flattens non-padding tokens into [1, chunk_tokens]
3. Uses FlashAttentionKwargs (cu_seq_lens) for varlen attention
4. Computes selective_log_softmax on the flat logits
5. Gathers completion logprobs back to (B, logits_to_keep) padded format
"""
핵심 아이디어는 (B, L) 형태의 padded 배치를 (1, total_tokens) 형태의 flat 텐서로 변환하고, cu_seq_lens로 각 시퀀스의 경계를 FlashAttention에 알려주는 것입니다.
3. 조건부 dispatch
Before:
old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
self.model, prompt_completion_ids, attention_mask,
logits_to_keep, logprob_batch_size,
num_images=num_images, **forward_kwargs,
)
After:
can_flatten = (
getattr(self.args, "batch_flattening", False)
and not forward_kwargs # no multimodal inputs
and not self.is_fsdp_enabled # FSDP needs wrapped model
)
if can_flatten:
old_per_token_logps = self._get_per_token_logps_flattened(
self.model, prompt_completion_ids, attention_mask,
logits_to_keep, batch_size=logprob_batch_size,
prompt_mask=prompt_mask,
)
else:
old_per_token_logps, _ = self._get_per_token_logps_and_entropies(...)
FSDP나 multimodal 입력이 있으면 기존 padded 경로를 사용하고, 그렇지 않으면 flattened 경로를 사용합니다.
왜 이게 좋은가
- 실측 20-34% 성능 향상: padding 비율에 비례하여 attention FLOP가 감소합니다. 시퀀스 길이 편차가 클수록 효과가 큽니다.
- 정밀도 차이 무시 가능: bf16 정밀도에서 평균 ~0.03의 per-token logprob 차이가 발생하지만, loss와 gradient에는 동등한 결과를 줍니다.
- 안전한 fallback: FSDP, multimodal, flash_attention_2 미사용 시 자동으로 기존 경로를 사용합니다.
정리
GRPO 트레이닝의 scoring 병목을 batch flattening으로 해결한 대규모 PR입니다. FlashAttention의 variable-length 기능을 활용하여 padding 오버헤드를 제거하고, 기존 경로와의 호환성을 조건부 dispatch로 유지합니다.
참고 자료
이 포스트는 AI가 작성하였으며, 사실과 다를 수 있습니다. 정확한 정보는 원본 PR을 참고해 주세요.
관련 포스트
- [Axolotl] LoRA 커널에 bias, dropout, DoRA, embedding 지원 추가
- [Axolotl] Qwen 3.5 모델 Liger 커널 지원 및 fused RMSNorm+Gated 커널 추가
- [논문리뷰] FIPO: Eliciting Deep Reasoning with Future-KL Influenced Policy Optimization
- [논문리뷰] EVA: Efficient Reinforcement Learning for End-to-End Video Agent
- [Axolotl] 플러그인에 scored rollout 디스패치, 외부 플러그인 경로 확장, vLLM 에러 처리 개선
PR Analysis 의 다른글
- 이전글 [CPython 3.14] asyncio.Queue docstring의 모호한 표현 수정 (backport)
- 현재글 : [Axolotl] GRPO 트레이너에 batch flattening/packing 지원 추가
- 다음글 [triton] AMD TDM의 Partition-Aware 분할 및 다중 Intrinsic 지원
댓글