模型训练的显存占用分布
训练过程中,显存消耗主要有模型参数、梯度、optimizer状态值和中间激活值。
- 模型参数:词表embedding部分占大头,与输入序列长度无关
- 梯度:每个参数对应有一个梯度
- 优化器状态值:每个参数有一个对应梯度,每个参数又对应优化器一个一阶动量和二阶动量
- 激活值:保存激活值是为了计算梯度,因此每个矩阵相乘、softmax、dropout都需要保存输入值的中间的激活值。与输入序列长度呈正相关