본문으로 건너뛰기

[sglang] VLM ShmPointerMMData 최적화: multi-pickle 안전성과 deferred unwrap

PR 링크: sgl-project/sglang#21465 상태: Merged | 변경: +55 / -47

들어가며

SGLang에서 VLM(Vision-Language Model)의 멀티모달 입력(이미지 feature 등)은 프로세스 간에 공유 메모리(shared memory)를 통해 전달됩니다. ShmPointerMMData 클래스가 이 역할을 담당하는데, 기존 구현에는 두 가지 문제가 있었습니다: (1) __getstate__에서 shm이 없을 때 텐서를 다시 공유 메모리에 올리는 복잡한 재생성 로직, (2) __setstate__에서 즉시 clone() + unlink()하여 broadcast 시 다른 rank가 접근하기 전에 shm이 사라질 수 있는 문제입니다.

핵심 코드 분석

1. 생성자 단순화

Before:

def __init__(self, tensor: torch.Tensor):
    self.cpu_tensor = tensor.cpu().contiguous()
    nbytes = self.cpu_tensor.numel() * self.cpu_tensor.element_size()
    self.shm = shared_memory.SharedMemory(create=True, size=nbytes)
    try:
        shm_view = np.ndarray((nbytes,), dtype=np.uint8, buffer=self.shm.buf)
        shm_view[:] = self.cpu_tensor.view(torch.uint8).numpy().flatten()
    finally:
        self.shm.close()

After:

def __init__(self, tensor: torch.Tensor):
    if not tensor.is_cpu:
        tensor = tensor.cpu()
    if not tensor.is_contiguous():
        tensor = tensor.contiguous()
    shm = shared_memory.SharedMemory(create=True, size=nbytes)
    try:
        dst = torch.frombuffer(shm.buf, dtype=torch.uint8)
        dst.copy_(tensor.view(torch.uint8).reshape(-1))
    except BaseException:
        shm.close()
        shm.unlink()
        raise
    self.shm_name = shm.name
    shm.close()
    self._shm_handle = None

numpy 대신 torch.frombuffer로 직접 복사하고, 예외 시 shm을 정리하는 안전한 패턴을 적용했습니다.

2. Deferred unwrap 패턴

Before (__setstate__):

def __setstate__(self, state):
    shm_handle = shared_memory.SharedMemory(name=self.shm_name)
    try:
        self.tensor = torch.frombuffer(shm_handle.buf, dtype=self.dtype).reshape(self.shape).clone()
    finally:
        shm_handle.close()
        shm_handle.unlink()  # 즉시 삭제!

After:

def __setstate__(self, state):
    self._shm_handle = shared_memory.SharedMemory(name=self.shm_name)
    self.tensor = torch.frombuffer(self._shm_handle.buf, dtype=self.dtype).reshape(self.shape)
    # clone과 unlink는 materialize()에서 수행

def materialize(self) -> torch.Tensor:
    tensor = self.tensor.clone()
    if self._shm_handle is not None:
        self._shm_handle.close()
        try:
            self._shm_handle.unlink()
        except FileNotFoundError:
            pass  # 다른 rank가 이미 unlink

3. Scheduler에서 broadcast 이후 unwrap

Before:

recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
recv_req = unwrap_shm_features(recv_req)  # broadcast 전에 unwrap

After:

recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
# broadcast 완료 후 unwrap
for req in recv_reqs:
    unwrap_shm_features(req)

왜 이게 좋은가

  1. Race condition 해결: broadcast 전에 shm을 unlink하면 다른 rank가 데이터를 읽지 못할 수 있습니다.
  2. Zero-copy 가능: materialize() 전까지 shm 버퍼를 직접 참조하여 불필요한 복사를 피합니다.
  3. 에러 안전성: 생성자에서 예외 발생 시 shm 누수를 방지합니다.

정리

IPC에서 공유 메모리의 생명주기 관리는 매우 중요합니다. "생성 -> pickle -> unpickle -> 사용 -> 해제" 순서를 명확히 하고, multi-rank 환경에서의 동시성을 고려한 개선입니다.

참고 자료

⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글