본문으로 건너뛰기

[Ray RLlib] 모듈별 루프에서 ALL_MODULES 처리량 메트릭을 루프 밖으로 이동하여 바이어스 제거

PR 링크: ray-project/ray#57215 상태: Merged | 변경: +39 / -26

들어가며

Ray RLlib에서 num_module_steps_trained_(lifetime)_throughput 메트릭은 학습 처리량을 측정합니다. 기존 구현에서는 개별 모듈 배치를 순회하는 루프 내부에서 ALL_MODULES 키에 대한 메트릭을 기록했습니다. 이 경우 모듈이 3개면 타임스탬프가 3번 찍히므로, 처리량 계산 시 시간 간격이 실제보다 짧게 측정되어 처리량이 부풀려집니다.

핵심 코드 분석

Before: 모듈 루프 내부에서 ALL_MODULES 기록

def _log_steps_trained_metrics(self, batch: MultiAgentBatch):
    for mid, module_batch in batch.policy_batches.items():
        module_batch_size = len(module_batch)
        # 개별 모듈 메트릭 기록
        self.metrics.log_value(
            key=(mid, NUM_MODULE_STEPS_TRAINED_LIFETIME),
            value=module_batch_size, reduce="sum",
        )
        # ALL_MODULES도 루프 안에서 기록 - 바이어스 발생!
        self.metrics.log_value(
            key=(ALL_MODULES, NUM_MODULE_STEPS_TRAINED),
            value=module_batch_size, reduce="sum",
            clear_on_reduce=True, with_throughput=True,
        )
        self.metrics.log_value(
            key=(ALL_MODULES, NUM_MODULE_STEPS_TRAINED_LIFETIME),
            value=module_batch_size, reduce="sum", with_throughput=True,
        )

After: 루프 밖에서 합산 후 한 번만 기록

def _log_steps_trained_metrics(self, batch: MultiAgentBatch):
    total_module_steps = 0
    for mid, module_batch in batch.policy_batches.items():
        module_batch_size = len(module_batch)
        # 개별 모듈 메트릭
        self.metrics.log_value(
            key=(mid, NUM_MODULE_STEPS_TRAINED_LIFETIME),
            value=module_batch_size, reduce="sum", with_throughput=True,
        )
        total_module_steps += module_batch_size

    # ALL_MODULES는 루프 밖에서 한 번만 기록
    self.metrics.log_value(
        key=(ALL_MODULES, NUM_MODULE_STEPS_TRAINED),
        value=total_module_steps, reduce="sum", clear_on_reduce=True,
    )
    self.metrics.log_value(
        key=(ALL_MODULES, NUM_MODULE_STEPS_TRAINED_LIFETIME),
        value=total_module_steps, reduce="sum", with_throughput=True,
    )

왜 이게 좋은가

  1. 처리량 바이어스 제거: with_throughput=True 메트릭은 기록 시점의 타임스탬프를 사용한다. 루프 내에서 여러 번 기록하면 시간 간격이 매우 짧아져 처리량이 인위적으로 높게 계산된다.
  2. 일관성: Learner와 DifferentiableLearner 양쪽에 동일한 수정을 적용하여 모든 학습기에서 정확한 메트릭을 보장한다.
  3. 논리적 정확성: ALL_MODULES 합산은 본질적으로 모든 모듈의 총합이므로, 루프 완료 후 한 번에 기록하는 것이 의미론적으로도 올바르다.

참고 자료

댓글

관련 포스트

PR Analysis 의 다른글