[sglang] 멀티프로세스 JIT 컴파일로 Custom All Reduce 테스트 속도 향상
PR 링크: sgl-project/sglang#21483 상태: Merged | 변경: +41 / -19
들어가며
SGLang의 custom all reduce 커널은 JIT(Just-in-Time)으로 컴파일됩니다. 테스트에서는 다양한 dtype과 world_size 조합에 대해 커널이 필요한데, 각 조합마다 별도 컴파일이 필요합니다. 기존에는 테스트 실행 중 필요할 때마다 순차적으로 컴파일했지만, 이번 PR은 모든 조합을 사전에 병렬 컴파일하여 테스트 시간을 약 40% 단축합니다.
핵심 코드 분석
1. 병렬 사전 컴파일
Before:
@pytest.mark.parametrize("nproc", [2, 3, 4, 5, 6, 7, 8])
def test_custom_allreduce(nproc: int) -> None:
device_count = torch.cuda.device_count()
if device_count < nproc:
pytest.skip(...)
run_torchrun(nproc)
After:
def _compile_one(dtype: torch.dtype, world_size: int):
_jit_custom_all_reduce_push_module(dtype, world_size)
_jit_custom_all_reduce_pull_module(dtype, world_size)
def _precompile_kernels() -> None:
process_map: Dict[Tuple[torch.dtype, int], mp.Process] = {}
COMPILE_SPACE = itertools.product(TEST_DTYPES, [2, 3, 4, 5, 6, 7, 8])
mp.set_start_method("spawn")
for config in COMPILE_SPACE:
process_map[config] = mp.Process(target=_compile_one, args=config)
for process in process_map.values():
process.start()
for (dtype, world_size), process in process_map.items():
process.join()
if process.exitcode != 0:
raise RuntimeError(f"Compilation failed {world_size=} {dtype=}")
@pytest.mark.parametrize("nproc", [1, 2, 3, 4, 5, 6, 7, 8])
def test_custom_allreduce(nproc: int) -> None:
if nproc == 1: # 특별 케이스: 사전 컴파일
return _precompile_kernels()
# ...
nproc=1 파라미터를 추가하여 pytest의 첫 번째 테스트 케이스에서 모든 커널을 병렬 컴파일합니다. 각 (dtype, world_size) 조합마다 별도 프로세스가 생성되어 동시에 컴파일됩니다.
2. 불필요한 동기화 제거
Before:
torch.cuda.synchronize()
nccl_group.barrier().wait()
After:
# 삭제됨 - all_reduce 자체가 동기화를 보장
3. JIT 모듈 이름 충돌 수정
# Before: 같은 이름으로 push/pull 모듈 로드
return load_jit("custom_all_reduce", ...)
# After: 고유한 이름 사용
return load_jit("custom_all_reduce_pull", ...) # pull
return load_jit("custom_all_reduce_push", ...) # push
왜 이게 좋은가
- CI 시간 단축: est_time이 500초에서 300초로 감소합니다.
- 컴파일 캐시 활용: 사전 컴파일 결과가 캐시되어, 이후 torchrun 프로세스에서 재컴파일이 불필요합니다.
- 이름 충돌 방지: push/pull 모듈의 JIT 캐시 키가 동일했던 잠재적 버그도 함께 수정했습니다.
정리
pytest 파라미터를 활용한 사전 컴파일이라는 창의적인 접근으로 CI 시간을 단축한 PR입니다. JIT 컴파일은 한 번만 하면 캐시되므로, 테스트 전에 모든 변형을 병렬로 컴파일하는 것이 효과적입니다.
참고 자료
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Ultralytics] v8.4.28 — autobatch를 데이터셋 크기로 제한하여 소규모 데이터셋 학습 안정화
- 현재글 : [sglang] 멀티프로세스 JIT 컴파일로 Custom All Reduce 테스트 속도 향상
- 다음글 [llm-compressor] GPTQ Block Quantization 지원
댓글