智能体/生成式 AI

在 NVIDIA CUDA Tile 中调整 Flash Attention 以实现峰值性能

在本文中,我们将深入探讨现代 AI 中至关重要的工作负载之一:Flash Attention,您将了解:

  • 如何使用 NVIDIA cuTile 实现 Flash AttentioncuTile。浏览完整代码,以构建生产就绪型应用。
  • “捕获并拯救”优化之旅。本案例研究展示了原生优化(如增大图块大小)可能适得其反的原因,并探讨了相应的解决方案。
  • 先进技术如 FMA 模式、快速数学计算、循环分割和自适应平铺,实现卓越性能。

环境要求:

  • CUDA 13.1 或更新版本
  • GPU 架构:NVIDIA Blackwell(例如,NVIDIA B200、GeForce RTX 50 系列)
  • Python:3.10 或更新版本

有关安装 cuTile Python 的更多信息,请参阅快速入门文档

什么是注意力?

注意力机制是 Transformer 模型的计算核心。给定 token 序列,注意力机制使每个 token 能够“查看”其他 token,并动态调整它们对当前 token 的贡献程度。从数学角度看,对于输入矩阵 Query (\(Q\))、Key (\(K\)) 和 Value (\(V\)),输出可表示如下:

\(O = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V\)

其中:

  • \(Q \text{ has shape } (N,d),\ N \text{ query tokens, each with dimension } d.\)
  • \(K \text{ has shape } (N,d),\ N \text{ key tokens.}\)
  • \(V \text{ has shape } (N,d),\ N \text{ value tokens.}\)
  • \(\text{The intermediate } QK^{T} \text{ matrix has shape } (N,N), \text{ is a problem.}\)

内存带宽问题

对于长度为 \(N = 16,384\)(在现代 LLM 中很常见)的序列,注意力矩阵 \(QK^{T}\) 包含约 \(N^2 = 268\) 亿个元素。在 FP16 精度下,每个注意力头、每个批量样本需要 512 MB 的中间存储空间。

标准注意力实现:

  1. 计算完整的 \(N \times N\) 注意力矩阵并将其写入全局内存(速度较慢)
  2. 逐行应用 softmax 函数
  3. 读回该矩阵并与 \(V\) 相乘

这种方法受内存限制,因为 GPU 大部分时间都在等待数据在 HBM 与计算单元之间传输,而非等待计算完成。

Flash Attention 如何解决内存带宽瓶颈问题

Flash Attention(由 Dao 等人于 2022 年提出)是一种 IO 感知算法,不会显式构造完整的 \(N \times N\) 矩阵。相反,它:

  1. 将计算划分为小块: 将进程 \(Q, K, V\) 载入快速片内 SMEM 中
  2. 采用在线 softmax: 无需整行数据,可逐步计算 softmax
  3. 融合运算: 将矩阵乘法与 softmax 融合至单个内核中

其结果是速度提升了2-4倍,同时节省了大量内存,从而支持更长的上下文长度。

A tiled flash attention figure showing Q, K^T, V and O in HBM, being accumulated to Q, K, V, and O in SMEM.
图 1. 平铺式闪存注意力计算

了解在线 Softmax

Flash Attention 的关键算法见解是在线 Softmax 技巧。数值稳定的 safe softmax 要求在计算前掌握整行的最大值:

\(\text{softmax}(x_i) = \frac{e^{x_i – \max(x)}}{\sum_j e^{x_j – \max(x)}}\)

但如果我们要处理图块,则无法访问完整的行。Online Softmax 通过维护可逐步更新的运行统计量来解决这一问题。

在线 softmax 算法

我们为每行维护两个运行值:

  • \(m_i\): 截至目前观测到的最大值(用于数值稳定性)
  • \(l_i\): 截至目前观测到的指数之和(softmax 分母)

当我们处理具有值 \(x_{new}\) 的新图块时:

  1. 更新最大值:\(m_{new} = \max(m_i, \max(x_{new}))\)
  2. 计算校正系数:\(\alpha = e^{m_i – m_{new}}\) (对先前的数值进行重新缩放)
  3. 更新总和:\(l_i = l_i \cdot \alpha + \sum e^{x_{new} – m_{new}}\)
  4. 更新累加器:\(acc = acc \cdot \alpha + P_{new} \cdot V_{tile}\)

