[sglang] NPU 성능 향상을 위한 causal_conv1d_update_v2 도입
PR 링크: sgl-project/sglang#24595 상태: Merged | 변경: +0 / -0
들어가며
최근 대규모 언어 모델(LLM)의 발전과 함께, 모델 추론 속도는 사용자 경험과 직결되는 중요한 요소가 되었습니다. 특히 NPU(신경망 처리 장치)와 같은 특수 하드웨어 환경에서는 최적의 성능을 이끌어내기 위한 세심한 최적화가 필수적입니다. 본 글에서는 sglang 레포지토리의 PR ([NPU] use causal_conv1d_update_v2 for performance)을 분석하여, NPU 환경에서 기존의 torch.ops.npu.causal_conv1d_update 대신 causal_conv1d_update_v2를 도입함으로써 얻게 된 성능 향상과 그 배경에 대해 상세히 알아보겠습니다.
이 PR은 NPU 환경에서의 모델 추론 속도를 개선하는 것을 목표로 합니다. 특히, Mamba와 같은 모델에서 사용되는 인과적 컨볼루션(causal convolution) 연산의 핵심 부분을 더 효율적인 함수로 교체함으로써, 실제 추론 시간을 크게 단축시키는 성과를 달성했습니다.
코드 분석
이번 PR의 핵심 변경 사항은 ascend_gdn_backend.py 파일 내에서 인과적 컨볼루션 업데이트 로직을 수정하는 것입니다. 이전에는 torch.ops.npu.causal_conv1d_update를 사용했지만, 이제는 causal_conv1d_update_v2라는 새로운 함수를 사용하도록 변경되었습니다.
ascend_gdn_backend.py 파일 변경 분석
1. causal_conv1d_update_v2 함수 임포트
먼저, sgl_kernel_npu.mamba.causal_conv1d 모듈에서 causal_conv1d_update_v2 함수를 임포트합니다.
--- a/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py
+++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py
@@ -8,6 +8,7 @@
from sgl_kernel_npu.mamba.causal_conv1d import (
causal_conv1d_fn_npu,
causal_conv1d_update_npu,
+ causal_conv1d_update_v2,
)
from sglang.srt.hardware_backend.npu.attention.ascend_hybrid_linear_attn_backend import (
이 변경은 causal_conv1d_update_v2 함수를 사용하기 위한 준비 단계입니다.
2. forward_extend 함수 내 로직 변경
가장 중요한 변경은 forward_extend 함수 내부에서 발생합니다. 기존의 torch.ops.npu.causal_conv1d_update 호출 부분이 causal_conv1d_update_v2를 사용하도록 수정되었습니다.
Before:
@@ -236,23 +235,24 @@
b = b[: forward_batch.num_token_non_padded_cpu]
seq_len = forward_batch.num_token_non_padded_cpu
- mixed_qkv_reshaped = mixed_qkv.view(batch_size, draft_token_num, -1)
- num_accept_tokens = torch.full(
+ batch_size = cache_indices.shape[0]
+ draft_token_num = forward_batch.spec_info.draft_token_num
+ num_accepted_tokens = torch.full(
(batch_size,),
draft_token_num,
dtype=torch.int32,
device=mixed_qkv.device,
)
- mixed_qkv = torch.ops.npu.causal_conv1d_update(
- mixed_qkv_reshaped,
- layer.conv_weights.transpose(0, 1).contiguous(),
- conv_states,
- cache_indices,
- layer.bias,
- num_accept_tokens,
- None,
- layer.activation == "silu",
- self.pad_slot_id,
+ mixed_qkv = causal_conv1d_update_v2(
+ x=mixed_qkv.view(batch_size, draft_token_num, -1).contiguous(),
+ conv_state=conv_states.contiguous(),
+ weight=layer.conv_weights.transpose(0, 1).contiguous(),
+ bias=layer.bias,
+ activation=layer.activation,
+ conv_state_indices=cache_indices,
+ num_accepted_tokens=num_accepted_tokens,
+ pad_slot_id=-1,
+ validate_data=False,
).view(seq_len, -1)
else:
mixed_qkv = mixed_qkv.transpose(0, 1)
After:
@@ -236,23 +235,24 @@
b = b[: forward_batch.num_token_non_padded_cpu]
seq_len = forward_batch.num_token_non_padded_cpu
- mixed_qkv_reshaped = mixed_qkv.view(batch_size, draft_token_num, -1)
- num_accept_tokens = torch.full(
+ batch_size = cache_indices.shape[0]
+ draft_token_num = forward_batch.spec_info.draft_token_num
+ num_accepted_tokens = torch.full(
(batch_size,),
draft_token_num,
dtype=torch.int32,
device=mixed_qkv.device,
)
- mixed_qkv = torch.ops.npu.causal_conv1d_update(
- mixed_qkv_reshaped,
- layer.conv_weights.transpose(0, 1).contiguous(),
- conv_states,
- cache_indices,
- layer.bias,
- num_accept_tokens,
- None,
- layer.activation == "silu",
- self.pad_slot_id,
+ mixed_qkv = causal_conv1d_update_v2(
+ x=mixed_qkv.view(batch_size, draft_token_num, -1).contiguous(),
+ conv_state=conv_states.contiguous(),
+ weight=layer.conv_weights.transpose(0, 1).contiguous(),
+ bias=layer.bias,
+ activation=layer.activation,
+ conv_state_indices=cache_indices,
+ num_accepted_tokens=num_accepted_tokens,
+ pad_slot_id=-1,
+ validate_data=False,
).view(seq_len, -1)
else:
mixed_qkv = mixed_qkv.transpose(0, 1)
이 변경에서 주목할 점은 다음과 같습니다:
- 함수 호출 변경:
torch.ops.npu.causal_conv1d_update가causal_conv1d_update_v2로 변경되었습니다. - 인자 전달 방식 변경:
causal_conv1d_update_v2는 함수 시그니처가 다르므로, 인자들이 새로운 형식에 맞게 재구성되었습니다. 예를 들어,mixed_qkv는.view(batch_size, draft_token_num, -1)형태로 먼저 reshape된 후.contiguous()가 적용되어x인자로 전달됩니다.conv_states역시.contiguous()가 적용되어conv_state인자로 전달됩니다.layer.conv_weights.transpose(0, 1).contiguous()는weight인자로,layer.bias는bias인자로,layer.activation은activation인자로 전달됩니다.cache_indices는conv_state_indices로,num_accept_tokens는num_accepted_tokens로 이름이 변경되어 전달됩니다.pad_slot_id와validate_data는 새로운 기본값으로 설정되었습니다. is_npu()체크 추가: PR 설명에 따르면,sgl_kernel_npu모듈을 정상적으로 임포트하기 위해is_npu()체크 로직이 추가되었을 가능성이 있습니다. 이는 NPU 환경이 아닐 때 발생할 수 있는 import 오류를 방지하는 중요한 수정입니다. (이 부분은 제공된 diff에는 직접적으로 보이지 않지만, PR 설명에 언급되어 있습니다.)
causal_conv1d_update_v2란 무엇인가?
causal_conv1d_update_v2는 기존의 causal_conv1d_update 함수보다 개선된 성능을 제공하는 것으로 보입니다. 구체적인 내부 구현은 NPU 하드웨어의 특성을 활용하여 연산을 최적화했을 가능성이 높습니다. 예를 들어, 메모리 접근 패턴을 개선하거나, 병렬 처리 효율을 높이는 등의 최적화가 적용되었을 수 있습니다. validate_data=False와 같이 일부 인자의 기본값이 변경된 것은, 특정 사용 사례에 맞춰 불필요한 검증 단계를 생략하여 성능을 향상시키려는 의도로 해석될 수 있습니다.
왜 이게 좋은가?
이번 PR의 가장 큰 장점은 실질적인 추론 속도 향상입니다.
PR 설명에 따르면, 다음과 같은 성능 개선이 확인되었습니다:
- Qwen3.5-397B 모델: 추론 속도가 약 300us에서 140us로 향상되었습니다. 이는 약 2배 이상의 성능 향상입니다.
- Qwen3-next 모델: 추론 속도가 약 550us에서 200us로 향상되었습니다. 이 역시 2배 이상의 성능 향상입니다.
이러한 성능 향상은 다음과 같은 이유로 긍정적으로 평가될 수 있습니다:
- 효율적인 함수 사용:
causal_conv1d_update_v2는 NPU 하드웨어에 더 최적화된 구현일 가능성이 높습니다. 이는 동일한 연산을 더 적은 시간과 자원으로 처리할 수 있게 하여 성능을 극대화합니다. - LLM 추론 속도 개선: LLM의 추론 속도는 실시간 대화형 애플리케이션이나 대규모 서비스 운영에 있어 매우 중요합니다. 이러한 속도 개선은 사용자 경험을 향상시키고, 더 많은 요청을 더 적은 자원으로 처리할 수 있게 하여 비용 효율성을 높입니다.
- 안정성 확보:
is_npu()체크 로직 추가는 NPU 환경이 아닌 경우 발생할 수 있는 잠재적인 오류를 방지하여 코드의 안정성을 높입니다. 이는 다양한 환경에서의 호환성을 보장하는 데 기여합니다.
일반적인 교훈: 특수 하드웨어(NPU, GPU 등)를 사용할 때는 해당 하드웨어에 최적화된 라이브러리나 함수를 적극적으로 활용하는 것이 중요합니다. 때로는 기존 함수를 단순히 호출하는 것보다, 해당 하드웨어 벤더가 제공하는 최신 또는 특화된 API를 사용하는 것이 훨씬 큰 성능 향상을 가져올 수 있습니다. 또한, 성능 테스트와 프로파일링을 통해 병목 구간을 식별하고, 이를 개선하기 위한 구체적인 함수 교체나 로직 변경을 시도하는 것이 효과적인 최적화 전략입니다.
결론
이번 PR은 NPU 환경에서 causal_conv1d_update_v2를 도입하여 Mamba와 같은 모델의 추론 속도를 획기적으로 개선한 좋은 사례입니다. 약 2배 이상의 성능 향상은 LLM 서비스의 효율성과 사용자 경험을 크게 증진시킬 수 있습니다. 이는 하드웨어 특화 최적화의 중요성을 다시 한번 강조하며, 앞으로도 이러한 성능 개선 노력이 지속될 것으로 기대됩니다.
References
- sglang GitHub PR (실제 PR 번호로 대체 필요)
- torch.ops.npu.causal_conv1d_update (실제 NPU SDK 문서 링크 필요)
- causal_conv1d_update_v2 (실제 NPU SDK 문서 링크 필요)
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [vllm] vLLM Mamba2 SSD 커널 웜업: 첫 요청 지연 시간 91% 감소의 비결
- 현재글 : [sglang] NPU 성능 향상을 위한 causal_conv1d_update_v2 도입
- 다음글 [vllm] vLLM W8W8 그룹 양자화 성능 최적화: 2D-Grid를 통한 Divmod 제거
댓글