JAX 优化方案

将 JAX 优化以运行在 Spark 上

基本思路

JAX 允许您编写 NumPy 风格Python 代码,并在 GPU 上快速运行,而无需编写 CUDA 代码。它通过以下方式实现:

  • 在 GPU 上使用 NumPy:jax.numpy 就像 NumPy 一样,但数组存储在 GPU 上。

  • 函数转换:
    • jit→ 将函数编译成快速的 GPU 代码

    • grad→ 提供自动微分

    • vmap→ 将函数向量化并跨批次执行

    • pmap→ 在多个 GPU 上并行

  • XLA 后端:JAX 将您的代码交给 XLA(加速的线性代数编译器),其会融合运算并生成优化的 GPU 内核。

您将完成

您将在 NVIDIA Spark 上搭建一个基于 Blackwell 架构的 JAX 开发环境,该环境使您能够利用熟悉的类 NumPy 抽象进行高性能机器学习原型设计,并具备完整的 GPU 加速和性能优化能力。

前置知识

  • 熟悉 Python 和 NumPy 编程

  • 对机器学习工作流程和技术有基本的了解

  • 有在终端环境工作​​的经验

  • 有使用和构建容器的经验

  • 熟悉不同版本的 CUDA

  • 具备线性代数基础(高中数学水平即可)

先决条件

  • 采用 Blackwell 架构的 NVIDIA Spark 设备

  • ARM64 (AArch64) 处理器架构

  • 安装 Docker 或其他容器运行时

  • 已配置 NVIDIA Container Toolkit 

  • 验证 GPU 访问权限:nvidia-smi

  • 端口 8080 可用,用于访问 marimo notebook

辅助文件

所有必要资源都可在 GitHub 找到。

时间和风险

  • 时长:2-3 小时,包括设置、完成教程和验证

  • 风险:
    • Python 环境中的包依赖冲突

    • 性能验证可能需要针对特定​​架构进行优化。

  • 还原:容器环境提供了隔离性;删除容器并重新启动即可重置状态。

  • 最后更新:2025 年 11 月 7 日
    • 进行了少量文字修改

第 1 步 – 验证系统先决条件

请确认您的 NVIDIA Spark 系统满足要求并已配置 GPU 访问权限。

# Verify GPU access
nvidia-smi

# Verify ARM64 architecture  
uname -m

# Check Docker GPU support
docker run --gpus all --rm nvcr.io/nvidia/cuda:13.0.1-runtime-ubuntu24.04 nvidia-smi

如果您看到 permission denied 错误(例如尝试连接到 Docker 守护进程套接字时出现权限被拒绝的错误),请将您的用户添加到 docker 组,这样您就不需要使用 sudo 运行命令了。

sudo usermod -aG docker $USER
newgrp docker

第 2 步 – 克隆 playbook 库

git clone https://github.com/NVIDIA/dgx-spark-playbooks

第 3 步 - 构建 Docker 镜像

重要提示:

该命令将下载基础镜像并在本地构建容器以支持此环境。

cd dgx-spark-playbooks/nvidia/jax/assets
docker build -t jax-on-spark .

第 4 步 - 启动 Docker 容器

在 Docker 容器中运行 JAX 开发环境,并启用 GPU 支持和端口转发以访问 marimo。

docker run --gpus all --rm -it \
    --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 \
    -p 8080:8080 \
    jax-on-spark

第 5 步 - 访问 marimo 界面

连接到 marimo notebook 以开始 JAX 教程。

# Access via web browser
# Navigate to: http://localhost:8080

界面将加载目录显示和对 marimo 的简要介绍。

第 6 步 - 完成 JAX 入门教程

请学习入门材料,了解 JAX 编程模型与 NumPy 的差异。

前往并完成 JAX 入门 notebook ,内容包括:

  • JAX 编程模型基础知识

  • 与 NumPy 的主要区别

  • 绩效评估技术

第 7 步 - 实现 NumPy 基准版本

完成基于 NumPy 的自组织映射 (SOM) 实现,以建立性能基准。

完成NumPy SOM notebook 中的示例,以:

  • 理解 SOM 训练算法

  • 使用熟悉的 NumPy 操作实现该算法

  • 记录性能指标以供对比

第 8 步 - 优化 JAX 实现

通过不断迭代改进 JAX 实现,以期获得性能提升。

完成 JAX SOM notebook各部分:

  • NumPy 实现的基本 JAX 移植

  • 性能优化的 JAX 版本

  • GPU 加速的并行 JAX 实现

  • 比较所有版本之间的性能

第 9 步 - 验证性能提升

这些 notebook 将向您展示如何检查每个 SOM 训练实现的性能;您会发现 JAX 实现比 NumPy 基准性能有所提高(有些甚至快得多)。

在随机颜色数据上直观检查 SOM 训练输出,以确认算法的正确性。

第 10 步 - 后续步骤

将 JAX 优化技术应用于您自己的 NumPy 机器学习代码。

# Example: Profile your existing NumPy code
python -m cProfile your_numpy_script.py

# Then adapt to JAX and compare performance

尝试将您最喜欢的 NumPy 算法适配到 JAX,并测量其在 Blackwell GPU 架构上的性能提升。

通过手动 SSH 连接时可能出现的问题

错误
原因
修复
nvidia-smi 未找到
缺少 NVIDIA 驱动
安装适用于 ARM64 的 NVIDIA 驱动程序
容器无法访问GPU
缺少 NVIDIA Container Toolkit
安装 nvidia-container-toolkit
JAX 仅使用 CPU
CUDA/JAX 版本不匹配
重新安装支持 CUDA 的 JAX
端口 8080 不可用
端口已被占用
使用 -p 8081:8080 或者终止 8080 上的进程
Docker 构建中的包冲突
过时的环境文件
更新 Blackwell 的环境文件

注意:

DGX Spark 使用统一内存架构 (UMA),支持 GPU 和 CPU 之间的动态内存共享。由于许多应用程序仍在更新以利用 UMA,即使在 DGX Spark 的内存容量范围内,您也可能遇到内存问题。如果发生这种情况,请手动刷新缓冲区缓存:

sudo sh -c 'sync; echo 3 > /proc/sys/vm/drop_caches'

资源

JAX 文档

DGX Spark 文档

DGX Spark 论坛