본문으로 건너뛰기

[Ray] NIXL 메타데이터 캐싱으로 GPU 텐서 전송 등록/해제 오버헤드 제거

PR 링크: ray-project/ray#60689 상태: Merged | 변경: +138 / -41

들어가며

분산 학습에서 weight sync는 동일한 GPU 메모리 버퍼에 반복적으로 가중치를 쓰고 전송합니다. Ray의 NIXL(Network Interface eXpress Layer) 텐서 전송에서는 ray.put() 시 메모리를 등록하고, object ref가 스코프 밖으로 나가면 해제합니다. 매번 같은 버퍼를 등록/해제하는 것은 불필요한 오버헤드입니다.

핵심 코드 분석

새 API: register_nixl_memory

def register_nixl_memory(self, tensor: "torch.Tensor") -> None:
    """텐서의 메모리를 NIXL에 등록하고 참조 카운트를 증가시켜
    메모리 영역이 절대 해제되지 않도록 합니다."""
    self._add_tensor_descs([tensor])

수신 측 텐서 등록과 해제를 참조 카운트로 관리

Before:

def recv_multiple_tensors(self, tensors, ...):
    local_descs = nixl_agent.register_memory(tensors)
    # ... 전송 수행 ...
    # finally에서 항상 해제
    if local_descs:
        nixl_agent.deregister_memory(local_descs)

After:

def recv_multiple_tensors(self, tensors, ...):
    self._add_tensor_descs(tensors)  # 참조 카운트 기반 등록
    local_xfer_descs = nixl_agent.get_xfer_descs(tensors)
    # ... 전송 수행 ...
    # finally에서 참조 카운트 감소, 0이면 해제
    for tensor in tensors:
        key = tensor.untyped_storage().data_ptr()
        tensor_desc = self._tensor_desc_cache[key]
        tensor_desc.metadata_count -= 1
        if tensor_desc.metadata_count == 0:
            nixl_agent.deregister_memory(tensor_desc.reg_desc)
            self._tensor_desc_cache.pop(key)

스레드 안전성 강화

# threading.Lock → threading.RLock (재진입 가능)
self._cache_lock = threading.RLock()

# _get_meta, _put_meta에도 lock 추가
def _get_meta(self, object_id: str):
    with self._cache_lock:
        return self._managed_meta_nixl.get(object_id)

왜 이게 좋은가

  • Weight sync처럼 동일 버퍼를 반복 전송하는 패턴에서 매번 NIXL 메모리 등록/해제가 발생하지 않습니다
  • 사용자가 register_nixl_memory()를 한 번 호출하면 해당 텐서의 메모리 영역이 프로세스 수명 동안 고정됩니다
  • 참조 카운트 기반이므로 같은 underlying storage를 공유하는 여러 텐서에 대해 안전하게 동작합니다
  • RLock으로 전환하여 _add_tensor_descs 내부에서 다시 lock을 잡는 재진입 시나리오도 안전합니다

참고 자료

댓글

관련 포스트

PR Analysis 의 다른글