[sglang] VLM ShmPointerMMData 최적화: multi-pickle 안전성과 deferred unwrap
PR 링크: sgl-project/sglang#21465 상태: Merged | 변경: +55 / -47
들어가며
SGLang에서 VLM(Vision-Language Model)의 멀티모달 입력(이미지 feature 등)은 프로세스 간에 공유 메모리(shared memory)를 통해 전달됩니다. ShmPointerMMData 클래스가 이 역할을 담당하는데, 기존 구현에는 두 가지 문제가 있었습니다: (1) __getstate__에서 shm이 없을 때 텐서를 다시 공유 메모리에 올리는 복잡한 재생성 로직, (2) __setstate__에서 즉시 clone() + unlink()하여 broadcast 시 다른 rank가 접근하기 전에 shm이 사라질 수 있는 문제입니다.
핵심 코드 분석
1. 생성자 단순화
Before:
def __init__(self, tensor: torch.Tensor):
self.cpu_tensor = tensor.cpu().contiguous()
nbytes = self.cpu_tensor.numel() * self.cpu_tensor.element_size()
self.shm = shared_memory.SharedMemory(create=True, size=nbytes)
try:
shm_view = np.ndarray((nbytes,), dtype=np.uint8, buffer=self.shm.buf)
shm_view[:] = self.cpu_tensor.view(torch.uint8).numpy().flatten()
finally:
self.shm.close()
After:
def __init__(self, tensor: torch.Tensor):
if not tensor.is_cpu:
tensor = tensor.cpu()
if not tensor.is_contiguous():
tensor = tensor.contiguous()
shm = shared_memory.SharedMemory(create=True, size=nbytes)
try:
dst = torch.frombuffer(shm.buf, dtype=torch.uint8)
dst.copy_(tensor.view(torch.uint8).reshape(-1))
except BaseException:
shm.close()
shm.unlink()
raise
self.shm_name = shm.name
shm.close()
self._shm_handle = None
numpy 대신 torch.frombuffer로 직접 복사하고, 예외 시 shm을 정리하는 안전한 패턴을 적용했습니다.
2. Deferred unwrap 패턴
Before (__setstate__):
def __setstate__(self, state):
shm_handle = shared_memory.SharedMemory(name=self.shm_name)
try:
self.tensor = torch.frombuffer(shm_handle.buf, dtype=self.dtype).reshape(self.shape).clone()
finally:
shm_handle.close()
shm_handle.unlink() # 즉시 삭제!
After:
def __setstate__(self, state):
self._shm_handle = shared_memory.SharedMemory(name=self.shm_name)
self.tensor = torch.frombuffer(self._shm_handle.buf, dtype=self.dtype).reshape(self.shape)
# clone과 unlink는 materialize()에서 수행
def materialize(self) -> torch.Tensor:
tensor = self.tensor.clone()
if self._shm_handle is not None:
self._shm_handle.close()
try:
self._shm_handle.unlink()
except FileNotFoundError:
pass # 다른 rank가 이미 unlink
3. Scheduler에서 broadcast 이후 unwrap
Before:
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
recv_req = unwrap_shm_features(recv_req) # broadcast 전에 unwrap
After:
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
# broadcast 완료 후 unwrap
for req in recv_reqs:
unwrap_shm_features(req)
왜 이게 좋은가
- Race condition 해결: broadcast 전에 shm을 unlink하면 다른 rank가 데이터를 읽지 못할 수 있습니다.
- Zero-copy 가능:
materialize()전까지 shm 버퍼를 직접 참조하여 불필요한 복사를 피합니다. - 에러 안전성: 생성자에서 예외 발생 시 shm 누수를 방지합니다.
정리
IPC에서 공유 메모리의 생명주기 관리는 매우 중요합니다. "생성 -> pickle -> unpickle -> 사용 -> 해제" 순서를 명확히 하고, multi-rank 환경에서의 동시성을 고려한 개선입니다.
참고 자료
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [CPython 3.14] OrderedDict.popitem() 메모리 누수 수정 (backport)
- 현재글 : [sglang] VLM ShmPointerMMData 최적화: multi-pickle 안전성과 deferred unwrap
- 다음글 [Ultralytics] MPS 디바이스에서 메모리 누수 방지를 위한 적극적 메모리 정리
댓글