본문으로 건너뛰기

[llm-compressor] Transformers Tracing: 모델 그래프 추적과 부분 forward

들어가며

Sequential Pipeline은 모델을 "레이어 단위로 쪼개어" 순차 실행한다. 이 분할이 가능하려면 모델의 forward 그래프를 정적으로 분석할 수 있어야 한다. llm-compressor는 PyTorch의 **torch.fx**를 사용해 HuggingFace 모델을 추적(trace)한다. 그런데 HF 모델들은 종종 control flow, Python 제어 구문, 외부 함수 호출 때문에 FX 추적이 실패한다. 이를 해결하는 것이 src/llmcompressor/transformers/tracing/의 헬퍼들이다.

핵심 구조/코드 분석

FX Tracing의 기본

PyTorch FX는 모델의 forward를 심볼릭 실행해 그래프를 만든다.

from torch.fx import symbolic_trace

traced = symbolic_trace(model)
print(traced.graph)
# call_module[input_layernorm]
# call_module[self_attn.q_proj]
# call_module[self_attn.k_proj]
# ...

이 그래프는 "각 노드가 어떤 모듈/함수를 호출하는지"를 보여준다. 서브그래프 분할은 이 그래프를 후처리해서 이뤄진다.

문제: HF 모델의 control flow

HF 모델 코드에는 종종 control flow가 들어간다.

# LlamaModel.forward 에서
if attention_mask is not None:
    causal_mask = self._update_causal_mask(...)
else:
    causal_mask = None

FX의 심볼릭 실행은 심볼릭 텐서에 대한 if 분기를 처리할 수 없다. if tensor is None: 같은 코드는 에러를 낸다.

tracing/debug.py의 헬퍼들

src/llmcompressor/transformers/tracing/debug.py는 이 문제들을 진단하고 우회하는 도구를 제공한다.

def diagnose_tracing_failure(model, sample_input, tracing_ignore: list[str]):
    """
    Attempt to trace model, report untraceable functions with suggestions.
    """
    try:
        from transformers.utils.fx import symbolic_trace as hf_symbolic_trace
        traced = hf_symbolic_trace(
            model,
            input_names=list(sample_input.keys()),
        )
        return traced
    except Exception as e:
        # 에러 메시지 분석해 어느 함수가 문제인지 파악
        offending = _extract_offending_function(str(e))
        logger.error(
            f"Tracing failed at {offending}. "
            f"Add '{offending}' to `tracing_ignore` in DatasetArguments."
        )
        raise


def _extract_offending_function(error_msg: str) -> str:
    """
    Parse FX error to find the name of the untraceable function.
    """
    # "TraceError: symbolically traced variables cannot be used as inputs to control flow"
    # 같은 메시지에서 함수 이름 추출
    ...

**핵심 유틸리티는 tracing_ignore**다. 사용자가 "이 함수들은 추적하지 말고 그대로 실행해라"를 지정할 수 있다.

기본 tracing_ignore 목록

Oneshot 진입점 글에서 본 것처럼, oneshot()은 기본 tracing_ignore 목록을 가진다.

tracing_ignore: list[str] = [
    "_update_causal_mask",
    "create_causal_mask",
    "_update_mamba_mask",
    "make_causal_mask",
    "get_causal_mask",
    "mask_interface",
    "mask_function",
    "_prepare_4d_causal_attention_mask",
    "_prepare_fsmt_decoder_inputs",
    "_prepare_4d_causal_attention_mask_with_cache_position",
    "_update_linear_attn_mask",
    "project_per_layer_inputs",
]

이 목록은 HF Transformers의 일반적인 attention mask 처리 함수들이다. FX가 이들을 추적하지 못하는 이유는 내부에 if None: 같은 dynamic check가 들어 있기 때문이다. llm-compressor는 이들을 "leaf 함수"로 마킹해 FX가 내부를 건드리지 않고 call 노드로만 기록하게 한다.

leaf_function 매핑

FX는 wrap 메커니즘으로 특정 함수를 leaf로 마킹할 수 있다.

import torch.fx as fx

# 이 함수 호출은 심볼릭 실행하지 않고 call_function 노드로만 기록
fx.wrap("_update_causal_mask")

llm-compressor의 tracing 헬퍼는 tracing_ignore 목록을 순회하며 각 함수를 이렇게 wrap한다. 이후 symbolic_trace가 실행되면 이 함수들은 그래프의 leaf 노드가 되어 내부를 들여다보지 않는다.

Sequential Pipeline과의 통합

Sequential Pipeline은 이 tracing 결과를 사용해 모델을 서브그래프로 쪼갠다.

# pipelines/sequential/pipeline.py 에서
def trace_subgraphs(model, sample_input, sequential_targets, ignore, ...):
    # 1) HF 의 fx symbolic_trace 실행 (tracing_ignore 적용)
    traced = symbolic_trace_with_ignores(model, sample_input, ignore)

    # 2) 그래프를 sequential_targets 기준으로 분할
    # 예: LlamaDecoderLayer 를 경계로 0..15, 16..31 ... 서브그래프 생성
    subgraphs = split_by_target(traced, sequential_targets)

    return subgraphs

split_by_target는 "각 LlamaDecoderLayer 경계에서 그래프를 잘라" 독립적으로 실행 가능한 서브그래프 리스트를 만든다. 각 서브그래프는 자신의 입력과 출력을 명시한다.

트레이싱 실패 시 디버깅

사용자가 새 모델에 llm-compressor를 적용할 때 tracing 실패는 자주 일어난다. debug.py의 유틸리티는 이런 경우 도움을 준다.

# 사용자 코드
from llmcompressor.transformers.tracing.debug import diagnose_tracing_failure

try:
    diagnose_tracing_failure(model, sample_input, tracing_ignore=[])
except Exception as e:
    # 에러 메시지에 "Add '..._some_func' to tracing_ignore" 가 포함됨
    print(e)

경험적으로 "새 모델에서 트레이싱 실패 → 에러 메시지의 함수 이름을 tracing_ignore에 추가 → 다시 시도"를 몇 번 반복하면 대부분 해결된다.

왜 이 설계인가

1. FX 기반 분할. torch.fx는 PyTorch 공식 그래프 IR이다. 안정적이고 생태계가 성숙해 있다. 외부 도구 없이 PyTorch 표준만으로 서브그래프 분할을 구현할 수 있다.

2. tracing_ignore로 우회. 모든 HF 모델을 완벽히 추적하려 하는 것은 비현실적이다. Control flow가 있는 함수들을 leaf로 마킹해 그 내부는 무시하면 대부분의 모델이 추적 가능해진다.

3. 기본 목록의 실용성. HF 모델의 공통 "문제 함수"들(attention mask 처리)을 기본 tracing_ignore에 포함. 사용자는 90% 이상의 모델을 기본 설정으로 돌릴 수 있다.

4. debug.py의 자기 진단. 에러 메시지에서 "어느 함수를 추가해야 하는지"를 파싱해 사용자에게 알려준다. 시행착오를 줄인다.

5. Leaf function 방식. fx.wrap이 내부적으로 함수 호출을 "단일 call 노드"로 만들므로 서브그래프 분할에 영향이 없다. 내부 값이 필요하지 않은 헬퍼 함수들에 적합한 방식.

마무리

Transformers Tracing은 llm-compressor가 HuggingFace 생태계와 매끄럽게 동작하게 만드는 접착제다. 다음 글은 압축된 모델을 파일로 저장하는 Compression Save를 본다.

참고 자료

댓글

관련 포스트

llm-compressor 의 다른글