大模型中的计算和内存占用理论分析

1. 理论FLOPS测算

参考Nvidia Megatron中给的方法,我们约定如下符号:

  • l:Transformer 层数。
  • h:hidden size大小。
  • s:sequence length 。
  • V:vocabulary size 词汇表大小。
  • b:训练的batch size。
    FLOPS基本贡献都来自所涉及的矩阵乘过程,例如:矩阵 $X(a,b) * Y(b,c)$ 一共需要2abc次OPs(2 是由于乘、加单独各算一次)。

Untitled 2 2|Untitled 2 2.png

1.1 Attention模块

  • $QKV$生成:输入X为$(b,s,h)$,通过一个线性映射生成QKV时需要权重:W$(h,3h)$, 一共是 $2\times 3\times bsh^2$ , 即 ($6bsh^2$)
  • $QK^T$:一个简单的表述 $(b,s,h) * (b,h,s)$,可以认为是$2bs^2h$
  • Map V: $(b,s,s) * (b,s,h)$,可以认为是 $2bs^2h$
  • post-attention 线性映射:$(b,s,h)*(h,h)$,可以认为是$2bsh^2$
  • 总共 $8bsh^2+4bs^2h$

1.2 Feed-Forward Network (FFN)模块

在前馈网络层,主要包含两个MLP层,实现h→4h 和 4h→h的映射:

  • MLP 1: $(b,s,h) * (h,4h)$ ,可认为是 $2*4bsh^2$
  • MLP 2: $(b,s,4h) * (4h,h)$,可以认为是$2*4bsh^2$
  • 总共 $16bsh^2$

1.3 前向&反向&激活重计算

由前两部分可知,一层Transformer的前向过程 其主要 FLOPs 大致可估算为:

$$24bsh^2 + 4bs^2h$$

我们一般认为反向计算 相比正向,大致可估算为双倍的FLOPs,因为需要分别计算 输入和权重的梯度(这里的输入 其实就是我们说的激活值,因为下一层的输入 就是前一层的计算结果)。

因此,对于不包含激活值重计算的层,的理论FLOPS就是:

$$3\times(24bsh^2+4bs^2h)=72bsh^2(1+\frac{s}{6h})$$

而对于存在激活值重计算的层,可以简单认为是按照 full的方式进行重计算,因此,相当于每次反向时多进行了一次额外前向计算(略显暴力的描述),因此, 理论FLOPS就是:

$$4\times(24bsh^2+4bs^2h)=96bsh^2(1+\frac{s}{6h})$$

1.4 完整模型

我们需要考虑了将h 映射到V的过程(Embedding),即 $(b,s,h) * (h,V)$。 前向 $2bshV$,反向直接变成 $4bshV$,共$6bshV$,只有一层。

综上,我们使用$l_1$ 表示非激活重计算的层数,$l_2$表示激活重计算的层数,可以量化一个包含$l_1+l_2$层Transformer和Embeding的简化版模型的计算复杂度:

$$ l_1 \times 72bsh^2(1+\frac{s}{6h}) + l_2 \times 96bsh^2(1+\frac{s}{6h}) ~+~6bshV \\ =(3l_1+4l_2)24bsh^2(1+\frac{s}{6h})+6bshV $$

这里可以引入一个比值$\gamma = \frac{l_2}{l}$,表示使用激活重计算层数的比例(全激活则1,无激活则0),可进一步简化为:

$$(3+\gamma)24bsh^2l(1+\frac{s}{6h})+6bshV$$

2. 理论内存占用分析 (训练)

Transformer-Decode-Only类模型在训练中的主要显存占用可划分为:

  • 模型状态量(包括模型参数、梯度以及优化器状态),只和模型参数量有关。
  • 模型的中间激活值。
  • 计算冗余、临时buffer、存储碎片等。

类似地,我们仍然规定一些符号:

  • l:transformer 模型的层数
  • h:隐藏层大小
  • a:注意力头数
  • s:输入的序列长度
  • b:输入的batch 大小

2.1 模型参数

Transformer类模型的主要结构形式是:Embedding+ Transformer Layer $\times~l$ 。其中主要的参数和激活值部分都在Transformer Layer,该层主要由Attention和FFN(FeedForwardNetwork)组成。 而模型的参数主要集中在一系列(可学习的)线性映射过程中,包含weight和bias矩阵(由于广播的因素,bias的规模一般远小于weight)。

与第一部分不同,该部分笔者根据LLaMA-v1的结构进行分析,由于swiglu的使用,FFN的结构发生了些许变化,如下:

