开发与优化

通过 Python API 利用 CuTe DSL 实现 CUTLASS C++ 级性能

CuTe,是 CUTLASS 3.x 的核心组件,它提供了统一的代数体系,用于描述数据布局和线程映射,并将复杂的内存访问模式抽象为可组合的数学运算。

尽管 CUTLASS 3.x 与 CuTe 使内核开发者能够通过直观的抽象在 Tensor Core 上实现高性能,但其广泛使用 C++ 模板导致编译时间过长。与此同时,Python 及即时(JIT)编译技术在研究与生产级生成式 AI 工作流中的应用日益广泛,进一步推动了 CUTLASS 4 的演进与发展。

本文将介绍使用 CuTe DSL 的优势。我们将展示其提供的一致性 C++ API,能够在不同 GPU 芯片上实现相近的 Tensor Core 利用效率,同时具备比传统 C++ 更低的编译开销。

有关 CuTe 和 CUTLASS 3.x 基础知识的更多详细信息,可参考CUTLASS:通过张量和空间微核处理多维数据的原理抽象以及CUTLASS 3.x:用于 GEMM 内核设计的正交、可复用且可组合的抽象

CuTe DSL:CUTLASS 4 的基础

CUTLASS 4 中新推出的 CuTe DSL(测试版)将 CuTe 的强大功能引入 Python 编程环境,使开发者能够以更简便的方式编写低级 GPU 内核,而无需依赖复杂的 C++ 模板元编程。

为简化与新 DSL 相关的学习曲线,CuTe DSL 延续了 CuTe 的核心设计理念。更多示例(包括密集 GEMM 的持久化变体分组 GEMM 以及融合多头注意力 FMHA)可访问 GitHub 上的 NVIDIA/cutlass 项目查看。

CuTe DSL 和 CuTe C++ 的比较

CuTe 凭借其可靠的布局表示和代数,在十多年的 NVIDIA GPU 架构中始终提供一致的 GPU 编程模型。CuTe DSL 在保留用户从 C++ 版本 CuTe 所熟悉的完整编程模型的同时,还具备 Python 的易用性。这不仅显著提升了编译速度,大幅改善了错误提示信息,也使学习过程更加顺畅,并能够快速无缝地集成到 Python 原生的深度学习框架中。

C++ 与 DSL 代码的并行对比表明,二者在编程模型和编程模式上具有一致性,唯一的差异在于 C++ 和 Python 的语言语法。

TiledMMA

cute::TiledMma is a spatial microkernel that describes the tiling and permutations of any hardware MMA atom across a set of “threads” and data. Its representation enables writing canonical triple for loops for any hardware MMA, be it SIMT FP64 or the cutting-edge NVFP4 Blackwell tensor core instructions.

auto tiled_mma = make_tiled_mma(SM100_MMA_F16BF16_SS<TA, TB, TC,
                                         128, 128,
                                         UMMA::Major::MN, UMMA::Major::MN>{},
                           Layout<Shape<_1,_1>>{});
 
// Allocate "fragments" -- these are actually umma tmem and smem descriptors
 Tensor tCrA = tiled_mma.make_fragment_A(sA);  // (MMA,MMA_M,MMA_K,PIPE)
 Tensor tCrB = tiled_mma.make_fragment_B(sB);  // (MMA,MMA_M,MMA_K,PIPE)
  
 // Allocate TMEM
 Tensor tCtC = tiled_mma.make_fragment_C(tCgC);// (MMA,MMA_M,MMA_N)
 for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
   static_assert(size<2>(tCrA) == size<2>(tCrB), "A and B contraction modes do not match!");
   gemm(mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCtC)
 }

# Construct a tiled_mma item
atom = tcgen05.MmaF16BF16Op(
        io_dtype,
        acc_dtype,
        mma_inst_shape_mnk, #(128, 128, 64)
        tcgen05.CtaGroup.ONE,
        tcgen05.OperandSource.SMEM,
        tcgen05.OperandMajorMode.K,
        tcgen05.OperandMajorMode.K,
    )
tiled_mma = cute.make_tiled_mma(atom)

tCrA = tiled_mma.make_fragment_A(sA)   # (MMA, MMA_M, MMA_K,PIPE)
tCrB = tiled_mma.make_fragment_B(sB)   # (MMA, MMA_N, MMA_K,PIPE)
tCtC = tiled_mma.make_fragemnt_C(tCgC) # (MMA, MMA_M, MMA_N)

