[Ultralytics] Pose Loss의 keypoint 배치 루프를 벡터 연산으로 최적화
PR 링크: ultralytics/ultralytics#23937 상태: Merged | 변경: +7 / -5
들어가며
Pose Estimation 모델의 Loss 계산에서 keypoint 텐서를 배치별로 정리하는 작업은 매 학습 iteration마다 반복됩니다. 기존에는 # TODO: any idea how to vectorize this?라는 주석과 함께 Python for 루프로 구현되어 있었습니다. 이번 PR은 이 TODO를 해결하여, detect/obb Loss 벡터화(#23966)와 동일한 scatter_add_ + cumsum 패턴을 적용합니다.
핵심 코드 분석
keypoint 배치 정리 벡터화
Before:
# TODO: any idea how to vectorize this?
# Fill batched_keypoints with keypoints based on batch_idx
for i in range(batch_size):
keypoints_i = keypoints[batch_idx == i]
batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i
After:
# Vectorized fill: compute within-batch position for each keypoint using cumulative offsets
batch_idx_long = batch_idx.long()
offsets = torch.zeros(batch_size + 1, dtype=torch.long, device=keypoints.device)
offsets.scatter_add_(0, batch_idx_long + 1, torch.ones_like(batch_idx_long))
offsets = offsets.cumsum(0)
within_idx = torch.arange(len(batch_idx), device=keypoints.device) - offsets[batch_idx_long]
batched_keypoints[batch_idx_long, within_idx] = keypoints
동작 원리:
scatter_add_: 각 배치에 속한 keypoint 수를 카운트cumsum: 배치별 시작 오프셋 계산arange - offsets[batch_idx]: 각 keypoint의 배치 내 순서 산출batched_keypoints[batch_idx, within_idx] = keypoints: 한 번의 advanced indexing으로 전체 할당
이는 PR #23966에서 detect/obb Loss에 적용한 것과 완전히 동일한 패턴이며, keypoint의 다차원 형태([N, K, D])에도 broadcasting 덕분에 그대로 적용됩니다.
왜 이게 좋은가
-
TODO 해결: 코드에 명시적으로 남겨진 TODO를 해결하여 기술 부채를 청산합니다.
-
일관된 패턴 적용: detect/obb/pose 세 가지 Loss 모듈이 동일한 벡터화 패턴을 사용하게 되어, 코드베이스의 일관성이 향상됩니다.
-
Pose 모델 학습 속도 향상: 배치 크기가 큰 Pose 학습에서 Python 루프 오버헤드가 제거되어 GPU 활용도가 높아집니다.
정리
Pose Loss의 keypoint 배치 정리에 scatter_add_ + cumsum 벡터화 패턴을 적용한 7줄의 변경입니다. 기존 코드의 TODO를 해결하고, detect/obb와 동일한 최적화 패턴으로 코드베이스 전체의 일관성을 확보합니다.
참고 자료
이 글은 AI(Claude)의 도움을 받아 작성되었으며, 실제 PR의 코드 변경 사항을 기반으로 분석한 내용입니다.
관련 포스트
PR Analysis 의 다른글
- 이전글 [axolotl] Tensor Parallelism batch_size 계산 버그 수정: dp_world_size 기반으로 전환
- 현재글 : [Ultralytics] Pose Loss의 keypoint 배치 루프를 벡터 연산으로 최적화
- 다음글 [PaddleOCR] MCP 서버에서 모든 OCR 결과 배치를 파싱하도록 수정
댓글