The first time I dove into the Llama model codebase, around two years ago, I was stumped by the definition of the MLP. Unlike the traditional 2-layer MLP, it consisted of 3 layers: an additional linear projection called “gate” sat alongside the traditional “up” and “down” projections. I had an intuition for 2-layer MLPs but not for the 3-layer one. So I went to Google Scholar and prepared myself to read a spate of papers on the history of MLPs in LLMs. What did I find? Crickets. There was just one paper, by the one and only deep learning magician worth billions, Noam Shazeer. The explanation? “divine benevolence”. Great! This Reddit thread had more information than any academic paper I could find at the time. Even after two years, I still don't think most people appreciate how different 3-layer MLPs are from 2-layer ones. As a refresher, a standard 2-layer MLP computes:
$$ \begin{equation} y = f_{act} ( x \space W^{up}) \space W^{down} \end{equation} $$
where $x \in \mathbb{R}^{d^t}$ is the input vector, $W^{up} \in \mathbb{R}^{d^t \times d^h}$ and $W^{down} \in \mathbb{R}^{d^h \times d^t}$ are the weight matrices, typically with $d^h = m \times d^t$ for some $m > 1$, and $f_{act}$ is the activation function. I understand this and have an intuition for it. The input is a vector with $d^t$ elements, while $W^{up}$ is a collection of $d^h$ column vectors, each with $d^t$ elements. Essentially, we're projecting the input vector $x$ onto each of the $d^h$ vectors in $W^{up}$. The activation function, typically ReLU, filters out negative projection scores, so the vector that multiplies $W^{down}$ is all non-negative. This means the first matrix acts as a pattern lookup (keys), while the rows of the second matrix are the values that get summed, weighted by those scores (values). Very similar to attention.
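To make the key-value reading concrete, here is a minimal pure-Python sketch that computes the 2-layer MLP twice: once as plain matrix products, and once as an explicit loop that scores $x$ against each key (a column of $W^{up}$), drops negative scores, and adds up the matching values (rows of $W^{down}$). The function names and the toy weights are hypothetical, chosen only to keep the arithmetic easy to follow.

```python
def vecmat(x, W):
    """Row vector x (length d_in) times W, given as d_in rows of length d_out."""
    return [sum(x[i] * W[i][k] for i in range(len(x))) for k in range(len(W[0]))]

def relu(v):
    return [max(0.0, a) for a in v]

def mlp_2layer(x, W_up, W_down):
    """y = relu(x W_up) W_down, computed with plain matrix products."""
    return vecmat(relu(vecmat(x, W_up)), W_down)

def mlp_as_key_value(x, W_up, W_down):
    """Same computation as a lookup: score x against each key (column k of
    W_up), keep only non-negative scores, and accumulate the value
    (row k of W_down) weighted by that score."""
    d_t, d_h = len(W_up), len(W_up[0])
    y = [0.0] * len(W_down[0])
    for k in range(d_h):
        key = [W_up[i][k] for i in range(d_t)]
        score = max(0.0, sum(a * b for a, b in zip(x, key)))
        value = W_down[k]
        y = [acc + score * v for acc, v in zip(y, value)]
    return y

# Toy sizes: d_t = 2, d_h = 3 (m = 1.5); weights are made up.
x = [1.0, 0.5]
W_up = [[1.0, 0.5, -1.0],
        [0.0, 1.0, 1.0]]
W_down = [[1.0, 0.0],
          [0.0, 1.0],
          [2.0, 2.0]]

print(mlp_2layer(x, W_up, W_down))        # [1.0, 1.0]
print(mlp_as_key_value(x, W_up, W_down))  # [1.0, 1.0]
```

Both views agree. The third key scores $-0.5$ against $x$, so ReLU zeroes it and its value row never reaches the output.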
Now, here comes Noam Shazeer saying, no, that’s not the way to do MLPs. You gotta have 3 layers. He proposes an additional linear layer $W^{gate}$ of the same size as $W^{up}$.
$$ \begin{equation} y = (f_{act}(x \space W^{gate}) \space \odot \space (x \space W^{up})) \space W^{down} \end{equation} $$
where $\odot$ is the element-wise product of the two vectors of size $d^h$.
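The gated MLP takes only a few more lines than the 2-layer one. This sketch assumes SiLU for $f_{act}$, the choice used in Llama-style SwiGLU MLPs; passing the identity function instead gives the activation-free version. The names (`glu_mlp`, `vecmat`, `silu`) and the tiny weights are hypothetical, for illustration only.

```python
import math

def vecmat(x, W):
    """Row vector x (length d_in) times W, given as d_in rows of length d_out."""
    return [sum(x[i] * W[i][k] for i in range(len(x))) for k in range(len(W[0]))]

def silu(z):
    """SiLU (a.k.a. swish): z * sigmoid(z), split by sign to avoid exp overflow."""
    if z >= 0:
        return z / (1.0 + math.exp(-z))
    e = math.exp(z)
    return z * e / (1.0 + e)

def glu_mlp(x, W_gate, W_up, W_down, act=silu):
    """Gated MLP: (act(x W_gate) ⊙ (x W_up)) W_down."""
    gate = [act(v) for v in vecmat(x, W_gate)]   # length d_h
    up = vecmat(x, W_up)                         # length d_h
    hidden = [g * u for g, u in zip(gate, up)]   # element-wise product
    return vecmat(hidden, W_down)                # back to length d_t

# Toy sizes: d_t = 2, d_h = 3; weights are made up.
x = [1.0, 2.0]
W_gate = [[1.0, 0.0, 1.0],
          [0.0, 1.0, 1.0]]
W_up = [[1.0, 0.0, 1.0],
        [0.0, 1.0, 2.0]]
W_down = [[1.0, 0.0],
          [0.0, 1.0],
          [1.0, 1.0]]

print(glu_mlp(x, W_gate, W_up, W_down))                   # SwiGLU-style
print(glu_mlp(x, W_gate, W_up, W_down, act=lambda z: z))  # [16.0, 19.0]
```

With the identity activation the hidden vector is $[1 \cdot 1, 2 \cdot 2, 3 \cdot 5] = [1, 4, 15]$, which $W^{down}$ maps to $[16, 19]$.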
If the attention analogy ($W^{up}$ as the set of key vectors, $W^{down}$ as the set of value vectors, and $x$ as the query vector) holds for 2-layer MLPs, then how do we think about the new $W^{gate}$ projection? Perhaps you could think of it as a query projection, but then how do you interpret its point-wise multiplication with the key scores (the output of the up projection)? We need some other way to intuit what's going on. Let's analyze the computation in detail.
$$ \begin{equation} y_1 = \sum_i x_i W_{i1} \end{equation} $$
This is how a single element of the output vector is computed in a vector-matrix multiplication: it is a weighted sum of the entries in a single column of $W$, with the elements of $x$ as the weights. Now let's ignore the activation function (assume $f_{act}(z) = z$) and see how the gate and up projections interact for a single hidden unit.
$$ \begin{equation} y_1 = \left( \sum_i x_i W^{gate}_{i1} \right) \times \left( \sum_j x_j W^{up}_{j1} \right) \end{equation} $$
Expanding the product of sums, we get:
$$ \begin{equation} y_1 = x_1^2 W^{gate}_{11} W^{up}_{11} + x_1 x_2 W^{gate}_{11} W^{up}_{21} + \dots = \sum_i \sum_j x_i x_j W^{gate}_{i1} W^{up}_{j1} \end{equation} $$
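The expansion is easy to check numerically. This sketch keeps the identity activation assumed above and evaluates one hidden unit two ways: as the product of two dot products, and as the double sum over all $(i, j)$ pairs. The vectors `g` and `u` stand in for the first columns of $W^{gate}$ and $W^{up}$; their values are made up.

```python
def gated_unit_direct(x, g, u):
    """One hidden unit with f_act = identity: (x . g) * (x . u)."""
    return sum(a * b for a, b in zip(x, g)) * sum(a * b for a, b in zip(x, u))

def gated_unit_expanded(x, g, u):
    """Same unit after expanding the product of sums:
    sum over every pair (i, j) of x_i * x_j * g_i * u_j (d_t * d_t terms)."""
    n = len(x)
    return sum(x[i] * x[j] * g[i] * u[j] for i in range(n) for j in range(n))

x = [1.0, 2.0, -1.0]
g = [0.5, -1.0, 2.0]   # stand-in for W_gate[:, 0]
u = [1.0, 0.0, 3.0]    # stand-in for W_up[:, 0]

print(gated_unit_direct(x, g, u))    # 7.0
print(gated_unit_expanded(x, g, u))  # 7.0
```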
Ha! This looks like a second-degree polynomial. Written another way, we're calculating the dot product between the transformed vectors $\hat{x} = x \otimes x$ and $\hat{W} = W^{gate}_{:1} \otimes W^{up}_{:1}$, where $\otimes$ is what math people call the Kronecker product. Another way to think about it is as a Cartesian product: there are two sets of numbers, and the operation multiplies every number in one set with every number in the other set to create a new set. Yet another way to think about it is as a 2nd-degree polynomial kernel. When plain dot products are not enough, you whip out kernels, which in machine learning let you calculate dot products in higher-dimensional, non-linear, implicit feature spaces, affording more expressivity to your model.
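A short sketch of the Kronecker-product view, using a hypothetical `kron` helper for plain Python vectors. With made-up values for $x$ and for the gate and up columns, the gated unit (identity activation) equals a single dot product between $\hat x = x \otimes x$ and $\hat W = W^{gate}_{:1} \otimes W^{up}_{:1}$. It also checks the degree-2 polynomial kernel identity $(x \cdot z)^2 = \langle x \otimes x, z \otimes z \rangle$ for another made-up vector $z$.

```python
def kron(a, b):
    """Kronecker product of two vectors: every a_i times every b_j,
    flattened so that entry i * len(b) + j holds a_i * b_j."""
    return [ai * bj for ai in a for bj in b]

def dot(a, b):
    return sum(p * q for p, q in zip(a, b))

x = [1.0, 2.0, -1.0]
g = [0.5, -1.0, 2.0]   # stand-in for W_gate[:, 0]
u = [1.0, 0.0, 3.0]    # stand-in for W_up[:, 0]

x_hat = kron(x, x)     # length d_t * d_t = 9
w_hat = kron(g, u)     # length 9

# Gated unit with identity activation == one dot product in the lifted space.
print(dot(x, g) * dot(x, u))   # 7.0
print(dot(x_hat, w_hat))       # 7.0

# Degree-2 polynomial kernel: (x . z)^2 == <x ⊗ x, z ⊗ z>.
z = [2.0, 1.0, 1.0]
print(dot(x, z) ** 2, dot(kron(x, x), kron(z, z)))  # 9.0 9.0
```

The lifted vectors have $d^t \times d^t$ entries, but the left-hand sides only ever touch $d^t$-sized vectors.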
The combination of a normal dot product and an element-wise product turns the first half of the MLP computation into a dot product over a much larger implicit feature space. Instead of $x = [x_1, x_2, \dots, x_{d^t}] \in \mathbb{R}^{d^t}$ we get $\hat x = [x_1^2, x_2^2, \dots, x_{d^t}^2, x_1x_2, \dots, x_{d^t-1}x_{d^t}] \in \mathbb{R}^{d^t d^t}$. But we never store the full implicit vector and never compute the full dot product explicitly. That's exactly the kernel trick! The same thing happens for the weight vectors. Instead of a full $d^t d^t$-dimensional weight vector per hidden unit, we factor it into two $d^t$-dimensional vectors: one column of the gate projection matrix and one column of the up projection matrix. Across all $d^h$ hidden units, that is $2 d^t d^h \ll d^t d^t d^h$ parameters.
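To put numbers on that inequality, here is the parameter arithmetic for Mixtral 8x7B-sized projections, whose published config uses a hidden size of 4096 ($d^t$) and an intermediate size of 14336 ($d^h$). The helper names are hypothetical; this is only counting, and no real model materializes $\hat W$.

```python
def gated_params(d_t, d_h):
    """Parameters in the gate + up projections: two d_t x d_h matrices."""
    return 2 * d_t * d_h

def explicit_params(d_t, d_h):
    """Parameters if each of the d_h hidden units stored its full
    (d_t * d_t)-dimensional lifted weight vector explicitly."""
    return d_t * d_t * d_h

# Mixtral 8x7B sizes (per expert MLP): d_t = 4096, d_h = 14336, so m = 3.5.
d_t, d_h = 4096, 14336

print(gated_params(d_t, d_h))     # 117440512
print(explicit_params(d_t, d_h))  # 240518168576
print(explicit_params(d_t, d_h) // gated_params(d_t, d_h))  # 2048, i.e. d_t / 2
```

The saving factor is always $d^t / 2$, independent of $d^h$.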
This has some interesting implications. First, you don't strictly need the activation function anymore; the GLU paper calls this the bilinear variant. Second, the rank of the weight matrix is not capped at $d^t$, as it is in 2-layer MLPs, but can reach $d^h$. How come? The up and gate projection matrices are $d^t \times d^h = d^t \times m d^t$ (recall that $d^h = m d^t$), and since $m > 1$ but $m \ll d^t$ (in Mixtral 8x7B, $d^t = 4096$ and $d^h = 14336$, so $m = 3.5$), their rank is at most $d^t$. However, once you factor in the element-wise product, you're implicitly using a matrix $\hat W \in \mathbb{R}^{d^t d^t \times m d^t}$. Hence, as long as $m < d^t$, its rank can be as high as $d^h = m d^t$. Thus, increasing $d^h$ should also increase the modeling capacity of the MLP. Essentially, the hidden dimension of MLPs can now be scaled to much larger values!
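A tiny, hand-picked illustration of the rank argument, assuming $d^t = 2$ and $d^h = 3$ (so $m = 1.5 < d^t$). It builds $\hat W$ column by column as the Kronecker product of matching gate and up columns, then computes exact ranks with a small Gaussian-elimination helper over `fractions.Fraction`. Both helpers (`rank`, `lifted_weights`) are hypothetical illustrations, not library functions, and the integer weights are made up to keep the arithmetic exact.

```python
from fractions import Fraction

def rank(rows):
    """Matrix rank via exact Gaussian elimination over the rationals."""
    m = [[Fraction(v) for v in row] for row in rows]
    r = 0
    for c in range(len(m[0])):
        # Find a row at or below r with a non-zero entry in column c.
        pivot = next((i for i in range(r, len(m)) if m[i][c] != 0), None)
        if pivot is None:
            continue
        m[r], m[pivot] = m[pivot], m[r]
        # Clear column c in every other row.
        for i in range(len(m)):
            if i != r and m[i][c] != 0:
                f = m[i][c] / m[r][c]
                m[i] = [a - f * b for a, b in zip(m[i], m[r])]
        r += 1
        if r == len(m):
            break
    return r

def lifted_weights(W_gate, W_up):
    """W_hat with d_t * d_t rows and d_h columns:
    column k is kron(W_gate[:, k], W_up[:, k])."""
    d_t, d_h = len(W_gate), len(W_gate[0])
    return [[W_gate[i][k] * W_up[j][k] for k in range(d_h)]
            for i in range(d_t) for j in range(d_t)]

# d_t = 2, d_h = 3; small integer weights picked by hand.
W_gate = [[1, 0, 1],
          [0, 1, 1]]
W_up = [[1, 0, 1],
        [0, 1, 2]]

W_hat = lifted_weights(W_gate, W_up)
print(rank(W_up))   # 2: capped at d_t
print(rank(W_hat))  # 3: reaches d_h
```

The columns of $\hat W$ here are $(1,0,0,0)$, $(0,0,0,1)$ and $(1,2,1,2)$, which are clearly independent. With random real-valued weights the same pattern holds for almost every draw, as long as $d^h \le d^t d^t$.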