for k_block_idx in cute.size(tCrA, mode = 2):
   assert(cute.size(tCrA, mode = 2) == cute.size(tCrB, mode = 2), "A and B contraction modes do not match!");
    cute.gemm(
        tiled_mma, tCtC, tCrA[None, None, k_block_idx], tCrB[None, None, k_block_idx], tCtC)

TiledCopy

典型的 cute::copy 包含一个单循环,用于发出数据移动指令,将一个张量复制到另一个张量,并利用张量的布局信息来描述传输过程中可能发生的转置或排列操作。cute::TiledCopy 则是一种用于表示和验证任意两个张量之间经过优化的数据传输是否适用的类型。

例如,无论是否结合布局转换(如转置),均可在不同显存空间(如全局内存与共享内存之间)使用任何硬件加速的复制原子操作。

using TMEM_LOAD = typename std::conditional<sizeof(TC) == 4, SM100_TMEM_LOAD_16dp256b1x, SM100_TMEM_LOAD_16dp256b1x_16b>::type;
// tCtC are accumuator layout
 auto tiled_ldtm = make_tmem_copy(TMEM_LOAD{}, tCtC);
 auto thr_ldtm   = tiled_ldtm.get_slice(threadIdx.x);
 Tensor tDtC = thr_ldtm.partition_S(tCtC);  // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
 Tensor tDgC = thr_ldtm.partition_D(tCgC);  // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
 Tensor tDrC = make_tensor<TC>(shape(tDgC));// ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
// TMEM_LOAD
copy(tiled_ldtm, tDtC, tDrC);
# Construct a tensor memory to register memory (T2R) tiled_copy item
# tCtACC are accumulator tensor, layout as (MMA, MMA_M, MMA_N)
# tCgC is the partitioned results (MMA, MMA_M, MMA_N, RestM, RestN, RestL) of global tensor C (M, N)
copy_atom = cute.make_copy_atom(
  tcgen05.Ld32x32bOp(tcgen05.Repetition.x128, tcgen05.Pack.NONE),
  cutlass.Float32)
tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom, tCtACC)
thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
# This is tensor memory layout (T2R_M, T2R_N, EPI_M, EPI_N)
tT2R_tAcc = thr_copy_t2r.partition_S(tCtACC)
# (T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL)
tT2R_gC = thr_copy_t2r.partition_D(tCgC)
# Construct register memory layout from the partitioned global tensor 
tT2R_rAcc = cute.make_fragment(
   tT2R_gC[None, None, None, None, 0, 0, 0].shape, cutlass.Float32)
cute.copy(tiled_copy_t2r, tT2R_tAcc, tT2R_rAcc)

CuTe DSL 在多代 GPU 上的性能表现

推动 CUTLASS C++ 在训练和推理框架中广泛应用的关键因素之一,在于其能够提供卓越的性能表现。CuTe DSL 目前已可实现接近同等水平的性能,且相关优化工作仍在持续进行中。

此外,CUTLASS 3 及其底层 CuTe 已在前几代 GPU 硬件的研究与生产场景中得到应用。由于部署的 GPU 硬件在生产环境中通常具有较长的使用寿命,即使在异构环境下亦是如此,因此 CuTe DSL 在发布时便支持从 Ampere 到 Blackwell 的多代 NVIDIA GPU,以满足这些实际部署需求。

NVIDIA Blackwell 性能

我们评估了三个关键运算的性能:密集GEMM、分组GEMM,以及基于CUTLASS C++和CuTe DSL实现的FMHA。总体来看,CuTe DSL的性能与CUTLASS C++相当。

密集 GEMM

我们在 float16 和 float8 e4m3 两种精度下测试了密集 GEMM 的性能,累加过程均采用 float32 精度。

图1展示了在NVIDIA DGX B200平台上,使用CuTe DSL实现的稠密GEMMNVIDIA/cutlass GitHub仓库中CUTLASS 3库的稠密GEMM进行的基准测试对比。横轴表示测试的问题规模,纵轴表示通过NVIDIA/cutlass采集的Tensor Core数学计算吞吐效率通过NVIDIA Compute Nsight。

对于规模较小的 GEMM-K 问题(K = 512),DSL 内核的当前执行速度仍慢于 C++。这主要是因为在进入内核数学计算之前,同步操作的开销较大、效率较低,相关优化工作正在积极推进中。

Bar chart titled ‘B200 Dense GEMM Math Throughput Efficiency %’ showing that DSL performance is on par with C++ except for the small K cases.
Bar chart titled ‘B200 Dense GEMM Math Throughput Efficiency %’ showing that DSL performance is on par with C++ except for the small K cases.
图 1。基于 NVIDIA DGX B200 平台,对 CuTe DSL 密集矩阵乘法(GEMM)与 CUTLASS 3 密集 GEMM 进行基准性能对比。