\(P_{new}\) 是注意力权重矩阵,\(V_{tile}\) 是与当前迭代的键图块对应的值矩阵图块。最后,我们进行归一化处理:\(O = acc / l_i\)。

这使我们能够在不存储整行的情况下,计算出精确的 softmax。

因果注意力和分组查询注意力

在深入探讨实现细节之前,我们先来了解现代大语言模型中采用的两种重要注意力变体:

因果注意力

在 GPT、LLaMA 和克劳德等自回归语言模型中,每个 token 只能基于序列中前面的 tokens 进行处理,而无法访问未来的信息。这种机制可防止模型在训练过程中提前获知后续内容,从而避免“作弊”现象的发生。

从数学角度来看,我们对注意力分数应用了三角形掩码

\(\text{mask}_{ij} = \begin{cases} 0 & \text{if } i \geq j \text{ (query position ≥ key position)} \ -\infty & \text{if } i < j \text{ (future tokens)} \end{cases}\)

被屏蔽的注意力变成:

\(O = \text{softmax}\left(\frac{QK^T}{\sqrt{d}} + \text{mask}\right)V\)

将 \(-\infty\) 添加到未来位置,可确保其在 softmax 之后变为零,从而有效阻断未来 token 的信息流动。

Causal attention mask matrix for 4 tokens showing how the upper triangle of the matrix is masked to 0, meaning that those values are not used in the computation.
图 2. 四个 tokens 的因果注意力遮罩

使用因果遮罩时,约有一半的注意力矩阵会被遮罩(上三角部分)。我们可以跳过这些被遮罩区域的计算,提供两倍的算法加速。这对 K 循环分块优化至关重要。

分组查询注意力

标准的多头注意力机制为每个注意力头都配备独立的 \(K,V\) 矩阵,导致内存占用较高:

  • 多头注意力 (MHA):32 个查询头 + 32 个 K/V 头(1:1 比率)
  • 分组查询注意力 (GQA):32 个查询头 + 4 个 K/V 头(8:1 比率)
  • 多查询注意力 (MQA):32 个查询头 + 1 个 K/V 头(32:1 比率)

在 GQA 中,多个查询头共享相同的 K/V 头。例如,当使用 32 个查询头和 4 个 K/V 头时:

  • 查询头 0-7 使用 K/V 头 0
  • 查询头 8-15 使用 K/V 头 1
  • 查询头 16-23 使用 K/V 头 2
  • 查询头 24-31 使用 K/V 头 3

这在推理期间将 K/V 缓存大小减少了 8 倍,对提供长上下文模型具有重要意义。LlamA 2、Llama 3、Mistral 和 Qwen 等现代 LLM 广泛采用 GQA。

在 Flash Attention 的实现中,每个 CUDA 块负责计算一个查询头的注意力,同时会加载对应的共享 K/V 头:

head_idx = bid_y % num_heads              # Which query head (0-31)
kv_head_idx = head_idx // query_group_size # Which K/V head (0-3)

查询组大小为 8 时,查询头 0-7 均映射至 kv_head_idx = 0,在内存中共享相同的 K/V 图块。

第 1 部分:CUDA Tile 中的 Flash 注意力内核

我们来逐步实现 Flash Attention。基准版本采用较小的 64 × 64 图块和简洁的代码——虽然正确,但尚未优化。

1. 定义内核接口

在 cuTile 中,@ct.kernel 装饰器用于将 Python 函数标记为 GPU 内核。我们通过 ct.Constant[T] 类型标注来传递编译时常量:

import math
import cuda.tile as ct

# Type aliases for compile-time constants
ConstInt = ct.Constant[int]
ConstBool = ct.Constant[bool]

# Conversion factor: we use exp2 instead of exp for efficiency
INV_LOG_2 = 1.0 / math.log(2)

@ct.kernel()
def fmha_kernel(
    Q, K, V, Out,              # Input/output tensors
    qk_scale: float,           # Scale factor (1/sqrt(d))
    input_pos: int,            # Position offset for causal masking
    TILE_D: ConstInt,          # Head dimension (for example, 128)
    H: ConstInt,               # Number of attention heads
    TILE_M: ConstInt,          # Tile size for Q dimension (for example, 64)
    TILE_N: ConstInt,          # Tile size for K/V dimension (for example, 64)
    QUERY_GROUP_SIZE: ConstInt,# For Grouped Query Attention (GQA)
    CAUSAL: ConstBool,         # Whether to apply causal mask
    EVEN_K: ConstBool,         # Whether K length is divisible by TILE_N
):

