본문으로 건너뛰기

[Axolotl] GRPO 트레이너에 batch flattening/packing 지원 추가

PR 링크: axolotl-ai-cloud/axolotl#3552 상태: Merged | 변경: +1307 / -52

들어가며

GRPO(Group Relative Policy Optimization) 트레이닝에서 scoring forward pass는 상당한 GPU 시간을 소비합니다. 문제는 배치 내 시퀀스 길이가 다를 때 padding 토큰이 많아지는데, 이 padding에 대해서도 attention을 계산한다는 것입니다. 이 PR은 "batch flattening" 기법을 도입하여 padding 토큰을 forward pass 전에 제거하고, FlashAttention의 variable-length attention(cu_seq_lens)을 활용합니다.

핵심 코드 분석

1. Batch flattening 설정 추가

# AsyncGRPOConfig
batch_flattening: bool = field(
    default=False,
    metadata={
        "help": "Use batch flattening for the scoring forward pass. "
        "Removes padding tokens before the forward pass, reducing "
        "attention FLOPs proportional to the padding ratio. "
        "Requires flash_attention_2 attention implementation."
    },
)

2. Flattened forward pass 구현

def _get_per_token_logps_flattened(self, model, input_ids, attention_mask,
                                    logits_to_keep, batch_size=None,
                                    prompt_mask=None) -> torch.Tensor:
    """Compute per-token log-probs using batch flattening (padding-free).

    1. Chunks the batch into sub-batches of ``batch_size`` sequences
    2. Flattens non-padding tokens into [1, chunk_tokens]
    3. Uses FlashAttentionKwargs (cu_seq_lens) for varlen attention
    4. Computes selective_log_softmax on the flat logits
    5. Gathers completion logprobs back to (B, logits_to_keep) padded format
    """

핵심 아이디어는 (B, L) 형태의 padded 배치를 (1, total_tokens) 형태의 flat 텐서로 변환하고, cu_seq_lens로 각 시퀀스의 경계를 FlashAttention에 알려주는 것입니다.

3. 조건부 dispatch

Before:

old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
    self.model, prompt_completion_ids, attention_mask,
    logits_to_keep, logprob_batch_size,
    num_images=num_images, **forward_kwargs,
)

After:

can_flatten = (
    getattr(self.args, "batch_flattening", False)
    and not forward_kwargs      # no multimodal inputs
    and not self.is_fsdp_enabled  # FSDP needs wrapped model
)
if can_flatten:
    old_per_token_logps = self._get_per_token_logps_flattened(
        self.model, prompt_completion_ids, attention_mask,
        logits_to_keep, batch_size=logprob_batch_size,
        prompt_mask=prompt_mask,
    )
else:
    old_per_token_logps, _ = self._get_per_token_logps_and_entropies(...)

FSDP나 multimodal 입력이 있으면 기존 padded 경로를 사용하고, 그렇지 않으면 flattened 경로를 사용합니다.

왜 이게 좋은가

  • 실측 20-34% 성능 향상: padding 비율에 비례하여 attention FLOP가 감소합니다. 시퀀스 길이 편차가 클수록 효과가 큽니다.
  • 정밀도 차이 무시 가능: bf16 정밀도에서 평균 ~0.03의 per-token logprob 차이가 발생하지만, loss와 gradient에는 동등한 결과를 줍니다.
  • 안전한 fallback: FSDP, multimodal, flash_attention_2 미사용 시 자동으로 기존 경로를 사용합니다.

정리

GRPO 트레이닝의 scoring 병목을 batch flattening으로 해결한 대규모 PR입니다. FlashAttention의 variable-length 기능을 활용하여 padding 오버헤드를 제거하고, 기존 경로와의 호환성을 조건부 dispatch로 유지합니다.

참고 자료


이 포스트는 AI가 작성하였으며, 사실과 다를 수 있습니다. 정확한 정보는 원본 PR을 참고해 주세요.

댓글

관련 포스트

PR Analysis 의 다른글