[ray] Ray Data의 hash_partition 성능을 7배 향상시킨 최적화 전략
PR 링크: ray-project/ray#63498 상태: Merged | 변경: +18 / -14
들어가며
Ray Data는 대규모 데이터 처리를 위한 강력한 프레임워크입니다. 하지만 hash_partition 연산은 데이터셋의 크기가 커질수록 성능 병목을 유발하는 구간이었습니다. 기존 구현은 파티션 개수(N)와 행의 개수(R)에 대해 O(N * R)의 복잡도를 가지는 스캔을 수행하고, 불필요한 데이터 복사(defragmentation)를 강제하여 메모리 사용량과 처리 시간을 모두 낭비하고 있었습니다. 본 글에서는 최근 Ray PR에서 적용된 sort_indices와 zero-copy slice를 활용한 최적화 기법을 살펴봅니다.
코드 분석
기존 방식의 문제점
기존의 hash_partition은 다음과 같은 비효율적인 과정을 거쳤습니다.
# Before
indices = [np.where(partitions_array == p)[0] for p in range(num_partitions)]
table = try_combine_chunked_columns(table)
return {p: table.take(idx) for p, idx in enumerate(indices) if len(idx) > 0}
- O(N * R) 스캔:
np.where를 파티션 개수만큼 반복 호출하여 전체 데이터를 매번 스캔했습니다. - 불필요한 Defragmentation:
try_combine_chunked_columns를 통해 전체 테이블을 메모리에 복사하여 청크를 합쳤습니다. - 반복적인 Take 연산:
table.take를 N번 호출하며 매번 오버헤드가 발생했습니다.
최적화된 방식
새로운 구현은 정렬을 통해 데이터를 재배치하고, 슬라이싱을 통해 메모리 복사를 제거했습니다.
# After
sort_indices = pac.sort_indices(pyarrow.array(partitions_array))
counts = np.bincount(partitions_array, minlength=num_partitions)
offsets = np.zeros(num_partitions + 1, dtype=np.int64)
offsets[1:] = np.cumsum(counts)
sorted_table = take_table(table, sort_indices)
return {
p: sorted_table.slice(int(offsets[p]), int(counts[p]))
for p in range(num_partitions) if counts[p] > 0
}
- Radix Sort 활용:
pac.sort_indices를 사용하여 전체 데이터를 단 한 번의 O(R) 패스로 정렬했습니다. - 단일 Take 연산: 정렬된 인덱스를 사용하여 전체 테이블을 한 번만
take함으로써 오버헤드를 최소화했습니다. - Zero-copy Slicing: 정렬된 테이블에서 각 파티션의 시작점과 길이를 계산하여
slice()를 호출합니다. 이는 데이터를 복사하지 않고 뷰(view)만 생성하므로 메모리 사용량을 획기적으로 줄입니다.
왜 이게 좋은가
이번 최적화의 핵심은 '데이터를 물리적으로 복사하지 않고 정렬된 인덱스를 통해 논리적으로 분할'한 점입니다. 벤치마크 결과에 따르면, 1GB 데이터 기준 기존 대비 5~8배의 속도 향상을 보였으며, 특히 청크가 많은 데이터셋(K=256)에서는 메모리 피크 사용량이 약 40% 감소하는 성과를 거두었습니다.
일반적 교훈
- 반복적인 스캔을 피하라: 데이터 전체를 여러 번 훑는 대신, 정렬이나 해시를 통해 한 번의 패스로 처리할 방법을 찾으세요.
- Zero-copy를 활용하라: PyArrow와 같은 라이브러리에서
slice는 메모리 복사 없이 데이터를 분할할 수 있는 강력한 도구입니다. - Take 연산의 비용을 이해하라:
take연산은 청크가 많을수록 비용이 큽니다. 이를 여러 번 호출하는 대신, 한 번의 정렬 후 슬라이싱하는 패턴이 훨씬 효율적입니다.
결론
이번 개선은 Ray Data의 파이프라인 성능을 크게 높였을 뿐만 아니라, 대규모 데이터 처리 시 메모리 효율성을 극대화하는 모범 사례를 보여줍니다. 데이터 엔지니어링에서 '복사'를 줄이는 것이 곧 '성능'임을 다시 한번 확인하게 됩니다.
참고 자료
- https://arrow.apache.org/docs/python/generated/pyarrow.compute.sort_indices.html
- https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.slice
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [Ray Data] PyArrow 스키마 해싱 방식 개선으로 대규모 데이터셋 성능 향상
- [ray] Ray Data의 차세대 데이터 소스 API: DataSourceV2 설계 및 최적화 전략
- [Ray] concat_tables의 Happy Path를 최적화하여 동일 스키마 테이블 연결 가속화
- [Ray RLlib] SingleAgentEnvRunner의 validate 호출 위치 최적화로 3.1배 속도 향상
- [Ray Serve] Pack 스케줄링 최적화: O(replicas x total_replicas)에서 O(replicas x nodes)로
PR Analysis 의 다른글
- 이전글 [sglang] DeepSeek V4의 Prefill 성능을 1.35배 향상시킨 FlashAttention 최적화
- 현재글 : [ray] Ray Data의 hash_partition 성능을 7배 향상시킨 최적화 전략
- 다음글 [flashinfer] FlashInfer의 MoE Routing 성능 최적화: Batcher's Odd-Even Merge Sort 도입
댓글