본문으로 건너뛰기

[sglang] SGLang, Diffusion 모델의 RL 기반 후처리 최적화를 위한 새로운 Rollout API 및 정밀도 개선

PR 링크: sgl-project/sglang#22604 상태: Merged | 변경: +None / -None

들어가며

최근 SGLang 프로젝트에서는 Diffusion 모델의 강화학습(RL) 기반 후처리 과정을 효율화하고 정밀도를 높이기 위한 중요한 개선 작업을 진행했습니다. 이전에는 일반적인 텍스트-이미지(T2I) 생성 경로에 의존해야 했던 롤아웃(rollout) 로직을 분리하고, Sequence Parallelism(SP) 환경에서의 수치적 안정성을 확보하며, RL 학습에 필요한 추가 메타데이터를 효율적으로 수집하는 것이 주요 목표였습니다. 본 PR은 이러한 요구사항을 충족시키기 위해 독립적인 롤아웃 HTTP API를 도입하고, 디노이징 환경 백패스 기능을 추가하며, SP 환경에서도 정확한 로그 확률(log-probability) 계산을 보장하는 정밀도 개선을 포함합니다.

이 글에서는 해당 PR의 코드 변경 사항을 상세히 분석하고, 각 변경이 왜 Diffusion 모델의 RL 후처리 과정에 긍정적인 영향을 미치는지, 그리고 어떤 기술적 교훈을 얻을 수 있는지 살펴보겠습니다.

코드 분석

1. 독립적인 롤아웃 HTTP API 도입 (runtime/entrypoints/post_training/rollout_api.py, io_struct.py)

기존에는 텍스트-이미지 생성 API(POST /generate) 내부에 롤아웃 관련 로직이 혼재되어 있었습니다. 이 PR은 이를 분리하여 POST /rollout/generate라는 독립적인 엔드포인트와 전용 요청/응답 구조체(RolloutRequest, RolloutResponse)를 도입했습니다. 이는 RL 트레이너가 개별 샘플의 롤아웃 데이터를 쉽게 소비할 수 있도록 샘플 단위의 응답을 제공하는 데 중점을 둡니다.

주요 변경 사항:

  • 새로운 API 엔드포인트: POST /rollout/generate가 추가되어 롤아웃 관련 요청을 전담합니다.
  • RolloutRequest / RolloutResponse: Pydantic 모델을 사용하여 롤아웃 관련 파라미터와 응답 구조를 정의합니다. RolloutResponse는 각 샘플별로 생성된 이미지(generated_output), 로그 확률(rollout_log_probs), 디버그 텐서(rollout_debug_tensors), 디노이징 환경(denoising_env), DiT 트래이젝토리(dit_trajectory) 등을 포함합니다.
  • 샘플 단위 응답: 배치 단위의 결과를 각 샘플별 RolloutResponse로 분리하여 직렬화된 트레이젝토리 슬라이스를 제공합니다. 이는 RL 트레이너가 개별 트레이젝토리를 처리할 때 배치 디먹싱(demuxing)이나 RL 전용 필드를 스레딩할 필요 없이 효율적으로 데이터를 사용할 수 있게 합니다.

코드 예시 (요청/응답 구조):

# runtime/entrypoints/post_training/io_struct.py

class RolloutRequest(BaseModel):
    prompt: str
    # ... other generation params
    rollout: bool = True
    rollout_sde_type: str = "sde"
    rollout_noise_level: float = 0.7
    rollout_log_prob_no_const: bool = False
    rollout_debug_mode: bool = True
    rollout_return_denoising_env: bool = False
    rollout_return_dit_trajectory: bool = False
    # ...

class RolloutResponse(BaseModel):
    request_id: str
    prompt: str
    seed: int
    generated_output: Any = None
    rollout_log_probs: Optional[dict[str, Any]] = None
    rollout_debug_tensors: Optional[dict[str, Any]] = None
    denoising_env: Optional[dict[str, Any]] = None
    dit_trajectory: Optional[dict[str, Any]] = None
    # ...

2. 롤아웃 디노이징 환경 백패스 (runtime/post_training/rollout_denoising_mixin.py)

