[ACE-Step-1.5] MLX VAE 디코딩 메모리 최적화: Apple Silicon에서 피크 메모리 56% 절감
PR 링크: ace-step/ACE-Step-1.5#1042 상태: Merged | 변경: +None / -None
들어가며
최근 ACE-Step 프로젝트에서 Apple Silicon 기반 Mac 사용자들의 VAE(Variational Autoencoder) 디코딩 과정에서 발생하는 높은 피크 메모리 사용량 문제를 해결하기 위한 중요한 최적화 작업이 진행되었습니다. 특히 통합 메모리(unified memory) 아키텍처를 사용하는 Mac 환경에서 대용량 데이터를 처리할 때 스왑(swap) 발생으로 인한 성능 저하가 두드러졌습니다. 이 PR은 MLX 프레임워크를 사용하는 VAE 디코딩 로직의 청크(chunk) 처리 방식을 개선하여, 피크 GPU 메모리 사용량을 획기적으로 줄이는 동시에 스왑 압력을 완화하는 것을 목표로 합니다.
코드 분석: MLX VAE 디코딩 최적화
이번 PR은 주로 acestep/core/generation/handler/mlx_vae_decode_native.py 파일의 _mlx_decode_single 함수를 수정하여 VAE 디코딩 로직을 최적화했습니다. 핵심 변경사항은 디코딩 청크 크기 조정과 중간 활성화(intermediate activations) 캐시 정리입니다.
1. mlx_vae_decode_native.py: 청크 크기 조정 및 캐시 정리
이 파일에서는 VAE 디코딩 시 처리하는 mlx_chunk의 기본 크기를 2048에서 512 프레임으로 줄였습니다. 이는 통합 메모리 시스템에서 메모리 할당 및 해제의 효율성을 높여 피크 메모리 사용량을 줄이는 데 기여합니다.
Before:
decode_fn = self._resolve_mlx_decode_fn()
latent_frames = z_nlc.shape[1]
mlx_chunk = 2048
mlx_overlap = 64
if latent_frames <= mlx_chunk:
# No tiling needed
return decode_fn(z_nlc)
After:
decode_fn = self._resolve_mlx_decode_fn()
latent_frames = z_nlc.shape[1]
# Smaller chunks reduce peak memory on unified-memory Apple Silicon.
# 512 produces byte-identical output to 2048 with ~56% less peak GPU.
mlx_chunk = 512
mlx_overlap = 64
if latent_frames <= mlx_chunk:
# No tiling needed
return decode_fn(z_nlc)
주석에서 명시된 바와 같이, 512 프레임으로 청크 크기를 줄임으로써 2048 프레임과 동일한 바이트 단위 출력을 유지하면서 피크 GPU 메모리를 약 56% 절감할 수 있습니다. 이는 특히 메모리 용량이 제한적인 Apple Silicon 환경에서 매우 중요한 개선입니다.
또한, 디코딩 루프 내에서 각 청크 처리 후 중간 활성화를 명시적으로 해제하고 MLX 캐시를 주기적으로 정리하는 로직이 추가되었습니다. 이는 메모리 사용량을 더욱 최적화하는 데 도움을 줍니다.
Before:
trim_end = int(round((win_end - core_end) * upsample_factor))
audio_len = audio_chunk.shape[1]
end_idx = audio_len - trim_end if trim_end > 0 else audio_len
decoded_parts.append(audio_chunk[:, trim_start:end_idx, :])
return mx.concatenate(decoded_parts, axis=1)
After:
trim_end = int(round((win_end - core_end) * upsample_factor))
audio_len = audio_chunk.shape[1]
end_idx = audio_len - trim_end if trim_end > 0 else audio_len
trimmed = audio_chunk[:, trim_start:end_idx, :]
mx.eval(trimmed)
decoded_parts.append(trimmed)
del audio_chunk, chunk, trimmed
if (idx + 1) % 4 == 0:
mx.clear_cache()
return mx.concatenate(decoded_parts, axis=1)
mx.eval(trimmed)은 MLX 연산을 즉시 실행하여 메모리에 로드되도록 합니다. 이후 del audio_chunk, chunk, trimmed를 통해 더 이상 필요 없는 중간 객체들을 명시적으로 삭제하여 가비지 컬렉션이 메모리를 회수할 수 있도록 합니다. 마지막으로 if (idx + 1) % 4 == 0: mx.clear_cache()는 4개의 청크마다 MLX의 내부 캐시를 비워줌으로써, 누적되는 중간 활성화로 인한 메모리 증가를 방지합니다. 이는 특히 긴 시퀀스를 처리할 때 메모리 안정성을 크게 향상시킵니다.
2. mlx_vae_native_test.py: 새로운 테스트 케이스 추가
새로운 최적화 로직이 올바르게 작동하는지 검증하기 위해 test_mlx_decode_single_default_chunk_tiles_at_512 테스트 케이스가 추가되었습니다. 이 테스트는 1500 프레임과 같이 이전에 단일 청크로 처리되었을 시퀀스가 이제 512 프레임 청크로 분할되어 처리되는지 확인합니다.
After:
def test_mlx_decode_single_default_chunk_tiles_at_512(self):
"""Default chunk=512 tiles sequences that previously decoded in one shot."""
host = _Host()
fake_mx_core = _fake_mx_core_module()
fake_mlx_pkg = types.ModuleType("mlx")
fake_mlx_pkg.__path__ = []
z_nlc = np.ones((1, 1500, 1), dtype=np.float32)
call_sizes = []
def tracking_decode(chunk):
call_sizes.append(chunk.shape[1])
return np.repeat(chunk, 2, axis=1)
with patch.dict(sys.modules, {"mlx": fake_mlx_pkg, "mlx.core": fake_mx_core}):
out = host._mlx_decode_single(z_nlc, decode_fn=tracking_decode)
self.assertEqual(tuple(out.shape), (1, 3000, 1))
self.assertGreater(len(call_sizes), 1)
for size in call_sizes:
self.assertLessEqual(size, 512)
이 테스트는 call_sizes 리스트를 통해 tracking_decode 함수가 호출될 때마다 전달되는 청크의 크기를 기록하고, 모든 청크 크기가 512 이하인지 확인하여 새로운 청크 분할 로직이 의도대로 작동함을 검증합니다. 이는 코드 변경이 예상대로 동작하며 회귀(regression)를 발생시키지 않음을 보장합니다.
왜 이게 좋은가?
이 PR은 Apple Silicon 환경에서 MLX 기반 VAE 디코딩의 메모리 효율성을 극적으로 향상시키는 좋은 최적화 사례입니다.
1. 피크 메모리 사용량 대폭 감소
벤치마크 결과에 따르면, 600초 길이의 전체 파이프라인 실행 시 피크 GPU 메모리(MLX)가 31.08 GB에서 13.44 GB로 약 56% 감소했습니다. 이는 통합 메모리 아키텍처를 사용하는 Mac에서 스왑 발생 가능성을 크게 줄여줍니다. 스왑은 디스크 I/O를 유발하여 전체 시스템 성능을 저하시키는 주범이므로, 이를 줄이는 것은 사용자 경험에 직접적인 영향을 미칩니다.
2. 스왑 압력 완화
스왑 델타(Swap delta) 또한 +12.2 GB에서 +9.8 GB로 감소했습니다. 비록 PyTorch MPS 드라이버의 메모리 할당 문제로 인해 여전히 스왑 압력이 존재하지만, 이 PR은 VAE 디코딩 단계에서 발생하는 스왑을 줄여 전체적인 시스템 안정성에 기여합니다.
3. 안정적인 성능과 정확성 유지
메모리 사용량을 줄이는 동시에, 출력 결과는 750, 1500, 3000, 15000 latent frames에서 바이트 단위로 동일함이 검증되었습니다. 이는 최적화가 성능 저하나 결과의 정확성 손실 없이 이루어졌음을 의미합니다.
4. 성능-메모리 트레이드오프의 현명한 선택
이 최적화는 VAE 디코딩 시간이 68초에서 78초로 약 15% 증가하는 성능 저하를 수반합니다. 하지만 피크 GPU 메모리를 56% 절감하고 스왑 압력을 완화하는 이점은 특히 메모리 제약이 있는 환경에서 훨씬 더 중요합니다. 이는 메모리 안정성(memory-stability)과 성능(speed) 사이의 현명한 트레이드오프를 보여주는 좋은 예시입니다.
일반적 교훈
- 청크 처리 및 타일링(Tiling) 전략: 대용량 데이터를 처리할 때는 데이터를 작은 청크로 나누어 처리하는 타일링 전략이 메모리 사용량을 줄이는 데 매우 효과적입니다. 특히 통합 메모리 시스템에서는 더욱 중요합니다.
- 명시적인 메모리 관리: Python과 같은 언어에서는 가비지 컬렉터가 메모리 관리를 담당하지만,
del키워드를 사용하여 더 이상 필요 없는 객체를 명시적으로 삭제하고,mx.clear_cache()와 같은 프레임워크별 캐시 정리 함수를 활용하는 것은 피크 메모리를 줄이는 데 큰 도움이 됩니다. - 벤치마킹 및 검증: 최적화 작업 후에는 반드시 실제 환경에서 벤치마킹을 수행하여 성능 변화를 측정하고, 기존 기능에 대한 회귀 테스트를 통해 정확성을 검증해야 합니다. 이 PR은 벤치마크와 바이트 단위의 결정론적 검증을 통해 이러한 원칙을 잘 따랐습니다.
- AI 지원 개발: 이 PR은 Claude Opus 4.6 및 GPT 5.4 High와 같은 AI 모델의 도움을 받아 개발 및 검토되었습니다. 이는 AI가 소프트웨어 개발 프로세스에서 효과적인 도구로 활용될 수 있음을 보여줍니다.
이 PR은 단순히 코드를 변경하는 것을 넘어, 특정 하드웨어 아키텍처의 특성을 이해하고 그에 맞는 최적화 전략을 적용하여 사용자 경험을 실질적으로 개선한 모범적인 사례입니다.
참고 자료
- https://ml-explore.github.io/mlx/build/html/index.html
- https://ml-explore.github.io/mlx/build/html/api/generated/mlx.core.eval.html
- https://ml-explore.github.io/mlx/build/html/api/generated/mlx.core.clear_cache.html
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [sglang] sglang 성능 최적화: torch.compile 퓨전 복원을 통한 TopK 후처리 개선
- 현재글 : [ACE-Step-1.5] MLX VAE 디코딩 메모리 최적화: Apple Silicon에서 피크 메모리 56% 절감
- 다음글 [sglang] SGLang AMD 환경에서의 GLM-5-FP8 성능 벤치마크 도입 및 최적화
댓글