[Ultralytics] multi_scale 옵션을 auto-batch 계산에 포함하여 OOM 방지
PR 링크: ultralytics/ultralytics#24051 상태: Merged | 변경: +3 / -2
들어가며
YOLO 학습에서 multi_scale 옵션은 학습 중 이미지 크기를 무작위로 변경하여 다양한 스케일에 강건한 모델을 만드는 기법입니다. 문제는 auto-batch 기능이 기본 imgsz만 고려하여 배치 크기를 계산한다는 점입니다. multi_scale이 활성화되면 실제 이미지 크기가 imgsz보다 커질 수 있어, auto-batch가 산출한 배치 크기에서 OOM(Out of Memory)이 발생할 수 있습니다.
핵심 코드 분석
auto_batch 메서드의 이미지 크기 계산 수정
Before:
def auto_batch(self, max_num_obj=0, dataset_size=0):
"""Calculate optimal batch size based on model and device memory constraints."""
return check_train_batch_size(
model=self.model,
imgsz=self.args.imgsz,
amp=self.amp,
batch=self.batch_size,
max_num_obj=max_num_obj,
dataset_size=dataset_size,
)
After:
def auto_batch(self, max_num_obj=0, dataset_size=0):
"""Calculate optimal batch size based on model and device memory constraints."""
max_imgsz = int(self.args.imgsz * (1 + self.args.multi_scale)) # need not be stride-aligned
return check_train_batch_size(
model=self.model,
imgsz=max_imgsz,
amp=self.amp,
batch=self.batch_size,
max_num_obj=max_num_obj,
dataset_size=dataset_size,
)
self.args.multi_scale은 0.0~1.0 사이의 float 값으로, 예를 들어 multi_scale=0.5이면 이미지 크기가 기본값의 50%까지 증가할 수 있습니다. imgsz=640이고 multi_scale=0.5인 경우:
- Before: auto-batch가 640x640 기준으로 배치 크기 계산 → 실제 학습 시 960x960 이미지가 들어오면 OOM
- After:
int(640 * 1.5) = 960기준으로 배치 크기 계산 → OOM 방지
주석에서 "need not be stride-aligned"라고 명시한 것은, auto-batch 계산 시에는 stride 정렬 없이 최대 크기만 보수적으로 추정하면 충분하기 때문입니다.
왜 이게 좋은가
-
OOM 사전 방지: multi_scale 학습에서 발생하던 간헐적 OOM을 근본적으로 해결합니다.
-
보수적 추정: 최대 가능 이미지 크기를 기준으로 배치를 산출하므로, 메모리가 부족한 상황이 발생하지 않습니다. 약간의 메모리 여유가 생길 수 있지만, OOM으로 학습이 중단되는 것보다 훨씬 나은 트레이드오프입니다.
-
단 한 줄의 핵심 수정:
max_imgsz계산 한 줄이 전부이며, 기존 함수 시그니처와 동작을 전혀 변경하지 않습니다.
정리
auto-batch 계산에 multi_scale 배율을 반영하여 실제 최대 이미지 크기를 기준으로 배치 크기를 산출하도록 수정한 PR입니다. 3줄의 변경으로 multi_scale 학습 시의 OOM 문제를 우아하게 해결합니다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었으며, 실제 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [Ultralytics] MPS 디바이스에서 메모리 누수 방지를 위한 적극적 메모리 정리
- 현재글 : [Ultralytics] multi_scale 옵션을 auto-batch 계산에 포함하여 OOM 방지
- 다음글 [sglang] GC Threshold 인자 추가: Python 가비지 컬렉션 주기 튜닝 지원
댓글