RL 기반 후처리는 단순히 로그 확률뿐만 아니라, 정책 경사(policy gradient) 계산을 위해 롤아웃 과정의 상세한 메타데이터가 필요합니다. 이 PR은 RolloutDenoisingMixin을 통해 DenoisingStage에 훅을 걸어, rollout_return_denoising_env 또는 rollout_return_dit_trajectory 플래그가 활성화되었을 때 필요한 데이터를 수집합니다.

주요 변경 사항:

  • RolloutDenoisingMixin: DenoisingStage_maybe_prepare_rollout, _maybe_init_denoising_env_collection, _maybe_append_dit_trajectory_step, _maybe_collect_rollout_log_probs, _maybe_finalize_dit_env_collection 등의 메서드를 오버라이드하거나 확장합니다.
  • 데이터 수집: 텍스트 임베딩, 이미지 임베딩, 가이던스 정보와 같은 고정된 트랜스포머의 키워드 인수(kwargs)와 각 스텝의 DiT 트레이젝토리(raw noisy x_{t_i}, 최종 latent, 타임스텝)를 수집합니다.
  • 옵트인(Opt-in) 설계: rollout_return_denoising_envrollout_return_dit_trajectory 플래그는 기본적으로 비활성화되어 있어, 해당 기능이 필요 없을 경우 메모리나 연산 오버헤드가 발생하지 않습니다.

코드 예시 (데이터 흐름):

graph TD
    A[POST /rollout/generate] --> B{build SamplingParams}
    B --> C{pipeline.forward(req)}
    C --> D[DenoisingStage]
    D --> E[_maybe_prepare_rollout]
    E --> F[_maybe_init_denoising_env_collection]
    F --> G[for each step]
    G --> H[_maybe_append_dit_trajectory_step]
    H --> I[scheduler.step]
    I --> J[SDE/CPS: log_prob on full buffer]
    I --> K[ODE: bit-exact step]
    G --> L[_postprocess_rollout_outputs]
    L --> M[_maybe_collect_rollout_log_probs]
    L --> N[_maybe_finalize_dit_env_collection]
    N --> O[SP-gather env + trajectory]
    C --> P[_build_response]
    P --> Q[split batch -> per-sample RolloutResponse]
    Q --> R[ORJSONResponse]

3. SP 환경에서의 로그 확률 정밀도 개선 (runtime/post_training/scheduler_rl_mixin.py)

Sequence Parallelism(SP) 환경에서 로그 확률을 계산할 때 발생하는 수치적 드리프트(drift)는 이전 PR에서 지적된 문제였습니다. 특히 enable_autocast=False 파이프라인에서 PyTorch의 래핑된 스칼라 승격(wrapped-scalar promotion)으로 인해 0차원 fp32 noise_std_dev가 N차원 bf16 variance_noise와 곱해질 때 bf16으로 강제 변환되는 현상이 원인이었습니다. 또한, bf16의 비결합성(non-associativity)도 셔드(shard) 간 합산 시 문제를 야기했습니다.

이 PR은 각 랭크가 전체 사전 셔드(pre-shard) 노이즈 버퍼에 대해 로그 확률을 계산하도록 변경하여 이 문제를 해결했습니다. 이를 통해 all_reduce 연산 없이도 SP 및 단일 GPU 환경 모두에서 비트-정확한(bit-exact) 결과를 얻을 수 있게 되었습니다.

주요 변경 사항:

  • 전체 버퍼 로그 확률 계산: 각 SP 랭크는 공통 시드(seed)를 사용하여 전체 rollout_session_data.noise_buffer를 생성하고, 이 전체 버퍼에 대해 log_prob_no_const_val을 계산합니다. 이를 통해 각 스텝의 합계가 동일하게 보장됩니다.
  • all_reduce 제거: 더 이상 셔드 간 all_reduce가 필요 없어 통신 비용이 절감되고, SP/단일 GPU 간 비트-정확성이 보장됩니다.
  • FlowGRPO Precision Policy: SDE/CPS 분기는 model_output.float()를 진입 시점에 캐스팅하여 FlowGRPO의 sd3_sde_with_logprob.py와 일치시킵니다. 이는 bf16 오버플로우를 방지하고 PyTorch의 래핑된 스칼라 승격 문제를 회피합니다.
  • ODE 모드 유지: ODE 분기는 캐스팅을 하지 않아 rollout(sde_type="ode")가 비롤아웃(non-rollout) 결정론적 스텝과 비트-정확한 prev_sample을 생성하도록 유지합니다.