分组 GEMM

比较基于 CuTe DSL 分组 GEMMCUTLASS 3 分组 GEMM 的基准测试结果来自 NVIDIA/cutlass GitHub 仓库。

Bar chart titled ‘B200 Float16 I/O Group GEMM Math Throughput Efficiency %’ showing that DSL performance is on par with C++.
Bar chart titled ‘B200 Float16 I/O Group GEMM Math Throughput Efficiency %’ showing that DSL performance is on par with C++.
图 2。在 NVIDIA DGX B200 上对 Float16 I/O 的 CuTe DSL Group GEMM 与 CUTLASS 3 Group GEMM 进行基准测试比较

融合多头注意力 (FMHA)

基准测试采用 CuTe DSL FMHACUTLASS 3 FMHA NVIDIA/cutlass GitHub 仓库进行比较。

Bar chart titled ‘B200 Float16 I/O Flash Attention Math Throughput Efficiency %’ showing that DSL performance is on par with C++.
Bar chart titled ‘B200 Float16 I/O Flash Attention Math Throughput Efficiency %’ showing that DSL performance is on par with C++.
图3:在NVIDIA DGX B200上对采用Float16 I/O的CuTe DSL闪存注意力与CUTLASS 3闪存注意力进行基准测试比较

Ampere 性能:Dense GEMM

基准测试采用 NVIDIA/cutlass GitHub 仓库中的 CuTe DSL 密集 GEMM(Ampere)CUTLASS 3 密集 GEMM(Ampere) 进行对比。

Bar chart titled ‘A100 FP16 I/O Dense GEMM Math Throughput Efficiency %’ showing that DSL performance is slightly slower than C++ and perf gaps to be investigated.
Bar chart titled ‘A100 FP16 I/O Dense GEMM Math Throughput Efficiency %’ showing that DSL performance is slightly slower than C++ and perf gaps to be investigated.
图 4。在 NVIDIA A100 上对使用 Float16 I/O 的 CuTe DSL 密集 GEMM 与 CUTLASS 3 密集 GEMM 进行基准测试比较。

缩短编译时间

CuTe DSL 使内核开发者能够利用 CuTe 抽象来即时编译内核,从而有效缓解 C++ 模板带来的高编译时间问题。

如图5所示,编译时间显著缩短,平均可减少达两个数量级。这不仅有助于内核开发者快速尝试更多图块大小和布局形状,高效确定最优配置以实现更优性能,还能有效缩短PyTorch Inductor自动调优功能的整体耗时。

Blackwell 上的 GEMM 编译速度相比 C++ 提升了约 100 倍,而基于 Blackwell 的闪光注意力技术可使编译速度提升 30 至 50 倍。

Bar chart showing compilation time on B200 is far faster than C++ with speedups of 33 -116x. 
Bar chart showing compilation time on B200 is far faster than C++ with speedups of 33 -116x.
图 5。NVIDIA Blackwell 的编译速度远超 C++。

轻松集成 DL 框架

在 DLPack 协议的支持下,CuTe DSL 能够直接将主流深度学习框架中的张量数据作为输入,并将其转换为 cute 张量,而无需复制底层内存。

CuTe DSL 的 Python 原生接口使深度学习框架能够直接嵌入自定义内核,无需编写繁琐的胶水代码,也无需深入掌握 CUDA C++。这使得研究人员和工程师可以在现有模型工作流中快速完成自定义线性代数内核的原型设计与部署,显著缩短开发周期。

DSL 的可组合布局抽象能够简化对复杂内存和线程映射的表达,对于在 NVIDIA Ampere、Hopper 和 Blackwell 架构之间高效利用 Tensor Core 硬件具有关键作用。

开始使用 CuTe DSL

CuTe DSL 引入了全新的编程接口,在保持 CUTLASS C++ 高性能的同时,显著提升了开发效率。通过查阅快速入门指南,您可以深入了解如何构建高性能计算内核。此外,欢迎将您开发的内核贡献至CUTLASS GitHub 仓库,共同丰富示例库。

首先,下载 CUTLASS 并仔细阅读其官方文档。同时加入 NVIDIA 开发者论坛,参与更深入的技术交流与讨论。

致谢

我们衷心感谢所有 CUTLASS 开源项目贡献者。正是得益于他们的基础性贡献,CUTLASS 4 才得以实现。

 

标签