본문으로 건너뛰기

[Ultralytics] detect/obb Loss 계산의 preprocess를 벡터화하여 학습 속도 향상

PR 링크: ultralytics/ultralytics#23966 상태: Merged | 변경: +16 / -16

들어가며

딥러닝 학습에서 Loss 계산은 매 iteration마다 실행되므로, 여기의 미세한 비효율도 전체 학습 시간에 누적됩니다. Ultralytics YOLO의 detect/obb Loss에서 preprocess 메서드는 타겟 텐서를 배치별로 정리하는 역할을 합니다. 기존에는 Python for 루프로 배치를 순회했지만, 이번 PR은 PyTorch의 텐서 연산(scatter_add_, cumsum)을 사용한 완전한 벡터화로 GPU 활용도를 극대화합니다.

핵심 코드 분석

detect Loss의 preprocess 벡터화

Before:

i = targets[:, 0]  # image index
_, counts = i.unique(return_counts=True)
counts = counts.to(dtype=torch.int32)
out = torch.zeros(batch_size, counts.max(), ne - 1, device=self.device)
for j in range(batch_size):
    matches = i == j
    if n := matches.sum():
        out[j, :n] = targets[matches, 1:]

After:

batch_idx = targets[:, 0].long()  # image index
_, counts = batch_idx.unique(return_counts=True)
counts = counts.to(dtype=torch.int32)
out = torch.zeros(batch_size, counts.max(), ne - 1, device=self.device)
offsets = torch.zeros(batch_size + 1, dtype=torch.long, device=self.device)
offsets.scatter_add_(0, batch_idx + 1, torch.ones_like(batch_idx))
offsets = offsets.cumsum(0)
within_idx = torch.arange(nl, device=self.device) - offsets[batch_idx]
out[batch_idx, within_idx] = targets[:, 1:]

핵심 아이디어는 각 타겟이 자신의 배치 내에서 몇 번째인지를 벡터 연산으로 계산하는 것입니다:

  1. scatter_add_로 각 배치의 타겟 수를 누적합니다
  2. cumsum으로 각 배치의 시작 오프셋을 구합니다
  3. arange - offsets[batch_idx]로 배치 내 인덱스를 한 번에 계산합니다
  4. advanced indexing out[batch_idx, within_idx]로 결과를 한 번에 할당합니다

동일한 패턴이 obb Loss의 preprocess에도 적용되었으며, obb에서는 추가로 bbox 스케일링(mul_)을 루프 밖으로 끌어내 한 번만 수행합니다.

왜 이게 좋은가

  1. GPU 병렬성 극대화: Python for 루프는 배치 수만큼 순차 실행되지만, 벡터 연산은 GPU의 수천 코어를 동시에 활용합니다. 배치 크기가 클수록 차이가 커집니다.

  2. CPU-GPU 통신 최소화: for 루프 내의 boolean masking과 조건문은 매 반복마다 GPU↔CPU 동기화를 유발할 수 있습니다. 벡터화는 단일 커널 호출로 처리합니다.

  3. 동일한 수학적 결과: 입출력 텐서의 형태와 값이 완전히 동일하므로, Loss 함수의 수학적 동작에 영향이 없습니다.

정리

scatter_add_ + cumsum + advanced indexing 패턴을 사용하여 detect/obb Loss의 preprocess를 완전히 벡터화한 PR입니다. for 루프를 텐서 연산으로 대체하는 이 패턴은 PyTorch에서 배치별 처리가 필요한 모든 곳에 적용할 수 있는 범용적 최적화 기법입니다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었으며, 실제 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.

댓글

관련 포스트

PR Analysis 의 다른글