코드 예시 (SP 정밀도 개선):

# python/sglang/multimodal_gen/runtime/post_training/scheduler_rl_mixin.py

- # SP-aligned log-prob. Previously drifted by 1 bf16 ulp per step.
- # Root cause: 0-dim fp32 noise_std_dev demoted to bf16 when multiplied
- # against N-dim bf16 variance_noise on enable_autocast=False pipelines.
- # Fix: each rank computes log-prob on the full pre-shard noise buffer
- # (no all_reduce), plus flowGRPO's fp32 entry-cast policy for SDE/CPS.
+ # SP-aligned log-prob. Previously drifted by 1 bf16 ulp per step.
+ # Root cause: 0-dim fp32 noise_std_dev demoted to bf16 when multiplied
+ # against N-dim bf16 variance_noise on enable_autocast=False pipelines.
+ # Fix: each rank computes log-prob on the full pre-shard noise buffer
+ # (no all_reduce), plus flowGRPO's fp32 entry-cast policy for SDE/CPS.
+ # ODE branch is left uncasted to preserve bit-exactness with non-rollout path.
 def _rollout_log_probs(self, batch: OutputBatch, rollout_session_data: RolloutTrajectoryData):
     # ...
     if self.sampling_params.rollout:
-        # SP-aligned log-prob calculation
-        # ... (previous implementation with all_reduce)
+        # SP-aligned log-prob calculation on full buffer
+        # ... (new implementation using full noise_buffer)

4. 기타 버그 수정 및 안정성 강화

이 PR에는 위에서 언급된 주요 기능 외에도 여러 버그 수정 및 안정성 개선 사항이 포함되어 있습니다.

  • dit_trajectory 누락된 최종 스텝: _postprocess_rollout_outputs에서 마지막 스텝의 latent를 누락 없이 추가하고, SP 수집과 일관되도록 순서를 조정했습니다.
  • dit_trajectory 잘못된 캡처 텐서: 디노이징 루프 상단에서 캡처 훅을 이동시켜, 스케일링/I2V-concat 이전의 원시 x_{t_i}를 저장하도록 수정했습니다. 필드 이름도 latent_model_inputs에서 latents로 변경되었습니다.
  • variance_noise 별칭(aliasing) 문제: SP=1일 때 append_local_rollout_debug_tensors가 버퍼를 복제(clone)하도록 수정하여, 재사용되는 noise_buffer 뷰가 덮어쓰여 이전 스텝의 노이즈를 참조하는 문제를 해결했습니다.
  • 중복 _maybe_collect_rollout_log_probs 호출 제거: _post_denoising_loop에서 불필요한 호출을 제거했습니다.
  • CLI 기본값 누수 방지: 롤아웃 CLI 인자의 명시적 기본값 설정을 제거하여 관련 테스트가 통과하도록 했습니다.
  • gather_latents_for_sp 키워드 인수 불일치 수정: Z-Image/Qwen-Image 롤아웃 파이프라인 믹스인 및 디버그 믹스인에서 발생한 인자 불일치를 수정했습니다.

이러한 수정들은 코드의 견고성을 높이고 예상치 못한 동작을 방지하는 데 기여합니다.

왜 이게 좋은가?

