[axolotl] Async GRPO 지원: vLLM 비동기 생성과 Importance Sampling으로 RLHF 학습 가속화
PR 링크: axolotl-ai-cloud/axolotl#3486 상태: Merged | 변경: +5474 / -36
들어가며
GRPO(Group Relative Policy Optimization)는 강화학습 기반 LLM 학습 방법론입니다. 기존 동기 방식에서는 vLLM으로 completions를 생성한 후 학습을 진행하는 과정이 순차적이었습니다. 이 PR은 axolotl에 Async GRPO를 도입하여, 백그라운드 스레드에서 다음 batch의 completions를 미리 생성하면서 현재 batch로 학습을 수행하는 파이프라인 병렬화를 구현합니다.
핵심 코드 분석
1. Async/Sync GRPO Trainer 클래스 분기
# src/axolotl/core/builders/rl.py
async_grpo = bool(
self.cfg.trl
and (
getattr(self.cfg.trl, "async_prefetch", False)
or getattr(self.cfg.trl, "use_data_producer", False)
)
)
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.context_parallel_size > 1,
async_grpo=async_grpo,
)
async_prefetch 또는 use_data_producer 플래그가 설정되면 AxolotlAsyncGRPOTrainer를 사용합니다. Context Parallel과 Async GRPO는 동시에 사용할 수 없도록 명시적으로 차단했습니다.
2. vLLM LoRA Sync 자동 선택
Before:
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
After:
serve_module = cli_args.get("serve_module") or getattr(
cfg.vllm, "serve_module", None
)
if (
serve_module is None
and getattr(cfg, "trl", None)
and getattr(cfg.trl, "vllm_lora_sync", False)
):
serve_module = "axolotl.scripts.vllm_serve_lora"
if serve_module is None:
serve_module = "trl.scripts.vllm_serve"
vllm_lora_sync: true 설정 시 LoRA 전용 serve 모듈을 자동으로 선택합니다. NCCL merge-sync 대신 파일시스템 기반 어댑터 동기화를 사용하여 multi-GPU 환경에서의 GPU 경합을 줄입니다.
3. Async Prefetch 아키텍처
새로 추가된 AsyncGRPOTrainer의 핵심 구조:
# 백그라운드 스레드에서 completions 생성
# → 메인 스레드에서 reward 계산 + 학습
# → Importance Sampling으로 stale weights 보정
주요 설정:
trl:
async_prefetch: true
prefetch_depth: 1
vllm_sync_interval: 2
streaming_partial_batch: true
vllm_importance_sampling_correction: true
replay_buffer_size: 100
왜 이게 좋은가
Async GRPO의 핵심 가치는 GPU 활용률 극대화입니다. 동기 방식에서는 vLLM 생성 동안 학습 GPU가 유휴 상태였지만, 비동기 방식은 이 시간을 overlap하여 wall-clock time을 크게 줄입니다. Importance Sampling 보정은 stale weights로 인한 분포 이동 문제를 수학적으로 해결하며, Replay Buffer와 Deferred Re-rolling은 학습 효율을 더욱 높입니다. 5400행 이상의 대규모 추가임에도 기존 동기 GRPO 경로에 영향을 주지 않는 깔끔한 분기 구조가 인상적입니다.
정리
| 항목 | 내용 |
|---|---|
| 기능 | Async GRPO 파이프라인 (비동기 생성 + 학습) |
| 핵심 | Importance Sampling 보정, LoRA Sync, Replay Buffer |
| 영향 | GRPO 학습 wall-clock time 대폭 감소 |
참고 자료
알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [llm-compressor] Intermediates Cache Prefetch - 중간 결과 프리페칭
- 현재글 : [axolotl] Async GRPO 지원: vLLM 비동기 생성과 Importance Sampling으로 RLHF 학습 가속화
- 다음글 [Ray Data] RAPIDS MPF 기반 GPU 셔플 지원으로 GPU 데이터 처리 파이프라인 가속
댓글