Transformer架构

image.png

image.png

公式

Parameter

Parameter Description
$bs$ batch size
$d$ the model size / hidden state dimension / positional encoding size
$h$ number of attention heads
$dh$ head dimension, usually $d/h$
$L$ sequence length
$N_l$ number of transformer layers
$X^n \in \mathbb R^{bs \times L \times d}$ hidden state of nth layer
$W^{n}_{fnn1} \in \mathbb R^{d \times 4d}$ Weight matrix of first feed-forward layer
$W^{n}_{fnn1} \in \mathbb R^{4d \times d}$ Weight matrix of second feed-forward layer
$W^{n}_{q} \in \mathbb R^{d \times d}$ Weight matrix of query projection
$W^{n}_{k} \in \mathbb R^{d \times d}$ Weight matrix of key projection
$W^{n}_{v} \in \mathbb R^{d \times d}$ Weight matrix of value projection
$W^{n}_{o} \in \mathbb R^{d \times d}$ Weight matrix of output projection

Transformer layer

$$ \begin{aligned} &\mathbf{TransformerLayer}(X^{n} \in \mathbb{R}^{bs\times L \times d}): \\\ & (1)~~X^{n}_h = X^{n} + \text{MultiHeadAttention}(X^{n}) \\\ & (2)~~X^{n+1} = X^{n}_h + \text{FFN}(X^{n}_h) \\\ & \text{return}~X^{n+1} \\\ \end{aligned} $$

FNN layer

$$ \begin{aligned} & \mathbf{FFN}(X \in \mathbb{R}^{bs\times L \times d}): \\\ & (3)~~ H_{pre} = X \cdot W_{ffn1} \in \mathbb{R}^{bs\times L \times 4d} \\\ & (4)~~ H_{post} = \text{GeLU}(H_{pre}) \in \mathbb{R}^{bs\times L \times 4d} \\\ & (5)~~ Y \in \mathbb{R}^{bs\times L \times d} = H_{post} \cdot W_{ffn2} \\\ & \text{return}~Y \\\ \end{aligned} $$

MultiHeadAttention

$$ \begin{aligned} & \mathbf{MultiHeadAttention}(X \in \mathbb{R}^{bs\times L \times d}): \\\ & (6)~~ Q = X \cdot W_q \in \mathbb{R}^{bs\times L \times d} \\\ & (7)~~ K = X \cdot W_k \in \mathbb{R}^{bs\times L \times d} \\\ & (8)~~ V = X \cdot W_v \in \mathbb{R}^{bs\times L \times d} \\\ & (9)~~ O = [\text{head}_1; …; \text{head}_j; …; \text{head}_h] \cdot W_o \in \mathbb{R}^{bs\times L \times d} \\\ & \text{head}_j = \text{SelfAttention}(Q\tiny{[:, :, (j-1)dh: j dh]}, \normalsize{K}\tiny{[:, :, (j-1)dh: j dh]},\normalsize{V}\tiny{[:, :, (j-1)dh: j dh]}\normalsize) \in \mathbb{R}^{bs\times L \times dh} \\\ & \text{return}~O \\\ \end{aligned} $$

SelfAttention

$$ \begin{aligned} & \mathbf{SelfAttention}(Q \in \mathbb{R}^{bs\times L \times dh}, K \in \mathbb{R}^{bs\times L \times dh}, V \in \mathbb{R}^{bs\times L \times dh}): \\\ & (10)~~ S = Q \cdot K^T \in \mathbb{R}^{bs\times L \times L} \\\ & (11)~~ P = \text{Softmax}(\frac{mask(S)}{\sqrt{dh}}) \in \mathbb{R}^{bs\times L \times L} \\\ & (12)~~ O = P \cdot V \in \mathbb{R}^{bs\times L \times dh} \\\ & \text{return}~O \end{aligned} $$

单次训练所需要的算力

一次MatMul所需要的算力

forward

MatMul是深度神经网络中最常见的的操作,也是最消耗算力的步骤,我们想先把他拆解开。

我们以向量乘法 $C = A \times B$为例,其中 $C \in \mathbb R^{M \times N}$,$A \in \mathbb R^{M \times K}$,$B \in \mathbb R^{K \times N}$。

计算结果向量C中一个元素,需要K次乘法(A向量中一行的每个元素逐个乘以B向量一列的每个元素,A向量中一行和B向量中一列的元素个数都为K),和K次加法,用于把每个乘的结果相加。计算矩阵中一个元素需要进行 $2K$ 次的计算

结果矩阵C中一共有 $M \times N$个,所以整个矩阵计算一共需要计算 $2K \times M \times N$次,也就是 $2K \times M \times N$ FLOPs