Untitled 2|Untitled 2.png

  • 传统FFN(例如GPT3):
    • $h_{ffn}=4h$
    • $Proj:h->h_{ffn} –GeLu– h_{ffn}->h$

  • LLaMA的FFN:
    • $h_{ffn}=256 \times ((\frac{8h}{3}+255) \mid 256)$,|表示整除。一般分析可近似为:$h_{ffn}=\frac{8h}{3}$
    • $Proj:Gate:h->h_{ffn},UP:h->h_{ffn} – SiLU(Gate)\cdot \ UP – Down:h_{ffn} ->h$

笔者统计出了各层的权重,如下表所示:

Untitled 1

一般再加上做Embedding时的proj矩阵(V,h),求和为:$l\times (4h^2+3h*h_{ffn})+V h$,可以近似$12h^2l+Vh$

对我们的LLaMA65B模型:$l=80,h=8192,V=50176,h_{fnn}=22016$,估算的模型参数量为:65.17B。

对LLaMA-7B模型:$l=32,h=4096,V=50176,h_{fnn}=11008$,估算的模型参数量为:6.68B。

2.2 简单内存占用估算

我们训练中实际使用的是默认的AdamW优化器,优化器状态中需要保留一阶动量二阶动量。而涉及到混合精度训练时,虽然参与前向计算的参数和反向计算的梯度都是fp16的形式,但为了避免参数更新时的舍入误差(FP16精度缺乏 导致的A+b=A之类情况),会额外在优化器内维护一份FP32的参数用于参数更新。这些信息全部采用FP32保存

梯度可以看作参数的一种“属性”,虽然不是所有参数都有梯度,但一般可将梯度大小近似看作与参数相同

一般的大模型内存分析,可以简单分析模型状态量的总占用。若我们将参数量记为$\Phi$,不同精度模式下的内存占用如下:

训练精度 参数 梯度 优化器 总大小
FP32 Φ*4Bytes Φ*4Bytes 2Φ*4Bytes 16Φ Bytes
FP16混精 Φ*2Bytes Φ*2Bytes 3Φ*4Bytes 16Φ Bytes

对于 65B模型为: 971 GB。

对于 7B模型为:99.5 GB。

2.3 中间激活值

我们在上面得到了一个反直觉的分析结果:为何采用混精并未减少内存占用? 这与我们在实践中得到的结论不一样啊。

笔者认为,应考虑如下两种情景:

  • 采用优化器侧的内存优化(例如ZeRo、不保留FP32副本),混精训练的内存节省在:参数和梯度、中间激活值
  • 不采用优化器的内存优化,混精训练的内存节省在: 中间激活值参数和梯度(被FP32副本抵消)。

为了简化分析,我们按照GPT3的decode-only结构讨论,计算过程采用FP16
Untitled 2 2|Untitled 2 2.png
我们仍然规定一些符号:

  • s:Seq Length
  • b: Batch Size
  • h:Hidden Size
  • a:Head Num
  • d:Dim Size, $h= a \times d$
  • l:Layer Num

2.3.1 Attention模块

