CS336: PyTorch Primitives & Resource Accounting (Lecture 2)
CS336: PyTorch Primitives & Resource Accounting (Lecture 2)
(PyTorch 原语与资源核算)
1. Introduction & Motivation (引言与动机)
- 课程目标:不仅仅是实现模型,更要关注Efficiency(效率)和Resource Accounting(资源核算)。
- Napkin Math (餐巾纸数学):在开始训练前,通过简单的数学估算所需的计算资源和时间。
- Example Question:在使用 1024 张 H100 GPU 上训练一个 70B 参数、15T Tokens 的模型需要多久?
- Key Formula:总计算量 (Flops) 。
- Hardware Constraints:
- H100 GPU 拥有 80GB HBM (High Bandwidth Memory)。
- 如果不进行优化,单张卡能容纳的最大参数量约为 40B(仅考虑参数、梯度和优化器状态,不含 Activations)。
2. Memory Accounting: Tensors (内存核算:张量)
Tensor Data Types (数据类型)
- Tensors 是深度学习中存储 Parameters, Gradients, Optimizer States, Activations 的基本单元。
- Float32 (FP32 / Single Precision):
- 默认数据类型,占用 4 bytes (32 bits)。
- 结构:1 bit 符号位,8 bits Exponent,23 bits Fraction。
- 通常用于 Parameters 和 Optimizer States 以保证数值稳定性。
- Float16 (FP16 / Half Precision):
- 占用 2 bytes (16 bits)。
- 缺点:Dynamic range(动态范围)较小,容易导致 Underflow(下溢)。
- BFloat16 (BF16 / Brain Float):
- 占用 2 bytes,但在深度学习中通常优于 FP16。
- 优势:拥有与 Float32 相同的 Exponent 位数(动态范围大),牺牲了 Fraction 精度,但这对 Deep Learning 影响较小。
- 通常用于 Matrix Multiplications 等计算过程。
- FP8:H100 支持的更低精度格式,用于进一步提升速度。
Memory Calculation Example (内存计算示例)
- 一个 的 Tensor (Float32) 占用内存: 。
3. PyTorch Internals & Tensor Views (PyTorch 内部机制与视图)
- Storage vs. View:
- Tensor 本质上是指向内存中连续数组的 Metadata(元数据)指针。
- Metadata 包含 Size 和 Stride(步长)。
- Stride 决定了在某一维度上移动索引时,在内存中需要跳过多少个元素。
- Zero-Copy Operations:
- 操作如
transpose,view, 切片通常不复制数据,而是创建新的 View(共享底层 Storage)。 - Mutation Risk:修改一个 View 会影响原始 Tensor。
- 操作如
- Contiguous:
- 某些操作(如
transpose)会导致 Tensor 在内存中不再连续 (Non-contiguous)。 - 调用
.contiguous()会强制复制数据并在内存中重新排列,这会消耗额外的 Memory。
- 某些操作(如
4. Compute Accounting: Flops (计算核算:浮点运算)
Definitions (定义)
- Flops (Floating Point Operations):计算操作的总数(小写 ‘s’)。
- Flops/s (Floating Point Operations per Second):硬件的计算速度(通常用 ‘/s’ 表示)。
Cost of Operations (运算成本)
- Matrix Multiplication (MatMul) 是深度学习中最主要的计算消耗。
- MatMul Flops Formula:
- 对于矩阵乘法 :
- 。
- 系数 2 来源于每个输出元素包含一次乘法和一次加法。
- Linear Model Forward Pass:
- 。
Model Flops Utilization (MFU)
- Definition: 。
- Benchmark:
- MFU 反映了硬件利用率,通常 。
- 系数 4 的来源:需要计算对 Weights 的梯度以及对 Inputs (Activations) 的梯度(以便传给上一层),每部分约为 Forward Pass 的 2 倍。
- Total Training Flops:
- 总计算量 = Forward () + Backward ()。
- Rule of Thumb: 。
6. Optimization & Total Memory Footprint (优化与总内存占用)
Tensor Operations with einsum
- 使用
torch.einsum或einops库代替复杂的索引操作(如x.transpose(-1, -2)),提高代码可读性和安全性。 - 示例:
einsum('b s h, b s h -> b s', x, y)直观表达了维度操作。
Optimizer Implementation
- Optimizers (如 SGD, Adam, Adagrad) 负责更新参数。
- 实现自定义 Optimizer 时,需要继承
Optimizer类并管理state。
Total Memory Components (总内存组成)
训练一个模型所需的总显存包括以下部分:
- Parameters (Weights): 个元素。
- Gradients:与 Parameters 形状相同, 个元素。
- Optimizer States:
- 取决于优化器。例如 Adagrad 需要存储梯度平方和,额外需要 个元素。
- Adam 通常需要存储一阶和二阶动量,需要 个元素。
- Activations:
- 用于反向传播计算梯度,大小为 。
- 可通过 Activation Checkpointing (重计算) 来减少显存占用,以时间换空间。
Mixed Precision Training (混合精度训练)
- Strategy:
- Parameters & Optimizer States:保持 Float32 以确保累积更新的精度。
- Forward/Backward Computation (MatMul):使用 BF16 或 FP8 提高速度并减少内存。
- PyTorch 提供自动混合精度工具 (AMP) 来管理这种转换。
7. Best Practices (最佳实践)
- Initialization:
- 使用 Xavier/Kaiming Initialization (例如除以 ) 防止 Activation 值在深层网络中爆炸或消失。
- Randomness:
- 固定 Random Seed 以确保 Debug 时的可复现性。
- Checkpointing:
- 定期保存 Model, Optimizer State 和 Iteration number,防止训练崩溃导致进度丢失。
Reference Formula Summary (参考公式汇总)
- Training Flops: (=Params, =Dataset size in tokens).
- Forward Flops: .
- Backward Flops: .
- Matrix Multiply Flops: .
- Memory (Bytes): (FP32=4, BF16=2).