본문으로 건너뛰기

[transformers] Apple Silicon의 MPS에서 Flash Attention 최적화: 속도와 효율성 향상

PR 링크: huggingface/transformers#45974 상태: Merged | 변경: +30 / -5

들어가며

최근 LLM(거대 언어 모델)의 발전 속도는 눈부십니다. 특히 Apple Silicon과 같이 효율적인 하드웨어가 주목받으면서, 해당 환경에서의 모델 추론 성능 최적화는 매우 중요한 과제가 되었습니다. Hugging Face의 transformers 라이브러리는 다양한 모델과 추론 환경을 지원하며 지속적으로 발전하고 있습니다. 이번 PR은 Apple Silicon의 MPS(Metal Performance Shaders) 환경에서 generategenerate_batch 함수의 성능을 크게 향상시키는 중요한 개선을 담고 있습니다.

기존에는 MPS 환경에서 Flash Attention의 성능을 제대로 활용하지 못하는 경우가 있었습니다. 이 PR은 kernels-community/metal-flash-sdpa라는 커뮤니티에서 개발된 최적화된 커널을 MPS 환경에서 기본적으로 사용하도록 설정하고, 연속 배치 처리 시 메모리 관리의 정확성을 높여 전반적인 추론 속도를 개선하는 것을 목표로 합니다.

본 글에서는 이 PR이 어떤 문제를 해결하고, 어떤 코드 변경을 통해 성능을 향상시키는지, 그리고 이러한 최적화가 왜 좋은지에 대해 자세히 알아보겠습니다.

코드 분석

이번 PR의 핵심 변경 사항은 크게 두 부분으로 나눌 수 있습니다: model_manager.py에서의 어텐션 구현 자동 선택 로직 추가, 그리고 requests.py에서의 MPS 메모리 관리 정확도 개선입니다.

1. src/transformers/cli/serving/model_manager.py

이 파일은 모델 로딩 및 관리를 담당하는 ModelManager 클래스를 정의합니다. 이번 PR에서는 __init__ 메서드와 새로 추가된 _resolve_attn_implementation 메서드를 통해 MPS 환경에서의 어텐션 구현 방식을 개선했습니다.

변경 전:

-        self.attn_implementation = attn_implementation
+        self.attn_implementation = self._resolve_attn_implementation(attn_implementation, self.device)

변경 후:

+    @classmethod
+    def _resolve_attn_implementation(cls, attn_implementation: str | None, device: str | int) -> str | None:
+        r"""
+        Default to a fast kernel for `mps` when available.
+        """
+        if attn_implementation is not None:
+            return attn_implementation
+
+        import torch
+
+        from ...integrations.hub_kernels import _kernels_available
+
+        is_mps_device = (
+            isinstance(device, str)
+            and device.startswith("mps")
+            or (device == "auto" and torch.backends.mps.is_available() and not torch.cuda.is_available())
+        )
+        if is_mps_device and _kernels_available:
+            logger.warning_once(
+                "MPS detected and `kernels` is installed: defaulting attention to "
+                "`kernels-community/metal-flash-sdpa@223ca3350d7ba32ecf19341ff2cbb8c43fa47d62. "
+                "Pass `--attn-implementation sdpa` to opt out."
+            )
+            return "kernels-community/metal-flash-sdpa@223ca3350d7ba32ecf19341ff2cbb8c43fa47d62"
+        return attn_implementation

설명:

  • __init__ 메서드에서 기존에는 단순히 전달받은 attn_implementation 값을 그대로 사용했습니다. 하지만 변경 후에는 _resolve_attn_implementation 메서드를 호출하여 attn_implementation을 결정합니다.
  • _resolve_attn_implementation 메서드는 다음과 같은 로직을 수행합니다:
    • 만약 attn_implementation이 명시적으로 제공되었다면, 해당 값을 그대로 반환합니다. 사용자가 특정 어텐션 구현을 원할 경우 이를 존중합니다.
    • device가 MPS 장치인지 확인합니다. device가 문자열로 "mps"로 시작하거나, "auto"이고 MPS가 사용 가능하며 CUDA가 사용 불가능한 경우 MPS 장치로 간주합니다.
    • MPS 장치이고, kernels 라이브러리가 설치되어 있다면 (_kernels_available 확인), 기본 어텐션 구현을 kernels-community/metal-flash-sdpa로 설정합니다. 이 커널은 Apple Silicon의 Metal API에 최적화되어 있어 상당한 성능 향상을 기대할 수 있습니다.
    • 이때, 사용자에게 어떤 어텐션 구현이 선택되었는지, 그리고 원치 않을 경우 --attn-implementation sdpa 옵션으로 비활성화할 수 있음을 알리는 경고 메시지를 출력합니다.
    • 위 조건에 해당하지 않으면, 원래의 attn_implementation 값을 반환합니다.

이 변경은 사용자가 별도의 설정을 하지 않아도 MPS 환경에서 자동으로 최적화된 Flash Attention 커널을 사용하게 함으로써, 별도의 노력 없이 성능 향상을 누릴 수 있게 합니다. 이는 사용자 경험을 크게 개선하는 중요한 개선입니다.

2. src/transformers/generation/continuous_batching/requests.py

이 파일은 연속 배치 처리(Continuous Batching) 시 요청(Request) 관련 로직을 다룹니다. 특히 get_device_and_memory_breakdown 함수에서 MPS 장치의 메모리 사용량 계산 방식을 수정했습니다.

변경 전:

