본문으로 건너뛰기

[Ultralytics] Pose Loss의 keypoint 배치 루프를 벡터 연산으로 최적화

PR 링크: ultralytics/ultralytics#23937 상태: Merged | 변경: +7 / -5

들어가며

Pose Estimation 모델의 Loss 계산에서 keypoint 텐서를 배치별로 정리하는 작업은 매 학습 iteration마다 반복됩니다. 기존에는 # TODO: any idea how to vectorize this?라는 주석과 함께 Python for 루프로 구현되어 있었습니다. 이번 PR은 이 TODO를 해결하여, detect/obb Loss 벡터화(#23966)와 동일한 scatter_add_ + cumsum 패턴을 적용합니다.

핵심 코드 분석

keypoint 배치 정리 벡터화

Before:

# TODO: any idea how to vectorize this?
# Fill batched_keypoints with keypoints based on batch_idx
for i in range(batch_size):
    keypoints_i = keypoints[batch_idx == i]
    batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i

After:

# Vectorized fill: compute within-batch position for each keypoint using cumulative offsets
batch_idx_long = batch_idx.long()
offsets = torch.zeros(batch_size + 1, dtype=torch.long, device=keypoints.device)
offsets.scatter_add_(0, batch_idx_long + 1, torch.ones_like(batch_idx_long))
offsets = offsets.cumsum(0)
within_idx = torch.arange(len(batch_idx), device=keypoints.device) - offsets[batch_idx_long]
batched_keypoints[batch_idx_long, within_idx] = keypoints

동작 원리:

  1. scatter_add_: 각 배치에 속한 keypoint 수를 카운트
  2. cumsum: 배치별 시작 오프셋 계산
  3. arange - offsets[batch_idx]: 각 keypoint의 배치 내 순서 산출
  4. batched_keypoints[batch_idx, within_idx] = keypoints: 한 번의 advanced indexing으로 전체 할당

이는 PR #23966에서 detect/obb Loss에 적용한 것과 완전히 동일한 패턴이며, keypoint의 다차원 형태([N, K, D])에도 broadcasting 덕분에 그대로 적용됩니다.

왜 이게 좋은가

  1. TODO 해결: 코드에 명시적으로 남겨진 TODO를 해결하여 기술 부채를 청산합니다.

  2. 일관된 패턴 적용: detect/obb/pose 세 가지 Loss 모듈이 동일한 벡터화 패턴을 사용하게 되어, 코드베이스의 일관성이 향상됩니다.

  3. Pose 모델 학습 속도 향상: 배치 크기가 큰 Pose 학습에서 Python 루프 오버헤드가 제거되어 GPU 활용도가 높아집니다.

정리

Pose Loss의 keypoint 배치 정리에 scatter_add_ + cumsum 벡터화 패턴을 적용한 7줄의 변경입니다. 기존 코드의 TODO를 해결하고, detect/obb와 동일한 최적화 패턴으로 코드베이스 전체의 일관성을 확보합니다.

참고 자료


이 글은 AI(Claude)의 도움을 받아 작성되었으며, 실제 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.

댓글

관련 포스트

PR Analysis 의 다른글