[sglang] [AMD/ROCm] Temporal Unfolding을 통한 VAE Conv3D 성능 최적화 분석
PR 링크: sgl-project/sglang#22971 상태: Merged | 변경: +0 / -0
들어가며
비디오 생성 모델(Video Generation Model)의 추론 과정에서 VAE(Variational Autoencoder) 디코딩은 상당한 연산 자원을 소모합니다. 특히 Wan2.1과 같은 최신 모델은 시간축 정보를 처리하기 위해 Conv3d를 사용하는데, 이는 공간축만 처리하는 Conv2d에 비해 연산 복잡도가 높고 GPU 커널 최적화가 까다로운 경우가 많습니다.
최근 SGLang 프로젝트에 반영된 이 PR은 AMD ROCm 환경에서 Temporal Unfolding 기법을 사용하여 Conv3d를 수학적으로 동일한 Batched Conv2d로 변환하는 최적화를 구현했습니다. 이를 통해 시각적 품질 저하 없이 약 3.6%의 End-to-End 성능 향상을 이끌어냈습니다.
핵심 변경 사항 분석
1. Temporal Unfolding: 3D를 2D로 변환하는 마법
가장 핵심적인 로직은 rocm.py에 추가된 _conv3d_as_batched_conv2d 함수입니다. 3D 컨볼루션 커널 $(K_t, K_h, K_w)$을 시간축($T$)에 대해 펼쳐서(unfold) 배치(batch) 차원으로 밀어넣는 방식입니다.
Before (Standard Conv3d)
기존에는 PyTorch의 표준 F.conv3d를 호출하여 5D 텐서(N, C, T, H, W)를 처리했습니다.
After (Batched Conv2d)
# python/sglang/multimodal_gen/runtime/platforms/rocm.py
# (N, C_in, T, H, W) -> (N, T_out, Kt, C_in, H, W) -> (N*T_out, Kt*C_in, H, W)
unfolded = x_padded.unfold(2, kt, stride_t)
unfolded = unfolded.permute(0, 2, 5, 1, 3, 4).reshape(
N * T_out, kt * C_in, H, W
)
# 2D 컨볼루션 실행
out = F.conv2d(unfolded, w, b, stride=(stride_h, stride_w))
# 다시 5D 텐서로 복구
return out.reshape(N, T_out, C_out, H_out, W_out).permute(0, 2, 1, 3, 4)
이 방식은 시간축의 윈도우를 채널 차원($K_t \times C_{in}$)으로 통합하여, GPU에서 매우 고도로 최적화된 Conv2d 커널을 활용할 수 있게 합니다.
2. 가중치 사전 변환 및 캐싱 (Weight Pre-transformation)
매번 추론 시마다 가중치를 재배열하면 오버헤드가 발생합니다. 이 PR은 모델 로드 시점에 가중치를 미리 변환하여 버퍼로 등록합니다.
+ # Pre-compute the 2-D weight: [C_out, C_in, Kt, Kh, Kw]
+ # -> [C_out, Kt*C_in, Kh, Kw] (cached as a buffer)
+ weight_2d = (
+ child.weight.data.permute(0, 2, 1, 3, 4)
+ .reshape(child.out_channels, kt * child.in_channels, kh, kw)
+ .contiguous()
+ )
+ child.register_buffer("_weight_2d", weight_2d)
register_buffer를 통해 변환된 가중치를 모듈의 일부로 저장함으로써, 실제 forward 시에는 추가적인 연산 없이 즉시 최적화된 커널을 사용할 수 있습니다.
3. BF16 연산 선택권 제공
성능을 극대화하기 위해 SGLANG_USE_ROCM_VAE_CONV2D_BF16 환경 변수를 도입했습니다. 입력 데이터가 FP32이더라도 실제 컨볼루션 연산은 BF16에서 수행하고 다시 캐스팅하는 옵션입니다.
if compute_bf16 and orig_dtype != torch.bfloat16:
unfolded = unfolded.to(torch.bfloat16)
w = w.to(torch.bfloat16)
# ... 연산 후 ...
out = out.to(orig_dtype)
왜 이게 좋은 최적화인가?
1. 수학적 동등성 유지와 품질 보장
이 최적화는 근사치가 아닌 수학적으로 동일한 연산을 수행합니다. 테스트 결과 PSNR 38.12 dB, SSIM 0.9637로 원본과 시각적으로 거의 차이가 없음이 증명되었습니다. 이는 공격적인 최적화임에도 불구하고 모델의 정확도를 해치지 않는다는 점에서 매우 가치가 높습니다.
2. 하드웨어 가속기(ROCm) 특성 활용
많은 GPU 아키텍처, 특히 AMD의 ROCm 환경에서 Conv2d는 Conv3d보다 훨씬 더 성숙하고 최적화된 라이브러리(MIOpen 등) 지원을 받습니다. 복잡한 3D 연산을 익숙한 2D 연산으로 치환함으로써 하드웨어 잠재력을 더 끌어올린 사례입니다.
3. 유연한 적용 (Monkey Patching)
기존 모델 코드를 직접 수정하지 않고, optimize_vae 함수 내에서 named_modules()를 순회하며 조건에 맞는 레이어만 동적으로 교체(types.MethodType 사용)하는 방식을 택했습니다. 이는 코드의 침습성을 최소화하면서도 강력한 성능 향상을 제공합니다.
결론
이번 최적화는 "데이터의 형태를 변경하여 더 효율적인 연산 경로를 찾는다"는 엔지니어링의 정석을 보여줍니다. 특히 비디오 모델처럼 고차원 데이터를 다루는 경우, 차원 축소나 Unfolding을 통한 최적화가 얼마나 효과적인지 잘 보여주는 사례입니다.
AMD GPU에서 SGLang을 사용하는 유저라면 SGLANG_USE_ROCM_VAE_CONV2D_BF16=1 설정을 통해 즉시 3% 이상의 속도 향상을 경험해 보시기 바랍니다.
참고 자료
- https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html
- https://pytorch.org/docs/stable/generated/torch.nn.functional.conv2d.html
- https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [sglang] AMD GPU에서 FP8 MLA를 활용한 Diffusion 모델 성능 최적화
- [sglang] SGLang의 AMD AITER AllReduce 최적화: 하드코딩된 제약 제거 및 성능 개선
- [sglang] AMD ROCm 환경에서의 DeepSeek-V4 성능 최적화: Aiter MHC 커널 통합 분석
- [sglang] SGLang 성능 최적화: torch.cuda.empty_cache() 호출 제어를 통한 가중치 업데이트 병목 해결
- [sglang] SGLang에서 GLM-5 모델 성능 최적화: Aiter 백엔드 활용 및 텐서 패딩 전략
PR Analysis 의 다른글
- 이전글 [sglang] AMD GPU에서 FP8 MLA를 활용한 Diffusion 모델 성능 최적화
- 현재글 : [sglang] [AMD/ROCm] Temporal Unfolding을 통한 VAE Conv3D 성능 최적화 분석
- 다음글 [flashinfer] FlashInfer, MoE 및 FP8 GEMM 성능 향상을 위한 커널 업데이트
댓글