본문으로 건너뛰기

[SGLang] Shared Memory Broadcast: 프로세스 간 고속 통신

들어가며

SGLang의 TP 그룹에서 rank 0이 스케줄링 결정을 내리면, 이를 다른 rank들에게 빠르게 전파해야 한다. GPU 텐서가 아닌 Python 객체(배치 메타데이터, 스케줄링 정보)를 전달하는 상황이므로 NCCL은 부적합하다.

shm_broadcast.py공유 메모리 링 버퍼ZMQ Pub/Sub를 결합한 고속 브로드캐스트 시스템을 구현한다. 같은 노드 내 프로세스 간에는 공유 메모리로, 원격 노드에는 ZMQ로 데이터를 전달한다.

구조도

Writer (rank 0)
    │
    ├── 로컬 리더 경로 (같은 노드)
    │     ├── ShmRingBuffer (공유 메모리)
    │     │     ├── chunk0 | chunk1 | ... | chunkN  (데이터)
    │     │     └── meta0  | meta1  | ... | metaN   (상태 플래그)
    │     │
    │     └── ZMQ XPUB (overflow 시 대규모 데이터)
    │           ├── Reader 0 (SUB)
    │           └── Reader 1 (SUB)
    │
    └── 원격 리더 경로 (다른 노드)
          └── ZMQ XPUB
                ├── Remote Reader 0 (SUB)
                └── Remote Reader 1 (SUB)

핵심 코드 분석

ShmRingBuffer 메모리 레이아웃

class ShmRingBuffer:
    def __init__(self, n_reader, max_chunk_bytes, max_chunks, name=None):
        self.n_reader = n_reader
        self.metadata_size = 1 + n_reader
        self.max_chunk_bytes = max_chunk_bytes
        self.max_chunks = max_chunks
        self.total_bytes_of_buffer = (
            self.max_chunk_bytes + self.metadata_size) * self.max_chunks
        self.data_offset = 0
        self.metadata_offset = self.max_chunk_bytes * self.max_chunks

메모리 구조는 다음과 같다:

+-------------------------------+----------------------------------------+
| chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
+-------------------------------+----------------------------------------+
| max_chunks x max_chunk_bytes  | max_chunks x (1 + n_reader) bytes      |

메타데이터 상태 머신

각 청크의 메타데이터는 1 + n_reader 바이트로, 첫 바이트가 written 플래그이고 나머지가 각 리더의 read 플래그다.

+--------------+--------------+--------------+-----+--------------+
| written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
+--------------+--------------+--------------+-----+--------------+

상태 전이:
(1) 0???...??? : 미작성 → 쓰기 가능
(2) 1000...000 : 방금 작성됨 → 읽기 가능
(3) 1???...??? : 일부 리더가 읽음 → 미읽은 리더는 읽기 가능
(4) 1111...111 : 모든 리더가 읽음 → 쓰기 가능

Writer: acquire_write

@contextmanager
def acquire_write(self):
    while True:
        with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
            read_count = sum(metadata_buffer[1:])
            written_flag = metadata_buffer[0]
            if written_flag and read_count != self.buffer.n_reader:
                os.sched_yield()  # 모든 리더가 읽을 때까지 대기
                continue

            # 쓰기 가능 상태 진입
            metadata_buffer[0] = 0  # written 플래그 해제
            with self.buffer.get_data(self.current_idx) as buf:
                yield buf  # 호출자가 데이터 기록

            # 순서 중요: 먼저 read 플래그 리셋, 그 다음 written 플래그 설정
            for i in range(1, self.buffer.n_reader + 1):
                metadata_buffer[i] = 0
            metadata_buffer[0] = 1
            self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
            break

순서가 핵심이다. Reader 플래그를 먼저 0으로 리셋한 후 written 플래그를 1로 설정한다. 역순이면 리더가 중간 상태(case 3)를 잘못 읽을 수 있다.

Reader: acquire_read

@contextmanager
def acquire_read(self):
    while True:
        with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
            read_flag = metadata_buffer[self.local_reader_rank + 1]
            written_flag = metadata_buffer[0]
            if not written_flag or read_flag:
                os.sched_yield()  # 데이터가 없거나 이미 읽었으면 대기
                continue
            with self.buffer.get_data(self.current_idx) as buf:
                yield buf
            metadata_buffer[self.local_reader_rank + 1] = 1  # 읽음 표시
            self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
            break

오버플로 처리

데이터가 max_chunk_bytes를 초과하면 ZMQ로 대체한다.

def enqueue(self, obj):
    serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    if self.n_local_reader > 0:
        if len(serialized_obj) >= self.buffer.max_chunk_bytes:
            with self.acquire_write() as buf:
                buf[0] = 1  # overflow 플래그
            self.local_socket.send(serialized_obj)
        else:
            with self.acquire_write() as buf:
                buf[0] = 0  # not overflow
                buf[1:len(serialized_obj) + 1] = serialized_obj
    if self.n_remote_reader > 0:
        self.remote_socket.send(serialized_obj)

리더 측에서는 overflow 플래그를 확인하여 공유 메모리 또는 ZMQ에서 데이터를 읽는다:

def dequeue(self):
    if self._is_local_reader:
        with self.acquire_read() as buf:
            overflow = buf[0] == 1
            if not overflow:
                obj = pickle.loads(buf[1:])
        if overflow:
            recv = self.local_socket.recv()
            obj = pickle.loads(recv)
    elif self._is_remote_reader:
        recv = self.remote_socket.recv()
        obj = pickle.loads(recv)
    return obj

ProcessGroup에서 자동 생성

@staticmethod
def create_from_process_group(pg, max_chunk_bytes, max_chunks, writer_rank=0):
    group_rank = dist.get_rank(pg)
    status = in_the_same_node_as(pg, source_rank=writer_rank)
    same_node_ranks = [i for i, s in enumerate(status) if s]
    n_reader = group_world_size - 1
    n_local_reader = len(same_node_ranks) - 1

    if group_rank == writer_rank:
        buffer_io = MessageQueue(
            n_reader=n_reader, n_local_reader=n_local_reader, ...)
        handle = buffer_io.export_handle()
        dist.broadcast_object_list([handle], src=global_ranks[writer_rank], group=pg)
    else:
        recv = [None]
        dist.broadcast_object_list(recv, src=global_ranks[writer_rank], group=pg)
        buffer_io = MessageQueue.create_from_handle(recv[0], group_rank)
    buffer_io.wait_until_ready()
    return buffer_io

in_the_same_node_as()로 같은 노드의 랭크를 판별하여 로컬/원격 리더를 자동 구분한다.

비교: Shared Memory vs ZMQ

특성 ShmRingBuffer ZMQ Pub/Sub
지연시간 수 마이크로초 수십 마이크로초
크기 제한 max_chunk_bytes 무제한
노드 간 불가 가능
동기화 메타데이터 플래그 소켓 프로토콜

설계 근거

왜 락을 사용하지 않는가? 링 버퍼의 Writer는 하나뿐이고, 각 Reader는 자신의 플래그만 수정한다. 원자적 바이트 쓰기와 정확한 순서만으로 동기화가 가능하므로 뮤텍스/세마포어가 불필요하다.

왜 pickle을 사용하는가? 전달되는 객체가 스케줄링 메타데이터(배치 정보, 토큰 ID 등)로, Python 네이티브 직렬화가 가장 범용적이다. pickle.HIGHEST_PROTOCOL로 직렬화 속도를 최적화한다.

관련 포스트

참고

댓글

관련 포스트

SGLang 의 다른글