[Ray] NIXL 메타데이터 캐싱으로 GPU 텐서 전송 등록/해제 오버헤드 제거
PR 링크: ray-project/ray#60689 상태: Merged | 변경: +138 / -41
들어가며
분산 학습에서 weight sync는 동일한 GPU 메모리 버퍼에 반복적으로 가중치를 쓰고 전송합니다. Ray의 NIXL(Network Interface eXpress Layer) 텐서 전송에서는 ray.put() 시 메모리를 등록하고, object ref가 스코프 밖으로 나가면 해제합니다. 매번 같은 버퍼를 등록/해제하는 것은 불필요한 오버헤드입니다.
핵심 코드 분석
새 API: register_nixl_memory
def register_nixl_memory(self, tensor: "torch.Tensor") -> None:
"""텐서의 메모리를 NIXL에 등록하고 참조 카운트를 증가시켜
메모리 영역이 절대 해제되지 않도록 합니다."""
self._add_tensor_descs([tensor])
수신 측 텐서 등록과 해제를 참조 카운트로 관리
Before:
def recv_multiple_tensors(self, tensors, ...):
local_descs = nixl_agent.register_memory(tensors)
# ... 전송 수행 ...
# finally에서 항상 해제
if local_descs:
nixl_agent.deregister_memory(local_descs)
After:
def recv_multiple_tensors(self, tensors, ...):
self._add_tensor_descs(tensors) # 참조 카운트 기반 등록
local_xfer_descs = nixl_agent.get_xfer_descs(tensors)
# ... 전송 수행 ...
# finally에서 참조 카운트 감소, 0이면 해제
for tensor in tensors:
key = tensor.untyped_storage().data_ptr()
tensor_desc = self._tensor_desc_cache[key]
tensor_desc.metadata_count -= 1
if tensor_desc.metadata_count == 0:
nixl_agent.deregister_memory(tensor_desc.reg_desc)
self._tensor_desc_cache.pop(key)
스레드 안전성 강화
# threading.Lock → threading.RLock (재진입 가능)
self._cache_lock = threading.RLock()
# _get_meta, _put_meta에도 lock 추가
def _get_meta(self, object_id: str):
with self._cache_lock:
return self._managed_meta_nixl.get(object_id)
왜 이게 좋은가
- Weight sync처럼 동일 버퍼를 반복 전송하는 패턴에서 매번 NIXL 메모리 등록/해제가 발생하지 않습니다
- 사용자가
register_nixl_memory()를 한 번 호출하면 해당 텐서의 메모리 영역이 프로세스 수명 동안 고정됩니다 - 참조 카운트 기반이므로 같은 underlying storage를 공유하는 여러 텐서에 대해 안전하게 동작합니다
RLock으로 전환하여_add_tensor_descs내부에서 다시 lock을 잡는 재진입 시나리오도 안전합니다
참고 자료
관련 포스트
PR Analysis 의 다른글
- 이전글 [sglang] SGLang, Helios 모델 통합으로 실시간 장편 비디오 생성의 새로운 지평을 열다
- 현재글 : [Ray] NIXL 메타데이터 캐싱으로 GPU 텐서 전송 등록/해제 오버헤드 제거
- 다음글 [Open WebUI] 저장 버튼 스피너 인라인 레이아웃 수정
댓글