我们考虑 Self-Attention 后接Linear和Dropout的结构。

  • $QKV$生成:输入X为$(b,s,h)$,通过一个线性映射生成QKV时需要权重:W**$(h,3h)$**和偏置bias(3h),然后再进行切分,一共得到a组QKV,共得到$a\times3\times (b,s,h/a)$。

    • 反向计算时,需要保留共同输入X:$bsh\times 2Bytes$
    • $X(b,s,h)\ \times W^Q_i(h,h/a)=> Q_i(b,s,h/a)\ X(b,s,h)\ \times W^K_i(h,h/a)=> K_i(b,s,h/a)\ X(b,s,h)\ \times W^V_i(h,h/a)=> V_i(b,s,h/a)$
  • $QK^T$:每个头的输入为$(b,s,h/a)\times2$,无需权重,得到矩阵$(b,s,s)$,共计 $a\times(b,s,s)$。

    • 反向计算时需要保留输入:$2bsh\times2Bytes$
    • $Q(b,s,h/a)\ \times\ K^T(b,h/a,s)=>A(b,s,s)$
  • Mask过程(可选),输入 $a\times(b,s,s)$的Mask。一般可认为每个元素采用1 Byte存储。

    • 反向时需要保留mask:$abs^2\times1Bytes$。
  • Softmax:每个头的输入是 (b,s,s),执行softmax无需权重也不改变输入大小,输出时仍然为(b,s,s),共计得到$a\times(b,s,s)$。

    • $softmax(\frac{A(b,s,s)}{\sqrt{h/a}})=> M(b,s,s)$
    • 反向计算时候需要保留:$abs^2\times 2Bytes$。
  • Softmax-Dropout:输入 $a\times(b,s,s)$的Mask,输出 $a\times(b,s,s)$。

    • $M(b,s,s)\odot Mask(b,s,s)=> M_{drop}(b,s,s)$
    • 反向计算需要保留mask:$abs^2\times1Bytes$。
  • $MapV$: 输入是$a\times Map (b,s,s)$和 $a\times V(b,s,h/a)$ ,计算$Map\times V$,无权重。输出是 $a\times V(b,s,h/a)$ , 多头输出通过Contact连接得到 $(b,s,h)$。

    • $M_{drop}(b,s,s)\times V(b,s,h/a)=>Y_i(b,s,h/a)$
    • 反向计算需要保留输入:$(abs^2+bsh)\times 2Bytes$。
  • Attention-Linear:输入是$(b,s,h)$ ,权重是W(h,h),偏置bias(h),输出为(b,s,h)。

    • 反向计算需要保留输入:$bsh\times 2Bytes$。
    • $Y(b,s,h)\times W(h,s)=>Z(b,s,h)$
  • Attention-Dropout:输入是一个Mask$(b,s,h)$。

    • $Z(b,s,h)\odot Mask(b,s,h)=>Z_{drop}(b,s,h)$
    • 反向计算需要保留mask:$bsh \times 1Bytes$。

综上,Attention 模块共需要激活值Mem:$11bsh+5abs^2 Bytes$。(如考虑softmax前的Mask,则是$\ 11bsh+6abs^2 Bytes$ )。

2.3.2 FFN模块

FFN我们考虑两个Linear层、一个GeLU激活和一个Dropout。

  • 第一个线性层:输入是$(b,s,h)$,权重$(h,4h)$,偏置bias(4h),输出是$(b,s,4h)$。
    • $X(b,s,h)\times W(h,4h)=>Z(b,s,4h)$
    • 反向需要保留输入:$bsh\times 2Bytes$。
  • GeLU:输入是$(b,s,4h)$,输出形状不变。
    • $GeLU(Z(b,s,4h)) => A(b,s,4h)$
    • 反向需保留输入:$4bsh\times 2Bytes$。
  • 第二个线性层:输入是$(b,s,4h)$,权重是$(4h,h)$,偏置bias(h),输出是$(b,s,h)$。
    • $A(b,s,4h)\times W(4h,h)=>Y(b,s,h)$
    • 反向需保留输入:$4bsh\times 2Bytes$。
  • Dropout:输入一个等大的$(b,s,h)$ Mask。
    • $Y(b,s,h)\odot Mask(b,s,h)=>Y_{drop}(b,s,h)$
    • 反向需要保留Mask:$bsh\times 1Bytes$。

综上,FFN(MLP)模块共需要激活值Mem:$\ 19bsh\ Bytes$ 。

2.3.3 LayerNorm层

  • LayerNorm:一个是在Attention前,一个是在MLP前。
    • 输入$Y(b,s,h)$ 按h维求均值方差 $\frac{X(b,s,h)-E[X]}{\sqrt(Var[X]+\xi)}\times \gamma+\beta$, 输出仍然是$(b,s,h)$
      两个层的输入都是$(b,s,h)$,输出不变。
      反向需要保留输入,共:$2bsh\times2Bytes$。

综上,2个LayerNorm模块共需要激活值Mem:$4bsh\ Bytes$ 。

最终,每个Transformer 需要保留的中间激活值(不采用激活重计算的情况下):$$\ bsh(34+\frac{5as}{h})~Bytes$$
对batch_size=1,seq_length=2048的情形:

  • 65B模型:$l=80,h=8192, a=64$, 中间激活值: 142.5GB(模型状态量 971GB)。
  • 7B模型:$l=32,h=4096, a=32$, 中间激活值:28.5GB(模型状态量 99.5GB)。
    若考虑FP32训练,中间激活值可以暴力翻倍来近似。

注意:

  • 当batch 增加时,中间激活值近乎成比例增加。
  • 当seq_length增加时,中间激活值将逐渐趋于平方量级增加。
    • seq扩增到8k时,65B模型的中间激活值约为 1770GB (已近12倍 )。

大模型中的计算和内存占用理论分析
http://example.com/posts/c00f223d/
发布于
2024年5月1日
许可协议