본문으로 건너뛰기

[axolotl] Async GRPO 지원: vLLM 비동기 생성과 Importance Sampling으로 RLHF 학습 가속화

PR 링크: axolotl-ai-cloud/axolotl#3486 상태: Merged | 변경: +5474 / -36

들어가며

GRPO(Group Relative Policy Optimization)는 강화학습 기반 LLM 학습 방법론입니다. 기존 동기 방식에서는 vLLM으로 completions를 생성한 후 학습을 진행하는 과정이 순차적이었습니다. 이 PR은 axolotl에 Async GRPO를 도입하여, 백그라운드 스레드에서 다음 batch의 completions를 미리 생성하면서 현재 batch로 학습을 수행하는 파이프라인 병렬화를 구현합니다.

핵심 코드 분석

1. Async/Sync GRPO Trainer 클래스 분기

# src/axolotl/core/builders/rl.py
async_grpo = bool(
    self.cfg.trl
    and (
        getattr(self.cfg.trl, "async_prefetch", False)
        or getattr(self.cfg.trl, "use_data_producer", False)
    )
)
trainer_cls = GRPOStrategy.get_trainer_class(
    sequence_parallel=self.cfg.context_parallel_size > 1,
    async_grpo=async_grpo,
)

async_prefetch 또는 use_data_producer 플래그가 설정되면 AxolotlAsyncGRPOTrainer를 사용합니다. Context Parallel과 Async GRPO는 동시에 사용할 수 없도록 명시적으로 차단했습니다.

2. vLLM LoRA Sync 자동 선택

Before:

serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")

After:

serve_module = cli_args.get("serve_module") or getattr(
    cfg.vllm, "serve_module", None
)
if (
    serve_module is None
    and getattr(cfg, "trl", None)
    and getattr(cfg.trl, "vllm_lora_sync", False)
):
    serve_module = "axolotl.scripts.vllm_serve_lora"
if serve_module is None:
    serve_module = "trl.scripts.vllm_serve"

vllm_lora_sync: true 설정 시 LoRA 전용 serve 모듈을 자동으로 선택합니다. NCCL merge-sync 대신 파일시스템 기반 어댑터 동기화를 사용하여 multi-GPU 환경에서의 GPU 경합을 줄입니다.

3. Async Prefetch 아키텍처

새로 추가된 AsyncGRPOTrainer의 핵심 구조:

# 백그라운드 스레드에서 completions 생성
# → 메인 스레드에서 reward 계산 + 학습
# → Importance Sampling으로 stale weights 보정

주요 설정:

trl:
  async_prefetch: true
  prefetch_depth: 1
  vllm_sync_interval: 2
  streaming_partial_batch: true
  vllm_importance_sampling_correction: true
  replay_buffer_size: 100

왜 이게 좋은가

Async GRPO의 핵심 가치는 GPU 활용률 극대화입니다. 동기 방식에서는 vLLM 생성 동안 학습 GPU가 유휴 상태였지만, 비동기 방식은 이 시간을 overlap하여 wall-clock time을 크게 줄입니다. Importance Sampling 보정은 stale weights로 인한 분포 이동 문제를 수학적으로 해결하며, Replay Buffer와 Deferred Re-rolling은 학습 효율을 더욱 높입니다. 5400행 이상의 대규모 추가임에도 기존 동기 GRPO 경로에 영향을 주지 않는 깔끔한 분기 구조가 인상적입니다.

정리

항목 내용
기능 Async GRPO 파이프라인 (비동기 생성 + 학습)
핵심 Importance Sampling 보정, LoRA Sync, Replay Buffer
영향 GRPO 학습 wall-clock time 대폭 감소

참고 자료

알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글