본문으로 건너뛰기

[triton] Gluon tmem_load에서 Register Layout 자동 추론

PR 링크: triton-lang/triton#9594 상태: Merged | 변경: +256 / -257

들어가며

NVIDIA Blackwell의 tensor memory(tmem)에서 데이터를 레지스터로 load할 때, 기존에는 사용자가 get_tmem_reg_layout를 명시적으로 호출하여 register layout을 지정해야 했습니다. 이 PR은 tensor memory descriptor 타입 자체에서 register layout을 자동 추론하도록 변경하여 API를 단순화합니다. BC(Backward Compatibility)-breaking 변경입니다.

핵심 코드 분석

Before (매번 layout 명시):

tmem_reg_layout = get_tmem_reg_layout(
    gl.float32, (BLOCK_M, BLOCK_N),
    acc_tmem_layout, num_warps=gl.num_warps())
acc = acc_tmem.load(tmem_reg_layout)

After (자동 추론):

acc = acc_tmem.load()

내부적으로는 tensor_memory_descriptor_type이 layout 정보를 포함하도록 변경되었습니다:

qk_tmem_ty = tensor_memory_descriptor_type(
    gl.float32, self.qk_shape, self.qk_tmem_layout, self.qk_shape)
self.qk_layout = gl.constexpr(
    qk_tmem_ty.get_reg_layout(num_warps=self.num_warps,
                              instr_variant="32x32b_splitn"))

C++ 바인딩에서도 allocShape를 별도 파라미터로 받도록 변경:

m.def("compute_tmem_reg_layout",
    [](py::object elementTyObj, std::vector<int64_t> shape,
       std::vector<int64_t> allocShape, py::object layoutObj, ...) {

왜 이게 좋은가

이 변경은 사용자 코드에서 get_tmem_reg_layout 호출을 완전히 제거하여 보일러플레이트를 크게 줄입니다. 테스트 코드만 봐도 7곳에서 6줄씩 제거되어 약 42줄이 절약됩니다. Layout 추론을 descriptor 타입에 위임함으로써, layout 불일치로 인한 실수를 원천 차단합니다. 예를 들어 load_max() 호출에서도 layout 인자가 불필요해졌습니다.

정리

tensor memory descriptor 타입에 get_reg_layout 메서드를 추가하고, tmem.load()가 layout을 자동 추론하도록 변경하여 get_tmem_reg_layout 함수의 사용을 제거했습니다.

참고 자료

이 분석은 AI가 실제 코드 diff를 기반으로 작성했습니다.

댓글

관련 포스트

PR Analysis 의 다른글