[Ray] DefaultCollateFn 병렬화로 Arrow-to-Tensor 변환 가속
PR 링크: ray-project/ray#58821 상태: Merged | 변경: +209 / -23
들어가며
Ray Data에서 PyTorch 학습을 위해 Apache Arrow 배치를 텐서로 변환하는 DefaultCollateFn은 모든 컬럼을 순차적으로 처리하고 있었다. 다수의 컬럼을 가진 데이터셋에서는 이 변환이 병목이 될 수 있다. 이 PR은 ThreadPoolExecutor를 도입하여 컬럼별 변환을 병렬화함으로써 처리 속도를 개선한다.
핵심 코드 분석
arrow_batch_to_tensors에 threadpool 파라미터 추가
def arrow_batch_to_tensors(
batch: "pyarrow.Table",
dtypes: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None,
combine_chunks: bool = False,
pin_memory: bool = False,
threadpool: Optional[ThreadPoolExecutor] = None,
) -> Union[Dict[str, torch.Tensor], Dict[str, List[torch.Tensor]]]:
컬럼이 여러 개일 때 병렬 처리
if num_columns > 1 and threadpool is not None:
def process_column(
col_name_col_array: Tuple[str, np.ndarray]
) -> Tuple[str, torch.Tensor]:
col_name, col_array = col_name_col_array
return col_name, convert_ndarray_batch_to_torch_tensor_batch(
col_array,
dtypes=dtypes[col_name] if isinstance(dtypes, dict) else dtypes,
pin_memory=pin_memory,
)
processed_cols = threadpool.map(process_column, numpy_batch.items())
return dict(processed_cols)
DefaultCollateFn에서 ThreadPool 관리
class DefaultCollateFn(ArrowBatchCollateFn):
_DEFAULT_NUM_WORKERS = env_integer(
"RAY_DATA_DEFAULT_COLLATE_FN_THREADPOOL_MAX_WORKERS",
4,
)
def __init__(self, ..., num_workers: int = _DEFAULT_NUM_WORKERS):
self.num_workers = num_workers
self._threadpool: Optional[ThreadPoolExecutor] = None
def __del__(self):
if getattr(self, "_threadpool", None):
self._threadpool.shutdown(wait=False)
왜 이게 좋은가
- GIL 우회: NumPy의
ndarray연산과 PyTorch의 텐서 생성은 내부적으로 C 확장을 사용하므로 Python GIL의 영향을 덜 받는다. 따라서ThreadPoolExecutor로도 실제 병렬 처리 효과를 얻을 수 있다. - 환경 변수 제어:
RAY_DATA_DEFAULT_COLLATE_FN_THREADPOOL_MAX_WORKERS로 워커 수를 조절할 수 있어 운영 환경에 맞게 튜닝이 가능하다. - 조건부 병렬화: 컬럼이 1개이거나 threadpool이 없으면 기존 순차 처리를 유지하여 오버헤드를 방지한다.
- 청크 단위 병렬화:
combine_chunks=False인 경우에도 전체 배열을 플랫하게 병렬 처리하고, 결과를 원래 컬럼-인덱스 구조로 재조립한다.
참고 자료
관련 포스트
PR Analysis 의 다른글
- 이전글 [triton] Out-of-tree TTIR/TTGIR 패스 플러그인 시스템
- 현재글 : [Ray] DefaultCollateFn 병렬화로 Arrow-to-Tensor 변환 가속
- 다음글 [Open WebUI] 외부 임베딩 API 호출을 병렬화하여 50배 성능 향상
댓글