[ollama] Ollama MLX Gemma4 성능 최적화: Fused Operations를 통한 효율성 증대
PR 링크: ollama/ollama#15587 상태: Merged | 변경: +None / -None
들어가며
최근 대규모 언어 모델(LLM)의 온디바이스 배포가 중요해지면서, 효율적인 추론(inference) 성능은 핵심적인 요소가 되었습니다. Ollama는 다양한 LLM을 로컬 환경에서 쉽게 실행할 수 있도록 돕는 강력한 도구이며, MLX와 같은 최적화된 백엔드를 활용하여 성능을 극대화하고 있습니다. 이번에 분석할 PR은 ollama/ollama 레포지토리에서 MLX 백엔드를 사용하는 Gemma4 모델의 성능을 개선하기 위한 중요한 최적화 작업을 담고 있습니다.
이 PR의 핵심 목표는 Gemma4 모델의 추론 속도를 향상시키는 것입니다. 특히, 모델 내부에서 자주 사용되는 일련의 연산들을 단일 "fused operation"으로 묶어 처리함으로써, 연산 오버헤드를 줄이고 GPU/NPU와 같은 가속기 하드웨어의 활용 효율을 높이는 데 중점을 둡니다. 이는 GELUApprox와 Multiply 조합, 그리고 Divide, Tanh, Multiply 조합과 같은 여러 단계의 연산을 하나의 커널로 컴파일하여 실행하는 방식입니다. 이러한 최적화를 통해 Gemma4 모델의 prefill 및 gen 속도가 모델 크기에 따라 최대 16.6%까지 향상되었습니다.
코드 분석: Fused Operations를 통한 성능 개선
이번 PR은 크게 두 가지 파일에서 변경 사항을 포함하고 있습니다. x/mlxrunner/mlx/act.go 파일에서는 새로운 fused operation을 정의하고, x/models/gemma4/gemma4.go 파일에서는 이 새로운 operation들을 Gemma4 모델의 추론 경로에 적용합니다.
x/mlxrunner/mlx/act.go: 새로운 Fused Operations 정의
이 파일에서는 두 가지 새로운 fused operation인 GeGLU와 LogitSoftcap이 Compile2 함수를 사용하여 정의되었습니다. Compile2는 MLX의 컴파일러 기능을 활용하여 여러 MLX 연산을 하나의 최적화된 커널로 묶는 역할을 합니다.
GeGLU 정의
GeGLU는 Gemma-family MLP 및 MoE 경로에서 사용되는 gelu_approx(gate) * up 연산을 퓨징합니다. 기존에는 GELUApprox와 Multiply 두 단계로 나뉘어 실행되던 연산을 하나로 묶습니다.
Before: (개념적 코드, 실제 act.go에는 없지만 gemma4.go에서 사용)
gate := mlx.GELUApprox(input)
result := mlx.Mul(gate, up)
After:
--- a/x/mlxrunner/mlx/act.go
+++ b/x/mlxrunner/mlx/act.go
@@ -42,3 +42,23 @@ var SwiGLU = Compile2(
},
Shapeless(),
)
+
+// GeGLU returns gelu_approx(gate) * up as a fused kernel. Matches mlx_lm's
+// geglu, used by Gemma-family MLP and MoE paths.
+var GeGLU = Compile2(
+ "GeGLU",
+ func(gate, up *Array) *Array {
+ return GELUApprox(gate).Multiply(up)
+ },
+ Shapeless(),
+)
GELUApprox(gate).Multiply(up)는 이제 MLX 런타임에 의해 단일 최적화된 커널로 컴파일되어 실행됩니다. 이는 중간 결과를 메모리에 쓰고 다시 읽는 오버헤드를 줄여줍니다.
LogitSoftcap 정의
LogitSoftcap은 tanh(x / cap) * cap 연산을 퓨징합니다. 이는 mlx_lm의 logit_softcap과 일치하며, Divide, Tanh, Multiply 세 단계의 연산을 하나로 묶습니다.
Before: (개념적 코드, 실제 act.go에는 없지만 gemma4.go에서 사용)
logits = mlx.MulScalar(logits, m.SoftcapInv)
logits = logits.Tanh()
logits = mlx.MulScalar(logits, m.FinalLogitSoftcapping)
After:
--- a/x/mlxrunner/mlx/act.go
+++ b/x/mlxrunner/mlx/act.go
@@ -42,3 +42,23 @@ var SwiGLU = Compile2(
},
Shapeless(),
)
+
+// GeGLU returns gelu_approx(gate) * up as a fused kernel. Matches mlx_lm's
+// geglu, used by Gemma-family MLP and MoE paths.
+var GeGLU = Compile2(
+ "GeGLU",
+ func(gate, up *Array) *Array {
+ return GELUApprox(gate).Multiply(up)
+ },
+ Shapeless(),
+)
+
+// LogitSoftcap returns tanh(x / cap) * cap as a fused kernel. Matches
+// mlx_lm's logit_softcap. cap must have the same dtype as x.
+var LogitSoftcap = Compile2(
+ "LogitSoftcap",
+ func(x, cap *Array) *Array {
+ return x.Divide(cap).Tanh().Multiply(cap)
+ },
+ Shapeless(),
+)
마찬가지로 x.Divide(cap).Tanh().Multiply(cap) 연산이 단일 커널로 컴파일되어 실행됩니다. 이는 특히 FinalLogitSoftcapping이 활성화된 경우 최종 로짓 계산 단계에서 성능 이점을 가져옵니다.
x/models/gemma4/gemma4.go: Gemma4 모델에 Fused Operations 적용
이 파일에서는 위에서 정의된 GeGLU와 LogitSoftcap을 Gemma4 모델의 추론 경로에 통합합니다. 이는 기존의 개별 연산 호출을 새로운 fused operation 호출로 대체하는 방식입니다.
Unembed 함수 변경
Unembed 함수는 모델의 최종 로짓을 계산하는 부분입니다. FinalLogitSoftcapping이 적용될 때 LogitSoftcap을 사용하도록 변경되었습니다.
Before:
--- a/x/models/gemma4/gemma4.go
+++ b/x/models/gemma4/gemma4.go
@@ -1114,9 +1110,8 @@ func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
logits := m.LMHead.Forward(x)
if m.FinalLogitSoftcapping > 0 {
- logits = mlx.MulScalar(logits, m.SoftcapInv)
- logits = logits.Tanh()
- logits = mlx.MulScalar(logits, m.FinalLogitSoftcapping)
+ cap := mlx.FromValue(m.FinalLogitSoftcapping).AsType(logits.DType())
+ logits = mlx.LogitSoftcap(logits, cap)
}
return logits
m.SoftcapInv 필드는 더 이상 사용되지 않으므로 TextConfig 구조체에서도 제거되었습니다. 이는 리뷰어 jessegross의 코멘트(m.SoftcapInv is now unused, I think.)와 일치하는 변경 사항입니다.
DecoderLayer.Forward 함수 변경
DecoderLayer의 Forward 함수 내 PLE injection 부분에서 GeGLU가 사용되도록 변경되었습니다.
Before:
--- a/x/models/gemma4/gemma4.go
+++ b/x/models/gemma4/gemma4.go
@@ -1231,8 +1226,7 @@ func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Tex
// PLE injection (after MLP residual).
if l.PLE != nil && pleInput != nil {
residual := h
- gate := mlx.GELUApprox(l.PLE.InputGate.Forward(h))
- gated := mlx.Mul(gate, pleInput)
+ gated := mlx.GeGLU(l.PLE.InputGate.Forward(h), pleInput)
projected := l.PLE.Projection.Forward(gated)
projected = mlx.RMSNormFn(projected, l.PLE.PostNormScaled, cfg.RMSNormEps)
h = mlx.Add(residual, projected)
MLP.Forward 함수 변경
MLP의 Forward 함수에서도 GeGLU가 적용되었습니다. 이는 Gemma 모델의 핵심 구성 요소 중 하나인 MLP 블록의 효율성을 높입니다.
Before:
--- a/x/models/gemma4/gemma4.go
+++ b/x/models/gemma4/gemma4.go
@@ -1375,9 +1369,9 @@ func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding b
}
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
- gate := mlx.GELUApprox(m.GateProj.Forward(x))
+ gate := m.GateProj.Forward(x)
up := m.UpProj.Forward(x)
- return m.DownProj.Forward(mlx.Mul(gate, up))
+ return m.DownProj.Forward(mlx.GeGLU(gate, up))
}
// Forward runs the router to select top-k experts per token.
MoEBlock.Forward 함수 변경
Mixture-of-Experts (MoE) 블록의 Forward 함수 내에서도 GeGLU가 여러 곳에 적용되었습니다. MoE 모델은 여러 전문가 네트워크를 활용하므로, 이 부분의 최적화는 전체 모델 성능에 큰 영향을 미칠 수 있습니다.
Before:
--- a/x/models/gemma4/gemma4.go
+++ b/x/models/gemma4/gemma4.go
@@ -1457,13 +1451,13 @@ func (m *MoEBlock) Forward(x *mlx.Array, scores, inds *mlx.Array, cfg *TextConfi
up := mlx.SliceStartStop(gateUp,
[]int32{0, 0, 0, mid},
[]int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])})
- hidden = mlx.Mul(mlx.GELUApprox(gate), up)
+ hidden = mlx.GeGLU(gate, up)
} else {
gate := mlx.GatherQMM(xFlat, m.GateWeightQ, m.GateScales, m.GateBiases,
nil, idxFlat, true, m.GateGroupSize, m.GateBits, m.QuantMode, doSort)
up := mlx.GatherQMM(xFlat, m.UpWeightQ, m.UpScales, m.UpBiases,
nil, idxFlat, true, m.UpGroupSize, m.UpBits, m.QuantMode, doSort)
- hidden = mlx.Mul(mlx.GELUApprox(gate), up)
+ hidden = mlx.GeGLU(gate, up)
}
downMode := m.DownQuantMode
if downMode == "" {
@@ -1482,11 +1476,11 @@ func (m *MoEBlock) Forward(x *mlx.Array, scores, inds *mlx.Array, cfg *TextConfi
up := mlx.SliceStartStop(gateUp,
[]int32{0, 0, 0, mid},
[]int32{int32(guDims[0]), int32(guDims[1]), int32(guDims[2]), int32(guDims[len(guDims)-1])})
- hidden = mlx.Mul(mlx.GELUApprox(gate), up)
+ hidden = mlx.GeGLU(gate, up)
} else {
gate := mlx.GatherMM(xFlat, m.GateWeight, nil, idxFlat, doSort)
up := mlx.GatherMM(xFlat, m.UpWeight, nil, idxFlat, doSort)
- hidden = mlx.Mul(mlx.GELUApprox(gate), up)
+ hidden = mlx.GeGLU(gate, up)
}
down = mlx.GatherMM(hidden, m.DownWeight, nil, idxFlat, doSort)
}
왜 이게 좋은가: 성능 향상과 일반적 교훈
이 PR의 최적화는 딥러닝 모델 추론 성능 향상에 있어 fused operations의 중요성을 잘 보여줍니다.
성능 수치
PR 설명에 따르면, 이 최적화를 통해 Gemma4 모델의 prefill 및 gen 성능이 다음과 같이 향상되었습니다.
| Size | Metric | Baseline | New (compiled) | Δ |
|---|---|---|---|---|
| E2B | prefill | 18,777 | 21,901 | +16.6% |
| E2B | gen | 153.8 | 176.9 | +15.0% |
| E4B | prefill | 6,980 | 8,086 | +15.8% |
| E4B | gen | 99.1 | 110.6 | +11.6% |
| 26B | prefill | 3,957 | 4,372 | +10.5% |
| 26B | gen | 101.3 | 107.5 | +6.1% |
| 31B | prefill | 531 | 593 | +11.7% |
| 31B | gen | 21.4 | 22.4 | +4.8% |
가장 작은 모델인 E2B(2Billion)의 prefill에서 16.6%라는 상당한 성능 향상을 보였으며, 다른 모델 크기에서도 일관되게 4.8%에서 15.8%까지 성능이 개선되었습니다. 이는 특히 prefill 단계에서 더 큰 이점을 제공하는데, 이는 prefill이 더 많은 연산을 한 번에 처리하는 경향이 있기 때문입니다.
Fused Operations의 이점
- 메모리 대역폭 절감: 여러 연산을 개별적으로 실행할 경우, 각 연산의 중간 결과가 메인 메모리(또는 GPU 메모리)에 쓰여지고 다음 연산을 위해 다시 읽혀야 합니다. Fused operation은 이러한 중간 결과를 레지스터나 캐시에 유지하여 메모리 접근 횟수를 줄이고 메모리 대역폭 사용을 최적화합니다. 이는 특히 메모리 대역폭이 병목이 되는 경우 큰 성능 이점을 제공합니다.
- 커널 실행 오버헤드 감소: GPU와 같은 가속기에서 각 연산은 별도의 커널(kernel)로 실행됩니다. 커널 실행에는 드라이버 호출, 컨텍스트 스위칭 등 일정량의 오버헤드가 발생합니다. 여러 연산을 하나의 fused kernel로 묶으면 이러한 커널 실행 오버헤드를 한 번으로 줄일 수 있습니다.
- 컴파일러 최적화 기회 증가: MLX의
Compile2와 같은 컴파일러는 여러 연산을 한 번에 볼 수 있게 되어, 전체 연산 그래프에 대한 더 깊은 최적화를 수행할 수 있습니다. 예를 들어, 불필요한 연산을 제거하거나, 데이터 흐름을 재구성하여 더 효율적인 병렬 처리를 가능하게 할 수 있습니다. - 에너지 효율성: 메모리 접근 감소와 연산 효율성 증가는 전력 소비 감소로 이어져, 온디바이스 AI 환경에서 특히 중요한 에너지 효율성 개선에 기여합니다.
일반적 교훈
이 PR은 딥러닝 프레임워크를 활용하여 모델을 최적화할 때 다음과 같은 중요한 교훈을 제공합니다.
- 하위 레벨 최적화의 중요성: 고수준의 모델 아키텍처 변경 없이도, 하위 레벨의 연산 퓨징과 같은 최적화를 통해 상당한 성능 향상을 이끌어낼 수 있습니다.
- 프레임워크 기능 활용: MLX의
Compile2와 같은 프레임워크가 제공하는 컴파일러 및 최적화 기능을 적극적으로 활용하는 것이 중요합니다. PyTorch의torch.compile이나 TensorFlow의 XLA 컴파일러도 유사한 목적을 가집니다. - 벤치마킹의 중요성: 최적화 전후의 성능을 정확하게 측정하는 벤치마킹은 변경 사항의 효과를 검증하고, 어떤 부분이 병목인지 파악하는 데 필수적입니다.
- 도메인 지식:
GeGLU와LogitSoftcap과 같이 특정 모델(Gemma-family)에서 자주 사용되는 패턴을 식별하고 이를 최적화하는 것은 도메인 지식이 뒷받침될 때 가능합니다.
결론적으로, 이 PR은 MLX 백엔드를 사용하는 Ollama의 Gemma4 모델에 대한 매우 효과적인 최적화 사례를 보여줍니다. Fused operations를 통해 연산 효율성을 높이고, 메모리 접근을 줄여 전반적인 추론 성능을 크게 향상시켰습니다. 이는 온디바이스 LLM의 실용성을 높이는 데 기여하는 중요한 진전입니다.
References
- MLX Documentation
- MLX Array API
- GELU (Gaussian Error Linear Unit) Activation Function
- PyTorch torch.compile
참고 자료
- https://ml-explore.github.io/mlx/build/html/index.html
- https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.array.html
- https://arxiv.org/abs/1606.08415
- https://pytorch.org/docs/stable/generated/torch.compile.html
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [ollama] Ollama의 Gemma 4 모델 Flash Attention 비활성화: 성능 회귀(Regression) 해결 사례
- [vllm] vLLM Gemma4 모델의 GPU/CPU 동기화 병목 현상 해결하기: non_blocking 전송의 중요성
- [sglang] sglang, AMD MI35x 환경에서 GLM-5-MXFP4 모델의 성능 및 정확도 테스트 추가
- [SGLang] Hardware Backends: MLX, NPU, XPU 하드웨어 추상화
- [SGLang] Activation Functions: SiLU, GELU 커스텀 구현
PR Analysis 의 다른글
- 이전글 [sglang] SGLang, Diffusion 모델의 RL 기반 후처리 최적화를 위한 새로운 Rollout API 및 정밀도 개선
- 현재글 : [ollama] Ollama MLX Gemma4 성능 최적화: Fused Operations를 통한 효율성 증대
- 다음글 [cpython] CPython의 BINARY_OP_EXTEND 최적화: 타입 정보 전파를 통한 성능 개선
댓글