
https://arxiv.org/abs/2401.10774

Presented by Daniel Varoli from Zapata AI

March 1st 2024

The Problem

LLMs are slow… and the bigger they are the slower they get

The bottleneck is not the math itself (the GPU can multiply plenty fast); it is the data transfer.

Per the paper, for each forward pass we have to move the inputs, model weights, and other data from High Bandwidth Memory (HBM, the GPU's on-board RAM) to the GPU's on-chip cache for computation

So a lot of time is spent doing data transfers

Essentially we are underutilizing the GPU's computational potential

Want to make our operations more arithmetically intensive (i.e. get more FLOPs for the same amount of data movement; see the sketch below)

And want to do fewer of those operations
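As a rough illustration of why decoding is memory bound (my own back-of-the-envelope numbers, not from the paper): during single-token decoding, each weight matrix is read from HBM once but only used for one matrix-vector multiply, so we get very few FLOPs per byte moved.

```python
# Back-of-the-envelope arithmetic intensity of one linear layer during decoding.
# Illustrative numbers only: an assumed 4096x4096 FP16 weight matrix.
d_in, d_out = 4096, 4096
bytes_per_param = 2  # FP16

flops = 2 * d_in * d_out                       # one multiply + one add per weight
bytes_moved = d_in * d_out * bytes_per_param   # weight traffic dominates at batch size 1

print(f"FLOPs per byte at batch size 1: {flops / bytes_moved:.1f}")  # ~1 FLOP/byte
```

Modern accelerators can sustain well over a hundred FLOPs for every byte of HBM bandwidth, so ~1 FLOP/byte means the compute units mostly sit idle waiting on memory.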

What Can We Do?

Batch inference

Enhances arithmetic intensity

My hand-wavy explanation as to why:

We still only need to move the big honking model weights the same amount as before

But we get more FLOPs out because we have more inputs (see the sketch below)

⇒ Profit
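Continuing the same back-of-the-envelope sketch (assumed numbers, not from the paper): the weight bytes are moved once per forward pass regardless of batch size, while the FLOPs scale with the batch, so arithmetic intensity grows roughly linearly with the batch size.

```python
# Same assumed 4096x4096 FP16 layer as above, now fed a batch of tokens.
d_in, d_out = 4096, 4096
bytes_per_param = 2  # FP16
weight_bytes = d_in * d_out * bytes_per_param  # moved once, regardless of batch size

for batch in (1, 8, 64, 256):
    flops = 2 * d_in * d_out * batch           # FLOPs grow with the batch
    print(f"batch={batch:4d}  ~{flops / weight_bytes:.0f} FLOPs per weight byte")
```

The ratio goes from ~1 at batch size 1 to ~256 at batch size 256, which is why batching gets us much closer to using the GPU's full compute.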

However!

Also have things like the KV cache that take up memory and can limit the size of the batches you can use

This is especially problematic with LLMs, where many layers and long sequences make the KV cache large

Reduce KV Cache Size

If we can reduce the size of the KV cache we may be able to get away with a higher batch size
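For a sense of scale (a rough sketch with assumed 7B-class model dimensions, not figures from the paper), the KV cache grows linearly with both batch size and sequence length and quickly competes with the weights for GPU memory:

```python
# Rough KV cache size estimate (assumed 7B-class dimensions, FP16 cache).
n_layers, n_kv_heads, head_dim = 32, 32, 128
bytes_per_value = 2  # FP16

def kv_cache_bytes(batch_size: int, seq_len: int) -> int:
    # 2x for keys and values, cached at every layer for every token in every sequence.
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch_size * bytes_per_value

for batch in (1, 8, 32):
    gb = kv_cache_bytes(batch, seq_len=4096) / 1e9
    print(f"batch={batch:3d}, 4096-token sequences -> ~{gb:.1f} GB of KV cache")
```

At batch size 32 this is already tens of gigabytes, on top of roughly 14 GB for the FP16 weights of a 7B model, so the cache is often what caps the batch size.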

Quantization

Instead of keeping everything in FP32 or even FP16, we use special techniques to represent the model weights and activations with 8-bit (or even 4-bit or 1-bit?) types, possibly mixed with 16-bit operations
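A minimal sketch of the core idea (plain absmax INT8 quantization of a hypothetical weight tensor; real schemes for weights, activations, or the KV cache are more sophisticated):

```python
import numpy as np

# Toy absmax INT8 quantization: map the largest magnitude to 127 and round.
def quantize_int8(x: np.ndarray):
    scale = np.abs(x).max() / 127.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q: np.ndarray, scale: float) -> np.ndarray:
    return q.astype(np.float32) * scale

weights = np.random.randn(4096, 4096).astype(np.float32)  # hypothetical FP32 weights
q, scale = quantize_int8(weights)

print(f"FP32: {weights.nbytes / 1e6:.0f} MB -> INT8: {q.nbytes / 1e6:.0f} MB (4x smaller)")
print(f"max abs reconstruction error: {np.abs(dequantize(q, scale) - weights).max():.4f}")
```

Storing the weights (or the KV cache) at lower precision directly reduces the bytes that have to move per forward pass, which helps both the memory-bandwidth bottleneck and the batch-size limit above.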