본문으로 건너뛰기

[onnxruntime] ONNX Runtime CUTLASS FMHA: BiasLoader 정렬 문제 해결로 안정성 및 호환성 향상

PR 링크: microsoft/onnxruntime#28369 상태: Merged | 변경: +128 / -3

들어가며

최근 Microsoft의 ONNX Runtime 레포지토리에서는 microsoft/onnxruntime의 Pull Request(#28369)를 통해 CUTLASS 기반의 Flash-Masked Multi-Head Attention (FMHA) 구현에서 발생하던 중요한 버그를 수정했습니다. 이 PR은 특히 BiasLoader가 편향(bias) 데이터를 로드할 때 발생하는 정렬(alignment) 문제를 해결하여, 특정 조건에서 발생하던 cudaErrorMisalignedAddress 오류를 방지하고 커널의 안정성을 높이는 데 기여했습니다. 본 글에서는 이 PR의 변경 내용을 상세히 분석하고, 왜 이러한 수정이 성능과 안정성 측면에서 중요한지 기술적인 관점에서 설명하고자 합니다.

이 PR은 BiasLoaderisAligned 템플릿 플래그와 관계없이 항상 128비트 벡터 로드를 사용하도록 하드코딩되어 있던 문제를 해결합니다. 이로 인해 편향 스트라이드(bias stride)가 8의 배수가 아닐 때, 비정렬 커널 경로가 선택되더라도 BiasLoader는 여전히 128비트 로드를 시도하여 cudaErrorMisalignedAddress 오류를 발생시켰습니다. 이번 수정으로 kAlignmentA 값을 사용하여 커널 경로에 따라 적절한 로드 폭을 선택하게 되면서, 이러한 문제를 근본적으로 해결했습니다.

코드 분석

1. onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h 변경 분석

이 PR의 핵심 변경 사항은 kernel_forward.h 파일 내 AttentionKernel 구조체의 BiasLoader 타입 정의에 있습니다. 기존 코드에서는 편향 타일(bias tile)을 전역 메모리에서 공유 메모리로 효율적으로 로드하기 위해 TileSmemLoader를 사용했는데, 이때 로드 폭을 결정하는 마지막 템플릿 인자가 128 / cutlass::sizeof_bits<scalar_t>::value로 하드코딩되어 있었습니다. 이는 scalar_tfloat16일 경우 8 (128 / 16)이 되어 항상 128비트(8 * fp16 요소) 로드를 의미했습니다.

Before:

-        // input restriction: kv_len has to be a multiple of this value
-        128 / cutlass::sizeof_bits<scalar_t>::value);
+        kAlignmentA);

After:

+    // used for efficient load of bias tile Bij from global to shared memory.
+    // Use kAlignmentA so the unaligned kernel path (kIsAligned=false) uses
+    // narrower vectorized loads (64-bit instead of 128-bit), matching the
+    // relaxed alignment requirement for Q/K/V. This allows bias_strideM
+    // (= total_kv_length) to be any multiple of 4 elements (fp16) rather
+    // than requiring a multiple of 8.
     using BiasLoader = TileSmemLoader<
         scalar_t,
         cutlass::MatrixShape<kQueriesPerBlock, kKeysPerBlock>,
         MmaCore::kThreads,
-        // input restriction: kv_len has to be a multiple of this value
-        128 / cutlass::sizeof_bits<scalar_t>::value);
+        kAlignmentA>;

수정된 코드에서는 이 하드코딩된 값을 kAlignmentA로 변경했습니다. kAlignmentA는 일반적으로 커널의 정렬 요구사항에 따라 4 또는 8의 값을 가집니다. AttentionKernelkIsAligned 템플릿 플래그가 false (즉, 비정렬 커널 경로)일 때, kAlignmentA는 4가 되어 64비트(4 * fp16 요소) 로드를 사용하게 됩니다. 이는 Q, K, V 입력에 대한 완화된 정렬 요구사항과 일치하며, total_kv_length가 4의 배수이기만 하면 되도록 허용합니다. 이전에는 8의 배수여야만 했습니다.

이 변경은 편향 데이터 로딩 시 메모리 접근의 유연성을 크게 향상시킵니다. 특히, total_kv_length가 8의 배수가 아닌 경우에도 cudaErrorMisalignedAddress 오류 없이 정상적으로 동작하게 됩니다. 이는 다양한 시퀀스 길이와 배치 구성에서 FMHA를 더 안정적으로 사용할 수 있게 함을 의미합니다.

2. onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py 변경 분석

테스트 코드 부분에서는 새로운 테스트 클래스 TestONNXAttentionMHACutlassBiasAlignment가 추가되었습니다. 이 테스트는 수정된 BiasLoader의 정렬 로직이 다양한 total_kv_len 값에 대해 올바르게 작동하는지 검증하기 위해 설계되었습니다.

추가된 테스트 클래스:

+@unittest.skipIf(not has_cuda_device(53), "CUDA device not available, skipping MHA tests.")
+@patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"})
+class TestONNXAttentionMHACutlassBiasAlignment(unittest.TestCase):
+    """Test CUTLASS BiasLoader alignment with unaligned total_kv lengths.
+    ...
+    """
+    # ... (parameterized tests for decode and prompt scenarios)

