[Ray RLlib] 모듈별 루프에서 ALL_MODULES 처리량 메트릭을 루프 밖으로 이동하여 바이어스 제거
PR 링크: ray-project/ray#57215 상태: Merged | 변경: +39 / -26
들어가며
Ray RLlib에서 num_module_steps_trained_(lifetime)_throughput 메트릭은 학습 처리량을 측정합니다. 기존 구현에서는 개별 모듈 배치를 순회하는 루프 내부에서 ALL_MODULES 키에 대한 메트릭을 기록했습니다. 이 경우 모듈이 3개면 타임스탬프가 3번 찍히므로, 처리량 계산 시 시간 간격이 실제보다 짧게 측정되어 처리량이 부풀려집니다.
핵심 코드 분석
Before: 모듈 루프 내부에서 ALL_MODULES 기록
def _log_steps_trained_metrics(self, batch: MultiAgentBatch):
for mid, module_batch in batch.policy_batches.items():
module_batch_size = len(module_batch)
# 개별 모듈 메트릭 기록
self.metrics.log_value(
key=(mid, NUM_MODULE_STEPS_TRAINED_LIFETIME),
value=module_batch_size, reduce="sum",
)
# ALL_MODULES도 루프 안에서 기록 - 바이어스 발생!
self.metrics.log_value(
key=(ALL_MODULES, NUM_MODULE_STEPS_TRAINED),
value=module_batch_size, reduce="sum",
clear_on_reduce=True, with_throughput=True,
)
self.metrics.log_value(
key=(ALL_MODULES, NUM_MODULE_STEPS_TRAINED_LIFETIME),
value=module_batch_size, reduce="sum", with_throughput=True,
)
After: 루프 밖에서 합산 후 한 번만 기록
def _log_steps_trained_metrics(self, batch: MultiAgentBatch):
total_module_steps = 0
for mid, module_batch in batch.policy_batches.items():
module_batch_size = len(module_batch)
# 개별 모듈 메트릭
self.metrics.log_value(
key=(mid, NUM_MODULE_STEPS_TRAINED_LIFETIME),
value=module_batch_size, reduce="sum", with_throughput=True,
)
total_module_steps += module_batch_size
# ALL_MODULES는 루프 밖에서 한 번만 기록
self.metrics.log_value(
key=(ALL_MODULES, NUM_MODULE_STEPS_TRAINED),
value=total_module_steps, reduce="sum", clear_on_reduce=True,
)
self.metrics.log_value(
key=(ALL_MODULES, NUM_MODULE_STEPS_TRAINED_LIFETIME),
value=total_module_steps, reduce="sum", with_throughput=True,
)
왜 이게 좋은가
- 처리량 바이어스 제거:
with_throughput=True메트릭은 기록 시점의 타임스탬프를 사용한다. 루프 내에서 여러 번 기록하면 시간 간격이 매우 짧아져 처리량이 인위적으로 높게 계산된다. - 일관성: Learner와 DifferentiableLearner 양쪽에 동일한 수정을 적용하여 모든 학습기에서 정확한 메트릭을 보장한다.
- 논리적 정확성: ALL_MODULES 합산은 본질적으로 모든 모듈의 총합이므로, 루프 완료 후 한 번에 기록하는 것이 의미론적으로도 올바르다.
참고 자료
관련 포스트
PR Analysis 의 다른글
- 이전글 [Loki] 청크 재정렬 시 파이프라인 처리 바이패스로 CPU 최적화
- 현재글 : [Ray RLlib] 모듈별 루프에서 ALL_MODULES 처리량 메트릭을 루프 밖으로 이동하여 바이어스 제거
- 다음글 [Ultralytics] 학습 중 Multi-GPU 검증 지원
댓글