본문으로 건너뛰기

[sglang] 멀티프로세스 JIT 컴파일로 Custom All Reduce 테스트 속도 향상

PR 링크: sgl-project/sglang#21483 상태: Merged | 변경: +41 / -19

들어가며

SGLang의 custom all reduce 커널은 JIT(Just-in-Time)으로 컴파일됩니다. 테스트에서는 다양한 dtype과 world_size 조합에 대해 커널이 필요한데, 각 조합마다 별도 컴파일이 필요합니다. 기존에는 테스트 실행 중 필요할 때마다 순차적으로 컴파일했지만, 이번 PR은 모든 조합을 사전에 병렬 컴파일하여 테스트 시간을 약 40% 단축합니다.

핵심 코드 분석

1. 병렬 사전 컴파일

Before:

@pytest.mark.parametrize("nproc", [2, 3, 4, 5, 6, 7, 8])
def test_custom_allreduce(nproc: int) -> None:
    device_count = torch.cuda.device_count()
    if device_count < nproc:
        pytest.skip(...)
    run_torchrun(nproc)

After:

def _compile_one(dtype: torch.dtype, world_size: int):
    _jit_custom_all_reduce_push_module(dtype, world_size)
    _jit_custom_all_reduce_pull_module(dtype, world_size)

def _precompile_kernels() -> None:
    process_map: Dict[Tuple[torch.dtype, int], mp.Process] = {}
    COMPILE_SPACE = itertools.product(TEST_DTYPES, [2, 3, 4, 5, 6, 7, 8])
    mp.set_start_method("spawn")
    for config in COMPILE_SPACE:
        process_map[config] = mp.Process(target=_compile_one, args=config)
    for process in process_map.values():
        process.start()
    for (dtype, world_size), process in process_map.items():
        process.join()
        if process.exitcode != 0:
            raise RuntimeError(f"Compilation failed {world_size=} {dtype=}")

@pytest.mark.parametrize("nproc", [1, 2, 3, 4, 5, 6, 7, 8])
def test_custom_allreduce(nproc: int) -> None:
    if nproc == 1:  # 특별 케이스: 사전 컴파일
        return _precompile_kernels()
    # ...

nproc=1 파라미터를 추가하여 pytest의 첫 번째 테스트 케이스에서 모든 커널을 병렬 컴파일합니다. 각 (dtype, world_size) 조합마다 별도 프로세스가 생성되어 동시에 컴파일됩니다.

2. 불필요한 동기화 제거

Before:

torch.cuda.synchronize()
nccl_group.barrier().wait()

After:

# 삭제됨 - all_reduce 자체가 동기화를 보장

3. JIT 모듈 이름 충돌 수정

# Before: 같은 이름으로 push/pull 모듈 로드
return load_jit("custom_all_reduce", ...)

# After: 고유한 이름 사용
return load_jit("custom_all_reduce_pull", ...)  # pull
return load_jit("custom_all_reduce_push", ...)  # push

왜 이게 좋은가

  1. CI 시간 단축: est_time이 500초에서 300초로 감소합니다.
  2. 컴파일 캐시 활용: 사전 컴파일 결과가 캐시되어, 이후 torchrun 프로세스에서 재컴파일이 불필요합니다.
  3. 이름 충돌 방지: push/pull 모듈의 JIT 캐시 키가 동일했던 잠재적 버그도 함께 수정했습니다.

정리

pytest 파라미터를 활용한 사전 컴파일이라는 창의적인 접근으로 CI 시간을 단축한 PR입니다. JIT 컴파일은 한 번만 하면 캐시되므로, 테스트 전에 모든 변형을 병렬로 컴파일하는 것이 효과적입니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글