본문으로 건너뛰기

[sglang] [HunyuanVideo] Sequence Parallelism 최적화: Text Token Sharding으로 성능 한계 돌파하기

PR 링크: sgl-project/sglang#28319 상태: Merged | 변경: +88 / -13

들어가며\n\n최근 비디오 생성 AI 모델인 HunyuanVideo와 같은 DiT(Diffusion Transformer) 아키텍처는 고해상도 비디오를 처리하기 위해 엄청난 수의 토큰을 다룹니다. 이러한 대규모 시퀀스를 단일 GPU 메모리에 담는 것은 불가능에 가깝기 때문에, 시퀀스 차원을 여러 GPU에 나누어 처리하는 Sequence Parallelism(SP) 기술이 필수적입니다.\n\n기존 SGLang의 HunyuanVideo 구현에서는 이미지 토큰은 SP를 통해 분산 처리되었지만, 상대적으로 길이가 짧은 텍스트 토큰은 모든 GPU에 복제(Replicated)되어 처리되는 구조였습니다. 하지만 모델의 규모가 커지고 최적화의 정밀도가 요구됨에 따라, 이 텍스트 토큰들조차 중복 연산의 원인이 되었습니다. 이번 PR([diffusion] Shard HunyuanVideo text tokens under SP)은 텍스트 토큰을 SP 환경에서 샤딩(Sharding)하고, 불균형한 시퀀스 길이를 처리하기 위한 Varlen Ulysses All-to-All 경로를 추가하여 성능을 한 단계 더 끌어올렸습니다.\n\n## 코드 분석: 핵심 변경 사항\n\n### 1. 텍스트 토큰의 동적 샤딩 (hunyuanvideo.py)\n\n가장 먼저 살펴볼 부분은 모델의 입력 단계에서 텍스트 토큰을 각 GPU의 랭크(Rank)에 맞춰 분할하는 로직입니다. 기존에는 모든 GPU가 동일한 txt 텐서를 가졌으나, 이제는 자신의 몫에 해당하는 부분만 가집니다.\n\n**[Before & After]\npython\n# Before: 텍스트 토큰 전체를 사용\n# (별도의 샤딩 로직 없이 txt_in 통과 후 전체 시퀀스 유지)\n\n# After: SP 설정에 따른 텍스트 토큰 샤딩\nsp_size = get_sp_world_size()\ntxt_is_sharded = (\n sp_size > 1\n and get_ring_parallel_world_size() == 1\n and txt_seq_len >= sp_size\n and not torch.is_grad_enabled()\n)\n\nif txt_is_sharded:\n sp_rank = get_sp_parallel_rank()\n base_text_shard_len = txt_seq_len // sp_size\n extra_text_tokens = txt_seq_len % sp_size\n # 불균형한 길이를 고려한 샤딩 인덱스 계산\n text_shard_start = base_text_shard_len * sp_rank + min(sp_rank, extra_text_tokens)\n text_shard_len = base_text_shard_len + (1 if sp_rank < extra_text_tokens else 0)\n txt = txt[:, text_shard_start : text_shard_start + text_shard_len].contiguous()\n\n\n여기서 주목할 점은 txt_seq_len % sp_size를 통해 나머지가 발생하는 경우에도 토큰이 누락되지 않도록 min(sp_rank, extra_text_tokens)를 사용하여 오프셋을 정교하게 계산한다는 것입니다. 이는 시퀀스 길이가 GPU 개수로 나누어떨어지지 않는 실제 상황을 완벽히 대응합니다.\n\n### 2. Varlen Ulysses Attention 도입 (layer.py)\n\n텍스트 토큰이 샤딩되면 각 GPU가 가진 전체 시퀀스(Image + Text)의 길이가 미세하게 달라질 수 있습니다. 기존의 UlyssesAttention은 고정된 길이를 가정하는 경우가 많았으나, 이번 변경을 통해 가변 길이(Variable Length)를 지원하는 All-to-All 통신 경로가 추가되었습니다.\n\n[Before & After]\npython\n# Before: 고정 길이를 가정한 All-to-All\nqkv = sequence_model_parallel_all_to_all_4D(qkv, scatter_dim=2, gather_dim=1)\n\n# After: seq_lens가 제공될 경우 Varlen 전용 함수 호출\nif seq_lens is None:\n qkv = sequence_model_parallel_all_to_all_4D(qkv, scatter_dim=2, gather_dim=1)\nelse:\n # 불균형 시퀀스 길이를 지원하는 새로운 통신 경로\n qkv = _usp_input_all_to_all_varlen(qkv, seq_lens, head_dim=2)\n\n\n_usp_input_all_to_all_varlen 함수는 각 프로세스가 가진 데이터의 크기가 다르더라도 정확하게 head 차원과 sequence 차원을 교환(Transpose)할 수 있게 해줍니다. 이는 DeepSpeed Ulysses의 개념을 확장하여 유연성을 높인 구현입니다.\n\n### 3. Block 수준에서의 데이터 흐름 제어\n\nHunyuanVideo의 DoubleStreamBlockSingleStreamBlock은 이제 텍스트가 샤딩되었는지 여부와 각 샤드의 실제 길이를 인자로 전달받습니다.\n\n[HunyuanVideoBlock.forward]\npython\n# DoubleStreamBlock 내부에서의 처리\nif txt_is_sharded:\n # 이미지와 텍스트 쿼리를 합쳐서 한 번에 Attention 수행\n attn, _ = self.attn(\n torch.cat((img_q, txt_q), dim=1),\n torch.cat((img_k, txt_k), dim=1),\n torch.cat((img_v, txt_v), dim=1),\n seq_lens=seq_lens,\n )\n img_attn, txt_attn = attn.split([image_seq_len, text_seq_len], dim=1)\n\n\n텍스트가 샤딩되었을 때는 이미지와 텍스트를 dim=1(시퀀스 차원)로 결합하여 하나의 거대한 Attention 연산으로 처리합니다. 이때 앞서 계산한 seq_lens 정보를 넘겨줌으로써 분산 환경에서도 정확한 Attention Score 계산이 가능해집니다.\n\n## 왜 이게 좋은가?\n\n### 1. 실질적인 성능 향상\n벤치마크 결과에 따르면, H100 4개 환경에서 FastHunyuan 모델의 추론 속도가 눈에 띄게 개선되었습니다.\n- Steady Denoise Step: 605.24 ms -> 583.57 ms (-3.58%**)\n- Denoise Total: 최대 -5.70% 감소\n\n단순히 텍스트 토큰 몇 개를 나눈 것 같지만, Transformer 블록이 수십 개 겹쳐진 구조에서는 각 블록마다 발생하는 중복 연산과 메모리 I/O가 누적되어 큰 차이를 만듭니다.\n\n### 2. 메모리 효율성\n텍스트 토큰을 복제하지 않고 샤딩함으로써, 각 GPU가 유지해야 하는 활성화(Activation) 메모리 양이 줄어듭니다. 이는 더 긴 비디오를 생성하거나 더 큰 배치 사이즈를 사용할 수 있는 여유를 제공합니다.\n\n### 3. 일반적인 교훈: "No Redundant Work"\n분산 시스템 설계의 핵심 원칙 중 하나는 **"어떤 노드도 다른 노드가 이미 한 일을 반복하지 않게 하는 것"**입니다. 이번 최적화는 비록 텍스트 토큰이 이미지에 비해 짧더라도, 분산 처리의 대상에서 예외를 두지 않음으로써 시스템 전체의 효율성을 극대화했습니다.\n\n## 마치며\n\n이번 PR은 SGLang이 대규모 멀티모달 모델을 얼마나 깊이 있게 최적화하고 있는지를 보여주는 좋은 사례입니다. 특히 Ulysses SP의 유연성을 극대화한 varlen 지원은 향후 다양한 가변 길이 입력 모델에도 적용될 수 있는 중요한 인프라가 될 것입니다. 고성능 서빙 엔진을 고민하는 엔지니어라면, 이러한 세밀한 샤딩 전략을 반드시 참고해 보시기 바랍니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글