2. 块 ID 映射

每个 CUDA 块计算一个输出图块。通过使用 ct.bid ,我们将二维网格映射到批量/头索引:

# Get block indices
    bid_x = ct.bid(0)  # Which tile along the sequence dimension
    bid_y = ct.bid(1)  # Which batch-head combination
    
    # Decode batch and head from flattened index
    batch_idx = bid_y // H
    head_idx = bid_y % H
    
    # For Grouped Query Attention: multiple Q heads share one K/V head
    off_kv_h = head_idx // QUERY_GROUP_SIZE

3. 初始化累加器

在主循环开始前,我们先初始化在线 softmax 状态和输出累加器:

# Convert scale for base-2 exponential (faster than natural exp)
    qk_scale = qk_scale * INV_LOG_2
    
    # Create position indices for this tile
    offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32)
    offs_m += input_pos
    offs_m = offs_m[:, None]  # Shape: [TILE_M, 1]
    
    offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)
    offs_n_tile = offs_n_tile[None, :]  # Shape: [1, TILE_N]
    
    # Online softmax state (float32 for numerical stability)
    m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)  # Running max
    l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)        # Running sum
    acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)   # Output accumulator

我们将 float32 用于累加器,即使输入为 float16,也能在迭代的 softmax 计算过程中保持数值精度。

4. 加载查询图块

查询图块加载一次,并在所有键/值迭代中重复使用:

    # Load Q tile: shape [1, 1, TILE_M, TILE_D] -> [TILE_M, TILE_D]
    q = ct.load(
        Q, 
        index=(batch_idx, head_idx, bid_x, 0), 
        shape=(1, 1, TILE_M, TILE_D)
    ).reshape((TILE_M, TILE_D))

当图块延伸至张量边缘时, ct.load 函数会自动处理边界条件。

5. K/ V 图块的主循环

这是 Flash Attention 的核心:我们将迭代 K/V 图块。

   # Calculate loop bounds
    m_end = input_pos + (bid_x + 1) * TILE_M
    k_seqlen = K.shape[2]
    
    if CAUSAL:
        # For causal attention, stop early (future tokens are masked)
        Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
    else:
        Tc = ct.cdiv(k_seqlen, TILE_N)
    
    for j in range(0, Tc):
        # --- Step A: Load Key tile and compute QK^T ---
        k = ct.load(
            K,
            index=(batch_idx, off_kv_h, 0, j),
            shape=(1, 1, TILE_D, TILE_N),
            order=(0, 1, 3, 2),  # Transpose for correct layout
            latency=2            # Hint for memory prefetching
        ).reshape((TILE_D, TILE_N))
        
        # Matrix multiply: Q @ K^T
        qk = ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32)
        qk = ct.mma(q, k, qk)  # Uses Tensor Cores automatically

参数中的 order=(0,1,3,2) 告知 cuTile 加载运算时使用 K 的转置,而 latency=2 表示我们可以容忍一定程度的延迟以实现更优的流水线化。随后,我们通过 ct.mma=(q, k, k,qk) 执行 cuTile 矩阵乘积累加运算。

6. 应用因果掩码

用于自回归模型(如 GPT、Llama 等),每个 token 只能处理其之前的 tokens:

# --- Step B: Apply causal masking ---
        if CAUSAL or not EVEN_K:
            offs_n = j * TILE_N + offs_n_tile
            mask = ct.full((TILE_M, TILE_N), True, dtype=ct.bool_)
            
            # Boundary mask (for non-divisible sequence lengths)
            if not EVEN_K:
                mask = mask & (offs_n < k_seqlen)
            
            # Causal mask: query position >= key position
            if CAUSAL:
                mask = mask & (offs_m >= offs_n)
            
            # Convert to additive mask: True->0, False->-inf
            mask = ct.where(mask, 0.0, -math.inf)
            qk += mask

-inf 添加至遮罩位置,可确保其在 softmax 运算后变为零。

7. 在线 Softmax 更新