-        # MPS memory reporting (PyTorch 2.0+)
-        total_memory = torch.mps.driver_allocated_memory()
-        allocated_memory = total_memory - getattr(torch.mps, "recommended_max_memory")()
-        reserved_memory = 0  # MPS does not track reserved separately
+        total_memory = torch.mps.recommended_max_memory()
+        allocated_memory = torch.mps.current_allocated_memory()
+        reserved_memory = torch.mps.driver_allocated_memory()

설명:

  • 기존 코드에서는 MPS의 메모리 관련 함수(torch.mps.driver_allocated_memory, torch.mps.recommended_max_memory)를 사용하여 total_memory, allocated_memory, reserved_memory를 계산했습니다. 하지만 이 계산 방식이 MPS의 실제 메모리 사용량을 정확하게 반영하지 못했을 가능성이 있습니다.
  • 변경 후에는 PyTorch 2.0 이상에서 제공하는 MPS 메모리 관련 API를 더 정확하게 사용하도록 수정되었습니다:
    • total_memory: torch.mps.recommended_max_memory()를 사용하여 권장 최대 메모리를 가져옵니다. 이는 시스템에서 MPS에 할당할 수 있는 최대 메모리 양을 나타낼 수 있습니다.
    • allocated_memory: torch.mps.current_allocated_memory()를 사용하여 현재 할당된 메모리 양을 직접 가져옵니다. 이는 모델이나 연산에 의해 실제로 사용 중인 메모리를 나타냅니다.
    • reserved_memory: torch.mps.driver_allocated_memory()를 사용하여 드라이버가 할당한 메모리 양을 가져옵니다. 이는 PyTorch가 관리하는 메모리 풀과는 별개로, 드라이버 수준에서 예약된 메모리를 의미할 수 있습니다.

이 변경은 연속 배치 처리 시 MPS 장치의 메모리 사용량을 더 정확하게 파악하고 관리하는 데 도움을 줍니다. 정확한 메모리 정보는 효율적인 자원 할당 및 잠재적인 메모리 부족 오류 방지에 필수적입니다. 리뷰어(remi-or)가 기본값(default) 누락을 지적한 부분과 관련하여, 이 수정은 해당 문제를 해결하고 더 신뢰할 수 있는 메모리 추적을 제공합니다.

왜 이게 좋은가?

이번 PR의 가장 큰 장점은 성능 향상사용 편의성 증대입니다.

1. 획기적인 성능 향상

PR 설명에 포함된 벤치마크 결과는 매우 인상적입니다:

  • 기존 sdpa 구현: 149.33초 소요, 158.4 tok/s
  • kernels-community/metal-flash-sdpa 사용: 89.78초 소요, 256.0 tok/s

이는 1.66배의 속도 향상을 의미합니다. 동일한 작업에 걸리는 시간이 크게 단축되어, 특히 대규모 언어 모델을 Apple Silicon 환경에서 서비스하거나 실험하는 경우 생산성이 크게 향상될 것입니다. 정확도 또한 거의 동일하게 유지되면서 성능만 개선된 점은 매우 고무적입니다.

2. 사용자 경험 개선

model_manager.py의 변경 사항 덕분에, 사용자는 별도의 복잡한 설정 없이도 MPS 환경에서 자동으로 최적화된 metal-flash-sdpa 커널을 사용할 수 있게 되었습니다. 이는 Hugging Face transformers 라이브러리가 다양한 하드웨어 환경을 최대한 활용하도록 지원하는 좋은 예시입니다. 사용자는 --attn-implementation sdpa와 같은 옵션을 통해 원치 않을 경우 이전 구현으로 쉽게 전환할 수도 있습니다.

3. 정확한 메모리 관리

requests.py의 메모리 관리 로직 개선은 연속 배치 처리와 같이 메모리 효율성이 중요한 시나리오에서 안정성을 높입니다. 정확한 메모리 사용량 추적은 모델 로딩, 배치 크기 조정, 그리고 잠재적인 Out-of-Memory(OOM) 오류 방지에 필수적입니다. 이는 특히 MPS와 같이 메모리 관리가 까다로울 수 있는 환경에서 더욱 중요합니다.

일반적 교훈

  • 하드웨어별 최적화의 중요성: 특정 하드웨어 아키텍처(예: Apple Silicon의 Metal)에 최적화된 커널을 활용하는 것은 상당한 성능 향상을 가져올 수 있습니다. 라이브러리는 이러한 최적화된 커널을 자동으로 감지하고 적용하는 메커니즘을 제공해야 합니다.
  • 자동화된 기본값 설정: 사용자가 복잡한 설정을 하지 않아도 최적의 성능을 경험할 수 있도록, 환경에 맞는 최적의 설정을 기본값으로 제공하는 것이 중요합니다.
  • 정확한 자원 모니터링: 특히 동적 자원 할당이 필요한 연속 배치 처리와 같은 고급 기능에서는, 시스템 자원(메모리, GPU 등)을 정확하게 모니터링하고 관리하는 것이 안정성과 성능의 핵심입니다.

결론

Hugging Face transformers 라이브러리의 이번 PR은 Apple Silicon의 MPS 환경에서 LLM 추론 성능을 획기적으로 개선하는 중요한 발걸음입니다. kernels-community/metal-flash-sdpa의 자동 적용과 정확한 메모리 관리 개선을 통해, 사용자들은 별도의 노력 없이도 더 빠르고 안정적인 모델 추론을 경험할 수 있게 되었습니다. 이는 LLM 기술의 접근성을 높이고, 다양한 하드웨어 환경에서의 활용을 더욱 촉진할 것입니다.

References

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글