성능 향상 및 효율성 증대

  1. 독립적인 API: 롤아웃 로직을 메인 생성 경로에서 분리함으로써, RL 후처리 작업에 필요한 기능만 선택적으로 활성화할 수 있게 되었습니다. 이는 불필요한 오버헤드를 줄이고, 메인 생성 경로의 성능에 영향을 주지 않으면서 롤아웃 관련 기능을 사용할 수 있게 합니다.
  2. 샘플 단위 처리: RolloutResponse가 각 샘플별로 직렬화된 데이터를 제공하므로, RL 트레이너는 배치 데이터를 직접 처리할 필요 없이 각 트레이젝토리를 개별적으로 효율적으로 소비할 수 있습니다. 이는 RL 학습 파이프라인의 복잡성을 줄여줍니다.
  3. SP 환경에서의 정밀도 보장: 전체 사전 셔드 노이즈 버퍼를 사용하여 로그 확률을 계산하는 방식은 SP 환경에서 발생하는 수치적 드리프트를 완전히 제거했습니다. Qwen-Image 모델에서 SDE 모드 시 max_abs_diff4.88e-4에서 0으로, CPS 모드 시 1.95e-3에서 0으로 개선된 결과는 이를 명확히 보여줍니다. 이는 RL 학습의 안정성과 재현성을 크게 향상시킵니다.
  4. 메모리 및 연산 효율성: rollout_return_denoising_envrollout_return_dit_trajectory와 같은 기능은 옵트인 방식으로 설계되어, 필요할 때만 활성화됩니다. 이는 기본 경로에서는 추가적인 메모리나 연산 비용 없이 제로 오버헤드를 보장합니다.

일반적인 교훈

  • API 분리의 중요성: 특정 기능(예: RL 후처리)이 일반적인 사용 사례와 다를 경우, 이를 별도의 API 엔드포인트로 분리하는 것이 코드의 모듈성, 유지보수성 및 효율성을 높이는 좋은 전략입니다.
  • 병렬 처리 시 정밀도 문제 해결: 분산 환경(특히 SP)에서는 데이터 타입 변환, 연산 순서, 집계 방식 등에 따라 미묘한 수치적 차이가 발생할 수 있습니다. 이러한 문제를 해결하기 위해 전체 버퍼를 사용하거나, 특정 연산의 데이터 타입을 명시적으로 제어하는 방식(예: fp32 캐스팅)이 효과적입니다. PyTorch의 래핑된 스칼라 승격과 같은 내부 동작을 이해하는 것이 중요합니다.
  • 옵트인 설계의 이점: 성능에 민감한 애플리케이션에서는 추가 기능이 기본적으로 비활성화되어 있어, 사용자가 명시적으로 선택해야만 활성화되도록 하는 것이 성능 저하를 방지하는 데 도움이 됩니다.
  • 철저한 테스트: 다양한 병렬 설정(TP, SP) 및 모델 구성(Qwen-Image, Z-Image-Turbo)에 걸쳐 비트-정확성 및 기능적 정확성을 검증하는 것은 이러한 복잡한 변경 사항의 신뢰성을 보장하는 데 필수적입니다. 특히, torch.equal을 사용한 단위 테스트는 결정론적 동작을 보장하는 데 유용합니다.

리뷰 댓글 분석

제공된 리뷰 댓글은 주로 CI 관련 태그(tag-and-rerun-ci, rerun-failed-ci)와 파일 경로에 대한 간단한 언급(pipeline_configs/mixins)으로 구성되어 있습니다. 이는 코드 자체의 복잡성보다는 CI 환경 설정이나 파일 구조에 대한 논의가 있었음을 시사합니다. 핵심적인 기술적 논쟁이나 반대 의견보다는, 코드 변경 사항이 의도한 대로 작동하는지 확인하는 과정에 초점을 맞춘 것으로 보입니다. 따라서 리뷰 댓글에서 직접적으로 도출할 수 있는 추가적인 기술적 분석은 제한적입니다.

References

  • torch.compile - PyTorch의 컴파일 기능 (직접적인 관련은 없으나, 성능 최적화 맥락에서 참고)
  • FastAPI - 비동기 웹 프레임워크
  • Pydantic - 데이터 유효성 검사 및 설정 관리 라이브러리
  • Sequence Parallelism - Hugging Face 블로그의 Sequence Parallelism 설명 (개념 이해에 도움)
  • FlowGRPO - PR에서 참조된 FlowGRPO의 로그 확률 계산 코드 (직접 링크는 PR 내부에만 존재)

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글