现在,我们更新正在运行的 softmax 统计数据:

   # --- Step C: Online softmax ---
        # Find max in current tile
        qk_max = ct.max(qk, axis=-1, keepdims=True)
        qk_max_scaled = qk_max * qk_scale
        
        # Update running maximum
        m_ij = max(m_i, qk_max_scaled)
        
        # Scale QK scores
        qk = qk * qk_scale
        qk = qk - m_ij
        
        # Compute attention weights (using exp2 for speed)
        p = ct.exp2(qk)
        
        # Update running sum
        l_ij = ct.sum(p, axis=-1, keepdims=True)
        alpha = ct.exp2(m_i - m_ij)  # Correction factor
        l_i = l_i * alpha
        l_i = l_i + l_ij
        
        # Rescale previous accumulator
        acc = acc * alpha

8. 累积输出

最后,加载“Value” (值) 图块并累加:

# --- Step D: Load V and accumulate ---
        v = ct.load(
            V,
            index=(batch_idx, off_kv_h, j, 0),
            shape=(1, 1, TILE_N, TILE_D),
            latency=4
        ).reshape((TILE_N, TILE_D))
        
        # Cast attention weights back to input dtype for Tensor Core MMA
        p = p.astype(Q.dtype)
        
        # Accumulate: acc += P @ V
        acc = ct.mma(p, v, acc)
        
        # Update max for next iteration
        m_i = m_ij

9. 最终规范化和存储

处理完所有图块后,我们依据总和进行归一化处理,并将结果写入。

   # --- Final: Normalize and store ---
    acc = ct.truediv(acc, l_i)
    acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
    ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)

启动核函数:主机侧代码

现在,我们来看一下启动核函数的主机端代码:

import torch
from math import ceil

def tile_fmha(q, k, v, sm_scale=None, is_causal=True):
    """
    Launch the Flash Attention kernel.
    
    Args:
        q: Query tensor, shape [batch, heads, seq_len, head_dim]
        k: Key tensor, shape [batch, kv_heads, seq_len, head_dim]
        v: Value tensor, shape [batch, kv_heads, seq_len, head_dim]
        sm_scale: Softmax scale (default: 1/sqrt(head_dim))
        is_causal: Whether to apply causal masking
    
    Returns:
        Output tensor, same shape as q
    """
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(q.size(-1))
    
    batch_size, num_heads, seq_len, head_dim = q.shape
    _, num_kv_heads, _, _ = k.shape
    
    # Calculate query group size for GQA
    query_group_size = num_heads // num_kv_heads
    
    # Ensure contiguous memory layout
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    
    # Allocate output
    o = torch.empty_like(q)
    
    # Choose tile sizes (we'll optimize this later!)
    TILE_M, TILE_N = 64, 64
    
    # Calculate grid dimensions
    grid_x = ceil(seq_len / TILE_M)  # Number of tiles along sequence
    grid_y = batch_size * num_heads  # One block per batch-head pair
    grid = (grid_x, grid_y, 1)
    
    # Check if K length is evenly divisible
    EVEN_K = (k.shape[2] % TILE_N) == 0
    
    # Launch kernel
    ct.launch(
        torch.cuda.current_stream(),
        grid,
        fmha_kernel,
        (q, k, v, o, sm_scale, 0, head_dim, num_heads,
         TILE_M, TILE_N, query_group_size, is_causal, EVEN_K)
    )
    
    return o

使用 64 * 64 个图块时,此基准可以正常运行。但能否进一步提升速度?让我们来探究一下。

第 2 部分:“陷阱和救援”优化之旅

我们基于以下配置进行基准测试:

  • 硬件:NVIDIA B200
  • 批量:4,打印头:32,打印头尺寸:128
  • 注意力机制:因果型,数据类型:FP16
  • 序列长度:1024、2048、4096、8192、16384

为解释每个步骤,我们使用带有极小截面集的 Nsight Compute:

  • LaunchStats
  • Occupancy
  • SpeedOfLight
  • ComputeWorkloadAnalysis
  • MemoryWorkloadAnalysis

基准性能

SeqLen 吞吐量 (TFLOPS)
1024 330
2,048 441
4,096 511
8,192 546
16,384 566
表 1. 基准性能,未进行任何特定优化

这是我们的起点,由 64 * 64 个图块组成,尚未进行优化。

NCU 见解 ( SeqLen 1024、B200):

  • 寄存器/线程:128
  • 理论/实际占用率:25%/19.8%
  • 计算 (SM) 吞吐量:37.8%
  • 显存吞吐量:19.7%
  • 网格大小:2048

1. 大牌的陷阱

GPU 编程的一种常见直觉是“图块越大,性能越高。”图块越大,计算密度越高,内存访问效率也往往随之提升。

  • 摊销内存访问开销,
  • 提升二级缓存利用率,
  • 降低每个元素的内核启动开销。

