[SGLang] Shared Memory Broadcast: 프로세스 간 고속 통신
들어가며
SGLang의 TP 그룹에서 rank 0이 스케줄링 결정을 내리면, 이를 다른 rank들에게 빠르게 전파해야 한다. GPU 텐서가 아닌 Python 객체(배치 메타데이터, 스케줄링 정보)를 전달하는 상황이므로 NCCL은 부적합하다.
shm_broadcast.py는 공유 메모리 링 버퍼와 ZMQ Pub/Sub를 결합한 고속 브로드캐스트 시스템을 구현한다. 같은 노드 내 프로세스 간에는 공유 메모리로, 원격 노드에는 ZMQ로 데이터를 전달한다.
구조도
Writer (rank 0)
│
├── 로컬 리더 경로 (같은 노드)
│ ├── ShmRingBuffer (공유 메모리)
│ │ ├── chunk0 | chunk1 | ... | chunkN (데이터)
│ │ └── meta0 | meta1 | ... | metaN (상태 플래그)
│ │
│ └── ZMQ XPUB (overflow 시 대규모 데이터)
│ ├── Reader 0 (SUB)
│ └── Reader 1 (SUB)
│
└── 원격 리더 경로 (다른 노드)
└── ZMQ XPUB
├── Remote Reader 0 (SUB)
└── Remote Reader 1 (SUB)
핵심 코드 분석
ShmRingBuffer 메모리 레이아웃
class ShmRingBuffer:
def __init__(self, n_reader, max_chunk_bytes, max_chunks, name=None):
self.n_reader = n_reader
self.metadata_size = 1 + n_reader
self.max_chunk_bytes = max_chunk_bytes
self.max_chunks = max_chunks
self.total_bytes_of_buffer = (
self.max_chunk_bytes + self.metadata_size) * self.max_chunks
self.data_offset = 0
self.metadata_offset = self.max_chunk_bytes * self.max_chunks
메모리 구조는 다음과 같다:
+-------------------------------+----------------------------------------+
| chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
+-------------------------------+----------------------------------------+
| max_chunks x max_chunk_bytes | max_chunks x (1 + n_reader) bytes |
메타데이터 상태 머신
각 청크의 메타데이터는 1 + n_reader 바이트로, 첫 바이트가 written 플래그이고 나머지가 각 리더의 read 플래그다.
+--------------+--------------+--------------+-----+--------------+
| written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
+--------------+--------------+--------------+-----+--------------+
상태 전이:
(1) 0???...??? : 미작성 → 쓰기 가능
(2) 1000...000 : 방금 작성됨 → 읽기 가능
(3) 1???...??? : 일부 리더가 읽음 → 미읽은 리더는 읽기 가능
(4) 1111...111 : 모든 리더가 읽음 → 쓰기 가능
Writer: acquire_write
@contextmanager
def acquire_write(self):
while True:
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
read_count = sum(metadata_buffer[1:])
written_flag = metadata_buffer[0]
if written_flag and read_count != self.buffer.n_reader:
os.sched_yield() # 모든 리더가 읽을 때까지 대기
continue
# 쓰기 가능 상태 진입
metadata_buffer[0] = 0 # written 플래그 해제
with self.buffer.get_data(self.current_idx) as buf:
yield buf # 호출자가 데이터 기록
# 순서 중요: 먼저 read 플래그 리셋, 그 다음 written 플래그 설정
for i in range(1, self.buffer.n_reader + 1):
metadata_buffer[i] = 0
metadata_buffer[0] = 1
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
break
순서가 핵심이다. Reader 플래그를 먼저 0으로 리셋한 후 written 플래그를 1로 설정한다. 역순이면 리더가 중간 상태(case 3)를 잘못 읽을 수 있다.
Reader: acquire_read
@contextmanager
def acquire_read(self):
while True:
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
read_flag = metadata_buffer[self.local_reader_rank + 1]
written_flag = metadata_buffer[0]
if not written_flag or read_flag:
os.sched_yield() # 데이터가 없거나 이미 읽었으면 대기
continue
with self.buffer.get_data(self.current_idx) as buf:
yield buf
metadata_buffer[self.local_reader_rank + 1] = 1 # 읽음 표시
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
break
오버플로 처리
데이터가 max_chunk_bytes를 초과하면 ZMQ로 대체한다.
def enqueue(self, obj):
serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
if self.n_local_reader > 0:
if len(serialized_obj) >= self.buffer.max_chunk_bytes:
with self.acquire_write() as buf:
buf[0] = 1 # overflow 플래그
self.local_socket.send(serialized_obj)
else:
with self.acquire_write() as buf:
buf[0] = 0 # not overflow
buf[1:len(serialized_obj) + 1] = serialized_obj
if self.n_remote_reader > 0:
self.remote_socket.send(serialized_obj)
리더 측에서는 overflow 플래그를 확인하여 공유 메모리 또는 ZMQ에서 데이터를 읽는다:
def dequeue(self):
if self._is_local_reader:
with self.acquire_read() as buf:
overflow = buf[0] == 1
if not overflow:
obj = pickle.loads(buf[1:])
if overflow:
recv = self.local_socket.recv()
obj = pickle.loads(recv)
elif self._is_remote_reader:
recv = self.remote_socket.recv()
obj = pickle.loads(recv)
return obj
ProcessGroup에서 자동 생성
@staticmethod
def create_from_process_group(pg, max_chunk_bytes, max_chunks, writer_rank=0):
group_rank = dist.get_rank(pg)
status = in_the_same_node_as(pg, source_rank=writer_rank)
same_node_ranks = [i for i, s in enumerate(status) if s]
n_reader = group_world_size - 1
n_local_reader = len(same_node_ranks) - 1
if group_rank == writer_rank:
buffer_io = MessageQueue(
n_reader=n_reader, n_local_reader=n_local_reader, ...)
handle = buffer_io.export_handle()
dist.broadcast_object_list([handle], src=global_ranks[writer_rank], group=pg)
else:
recv = [None]
dist.broadcast_object_list(recv, src=global_ranks[writer_rank], group=pg)
buffer_io = MessageQueue.create_from_handle(recv[0], group_rank)
buffer_io.wait_until_ready()
return buffer_io
in_the_same_node_as()로 같은 노드의 랭크를 판별하여 로컬/원격 리더를 자동 구분한다.
비교: Shared Memory vs ZMQ
| 특성 | ShmRingBuffer | ZMQ Pub/Sub |
|---|---|---|
| 지연시간 | 수 마이크로초 | 수십 마이크로초 |
| 크기 제한 | max_chunk_bytes | 무제한 |
| 노드 간 | 불가 | 가능 |
| 동기화 | 메타데이터 플래그 | 소켓 프로토콜 |
설계 근거
왜 락을 사용하지 않는가? 링 버퍼의 Writer는 하나뿐이고, 각 Reader는 자신의 플래그만 수정한다. 원자적 바이트 쓰기와 정확한 순서만으로 동기화가 가능하므로 뮤텍스/세마포어가 불필요하다.
왜 pickle을 사용하는가? 전달되는 객체가 스케줄링 메타데이터(배치 정보, 토큰 ID 등)로, Python 네이티브 직렬화가 가장 범용적이다. pickle.HIGHEST_PROTOCOL로 직렬화 속도를 최적화한다.
관련 포스트
참고
관련 포스트
SGLang 의 다른글
- 이전글 [SGLang] Ray 통합: 분산 엔진과 스케줄러 액터
- 현재글 : [SGLang] Shared Memory Broadcast: 프로세스 간 고속 통신
- 다음글 [SGLang] 하드웨어별 통신: HPU, NPU, XPU 커뮤니케이터
댓글