이 테스트 클래스는 @patch.dict(os.environ, {"ORT_DISABLE_FLASH_ATTENTION": "1"}) 데코레이터를 사용하여 ONNX Runtime의 Flash Attention 기능을 비활성화하고 CUTLASS MEA(Memory Efficient Attention) 경로를 강제로 사용하도록 설정합니다. 이는 PR에서 수정된 BiasLoader 로직이 실제로 실행되는지 확인하기 위함입니다. 리뷰어 tianleiwu가 지적한 것처럼, SM80 이상 GPU에서는 기본적으로 Flash Attention이 우선적으로 사용될 수 있기 때문에, CUTLASS MEA 경로를 명시적으로 테스트하는 것이 중요합니다.

테스트는 parameterized.expand를 사용하여 다음과 같은 다양한 시나리오를 커버합니다:

  • test_mha_bias_alignment_decode: 디코딩 단계(q_sequence_length=1)에서 past_kv_lentotal_kv_len을 다르게 설정하여, 정렬된(total_kv_len % 8 == 0) 경우와 비정렬된(total_kv_len % 8 != 0) 경우 모두에서 올바른 결과를 생성하는지 확인합니다. Gemma4 Attention + mask 시나리오를 모든 시퀀스 길이 1-32에 대해 테스트했다고 명시되어 있습니다.
  • test_mha_bias_alignment_prompt: 프롬프트 단계(past_kv 없음)에서 kv_seq_len을 다양하게 설정하여 동일한 정렬 검증을 수행합니다.
  • test_gqa_bias_alignment_decode: Grouped Query Attention (GQA)과 같이 쿼리 헤드 수와 키/밸류 헤드 수가 다른 경우에도 동일한 정렬 검증을 수행합니다.

이러한 포괄적인 테스트는 수정 사항이 다양한 사용 사례에서 안정적으로 작동함을 보장합니다.

왜 이게 좋은가?

1. cudaErrorMisalignedAddress 오류 방지 및 안정성 향상

가장 직접적인 이점은 cudaErrorMisalignedAddress 오류가 발생하는 것을 방지하는 것입니다. 이 오류는 GPU가 메모리에서 데이터를 로드하려고 할 때, 해당 데이터의 주소가 하드웨어 요구사항(예: 벡터 로드의 경우 특정 바이트 경계)을 만족하지 못할 때 발생합니다. 이전에는 BiasLoadertotal_kv_length의 정렬 상태와 관계없이 128비트 로드를 강제했기 때문에, total_kv_length가 8의 배수가 아닐 때 문제가 발생했습니다. kAlignmentA를 사용하도록 변경함으로써, 비정렬 경로에서도 더 작은 단위(예: 64비트)로 로드할 수 있게 되어 이러한 충돌을 피할 수 있습니다.

2. 호환성 및 유연성 증대

이 수정은 FMHA 커널이 더 넓은 범위의 입력 데이터와 구성에 대해 호환되도록 만듭니다. 특히, total_kv_length가 8의 배수가 아닌 경우에도 모델이 정상적으로 작동할 수 있게 되므로, 사용자는 시퀀스 길이나 배치 크기를 선택할 때 이러한 제약에 얽매이지 않아도 됩니다. 이는 모델 배포 및 최적화 과정에서 상당한 유연성을 제공합니다.

3. CUTLASS 라이브러리 활용의 모범 사례

CUTLASS는 고성능 GPU 커널을 작성하기 위한 템플릿 라이브러리입니다. TileSmemLoader와 같은 CUTLASS의 컴포넌트를 사용할 때, 제공되는 정렬 관련 템플릿 파라미터(kAlignmentA 등)를 올바르게 활용하는 것이 중요합니다. 이 PR은 CUTLASS 라이브러리의 기능을 정확하게 이해하고 적용한 좋은 예시를 보여줍니다. 하드코딩 대신 동적인(템플릿 기반) 설정을 사용함으로써 라이브러리의 설계 의도를 따르고, 더 견고한 코드를 작성했습니다.

4. 테스트 커버리지 강화

새로운 테스트 케이스 추가는 코드 변경의 중요성을 강조합니다. 특히, 리뷰어의 피드백을 반영하여 Flash Attention을 비활성화하고 CUTLASS MEA 경로를 명시적으로 테스트함으로써, 수정된 로직이 실제 환경에서 의도한 대로 작동함을 검증했습니다. 이는 향후 유사한 문제가 발생했을 때 재발을 방지하고, 코드 품질을 유지하는 데 중요한 역할을 합니다.

결론

이번 ONNX Runtime PR(#28369)은 CUTLASS FMHA 구현에서 BiasLoader의 정렬 문제를 해결함으로써, CUDA 커널의 안정성과 호환성을 크게 향상시켰습니다. kAlignmentA를 사용하여 로드 폭을 동적으로 조절함으로써 cudaErrorMisalignedAddress 오류를 방지하고, 더 다양한 입력 조건에서 FMHA를 사용할 수 있게 되었습니다. 이는 ONNX Runtime을 사용하는 개발자들에게 더 견고하고 유연한 환경을 제공하며, 고성능 딥러닝 모델의 효율적인 배포에 기여할 것입니다. 또한, CUTLASS 라이브러리의 기능을 올바르게 활용하고 철저한 테스트를 통해 이를 검증하는 모범 사례를 보여줍니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글