我们将图块大小从 64* 64 调整为 256* 128

TILE_M, TILE_N = 256, 128  # Was 64, 64

预期是更高的内存带宽利用率 → 更快的性能。然而,以 TFLOPS 计算的结果却是

SeqLen Baseline (64*64) 使用更大的图块 (256*128) 性能下降
1024 330 187 -43%
2048 441 268 -39%
4096 511 347 -32%
8192 546 415 -24%
16384 566 463 -18%
表 2. 基准性能与较大图块尺寸的性能对比,显示采用较大图块尺寸时性能有所下降

所有序列长度的性能都会降低 18-43%。这是陷阱,使用大块图块会导致性能 下降

为什么会出现这种情况?

  1. 计算瓶颈: 随着每个图块中的元素不断增多,低效运算(独立的复加运算、精确的数学运算)逐渐成为性能瓶颈。
  2. 指令用度: 每个图块承担的任务越多,执行下一次内存操作前所需的指令也越多。

课程: 图块大小与计算效率相互关联,只有在计算效率足够高时,采用较大的图块才能发挥作用。

NCU 见解 ( SeqLen™ 1024、NVIDIA B200):

  • 寄存器/ 线程跳转至 168 (~ 31%) ,理论占用率降至 18.75%
  • 实际占用率降至 16.5%
  • 计算吞吐量降至 17.4% (存在陷阱)
  • 显存吞吐量降至 7.4%
  • 网格大小缩小至 512 (图块越大,对应图块尺寸越小)

2. 借助快速数学运算进行救援

其中一个瓶颈是特殊函数:exp2(指数运算)和 truediv(真除法)。默认情况下,这些运算采用 IEEE-754 精度,精度较高,但运算速度较慢。

对于深度学习,我们可以通过微小的精度损失来实现显著的加速:

之前 (精确运算) :

p = ct.exp2(qk)
alpha = ct.exp2(m_i - m_ij)
acc = ct.truediv(acc, l_i)

After (快速数学运算) :

p = ct.exp2(qk, flush_to_zero=True)
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)
acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)

这些标志的作用:

  • flush_to_zero=True : Denormal numbers (extremely small values near zero) become exactly zero. This avoids slow microcode paths on the GPU.
  • rounding_mode=RMd.APPROX : Skips iterative refinement after initial hardware approximation.

借助快速数学运算,我们成功优化了大型图块,TFLOPS 结果如下:

SeqLen Larger tile (trap) Fast math (rescue) Improvement
1024 187 322 +72%
2048 268 436 +63%
4096 347 524 +51%
8192 415 585 +41%
16384 463 620 +34%
表 3. 采用两种快速数学运算优化后的性能提升

现在我们已达到或超越小图块基准,序列更长时可获得 10-20% 的提升

NCU 见解 ( SeqLen™ 1024、NVIDIA B200):

  • 寄存器/线程:168 个(保持不变)
  • 理论/实际占用率:18.75% / 16.6%(保持不变)
  • 计算吞吐量回升至 24.0%
  • 显存吞吐量提升至 12.9%

3. K 环拆分

为引起因果注意,我们采用了一个三角形掩码:每个查询仅能处理此前位置的键。在基准测试中,我们会在每次循环迭代时检查 if CAUSAL: mask

但想一想:对于位置为 1000 的查询图块,多数关键图块 (0-900) 根本不需要任何掩码。只有靠近对角线的图块需要遮罩。查询位置以外的图块被完全遮罩(我们可以完全跳过它们)。 无需任何掩码。只有靠近对角线的图块需要遮罩。查询位置以外的图块被 完全遮罩(我们可以完全跳过它们)。

Q by K tiled causal attention matrix showing 8 tiles per side and showing how the lower triangle is computed. The diagonal is partially computed, and the upper triangle is skipped.
图 3. 图块因果注意力矩阵(每边 8 个图块)

该优化将循环分为几个阶段:

