[sglang] SGLang, Diffusion 모델의 RL 기반 후처리 최적화를 위한 새로운 Rollout API 및 정밀도 개선
PR 링크: sgl-project/sglang#22604 상태: Merged | 변경: +None / -None
들어가며
최근 SGLang 프로젝트에서는 Diffusion 모델의 강화학습(RL) 기반 후처리 과정을 효율화하고 정밀도를 높이기 위한 중요한 개선 작업을 진행했습니다. 이전에는 일반적인 텍스트-이미지(T2I) 생성 경로에 의존해야 했던 롤아웃(rollout) 로직을 분리하고, Sequence Parallelism(SP) 환경에서의 수치적 안정성을 확보하며, RL 학습에 필요한 추가 메타데이터를 효율적으로 수집하는 것이 주요 목표였습니다. 본 PR은 이러한 요구사항을 충족시키기 위해 독립적인 롤아웃 HTTP API를 도입하고, 디노이징 환경 백패스 기능을 추가하며, SP 환경에서도 정확한 로그 확률(log-probability) 계산을 보장하는 정밀도 개선을 포함합니다.
이 글에서는 해당 PR의 코드 변경 사항을 상세히 분석하고, 각 변경이 왜 Diffusion 모델의 RL 후처리 과정에 긍정적인 영향을 미치는지, 그리고 어떤 기술적 교훈을 얻을 수 있는지 살펴보겠습니다.
코드 분석
1. 독립적인 롤아웃 HTTP API 도입 (runtime/entrypoints/post_training/rollout_api.py, io_struct.py)
기존에는 텍스트-이미지 생성 API(POST /generate) 내부에 롤아웃 관련 로직이 혼재되어 있었습니다. 이 PR은 이를 분리하여 POST /rollout/generate라는 독립적인 엔드포인트와 전용 요청/응답 구조체(RolloutRequest, RolloutResponse)를 도입했습니다. 이는 RL 트레이너가 개별 샘플의 롤아웃 데이터를 쉽게 소비할 수 있도록 샘플 단위의 응답을 제공하는 데 중점을 둡니다.
주요 변경 사항:
- 새로운 API 엔드포인트:
POST /rollout/generate가 추가되어 롤아웃 관련 요청을 전담합니다. RolloutRequest/RolloutResponse: Pydantic 모델을 사용하여 롤아웃 관련 파라미터와 응답 구조를 정의합니다.RolloutResponse는 각 샘플별로 생성된 이미지(generated_output), 로그 확률(rollout_log_probs), 디버그 텐서(rollout_debug_tensors), 디노이징 환경(denoising_env), DiT 트래이젝토리(dit_trajectory) 등을 포함합니다.- 샘플 단위 응답: 배치 단위의 결과를 각 샘플별
RolloutResponse로 분리하여 직렬화된 트레이젝토리 슬라이스를 제공합니다. 이는 RL 트레이너가 개별 트레이젝토리를 처리할 때 배치 디먹싱(demuxing)이나 RL 전용 필드를 스레딩할 필요 없이 효율적으로 데이터를 사용할 수 있게 합니다.
코드 예시 (요청/응답 구조):
# runtime/entrypoints/post_training/io_struct.py
class RolloutRequest(BaseModel):
prompt: str
# ... other generation params
rollout: bool = True
rollout_sde_type: str = "sde"
rollout_noise_level: float = 0.7
rollout_log_prob_no_const: bool = False
rollout_debug_mode: bool = True
rollout_return_denoising_env: bool = False
rollout_return_dit_trajectory: bool = False
# ...
class RolloutResponse(BaseModel):
request_id: str
prompt: str
seed: int
generated_output: Any = None
rollout_log_probs: Optional[dict[str, Any]] = None
rollout_debug_tensors: Optional[dict[str, Any]] = None
denoising_env: Optional[dict[str, Any]] = None
dit_trajectory: Optional[dict[str, Any]] = None
# ...
2. 롤아웃 디노이징 환경 백패스 (runtime/post_training/rollout_denoising_mixin.py)
RL 기반 후처리는 단순히 로그 확률뿐만 아니라, 정책 경사(policy gradient) 계산을 위해 롤아웃 과정의 상세한 메타데이터가 필요합니다. 이 PR은 RolloutDenoisingMixin을 통해 DenoisingStage에 훅을 걸어, rollout_return_denoising_env 또는 rollout_return_dit_trajectory 플래그가 활성화되었을 때 필요한 데이터를 수집합니다.
주요 변경 사항:
RolloutDenoisingMixin:DenoisingStage의_maybe_prepare_rollout,_maybe_init_denoising_env_collection,_maybe_append_dit_trajectory_step,_maybe_collect_rollout_log_probs,_maybe_finalize_dit_env_collection등의 메서드를 오버라이드하거나 확장합니다.- 데이터 수집: 텍스트 임베딩, 이미지 임베딩, 가이던스 정보와 같은 고정된 트랜스포머의 키워드 인수(
kwargs)와 각 스텝의 DiT 트레이젝토리(raw noisyx_{t_i}, 최종 latent, 타임스텝)를 수집합니다. - 옵트인(Opt-in) 설계:
rollout_return_denoising_env및rollout_return_dit_trajectory플래그는 기본적으로 비활성화되어 있어, 해당 기능이 필요 없을 경우 메모리나 연산 오버헤드가 발생하지 않습니다.
코드 예시 (데이터 흐름):
graph TD
A[POST /rollout/generate] --> B{build SamplingParams}
B --> C{pipeline.forward(req)}
C --> D[DenoisingStage]
D --> E[_maybe_prepare_rollout]
E --> F[_maybe_init_denoising_env_collection]
F --> G[for each step]
G --> H[_maybe_append_dit_trajectory_step]
H --> I[scheduler.step]
I --> J[SDE/CPS: log_prob on full buffer]
I --> K[ODE: bit-exact step]
G --> L[_postprocess_rollout_outputs]
L --> M[_maybe_collect_rollout_log_probs]
L --> N[_maybe_finalize_dit_env_collection]
N --> O[SP-gather env + trajectory]
C --> P[_build_response]
P --> Q[split batch -> per-sample RolloutResponse]
Q --> R[ORJSONResponse]
3. SP 환경에서의 로그 확률 정밀도 개선 (runtime/post_training/scheduler_rl_mixin.py)
Sequence Parallelism(SP) 환경에서 로그 확률을 계산할 때 발생하는 수치적 드리프트(drift)는 이전 PR에서 지적된 문제였습니다. 특히 enable_autocast=False 파이프라인에서 PyTorch의 래핑된 스칼라 승격(wrapped-scalar promotion)으로 인해 0차원 fp32 noise_std_dev가 N차원 bf16 variance_noise와 곱해질 때 bf16으로 강제 변환되는 현상이 원인이었습니다. 또한, bf16의 비결합성(non-associativity)도 셔드(shard) 간 합산 시 문제를 야기했습니다.
이 PR은 각 랭크가 전체 사전 셔드(pre-shard) 노이즈 버퍼에 대해 로그 확률을 계산하도록 변경하여 이 문제를 해결했습니다. 이를 통해 all_reduce 연산 없이도 SP 및 단일 GPU 환경 모두에서 비트-정확한(bit-exact) 결과를 얻을 수 있게 되었습니다.
주요 변경 사항:
- 전체 버퍼 로그 확률 계산: 각 SP 랭크는 공통 시드(seed)를 사용하여 전체
rollout_session_data.noise_buffer를 생성하고, 이 전체 버퍼에 대해log_prob_no_const_val을 계산합니다. 이를 통해 각 스텝의 합계가 동일하게 보장됩니다. all_reduce제거: 더 이상 셔드 간all_reduce가 필요 없어 통신 비용이 절감되고, SP/단일 GPU 간 비트-정확성이 보장됩니다.- FlowGRPO Precision Policy: SDE/CPS 분기는
model_output.float()를 진입 시점에 캐스팅하여 FlowGRPO의sd3_sde_with_logprob.py와 일치시킵니다. 이는 bf16 오버플로우를 방지하고 PyTorch의 래핑된 스칼라 승격 문제를 회피합니다. - ODE 모드 유지: ODE 분기는 캐스팅을 하지 않아
rollout(sde_type="ode")가 비롤아웃(non-rollout) 결정론적 스텝과 비트-정확한prev_sample을 생성하도록 유지합니다.
코드 예시 (SP 정밀도 개선):
# python/sglang/multimodal_gen/runtime/post_training/scheduler_rl_mixin.py
- # SP-aligned log-prob. Previously drifted by 1 bf16 ulp per step.
- # Root cause: 0-dim fp32 noise_std_dev demoted to bf16 when multiplied
- # against N-dim bf16 variance_noise on enable_autocast=False pipelines.
- # Fix: each rank computes log-prob on the full pre-shard noise buffer
- # (no all_reduce), plus flowGRPO's fp32 entry-cast policy for SDE/CPS.
+ # SP-aligned log-prob. Previously drifted by 1 bf16 ulp per step.
+ # Root cause: 0-dim fp32 noise_std_dev demoted to bf16 when multiplied
+ # against N-dim bf16 variance_noise on enable_autocast=False pipelines.
+ # Fix: each rank computes log-prob on the full pre-shard noise buffer
+ # (no all_reduce), plus flowGRPO's fp32 entry-cast policy for SDE/CPS.
+ # ODE branch is left uncasted to preserve bit-exactness with non-rollout path.
def _rollout_log_probs(self, batch: OutputBatch, rollout_session_data: RolloutTrajectoryData):
# ...
if self.sampling_params.rollout:
- # SP-aligned log-prob calculation
- # ... (previous implementation with all_reduce)
+ # SP-aligned log-prob calculation on full buffer
+ # ... (new implementation using full noise_buffer)
4. 기타 버그 수정 및 안정성 강화
이 PR에는 위에서 언급된 주요 기능 외에도 여러 버그 수정 및 안정성 개선 사항이 포함되어 있습니다.
dit_trajectory누락된 최종 스텝:_postprocess_rollout_outputs에서 마지막 스텝의 latent를 누락 없이 추가하고, SP 수집과 일관되도록 순서를 조정했습니다.dit_trajectory잘못된 캡처 텐서: 디노이징 루프 상단에서 캡처 훅을 이동시켜, 스케일링/I2V-concat 이전의 원시x_{t_i}를 저장하도록 수정했습니다. 필드 이름도latent_model_inputs에서latents로 변경되었습니다.variance_noise별칭(aliasing) 문제: SP=1일 때append_local_rollout_debug_tensors가 버퍼를 복제(clone)하도록 수정하여, 재사용되는noise_buffer뷰가 덮어쓰여 이전 스텝의 노이즈를 참조하는 문제를 해결했습니다.- 중복
_maybe_collect_rollout_log_probs호출 제거:_post_denoising_loop에서 불필요한 호출을 제거했습니다. - CLI 기본값 누수 방지: 롤아웃 CLI 인자의 명시적 기본값 설정을 제거하여 관련 테스트가 통과하도록 했습니다.
gather_latents_for_sp키워드 인수 불일치 수정: Z-Image/Qwen-Image 롤아웃 파이프라인 믹스인 및 디버그 믹스인에서 발생한 인자 불일치를 수정했습니다.
이러한 수정들은 코드의 견고성을 높이고 예상치 못한 동작을 방지하는 데 기여합니다.
왜 이게 좋은가?
성능 향상 및 효율성 증대
- 독립적인 API: 롤아웃 로직을 메인 생성 경로에서 분리함으로써, RL 후처리 작업에 필요한 기능만 선택적으로 활성화할 수 있게 되었습니다. 이는 불필요한 오버헤드를 줄이고, 메인 생성 경로의 성능에 영향을 주지 않으면서 롤아웃 관련 기능을 사용할 수 있게 합니다.
- 샘플 단위 처리:
RolloutResponse가 각 샘플별로 직렬화된 데이터를 제공하므로, RL 트레이너는 배치 데이터를 직접 처리할 필요 없이 각 트레이젝토리를 개별적으로 효율적으로 소비할 수 있습니다. 이는 RL 학습 파이프라인의 복잡성을 줄여줍니다. - SP 환경에서의 정밀도 보장: 전체 사전 셔드 노이즈 버퍼를 사용하여 로그 확률을 계산하는 방식은 SP 환경에서 발생하는 수치적 드리프트를 완전히 제거했습니다. Qwen-Image 모델에서 SDE 모드 시
max_abs_diff가4.88e-4에서0으로, CPS 모드 시1.95e-3에서0으로 개선된 결과는 이를 명확히 보여줍니다. 이는 RL 학습의 안정성과 재현성을 크게 향상시킵니다. - 메모리 및 연산 효율성:
rollout_return_denoising_env및rollout_return_dit_trajectory와 같은 기능은 옵트인 방식으로 설계되어, 필요할 때만 활성화됩니다. 이는 기본 경로에서는 추가적인 메모리나 연산 비용 없이 제로 오버헤드를 보장합니다.
일반적인 교훈
- API 분리의 중요성: 특정 기능(예: RL 후처리)이 일반적인 사용 사례와 다를 경우, 이를 별도의 API 엔드포인트로 분리하는 것이 코드의 모듈성, 유지보수성 및 효율성을 높이는 좋은 전략입니다.
- 병렬 처리 시 정밀도 문제 해결: 분산 환경(특히 SP)에서는 데이터 타입 변환, 연산 순서, 집계 방식 등에 따라 미묘한 수치적 차이가 발생할 수 있습니다. 이러한 문제를 해결하기 위해 전체 버퍼를 사용하거나, 특정 연산의 데이터 타입을 명시적으로 제어하는 방식(예: fp32 캐스팅)이 효과적입니다. PyTorch의 래핑된 스칼라 승격과 같은 내부 동작을 이해하는 것이 중요합니다.
- 옵트인 설계의 이점: 성능에 민감한 애플리케이션에서는 추가 기능이 기본적으로 비활성화되어 있어, 사용자가 명시적으로 선택해야만 활성화되도록 하는 것이 성능 저하를 방지하는 데 도움이 됩니다.
- 철저한 테스트: 다양한 병렬 설정(TP, SP) 및 모델 구성(Qwen-Image, Z-Image-Turbo)에 걸쳐 비트-정확성 및 기능적 정확성을 검증하는 것은 이러한 복잡한 변경 사항의 신뢰성을 보장하는 데 필수적입니다. 특히,
torch.equal을 사용한 단위 테스트는 결정론적 동작을 보장하는 데 유용합니다.
리뷰 댓글 분석
제공된 리뷰 댓글은 주로 CI 관련 태그(tag-and-rerun-ci, rerun-failed-ci)와 파일 경로에 대한 간단한 언급(pipeline_configs/mixins)으로 구성되어 있습니다. 이는 코드 자체의 복잡성보다는 CI 환경 설정이나 파일 구조에 대한 논의가 있었음을 시사합니다. 핵심적인 기술적 논쟁이나 반대 의견보다는, 코드 변경 사항이 의도한 대로 작동하는지 확인하는 과정에 초점을 맞춘 것으로 보입니다. 따라서 리뷰 댓글에서 직접적으로 도출할 수 있는 추가적인 기술적 분석은 제한적입니다.
References
- torch.compile - PyTorch의 컴파일 기능 (직접적인 관련은 없으나, 성능 최적화 맥락에서 참고)
- FastAPI - 비동기 웹 프레임워크
- Pydantic - 데이터 유효성 검사 및 설정 관리 라이브러리
- Sequence Parallelism - Hugging Face 블로그의 Sequence Parallelism 설명 (개념 이해에 도움)
- FlowGRPO - PR에서 참조된 FlowGRPO의 로그 확률 계산 코드 (직접 링크는 PR 내부에만 존재)
참고 자료
- https://fastapi.tiangolo.com/
- https://docs.pydantic.dev/latest/
- https://huggingface.co/blog/sequence-parallelism
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [vllm] vLLM TurboQuant: KV 캐시 압축으로 LLM 서빙 효율 극대화
- 현재글 : [sglang] SGLang, Diffusion 모델의 RL 기반 후처리 최적화를 위한 새로운 Rollout API 및 정밀도 개선
- 다음글 [ollama] Ollama MLX Gemma4 성능 최적화: Fused Operations를 통한 효율성 증대
댓글