본문으로 건너뛰기

[ACE-Step-1.5] Apple Silicon을 위한 네이티브 MLX DiT 백엔드 도입: 2-3배 성능 향상

PR 링크: ace-step/ACE-Step-1.5#439 상태: Merged | 변경: +1164 / -9

들어가며

Apple Silicon(M1/M2/M3/M4) 환경에서 PyTorch의 mps 백엔드를 사용하여 Diffusion Transformer(DiT) 모델을 실행할 때, 반복적인 추론 루프에서 발생하는 디스패치 및 동기화 오버헤드가 큰 병목 현상으로 작용합니다. 특히 오디오 생성과 같은 실시간성이 중요한 작업에서 이러한 오버헤드는 사용자 경험을 저해합니다. 본 PR은 이 문제를 해결하기 위해 Apple의 네이티브 프레임워크인 MLX를 도입하여 DiT 디코더 루프를 재구현함으로써, PyTorch-to-MPS 오버헤드를 완전히 제거하고 2-3배의 성능 향상을 달성했습니다.

코드 분석

1. acestep/handler.py: MLX 백엔드 통합 및 Fallback 로직

핵심 변경 사항은 AceStepHandler에 MLX 초기화 및 실행 경로를 추가한 것입니다. 기존 PyTorch 경로를 유지하면서, Apple Silicon 환경에서만 선택적으로 MLX를 사용하도록 설계되었습니다.

Before (기존):

# PyTorch MPS 경로만 존재
self.model.generate_audio(**generate_kwargs)

After (개선):

# MLX 가용성 확인 후 fast-path 실행
if self.use_mlx_dit and self.mlx_decoder is not None:
    try:
        return self._mlx_run_diffusion(...)
    except Exception as e:
        logger.warning(f"MLX diffusion failed: {e}, falling back to PyTorch")
        return self.model.generate_audio(...)

2. acestep/mlx_dit/convert.py: 가중치 변환

PyTorch의 state_dict를 MLX 배열로 변환하는 과정에서 텐서 레이아웃을 최적화했습니다. 예를 들어 Conv1d 레이어의 경우, PyTorch의 [out, in, K] 형식을 MLX의 효율적인 연산을 위해 [out, K, in]으로 재배치합니다.

# Conv1d 변환 예시
pt_w = sd["proj_in.1.weight"].numpy() # [256, 192, 2]
mlx_w = np.array(weight_dict["proj_in.weight"]) # [256, 2, 192]
# swapaxes(1, 2)를 통해 메모리 레이아웃 최적화

왜 이게 좋은가

  1. 오버헤드 제거: PyTorch의 mps 백엔드는 커널 디스패치 시 동기화 비용이 큽니다. MLX는 Metal 기반의 그래프 실행을 통해 이 오버헤드를 최소화합니다.
  2. 안정성: try-except 블록을 통해 MLX 실행 실패 시 즉시 기존 PyTorch 경로로 Fallback되도록 설계하여 사용자 경험을 보장합니다.
  3. 성능: 실제 벤치마크에서 DiT 디코더 루프 기준 2-3배의 wall-clock 속도 향상을 확인했습니다.

교훈: 특정 하드웨어(Apple Silicon)에 최적화된 프레임워크(MLX)를 선택적으로 활용하는 것은 범용 프레임워크(PyTorch)의 한계를 극복하는 매우 효과적인 전략입니다. 특히 추론 루프가 반복적일수록 이러한 네이티브 가속의 이점은 극대화됩니다.

리뷰어 피드백 반영

리뷰 과정에서 VAE 디코딩 시 VRAM 부족 문제(0GB vram detection)가 언급되었습니다. 이 PR은 DiT 루프를 가속화하지만, VAE 디코딩은 여전히 PyTorch MPS를 사용하므로 기존의 Tiled VAE 디코딩 로직과 호환성을 유지하도록 설계되었습니다. 또한, torch.compile이 활성화된 경우 MLX 경로를 비활성화하여 충돌을 방지했습니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글