# Calculate where masking starts being necessary
mask_start = (input_pos + bid_x * TILE_M) // TILE_N
mask_start = min(mask_start, k_seqlen // TILE_N)

# Calculate where to stop (for causal, we exit early)
if CAUSAL:
    Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
else:
    Tc = ct.cdiv(k_seqlen, TILE_N)

for j in range(0, Tc):
    # Load K and compute QK...
    
    # ONLY apply masking when necessary
    if (CAUSAL or not EVEN_K) and j >= mask_start:
        offs_n = j * TILE_N + offs_n_tile
        mask = ct.full((TILE_M, TILE_N), True, dtype=ct.bool_)
        if not EVEN_K:
            mask = mask & (offs_n < k_seqlen)
        if CAUSAL:
            mask = mask & (offs_m >= offs_n)
        mask = ct.where(mask, 0.0, -math.inf)
        qk += mask
    
    # Continue with softmax and accumulation...

为何如此重要: 对于长度为 16K 的序列,若每个图块包含 256 个 token:

  • ・50% 的图块完全取消掩码(无分支,无掩码计算)
  • 每行有 1 个图块被部分掩码(完整逻辑)
  • 其余图块则被完全跳过(提前退出)

结果为 TFLOPS:

SeqLen 快速数学运算 循环分割 优化
1024 322 373 +16%
2048 436 552 +27%
4096 524 684 +31%
8192 585 770 +32%
16384 620 813 +31%
表 4. 采用 K 循环拆分优化后的性能提升

这是单次优化中幅度最大的一次,所有序列长度的速度提升高达 32%。

NCU 见解 ( SeqLen 1024、B200):

  • 寄存器/线程:168 个(保持不变)
  • 理论/实际占用率:18.75% / 16.6%(保持不变)
  • 显存吞吐量提升至 14.5%(减少了工作浪费)
  • 计算吞吐量仍为 24.0%(任务效率提高,但每个周期的执行速度未必更快)

4. ProgramId 重映射

其中一个细微的优化是反转因果注意力的分块顺序。当从右下角向左上角反向处理图块时,由于因果掩码的作用,后启动的方块所需计算量较小,从而改善了负载均衡,减轻了尾部效应。

之前 (标准顺序):

bid_x = ct.bid(0)  # Process tiles 0, 1, 2, ...

之后 (因果反向) :

if CAUSAL:
    bid_x = NUM_M_BLOCKS - 1 - ct.bid(0)  # Process tiles N, N-1, N-2, ...
else:
    bid_x = ct.bid(0)

这一微小的调整优化了波形调度,使 GPU 上的线程块完成进度更加一致。

结果为 TFLOPS:

SeqLen Loop split Remapping Improvement
1024 373 377 1
2048 552 560 1.5
4096 684 696 1.8
8192 770 781 1.5
16384 813 835 2.6
表 5. 重新映射图块的块顺序后,性能有所提升

适度但一致的 1-3% 的增益,在尾部效果尤为明显的更长序列中表现突出。

5. 自动调整

我们优化了大图块,但存在一个问题:短序列仍然倾向于使用小图块

为什么?当使用 1024-token 序列和 256-token 图块时,我们仅有 4 个图块,这不足以在 B200 上充分占用所有 SM。而较小的图块(64 × 64)可生成 16 个图块,从而更有效地利用 GPU 资源。

我们可以让 cuTile 的自动调整器 对多个配置进行基准测试,并为每个输入形状缓存较优配置,而不是手动选择值。

自动调整器方法

def _fmha_autotune_configs():
    """Search space for autotuning.

    The autotuner will benchmark these configurations and cache the best one
    per input shape (sequence length, batch size, etc.).
    """
    gpu_capability = torch.cuda.get_device_capability()

    if gpu_capability in [(12, 0), (12, 1)]:
        # RTX 50 series (sm120, sm121)
        yield SimpleNamespace(TILE_M=64, TILE_N=64, num_ctas=1, occupancy=2)
    else:
        # B200/GB200 (sm100) - Try multiple tile sizes
        # Autotuner will discover:
        # - 64x64 is best for short sequences (1024-2048)
        # - 128x128 may be best for medium sequences (4096)
        # - 256x128 is best for long sequences (8192+)
        yield SimpleNamespace(TILE_M=64, TILE_N=64, num_ctas=1, occupancy=2)
        yield SimpleNamespace(TILE_M=128, TILE_N=128, num_ctas=1, occupancy=2)
        yield SimpleNamespace(TILE_M=256, TILE_N=128, num_ctas=1, occupancy=1)

如何通过自动调整启动:

使用 ct_experimental.autotune_launch 而不是直接调用 ct.launch

import cuda.tile_experimental as ct_experimental

def autotune_launch_fmha(
    stream, q, k, v, o, sm_scale, input_pos,
    hidden_size, num_heads, query_group_size, is_causal
):
    batch_size, _, q_len, _ = q.shape

    def _grid_fn(cfg):
        return (math.ceil(q_len / cfg.TILE_M), batch_size * num_heads, 1)

    def _args_fn(cfg):
        num_m_blocks = math.ceil(q_len / cfg.TILE_M)
        even_k = (k.shape[2] % cfg.TILE_N) == 0
        return (
            q, k, v, o, sm_scale, input_pos,
            hidden_size, num_heads, cfg.TILE_M, cfg.TILE_N,
            query_group_size, is_causal, even_k, num_m_blocks,
        )

    ct_experimental.autotune_launch(
        stream,
        grid_fn=_grid_fn,
        kernel=fmha_kernel,
        args_fn=_args_fn,
        hints_fn=lambda cfg: {"num_ctas": cfg.num_ctas, "occupancy": cfg.occupancy},
        search_space=_fmha_autotune_configs,
    )

注意:autotuner API 可能会发生变化。

自动调整器以智能方式工作:

  1. 使用 seq_len = 1024 进行首次调用:对全部 3 个配置执行基准测试,缓存表现较优的配置之一
  2. 首次调用 seq_len = 2048:对全部 3 个配置执行基准测试,缓存表现较优的配置之一
  3. 后续调用:直接采用缓存配置(零开销)

缓存键包含张量形状,因此不同序列长度会自动匹配不同的优化配置。

结果为 TFLOPS:

SeqLen Baseline Remapping Autotune Speedup 与基准
1024 330 377 548 1.66 x
2048 441 560 708 1.61 x
4096 511 696 817 1.60 x
8192 546 781 887 1.62 x
16384 566 835 918 1.62 x
表 6. 原始基准与第 5 步和第 6 步自动调整结果的比较

自动调优器会发现 64* 64 图块适合序列 2048,随后切换到更大的图块以处理更长的序列。相比固定的大型图块,这种方法可将短序列的性能提升 45%,同时在长序列情况下保持峰值性能。

自动调整器的选择 (在 B200 上):

  • SeqLen 1024:64 × 64 图块(高并行度)
  • SeqLen 2048:64 × 64 或 128 × 128 图块(兼顾性能与效率)
  • SeqLen 4096 及以上:128 × 128 或 256 × 128 图块(显存效率优先)

现在,我们无需手动调整,便可在所有序列长度上实现优异性能。

摘要:优化堆栈

优化 关键见解 影响
基准 (64* 64) 正确但未优化 基准
大型图块 (256* 128) 捕获点:速度较慢的 18–43%! -18% 到 -43%
+ 快速数学运算 (FTZ,APPROX) 提供显著改善:大型图块现可从陷阱 +34% 到 +72% 从陷阱
+ K 循环拆分 单次优化幅度最大 +16% 到 +32%
+ 程序 ID 重新映射 实现更佳负载均衡 +1% 到 +3%
+ 自动调整 每序列最优图块配置 +10% 到 +45%
表 7. 逐步优化结果,每一步骤均对性能产生影响

最终加速:在所有序列长度上提升至 1.60 x-1.66 倍。

开始使用

编写高性能内核时,通常难以找到单一的“魔术”设置。正如我们在“陷阱和救援”中所见:

  1. 优化是相互依赖的:在我们修复数学运算之前,大型图块的速度较慢。图块大小无法被孤立地评估。
  2. 数学至关重要flush_to_zeroAPPROX 等标志对于释放 Tensor Core 吞吐量具有关键作用。在深度学习中,精确的数学运算往往过于简单。
  3. 算法优势可叠加:K 循环拆分通过消除冗余计算,带来了显著的单次性能提升(最高达 32%)。
  4. 自动调整优于手动启发式策略:cuTile 的自动调整器能针对不同序列长度动态选择最优图块配置(短序列采用 64×64,长序列采用 256×128),相比固定配置可获得 10% 至 45% 的性能增益。
  5. 累积效应呈倍数增长:完整的优化组合在所有序列长度上实现了 1.60 倍至 1.66 倍的加速,远超单一优化所能达到的效果。

cuTile 使开发者能够在简洁、可读的 Python 代码中表达这些优化(平铺、快速数学控制、循环分割、自动调整),同时为 NVIDIA GPU 生成高度优化的 PTX。

您可以在 TileGym 资源库 中找到经过全面优化的内核。祝使用愉快。

 

 

标签