본문으로 건너뛰기

[Ray] DefaultCollateFn 병렬화로 Arrow-to-Tensor 변환 가속

PR 링크: ray-project/ray#58821 상태: Merged | 변경: +209 / -23

들어가며

Ray Data에서 PyTorch 학습을 위해 Apache Arrow 배치를 텐서로 변환하는 DefaultCollateFn은 모든 컬럼을 순차적으로 처리하고 있었다. 다수의 컬럼을 가진 데이터셋에서는 이 변환이 병목이 될 수 있다. 이 PR은 ThreadPoolExecutor를 도입하여 컬럼별 변환을 병렬화함으로써 처리 속도를 개선한다.

핵심 코드 분석

arrow_batch_to_tensors에 threadpool 파라미터 추가

def arrow_batch_to_tensors(
    batch: "pyarrow.Table",
    dtypes: Optional[Union[torch.dtype, Dict[str, torch.dtype]]] = None,
    combine_chunks: bool = False,
    pin_memory: bool = False,
    threadpool: Optional[ThreadPoolExecutor] = None,
) -> Union[Dict[str, torch.Tensor], Dict[str, List[torch.Tensor]]]:

컬럼이 여러 개일 때 병렬 처리

if num_columns > 1 and threadpool is not None:
    def process_column(
        col_name_col_array: Tuple[str, np.ndarray]
    ) -> Tuple[str, torch.Tensor]:
        col_name, col_array = col_name_col_array
        return col_name, convert_ndarray_batch_to_torch_tensor_batch(
            col_array,
            dtypes=dtypes[col_name] if isinstance(dtypes, dict) else dtypes,
            pin_memory=pin_memory,
        )

    processed_cols = threadpool.map(process_column, numpy_batch.items())
    return dict(processed_cols)

DefaultCollateFn에서 ThreadPool 관리

class DefaultCollateFn(ArrowBatchCollateFn):
    _DEFAULT_NUM_WORKERS = env_integer(
        "RAY_DATA_DEFAULT_COLLATE_FN_THREADPOOL_MAX_WORKERS",
        4,
    )

    def __init__(self, ..., num_workers: int = _DEFAULT_NUM_WORKERS):
        self.num_workers = num_workers
        self._threadpool: Optional[ThreadPoolExecutor] = None

    def __del__(self):
        if getattr(self, "_threadpool", None):
            self._threadpool.shutdown(wait=False)

왜 이게 좋은가

  1. GIL 우회: NumPy의 ndarray 연산과 PyTorch의 텐서 생성은 내부적으로 C 확장을 사용하므로 Python GIL의 영향을 덜 받는다. 따라서 ThreadPoolExecutor로도 실제 병렬 처리 효과를 얻을 수 있다.
  2. 환경 변수 제어: RAY_DATA_DEFAULT_COLLATE_FN_THREADPOOL_MAX_WORKERS로 워커 수를 조절할 수 있어 운영 환경에 맞게 튜닝이 가능하다.
  3. 조건부 병렬화: 컬럼이 1개이거나 threadpool이 없으면 기존 순차 처리를 유지하여 오버헤드를 방지한다.
  4. 청크 단위 병렬화: combine_chunks=False인 경우에도 전체 배열을 플랫하게 병렬 처리하고, 결과를 원래 컬럼-인덱스 구조로 재조립한다.

참고 자료

댓글

관련 포스트

PR Analysis 의 다른글