“drop the ‘the’ just facebook it’s cleaner” — “Social Network” (link)
This Thinky blog post (https://thinkingmachines.ai/blog/on-policy-distillation/) has some sleek definitions of the different stages of LLM training, which I elaborate on a bit more below:
Pre-training
- Goal: learn general capabilities (language use, plus knowledge of the world and of the human online experience)
- Dataset: internet-scale, mixed quality, various sources
- There isn’t much universal structure in data at this scale. The only structure is that human language isn’t a sequence of random tokens: it follows grammar and encodes underlying knowledge.
- Objective: next-token prediction, since the only universal structure we can exploit is that the next token depends on its preceding context (minimal sketch below).
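To make the objective concrete, here is a minimal PyTorch-style sketch of next-token prediction. The `model` is a placeholder for any causal LM that maps token ids to per-position logits; shapes are illustrative, not tied to any particular library API.

```python
import torch
import torch.nn.functional as F

def next_token_loss(model, token_ids):
    # token_ids: (batch, seq_len) LongTensor of token ids
    logits = model(token_ids)             # (batch, seq_len, vocab_size)
    # Predict token t+1 from tokens <= t: shift predictions and targets by one.
    pred = logits[:, :-1, :]              # predictions for positions 0..T-2
    target = token_ids[:, 1:]             # the "next" tokens at positions 1..T-1
    return F.cross_entropy(pred.reshape(-1, pred.size(-1)), target.reshape(-1))
```

The same loss is used regardless of domain or quality; all the "structure" the model learns comes from the data itself.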
Post-Training:
- Goal: Elicit targeted behaviors by morphing the model’s internal knowledge, which lives in a high-dimensional space, into a more structured, lower-dimensional form (e.g., Q&A). This could mean instruction following, reasoning through math problems, or chat.
- Dataset: smaller-scale, higher-quality, less diverse, similar sources
- There is a lot of structure in Q&A data or math reasoning problems, where there is always an answer in the box.
- Objective: there is a lot of structure we can exploit, like preference signals or verifiable rewards from the answer in the box. So we can design learning algorithms tailored to the downstream task, with SFT and RL (a toy verifiable reward is sketched below).
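As a toy illustration of the "answer in the box" structure, here is a sketch of a verifiable reward that checks a `\boxed{...}` answer against the gold answer. The regex and exact-string match are deliberate simplifications (no nested braces, no numeric normalization); the function name is made up for this example.

```python
import re

def boxed_answer_reward(completion: str, gold_answer: str) -> float:
    """Toy verifiable reward: 1.0 if the model's \\boxed{...} answer matches the gold answer."""
    match = re.search(r"\\boxed\{([^}]*)\}", completion)
    if match is None:
        return 0.0                                  # no boxed answer produced
    predicted = match.group(1).strip()
    return 1.0 if predicted == gold_answer.strip() else 0.0
```

A scalar reward like this can drive an RL objective (PPO/GRPO-style policy gradients), while SFT instead imitates reference solutions directly.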
⭐️Mid-Training⭐️:
- Goal: Turn a general model into a domain-expert model by imparting domain knowledge, such as code, medical databases, or internal company documents.
- Dataset:
  - Much smaller than web-scale data, usually higher quality, or at least closer in distribution to the target domain.
  - There can be a lot of structure to exploit, like URLs, citations, or code syntax.
- Objective: We can still use next-token prediction, but we can also leverage this structural metadata to generate synthetic data (a toy example below), or use a JEPA-style loss to learn embeddings based on the data structure.
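Here is one toy example of exploiting structure for synthetic data: using Python syntax and docstrings to turn raw code into prompt/completion pairs for mid-training. The prompt template and function names are made up for illustration; real pipelines would do far more filtering and rewriting.

```python
import ast

def synthesize_from_code(source: str):
    """Toy sketch: exploit code structure (function defs + docstrings)
    to create synthetic (prompt, completion) training pairs."""
    pairs = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            doc = ast.get_docstring(node)
            if doc:
                # Hypothetical prompt template; the docstring becomes the instruction,
                # the original function body becomes the target completion.
                prompt = f"Write a Python function `{node.name}` that {doc.strip()}"
                completion = ast.get_source_segment(source, node)
                pairs.append({"prompt": prompt, "completion": completion})
    return pairs
```

The same idea applies to URLs (link text as a query, page as the answer) or citations (claim plus cited passage); the metadata tells you which spans belong together.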
Now you see there isn’t really a formal definition of pre-/mid-/post-training. The real questions are: What is the scale of the dataset? What inherent structure does the dataset have? How can we exploit it with data curation or learning algorithms?
The reason JEPA will not work on language is that its learning objective, i.e., contrastive learning of joint embeddings, bakes in too many assumptions about the data’s structure and cannot generalize to internet scale.
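For contrast with next-token prediction, here is a rough sketch of one JEPA-style objective, the predictive (I-JEPA-flavored) variant that regresses target embeddings instead of raw tokens. The encoders, predictor, and the choice of context/target views are all placeholders, and it is exactly those view/masking choices that encode the structural assumptions.

```python
import torch
import torch.nn.functional as F

def jepa_style_loss(context_encoder, target_encoder, predictor, x_context, x_target):
    """Sketch of a joint-embedding predictive objective: predict the embedding of a
    target view from a context view, rather than predicting raw tokens."""
    z_context = context_encoder(x_context)       # (batch, dim)
    with torch.no_grad():
        z_target = target_encoder(x_target)      # target branch gets no gradient
    z_pred = predictor(z_context)                # predict target embedding from context
    return F.mse_loss(z_pred, z_target)
```

Defining "context" and "target" views requires knowing how the data decomposes (image patches, document sections, code blocks); next-token prediction sidesteps that by only assuming sequential dependence.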