[axolotl] Axolotl, 대규모 언어 모델 학습 시 메모리 부족 문제 해결: 효율적인 데이터셋 처리 개선
PR 링크: axolotl-ai-cloud/axolotl#3711 상태: Merged | 변경: +142 / -13
들어가며
대규모 언어 모델(LLM)을 학습시킬 때, 특히 지도 미세 조정(Supervised Fine-Tuning, SFT) 데이터셋을 다룰 때 메모리 부족(Out-Of-Memory, OOM) 문제는 흔하게 발생합니다. 이는 데이터셋의 크기가 방대하고, 각 데이터 샘플이 긴 컨텍스트를 포함할 경우 더욱 심화됩니다. 최근 axolotl-ai-cloud/axolotl 레포지토리의 PR #2975는 이러한 문제를 해결하기 위해 calculate_total_num_steps 함수에서 발생하는 메모리 누수를 개선했습니다. 본 글에서는 해당 PR의 코드 변경 사항을 분석하고, 왜 이러한 변경이 효율적인 데이터셋 처리에 기여하는지 자세히 살펴보겠습니다.
기존 코드의 문제는 calculate_total_num_steps 함수가 전체 input_ids와 labels 컬럼을 to_pandas()를 사용하여 Python 객체로 변환하면서 발생했습니다. 이는 Hugging Face datasets 라이브러리의 Arrow 기반 메모리 매핑(mmap) 효율성을 저해하고, 특히 긴 컨텍스트를 가진 데이터셋의 경우 토큰당 약 7-8바이트의 메모리를 추가로 사용하여 총 메모리 사용량을 급증시키는 결과를 초래했습니다. 예를 들어, 9.5백만 행에 16k 컨텍스트 길이를 가진 데이터셋에서는 약 1.3TB에 달하는 메모리가 필요할 수 있습니다.
이번 PR은 이러한 비효율적인 데이터 처리 방식을 개선하여 메모리 사용량을 획기적으로 줄이고, 대규모 데이터셋 처리 속도를 향상시키는 것을 목표로 합니다.
코드 분석
src/axolotl/utils/trainer.py 파일 변경 분석
이번 PR의 핵심 변경 사항은 calculate_total_num_steps 함수 내에서 input_ids와 labels 컬럼의 총 토큰 수를 계산하는 방식에 있습니다.
1. total_num_tokens 계산 방식 변경
기존 코드 (Before):
- total_num_tokens = np.sum(
- train_dataset.select_columns("input_ids")
- .to_pandas()["input_ids"]
- .apply(len)
- .values
- )
기존 코드에서는 train_dataset.select_columns("input_ids").to_pandas()를 호출하여 input_ids 컬럼 전체를 Pandas DataFrame으로 변환했습니다. 이후 각 행의 input_ids 리스트 길이를 계산하고, 이를 NumPy 배열로 변환한 뒤 합산했습니다. 이 과정에서 전체 input_ids 데이터가 메모리로 로드되어 상당한 메모리 사용량을 발생시켰습니다.
변경된 코드 (After):
+ if "length" in train_dataset.data.column_names:
+ total_num_tokens = int(
+ pc.sum(train_dataset.data.column("length"), min_count=0).as_py()
+ )
+ else:
+ total_num_tokens = int(
+ pc.sum(
+ pc.list_value_length(train_dataset.data.column("input_ids")),
+ min_count=0,
+ ).as_py()
+ )
변경된 코드는 두 가지 경우를 고려합니다.
length컬럼이 존재하는 경우: 만약 데이터셋에 미리 계산된length컬럼이 있다면,train_dataset.data.column("length")를 직접 참조하여pyarrow.compute.sum함수로 합계를 계산합니다. 이는 가장 효율적인 방법으로, 별도의 길이 계산 없이 저장된 값을 바로 사용합니다.length컬럼이 없는 경우:length컬럼이 없다면,pyarrow.compute.list_value_length(train_dataset.data.column("input_ids"))를 사용하여 각input_ids리스트의 길이를 계산합니다. 이 결과에 대해pyarrow.compute.sum을 적용하여 총 토큰 수를 얻습니다. 이 방식은to_pandas()를 사용하는 것보다 훨씬 효율적입니다.pyarrow.compute함수들은 Arrow 테이블의 내부 표현을 직접 활용하므로, 전체 데이터를 Python 객체로 변환하는 오버헤드가 없습니다.
pc.sum(...).as_py()를 통해 최종적으로 계산된 값을 Python 정수형으로 변환합니다.
2. total_supervised_tokens 계산 방식 변경
기존 코드 (Before):
- total_supervised_tokens = (
- train_dataset.data.column("labels")
- .to_pandas()
- .apply(lambda x: np.sum(np.array(x) != -100))
- .sum()
- )
기존 코드 역시 train_dataset.data.column("labels").to_pandas()를 사용하여 labels 컬럼 전체를 Pandas DataFrame으로 변환했습니다. 이후 각 행의 labels 리스트에서 -100이 아닌 토큰의 개수를 세고, 이 값들을 합산했습니다. 이 역시 대규모 데이터셋에서는 상당한 메모리 사용량을 유발했습니다.
변경된 코드 (After):
+ # Stream the labels column in record-batch chunks instead of
+ # .to_pandas(), which materializes one Python list per row.
+ total_supervised_tokens = 0
+ for batch in train_dataset.data.to_batches(max_chunksize=1024):
+ labels = batch.column("labels")
+ if pa.types.is_list(labels.type) or pa.types.is_large_list(labels.type):
+ flat = labels.flatten().to_numpy(zero_copy_only=False)
+ else:
+ flat = labels.to_numpy(zero_copy_only=False)
+ total_supervised_tokens += int((flat != -100).sum())
변경된 코드는 train_dataset.data.to_batches(max_chunksize=1024)를 사용하여 labels 컬럼을 작은 배치(chunk) 단위로 스트리밍합니다. 각 배치에 대해:
batch.column("labels")로labels데이터를 가져옵니다.labels가 리스트 타입인지 확인하고,flatten()을 통해 1차원 배열로 만듭니다. (이 부분은labels가 중첩된 리스트가 아닐 경우에도 안전하게 동작하도록to_numpy로 변환됩니다.)to_numpy(zero_copy_only=False)를 사용하여 NumPy 배열로 변환합니다.flat != -100조건을 사용하여-100이 아닌 토큰의 개수를 센 후, 이를total_supervised_tokens에 누적합니다.
이 방식은 전체 labels 데이터를 한 번에 메모리에 올리지 않고 배치 단위로 처리하므로 메모리 사용량을 크게 줄일 수 있습니다. max_chunksize=1024는 처리할 배치 크기를 지정하며, 필요에 따라 조절할 수 있습니다.
tests/test_calculate_total_num_steps.py 파일 변경 분석
새롭게 추가된 tests/test_calculate_total_num_steps.py 파일은 calculate_total_num_steps 함수의 정확성과 효율성을 검증하기 위한 테스트 케이스들을 포함합니다.
TestCalculateTotalNumSteps클래스: 다양한 시나리오에서 토큰 계산이 올바르게 이루어지는지 테스트합니다.test_in_memory_single_chunk: 메모리 내 단일 청크 데이터셋에 대한 테스트.test_disk_round_trip: 디스크에 저장했다가 다시 로드한 데이터셋에 대한 테스트.test_large_list:LargeList타입의 데이터셋에 대한 테스트.test_length_column_takes_precedence:length컬럼이 존재할 때 우선적으로 사용되는지 확인하는 테스트.test_empty_dataset: 빈 데이터셋에 대한 테스트.test_update_false_does_not_mutate_cfg:update=False옵션 사용 시 설정이 변경되지 않는지 확인하는 테스트.
이 테스트들은 변경된 코드가 다양한 데이터셋 구조와 크기에서도 안정적으로 작동함을 보장합니다.
왜 이게 좋은가?
이번 PR의 변경 사항은 다음과 같은 이유로 매우 긍정적입니다:
- 획기적인 메모리 절감: 가장 큰 이점은 메모리 사용량의 급격한 감소입니다. 기존의
to_pandas()방식은 전체 데이터셋을 메모리로 로드하여 OOM 오류를 유발했지만, 변경된 코드는pyarrow.compute와 배치 스트리밍을 활용하여 필요한 데이터만 처리합니다. 이는 수백 GB 또는 TB 단위의 메모리 사용량을 수 GB 수준으로 줄일 수 있어, 더 큰 모델이나 더 긴 컨텍스트를 가진 데이터셋을 메모리 제약 없이 학습할 수 있게 합니다. - 성능 향상: 메모리 로딩 및 Python 객체 변환 오버헤드가 줄어들면서 데이터셋 처리 속도 또한 향상됩니다. 특히 대규모 데이터셋의 경우, 이로 인한 학습 준비 시간 단축 효과는 상당할 것입니다.
datasets라이브러리 활용 극대화: Hugging Facedatasets라이브러리는 Arrow를 기반으로 효율적인 데이터 처리를 지원합니다. 이번 변경은pyarrow.compute와 같은 라이브러리의 기능을 적극적으로 활용하여 Arrow의 장점을 최대한 끌어냈습니다. 이는datasets라이브러리를 사용하는 다른 프로젝트에서도 참고할 만한 모범 사례입니다.- 코드의 견고성 및 테스트 커버리지 향상: 새로운 테스트 케이스 추가는 코드의 신뢰성을 높이고, 향후 발생할 수 있는 회귀 오류를 방지하는 데 기여합니다. 다양한 엣지 케이스를 다루는 테스트는 라이브러리의 안정성을 보장합니다.
일반적 교훈:
- 대규모 데이터셋 처리 시, 전체 데이터를 메모리로 로드하는
to_pandas()와 같은 방식은 피해야 합니다. 대신 Arrow의 내장 함수(pyarrow.compute)나 배치 스트리밍(to_batches)을 활용하여 메모리 효율성을 높여야 합니다. - 데이터셋에 미리 계산된 메타데이터(예:
length)가 있다면, 이를 적극적으로 활용하여 불필요한 계산을 줄여야 합니다. - 성능 병목 지점을 식별하고, 해당 부분을 라이브러리의 저수준 API나 최적화된 기능을 사용하여 개선하는 것이 중요합니다.
리뷰 피드백 반영
PR 설명과 CodeRabbit의 요약에 따르면, 이 변경은 주로 성능 개선과 메모리 효율성 증대에 초점을 맞추고 있습니다. 리뷰어들의 구체적인 피드백 내용은 제공되지 않았지만, 이러한 종류의 메모리 최적화 PR은 일반적으로 코드의 명확성, 잠재적인 엣지 케이스 처리, 그리고 성능 향상의 실제 측정치 등을 중점적으로 검토받습니다. 추가된 테스트 케이스들은 이러한 검토 과정에서 요구될 수 있는 견고성을 입증하는 데 도움이 됩니다.
References
- Hugging Face Datasets
pyarrow.computedocumentation - Hugging Face Datasets
to_batchesdocumentation - Arrow
list_value_lengthfunction
참고 자료
- https://arrow.apache.org/docs/python/compute.html
- https://huggingface.co/docs/datasets/process#streaming-with-batches
- https://arrow.apache.org/docs/python/compute.html#list_value_length
⚠️ 알림: 이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.
관련 포스트
- [ACE-Step-1.5] 외부 의존성을 걷어내고 성능을 잡다: ACE-Step 1.5의 커스텀 vLLM 엔진 도입기
- [axolotl] Axolotl에 도입된 Stateless 최적화: SinkGD로 메모리 효율 극대화하기
- [sglang] SGLang의 긴 문맥 처리 최적화: fill_ids 재구성 오버헤드 줄이기
- [sglang] SGLang 성능 최적화: torch.cuda.empty_cache() 호출 제어를 통한 가중치 업데이트 병목 해결
- [onnxruntime] ONNX Runtime QMoE SwiGLU GEMV 최적화: Split-K2 커널로 LLM 추론 가속화
PR Analysis 의 다른글
- 이전글 [onnxruntime] ONNX Runtime: MoE Router GEMV 최적화 및 Bias Fusion 구현
- 현재글 : [axolotl] Axolotl, 대규모 언어 모델 학습 시 메모리 부족 문제 해결: 효율적인 데이터셋 처리 개선
- 다음글 [sglang] SGLang 성능 최적화: D2H 복사 연산의 비동기 오버랩 구현
댓글