[CUDA] Reuse blocks with record_stream during CUDA Graph capture in the CUDACachingAllocator #158352

https://github.com/pytorch/pytorch/pull/158352

专栏另一篇文章解读了record_stream相关内容,但CUDACachingAllocator源码中存在一个分支,会根据CUDAAllocatorConfig::graph_capture_record_stream_reuse()判断是否free和malloc某个block,这里填个坑,找一下对应的pr看是怎么复用的。

背景:CUDA Graph 下显存管理限制

在 CUDA Graph 捕获过程中,CUDACachingAllocator对内存的管理存在一个关键限制:必须等到捕获结束后才能回收内存块。这一限制源于 CUDA 的底层约束 —— 捕获阶段不允许查询event状态(因为此时 CUDA 操作尚未执行),而CUDACachingAllocator依赖事件驱动逻辑判断内存块是否可回收。而这会导致本该被释放的block未能被正确释放,直到捕获完全结束,比如free中的代码片段:

    if (!block->stream_uses.empty()) {
      if (C10_UNLIKELY(!captures_underway.empty())) {
        // It's forbidden to cudaEventQuery an event recorded during CUDA graph
        // capture. We conservatively defer recording end-of-life events until
        // the next call to process_events() (which won't happen until no
        // captures are underway)
        needs_events_deferred_until_no_capture.push_back(block);

cudaGraph下安全的多流复用

为实现安全重用,PR 首先明确了两个核心术语,作为后续判断逻辑的基础:

  • Free marker:通过cudaGraphAddEmptyNode创建的 “捕获合法” 空节点,插入到每个使用过该内存块的stream中,且位于该块最后一次被使用的操作之后,用于标记 “内存块已空闲”。
  • Terminal Node:流或捕获图中 “最新操作” 的集合,新捕获的操作会附加在终端节点之后。对于正在捕获的流,可通过cudaStreamGetCaptureInfo的dependencies_out参数获取终端节点集合。

内存块可重用性判断规则

cudaGraph生成一个DAG,因此作者提出了两种平衡 “安全性” 与 “灵活性” 的判断规则:

  • Strong Rule (Graph-Wide Safety):从全局保障安全,规定若内存块的所有空闲标记均为所有活跃流的终端节点的 “前驱节点”,则该块可安全重用,通过严格的全局执行顺序避免重放时的生命周期重叠;
  • Per-stream Rule (A Practical Optimization): 仅验证单流,对于流 S 上的分配请求,若内存块的所有空闲标记是流 S 的终端节点的前驱节点,即可在流 S 上重用,因此无需验证全图安全性。

Image Alt Text

核心实现流程

捕获期间执行free

  1. 对block->stream_uses中的每个流及 “分配流”,插入free marker,关联对应的block和free marker所对应的空node
        if (CUDAAllocatorConfig::graph_capture_record_stream_reuse()) {
          deferred_blocks.emplace(block, insert_free_marker(block)); // 插入free marker
        } else {
          deferred_blocks.emplace(block, std::vector<cudaGraphNode_t>{}); // 捕获结束后访问
        }
  1. 捕获期间malloc
      if (CUDAAllocatorConfig::graph_capture_record_stream_reuse()) {
        free_safe_blocks_in_capture(context, stream);
      }

free_safe_blocks_in_capture函数内容:

get_reusable_empty_nodes通过获取当前流的终端节点,通过反向 DFS 遍历每个终端节点的依赖链,统计空闲节点能到达的终端数量;最终筛选出 “能到达所有终端节点” 的空闲节点,构成当前流可复用的空闲节点集合,从而判断哪些free marker属于当前流可复用空闲节点,之后,判断free时载入过的block中哪些可以被正确复用,如果可复用则释放:

auto reusable_empty_nodes = get_reusable_empty_nodes(stream);
if (reusable_empty_nodes.empty()) {
  return;
}

std::vector<Block*> blocks_to_erase; // 记录待删除的块

for (auto& [block, inserted_empty_nodes] : deferred_blocks) {
  // 跳过两类块:1. 无空闲标记的块(无法判断安全);2. 分配流≠当前流的块
  if (inserted_empty_nodes.empty() || block->stream != stream) {
    continue;
  }

  // 该块的所有空闲标记是否都在reusable_empty_nodes中
  bool is_reusable = true;
  for (const auto& node : inserted_empty_nodes) {
    if (reusable_empty_nodes.find(node) == reusable_empty_nodes.end()) {
      is_reusable = false;
      break;
    }
  }

  // 若所有标记都安全 → 该块可回收
  if (is_reusable) {
    // 清除stream_uses:Graph已通过依赖保证同步,无需再跟踪多流使用
    block->stream_uses.clear();
    // 将块回收到空闲列表
    free_block(block, context);
    // 记录该块,后续从deferred_blocks中删除
    blocks_to_erase.push_back(block);
  }
}

现在,只要开启graph_capture_record_stream_reuse,pytorch会自动尝试在cuda_graph中通过插入free marker来提前释放一些block,使其生命周期变短,而不再是等到捕获完全结束时一次性释放。