JAX 优化方案
将 JAX 优化以运行在 Spark 上
基本思路
JAX 允许您编写 NumPy 风格的 Python 代码,并在 GPU 上快速运行,而无需编写 CUDA 代码。它通过以下方式实现:
在 GPU 上使用 NumPy:jax.numpy 就像 NumPy 一样,但数组存储在 GPU 上。
- 函数转换:
jit→ 将函数编译成快速的 GPU 代码
grad→ 提供自动微分
vmap→ 将函数向量化并跨批次执行
pmap→ 在多个 GPU 上并行
XLA 后端:JAX 将您的代码交给 XLA(加速的线性代数编译器),其会融合运算并生成优化的 GPU 内核。
您将完成
您将在 NVIDIA Spark 上搭建一个基于 Blackwell 架构的 JAX 开发环境,该环境使您能够利用熟悉的类 NumPy 抽象进行高性能机器学习原型设计,并具备完整的 GPU 加速和性能优化能力。
前置知识
熟悉 Python 和 NumPy 编程
对机器学习工作流程和技术有基本的了解
有在终端环境工作的经验
有使用和构建容器的经验
熟悉不同版本的 CUDA
具备线性代数基础(高中数学水平即可)
先决条件
采用 Blackwell 架构的 NVIDIA Spark 设备
ARM64 (AArch64) 处理器架构
安装 Docker 或其他容器运行时
已配置 NVIDIA Container Toolkit
验证 GPU 访问权限:nvidia-smi
端口 8080 可用,用于访问 marimo notebook
辅助文件
所有必要资源都可在 GitHub 找到。
JAX 介绍 notebook — 涵盖 JAX 与 NumPy 的编程模型差异及性能评估
NumPy SOM 实现 — NumPy中自组织映射训练算法的参考实现
JAX SOM 实现 — JAX 中多种经过迭代优化的 SOM 算法实现
环境配置 — 软件包依赖关系和容器设置规范
时间和风险
时长:2-3 小时,包括设置、完成教程和验证
- 风险:
Python 环境中的包依赖冲突
性能验证可能需要针对特定架构进行优化。
还原:容器环境提供了隔离性;删除容器并重新启动即可重置状态。
- 最后更新:2025 年 11 月 7 日
进行了少量文字修改
第 1 步 – 验证系统先决条件
请确认您的 NVIDIA Spark 系统满足要求并已配置 GPU 访问权限。
# Verify GPU access nvidia-smi # Verify ARM64 architecture uname -m # Check Docker GPU support docker run --gpus all --rm nvcr.io/nvidia/cuda:13.0.1-runtime-ubuntu24.04 nvidia-smi
如果您看到 permission denied 错误(例如尝试连接到 Docker 守护进程套接字时出现权限被拒绝的错误),请将您的用户添加到 docker 组,这样您就不需要使用 sudo 运行命令了。
sudo usermod -aG docker $USER newgrp docker
第 2 步 – 克隆 playbook 库
git clone https://github.com/NVIDIA/dgx-spark-playbooks
第 3 步 - 构建 Docker 镜像
重要提示:
该命令将下载基础镜像并在本地构建容器以支持此环境。
cd dgx-spark-playbooks/nvidia/jax/assets docker build -t jax-on-spark .
第 4 步 - 启动 Docker 容器
在 Docker 容器中运行 JAX 开发环境,并启用 GPU 支持和端口转发以访问 marimo。
docker run --gpus all --rm -it \
--shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 \
-p 8080:8080 \
jax-on-spark第 5 步 - 访问 marimo 界面
连接到 marimo notebook 以开始 JAX 教程。
# Access via web browser # Navigate to: http://localhost:8080
界面将加载目录显示和对 marimo 的简要介绍。
第 6 步 - 完成 JAX 入门教程
请学习入门材料,了解 JAX 编程模型与 NumPy 的差异。
前往并完成 JAX 入门 notebook ,内容包括:
JAX 编程模型基础知识
与 NumPy 的主要区别
绩效评估技术
第 7 步 - 实现 NumPy 基准版本
完成基于 NumPy 的自组织映射 (SOM) 实现,以建立性能基准。
完成NumPy SOM notebook 中的示例,以:
理解 SOM 训练算法
使用熟悉的 NumPy 操作实现该算法
记录性能指标以供对比
第 8 步 - 优化 JAX 实现
通过不断迭代改进 JAX 实现,以期获得性能提升。
完成 JAX SOM notebook各部分:
NumPy 实现的基本 JAX 移植
性能优化的 JAX 版本
GPU 加速的并行 JAX 实现
比较所有版本之间的性能
第 9 步 - 验证性能提升
这些 notebook 将向您展示如何检查每个 SOM 训练实现的性能;您会发现 JAX 实现比 NumPy 基准性能有所提高(有些甚至快得多)。
在随机颜色数据上直观检查 SOM 训练输出,以确认算法的正确性。
第 10 步 - 后续步骤
将 JAX 优化技术应用于您自己的 NumPy 机器学习代码。
# Example: Profile your existing NumPy code python -m cProfile your_numpy_script.py # Then adapt to JAX and compare performance
尝试将您最喜欢的 NumPy 算法适配到 JAX,并测量其在 Blackwell GPU 架构上的性能提升。
通过手动 SSH 连接时可能出现的问题
错误 | 原因 | 修复 |
|---|---|---|
nvidia-smi 未找到 | 缺少 NVIDIA 驱动 | 安装适用于 ARM64 的 NVIDIA 驱动程序 |
容器无法访问GPU | 缺少 NVIDIA Container Toolkit | 安装 nvidia-container-toolkit |
JAX 仅使用 CPU | CUDA/JAX 版本不匹配 | 重新安装支持 CUDA 的 JAX |
端口 8080 不可用 | 端口已被占用 | 使用 -p 8081:8080 或者终止 8080 上的进程 |
Docker 构建中的包冲突 | 过时的环境文件 | 更新 Blackwell 的环境文件 |
注意:
DGX Spark 使用统一内存架构 (UMA),支持 GPU 和 CPU 之间的动态内存共享。由于许多应用程序仍在更新以利用 UMA,即使在 DGX Spark 的内存容量范围内,您也可能遇到内存问题。如果发生这种情况,请手动刷新缓冲区缓存:
sudo sh -c 'sync; echo 3 > /proc/sys/vm/drop_caches'