Article·AI & Engineering·Mar 7, 2024
22 min read

Building an LLM Stack Part 3: The art and magic of Fine-tuning

Zian (Andy) Wang
Published Mar 7, 2024
Updated Jun 27, 2024

In the previous article on building the Large Language Model (LLM) Stack, we covered various methods and techniques utilized in the first stage of LLM training.

Large Language Models are trained in several stages. The initial phase is pre-training on a vast text corpus in a self-supervised manner, meaning without human-provided labels. While pre-trained models have an excellent grasp of human language, they are not inherently equipped to perform specific, nuanced tasks such as engaging in dialogue, answering questions, or writing summaries.

To bridge this gap, the next phase involves fine-tuning the pre-trained models. Think of this process as sharpening a broadsword to perform the precise work of a scalpel. Fine-tuning tailors the model to excel in particular tasks by training it further on a smaller, task-specific dataset. This stage equips the model with the ability to apply its broad language understanding to specific applications, enhancing its performance and making it adept at tasks it was previously unprepared for.

As opposed to pre-training, which typically demands millions of dollars and enormous amounts of compute time, fine-tuning is significantly more affordable and, for smaller models, feasible on consumer-grade hardware. Anyone with access to a powerful GPU, or willing to invest a modest sum in cloud computing resources, can fine-tune language models to suit their needs.

Moreover, most pre-trained models are trained on datasets of up to hundreds of terabytes, spanning billions to trillions of tokens, to broaden the model’s knowledge base, and many of these datasets are proprietary assets that large companies license from third-party data providers. During fine-tuning, however, a dataset of just a few thousand tokens may be enough to guide the model toward a downstream task.

In fact, the majority of the leading open-source Large Language Models originate from just two pre-trained foundation models: they are built either on the LLaMA architecture, released by Meta, or on the Mistral model, introduced by the French startup Mistral AI.

Why Fine-tune at All?

Imagine building a skyscraper. It’s far more practical to start with a strong, versatile foundation capable of supporting various structures, and then customize the upper floors to serve specific purposes, rather than constructing a small building for every single purpose.

In terms of training Large Language Models, fine-tuning a pre-trained model for specific tasks is similar to adding and customizing the upper floors of our skyscraper. It’s easier and more efficient to refine something large and general into something specialized, than to attempt building the vast, complex foundation for each new task from the ground up. This approach not only saves time and resources but also ensures the model benefits from a rich, diverse linguistic base.

Language models learn much like humans do. When learning a new language, you don’t start by memorizing phrases for very specific situations, such as ordering at a restaurant or asking for directions. Instead, you first lay the groundwork by understanding the grammar, vocabulary, and the nuances of how words come together to form meaningful sentences.

This foundational knowledge equips you with the flexibility to navigate a vast array of conversations, even those you’ve never explicitly prepared for. Once you have a solid base, you can then fine-tune your language skills for specific scenarios, enhancing your fluency in particular contexts without being limited to them.

The Basics of LLM Fine-Tuning

Before diving into the process of fine-tuning, there is one term that gets thrown around a lot in the context of fine-tuning that needs to be addressed: alignment.

As shown in the first figure, a language model that has not been fine-tuned produces results that are either incoherent or irrelevant to what the user wants. Pre-trained language models are only trained to predict the most likely next token following the user’s input, based on their training data. This is where alignment comes into play.

To extract the knowledge and predictive power of LLMs, we need to align the model so that it produces coherent, well-structured text that is useful to humans. This alignment is achieved through fine-tuning.

Stages of Fine-Tuning

Fine-tuning typically unfolds in two distinct phases: the initial phase is Supervised Fine-Tuning (SFT), followed by a second phase that involves alignment through either Reinforcement Learning or other task-specific techniques.

Conversational chatbots released by companies such as Google and OpenAI have gone through both phases of fine-tuning. The first phase transforms the initially “useless” pre-trained model into one capable of engaging in natural conversations with users.

Subsequently, a second phase of fine-tuning is applied to further refine the model, so that its responses are appropriate and less biased. This stage often employs a Reinforcement Learning algorithm for alignment.

On the other hand, many open-source models available to the public may have undergone only the first fine-tuning phase. While these models are operational, their capacity to deliver responses in a preferred tone or to generate unbiased outputs may be limited.

This two-phase approach to language model fine-tuning was popularized by OpenAI’s InstructGPT, the predecessor to the later-released ChatGPT. Its surprising effectiveness, combined with its relatively low resource cost, has made it the de facto standard for LLM fine-tuning.

Supervised Fine-Tuning

Supervised Fine-Tuning (SFT) is surprisingly simple to understand and to execute. Recall from the previous article on the pre-training stage of Large Language Models that the training objective is next-token prediction. SFT employs this same fundamental objective; the primary distinction from the pre-training phase lies in the dataset used.

Rather than predicting the next token over arbitrary raw text, SFT trains on data with a task-specific structure.

Taking conversation as an example, the SFT dataset would include pairs of well-curated prompts and responses. These prompts represent potential user inputs, while the responses are what the Large Language Model is expected to produce.

We then fine-tune the model on these prompt-response pairs with the next-token prediction objective. During training, the prompt is fed into the model as “context,” and the model’s next-token predictions are evaluated against the expected response. The loss function is the same as in pre-training, usually the cross-entropy loss, and it is often computed only over the response tokens so that the model is not trained to reproduce the prompts.
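To make the masking idea concrete, here is a minimal NumPy sketch of the SFT loss for a single example. It assumes the prompt and response are already tokenized into one sequence and that the model’s per-position logits are given. It borrows the common convention of marking unscored positions with the label -100. The function names and the toy vocabulary are hypothetical, not TRL’s API.

```python
import numpy as np

IGNORE_INDEX = -100  # label meaning "do not score this position" (a common PyTorch convention)


def log_softmax(logits):
    # Numerically stable log-softmax over the vocabulary (last) axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def sft_loss(logits, input_ids, prompt_len):
    """Mean next-token cross-entropy over the response tokens of one example.

    logits:     (seq_len, vocab) scores the model produced at each position
    input_ids:  (seq_len,) prompt tokens followed by response tokens
    prompt_len: number of leading prompt tokens (context only, never scored)
    """
    # Position t predicts token t + 1, so the targets are the inputs shifted by one.
    targets = np.array(input_ids[1:], copy=True)
    # Targets that still belong to the prompt are masked out of the loss.
    targets[: prompt_len - 1] = IGNORE_INDEX
    logp = log_softmax(logits[:-1])
    mask = targets != IGNORE_INDEX
    picked = logp[np.flatnonzero(mask), targets[mask]]
    return -picked.mean()


rng = np.random.default_rng(0)
ids = np.array([1, 4, 2, 7, 3, 5])        # first 3 tokens: prompt, last 3: response
logits = rng.normal(size=(len(ids), 10))  # toy vocabulary of 10 tokens
print(sft_loss(logits, ids, prompt_len=3))
```

The prompt positions are masked. Changing the logits that predict prompt tokens therefore leaves the loss unchanged, and only the response tokens drive learning.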

A Brief History on Fine-Tuning

Note that SFT is different from fine-tuning in general. The notion of fine-tuning in machine learning existed long before Large Language Models became popular.

Fine-tuning, as a general principle in machine learning, involves adjusting a pre-trained model to enhance its performance on a particular task. This concept is not exclusive to language models but extends across various domains of machine learning.

For instance, in computer vision, Convolutional Neural Networks (CNNs) pre-trained on large datasets like ImageNet have been fine-tuned for specialized tasks such as facial recognition or medical image analysis. This practice goes back to the introduction of the ImageNet dataset in the late 2000s. Popular CNN architectures such as ResNets and EfficientNets are widely distributed with pre-trained weights and can be fine-tuned for a new task with little additional effort.

The Pros and Cons of Supervised Fine-Tuning

Supervised Fine-Tuning is extremely cost-efficient: it requires only a small fraction of the time and resources needed for pre-training. Many existing libraries, such as the Transformer Reinforcement Learning (TRL) library, provide out-of-the-box implementations of SFT.

For example, fine-tuning a model on the chat-style dataset openassistant-guanaco requires only a few lines of code to format the dataset into a prompt-response structure, followed by a one-liner to start training.
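The sketch below is a rough illustration of the formatting step: it splits one conversation into a prompt and a response. It assumes single-turn records stored as one text string in the “### Human: … ### Assistant: …” style used by openassistant-guanaco. The function name and the output keys are hypothetical, and multi-turn conversations would need extra handling.

```python
HUMAN_TAG = "### Human:"
ASSISTANT_TAG = "### Assistant:"


def split_turn(text):
    """Split one single-turn, guanaco-style record into a prompt and a response."""
    if not text.startswith(HUMAN_TAG) or ASSISTANT_TAG not in text:
        raise ValueError("record does not follow the assumed Human/Assistant template")
    # Everything before the first Assistant tag is the user's turn.
    human, assistant = text[len(HUMAN_TAG):].split(ASSISTANT_TAG, 1)
    return {
        # The prompt ends with the Assistant tag so the model learns to answer after it.
        "prompt": f"{HUMAN_TAG} {human.strip()} {ASSISTANT_TAG}",
        "response": " " + assistant.strip(),
    }


example = split_turn("### Human: What is the capital of France?### Assistant: Paris.")
print(example["prompt"])  # ### Human: What is the capital of France? ### Assistant:
print(example["response"])
```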

However, this efficiency and ease of use come with several drawbacks.

  1. Data Dependency: The behavior of the model is highly dependent on the quality of data and the diversity of its contents. Datasets that aren’t well curated are likely to result in models that behave inconsistently.

  2. Data Curation: Collecting enough high-quality prompt-response pairs, while also ensuring that their content is diverse, is challenging. Although techniques such as Self-Instruct use the LLM itself to generate its own fine-tuning data, creating a perfect dataset to train an LLM solely with SFT remains difficult.

  3. Catastrophic Forgetting: Studies have shown that when an LLM is fine-tuned on domain-specific data, it is prone to catastrophic forgetting of the foundational knowledge it acquired during pre-training. Various techniques have been developed to mitigate this issue, which we will discuss later, but none are 100% effective.

Reinforcement Learning Based Fine-Tuning

After performing supervised fine-tuning, the model will have a basic understanding of the task at hand, but further alignment may be needed in order to achieve the desired behavior.

Going back to the SFT example, a supervised fine-tuned conversational model may be able to mimic the back-and-forth style of human conversation, but the tone it uses could be overly casual or monotonous. This is where Reinforcement Learning based approaches come in.

With SFT, what is classified as “correct” is difficult to define. Although producing the exact same output as the label in the training dataset would result in a perfect loss value, there are other ways to respond that would appeal to a user in a similar way, if not better.

Teaching language models how to “speak” a language is drastically different from training them to classify 1s and 0s. Consider the classic contrast between math and English: in math, most problems have a definite, numerical answer, while in English, an essay can meet the teacher’s requirements in infinitely many ways.

Constructing a loss function that captures the intricate “correctness” of human language is a nearly impossible task, since the art of language is an inherently “human” ability. This is exactly why humans are involved in the next stage of fine-tuning.

Reinforcement Learning from Human Feedback (RLHF)

Reinforcement Learning from Human Feedback was introduced in the 2017 paper “Deep reinforcement learning from human preferences” by researchers at OpenAI and DeepMind, originally for classic Reinforcement Learning tasks. OpenAI later used it to fine-tune GPT-3 into InstructGPT and, subsequently, to train ChatGPT.

RLHF does not strictly require the model to be fine-tuned on task-specific data beforehand. In most cases, however, it is applied to a supervised fine-tuned model rather than directly to a raw, pre-trained one.

Human preferences in RLHF are captured by the reward model. The reward model is another language model whose only job is to evaluate a given output against human preferences. It produces a scalar reward that the Reinforcement Learning algorithm uses to update the LLM.

To generate a dataset for training the reward model, human annotators will rank LLM outputs based on certain criteria, depending on the eventual goal of the fine-tuning. Relative comparisons such as ranking generally produce much better results than directly assigning a scalar reward to an output.

These rankings are then converted into scalar values using rating systems similar to the Elo system used in chess. The values represent the relative “correctness” or “incorrectness” of each LLM output compared with the others, rather than an absolute measure of preference.
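The Elo-style conversion can be sketched in a few lines of Python. This is only an illustration under several assumptions. Each human ranking is expanded into pairwise “wins”, where a higher-ranked output beats every lower-ranked one. The update is the standard chess Elo rule with hypothetical defaults (K = 32, starting rating 1000). Real annotation pipelines may aggregate rankings differently.

```python
from itertools import combinations


def ranking_to_pairs(ranking):
    """A ranked list [best, ..., worst] implies each output beats every lower-ranked one."""
    return list(combinations(ranking, 2))


def expected_score(r_a, r_b):
    # Elo's logistic estimate of the probability that A beats B.
    return 1.0 / (1.0 + 10 ** ((r_b - r_a) / 400))


def elo_from_comparisons(pairs, k=32.0, initial=1000.0):
    """pairs: (winner, loser) tuples; returns a scalar rating per output."""
    ratings = {}
    for winner, loser in pairs:
        r_w = ratings.setdefault(winner, initial)
        r_l = ratings.setdefault(loser, initial)
        # The winner gains more when the win was unexpected.
        gain = k * (1.0 - expected_score(r_w, r_l))
        ratings[winner] = r_w + gain
        ratings[loser] = r_l - gain
    return ratings


# One annotator ranked three outputs for the same prompt: a > b > c.
ratings = elo_from_comparisons(ranking_to_pairs(["out_a", "out_b", "out_c"]))
print(ratings)
```

Each update moves points from the loser to the winner, so the total rating is conserved. Outputs that win more often end up with higher scores.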

Finally, the output-reward pairs are used to train a language model. After training, that model is calibrated to “judge” LLM outputs according to human preference.
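Many reward models, including the one used for InstructGPT, skip the intermediate scalar scores. Instead they are trained directly on the comparisons with a pairwise (Bradley-Terry style) loss. Here is a minimal NumPy sketch of that loss. It assumes the reward model has already produced one scalar score per output; the function names are illustrative.

```python
import numpy as np


def log_sigmoid(x):
    # log(sigmoid(x)) = -log(1 + exp(-x)), computed stably.
    return -np.logaddexp(0.0, -x)


def pairwise_reward_loss(r_preferred, r_other):
    """Mean of -log sigmoid(r_preferred - r_other) over a batch of comparisons."""
    margin = np.asarray(r_preferred, dtype=float) - np.asarray(r_other, dtype=float)
    return float(-np.mean(log_sigmoid(margin)))


print(pairwise_reward_loss([0.0], [0.0]))  # log(2): the model cannot tell the outputs apart
print(pairwise_reward_loss([2.0], [0.0]))  # smaller: the preferred output already scores higher
```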

Once the reward model is trained, it is integrated into the RLHF process for fine-tuning the LLM. This involves taking the (typically supervised fine-tuned) model and updating its parameters with a policy-gradient method such as Proximal Policy Optimization (PPO). Updating every parameter of a very large model is expensive, so some setups freeze part of the model and adjust only a subset of its parameters. The RL policy in this case is the LLM itself: it takes a prompt and generates text. Its action space is the set of tokens it can produce, and its observation space is the set of possible inputs it can receive.
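The core of the PPO update can be sketched as its clipped surrogate objective. In the LLM setting, each sampled token is an action. This sketch assumes two inputs are already computed: the per-token log-probabilities under the current and the sampling-time policy, and the advantage estimates (which in practice come from a learned value function). It uses the clipping value 0.2 from the PPO paper as a default.

```python
import numpy as np


def ppo_clip_objective(logp_new, logp_old, advantages, clip_eps=0.2):
    """PPO's clipped surrogate objective (to be maximized) averaged over sampled tokens."""
    # Probability ratio pi_new(a|s) / pi_old(a|s) for each sampled token.
    ratio = np.exp(np.asarray(logp_new, dtype=float) - np.asarray(logp_old, dtype=float))
    adv = np.asarray(advantages, dtype=float)
    unclipped = ratio * adv
    clipped = np.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * adv
    # The elementwise minimum removes any incentive to push the ratio outside the clip range.
    return float(np.mean(np.minimum(unclipped, clipped)))


# A token whose probability doubled: a positive advantage is capped at 1.2x,
# while a negative advantage is counted in full (the objective stays pessimistic).
print(ppo_clip_objective([np.log(2.0)], [0.0], [1.0]))   # about 1.2
print(ppo_clip_objective([np.log(2.0)], [0.0], [-1.0]))  # about -2.0
```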

The reward function in RLHF combines the scalar reward from the reward model with a penalty based on the Kullback–Leibler (KL) divergence between the fine-tuned model and the model it started from. This penalty keeps the fine-tuned model from drifting too far from its starting point, which prevents it from generating nonsensical text that merely happens to score well with the reward model. The final reward passed to the RL update rule is therefore the reward model’s score minus the KL penalty.
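The combined reward can be written as $R(x, y) = r_{\text{RM}}(x, y) - \beta \log \frac{\pi(y \mid x)}{\pi_{\text{ref}}(y \mid x)}$. The sketch below computes it once for a whole generated sequence. It assumes the per-token log-probabilities from the policy and from the frozen reference model are available. The value $\beta = 0.1$ is a hypothetical default, and many implementations apply the penalty per token rather than once per sequence.

```python
import numpy as np


def rlhf_reward(rm_score, logp_policy, logp_ref, beta=0.1):
    """Reward-model score minus a KL penalty for one generated sequence.

    logp_policy / logp_ref: log-probabilities that the policy being trained and the
    frozen reference model assign to each generated token.
    """
    # The summed per-token log-ratio equals log pi(y|x) - log pi_ref(y|x),
    # a single-sample estimate of the KL divergence between the two models.
    log_ratio = np.sum(np.asarray(logp_policy, dtype=float) - np.asarray(logp_ref, dtype=float))
    return float(rm_score - beta * log_ratio)


print(rlhf_reward(1.5, [-1.0, -2.0], [-1.0, -2.0]))  # 1.5: no drift, no penalty
print(rlhf_reward(1.5, [-0.5, -1.0], [-1.0, -2.0]))  # about 1.35: drift is penalized
```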

The details of RLHF training are beyond the scope of this article; for a more in-depth explanation, check out the following article.

Unfortunately, there are no out-of-the-box one-liners for fine-tuning an LLM with RLHF, as it is a complex method involving human-in-the-loop processes. Below is a minimal working example using the TRL library.

Notice that the code provides only one example prompt and a corresponding sample reward. To fully fine-tune an LLM with RLHF, an actual reward model needs to be trained, and many more prompts are required.

Direct Preference Optimization (DPO)

Because of the complexity and effort that RLHF requires, it is typically not the preferred fine-tuning method, especially among smaller research labs and open-source model developers. Direct Preference Optimization (DPO) is a preference-based fine-tuning approach for LLMs that does not rely on Reinforcement Learning.

RLHF is also prone to instability, as it relies on a reward model that mimics the preferences of a human, which can be unpredictable at times.

Instead of training a reward model for use with Reinforcement Learning, DPO formulates the fine-tuning objective directly as a loss function, much like Supervised Fine-Tuning does. However, SFT operates on pairs of text (a prompt and an expected response). DPO adds a dispreferred, “incorrect” response to each prompt alongside the preferred one.

DPO sidesteps the need for a reward model. Instead of learning what constitutes a “good” or “bad” output indirectly through a reward model, DPO adjusts the language model directly based on human preference data. Essentially, it treats the preference data as a target to optimize towards, thus removing the need to first learn a separate reward function and then optimize the language model towards that.

The key to DPO is its loss function. SFT only implements a simple next-token prediction objective, which works well for basic, general tasks but does not directly capture human preferences.
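For reference, here is the DPO loss as defined in the original DPO paper (Rafailov et al., 2023). It is written with $y_u$ for the preferred completion and $y_l$ for the less-preferred one:

$$\mathcal{L}_{\text{DPO}}(\pi_\theta; \pi_{\text{ref}}) = -\mathbb{E}_{(x, y_u, y_l) \sim \mathcal{D}}\left[\log \sigma\left(\beta \log \frac{\pi_\theta(y_u \mid x)}{\pi_{\text{ref}}(y_u \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right)\right]$$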

Here, the loss function measures how well the model $\pi_\theta$ (with parameters $\theta$) predicts human preferences relative to a reference model $\pi_{\text{ref}}$. The reference model is typically the supervised fine-tuned model before any further alignment.

The function $\sigma$ is the logistic sigmoid function, which maps a log-odds value to a probability between 0 and 1. Its argument compares how much more likely the model is to generate the human-preferred completion $y_u$ than the less-preferred completion $y_l$, relative to the reference model.

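Beta ($\beta$) is a hyperparameter that controls how far the model may deviate from the reference model, playing the same role as the KL-penalty coefficient in RLHF. A higher $\beta$ keeps the model closer to $\pi_{\text{ref}}$, while a lower $\beta$ allows it to move further in order to fit the preference data.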

If we take a look at the gradient of the loss function, which is what will actually be employed to update the model during training, things actually become somewhat intuitive.
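Following the DPO paper and using the same notation, the gradient is:

$$\nabla_\theta \mathcal{L}_{\text{DPO}}(\pi_\theta; \pi_{\text{ref}}) = -\beta\, \mathbb{E}_{(x, y_u, y_l) \sim \mathcal{D}}\Big[\sigma\big(\hat{r}_\theta(x, y_l) - \hat{r}_\theta(x, y_u)\big)\big[\nabla_\theta \log \pi_\theta(y_u \mid x) - \nabla_\theta \log \pi_\theta(y_l \mid x)\big]\Big]$$

where $\hat{r}_\theta(x, y) = \beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\text{ref}}(y \mid x)}$ is the reward implicitly defined by the model and the reference model.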

This gradient tells us how to adjust the parameters $\theta$ of our model $\pi_\theta$ to make the preferred completion $y_u$ more likely and the less-preferred completion $y_l$ less likely. The implicit reward is $\hat{r}_\theta(x, y) = \beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\text{ref}}(y \mid x)}$. The term $\sigma(\hat{r}_\theta(x, y_l) - \hat{r}_\theta(x, y_u))$ serves as a weight: if the model wrongly assigns a higher implicit reward to the less-preferred completion, this term is large, prompting a larger update to the model parameters. The term in the inner square brackets increases the likelihood of the preferred completion and decreases the likelihood of the dispreferred one.
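Given sequence-level log-probabilities under the model and the frozen reference model, the loss and the per-example gradient weight can be computed directly. Each sequence-level value is the sum of the completion’s token log-probabilities. The minimal NumPy sketch below uses a hypothetical $\beta = 0.1$:

```python
import numpy as np


def dpo_loss(logp_u, logp_l, ref_logp_u, ref_logp_l, beta=0.1):
    """DPO loss and per-example gradient weights from sequence log-probabilities.

    logp_u / logp_l: log pi_theta of the preferred / less-preferred completions
    ref_logp_u / ref_logp_l: the same quantities under the frozen reference model
    """
    # Implicit rewards r_hat = beta * log(pi_theta / pi_ref) for each completion.
    r_u = beta * (np.asarray(logp_u, dtype=float) - np.asarray(ref_logp_u, dtype=float))
    r_l = beta * (np.asarray(logp_l, dtype=float) - np.asarray(ref_logp_l, dtype=float))
    margin = r_u - r_l
    # -log sigmoid(margin), computed stably as log(1 + exp(-margin)).
    loss = float(np.mean(np.logaddexp(0.0, -margin)))
    # sigma(r_l - r_u): close to 1 when the model wrongly prefers y_l, near 0 when it is right.
    weight = 1.0 / (1.0 + np.exp(margin))
    return loss, weight


# Before training the model equals the reference: loss is log(2), every weight is 0.5.
print(dpo_loss([-10.0], [-12.0], [-10.0], [-12.0]))
# The model made y_u less likely than the reference did: the weight rises above 0.5.
print(dpo_loss([-14.0], [-12.0], [-10.0], [-12.0]))
```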

In practice, DPO requires far fewer computational resources and less manual labor than RLHF to achieve comparable, and sometimes better, results. Curating a dataset for DPO only takes a few lines of code. The below example constructs a dataset suitable for DPO training with the TRL library from the Stack Exchange Preferences dataset.
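DPO trainers such as TRL’s DPOTrainer expect each example to provide a prompt, a chosen response, and a rejected response, so the formatting step can be sketched as a single mapping function. The function below assumes paired Stack Exchange records with the fields question, response_j (the preferred answer) and response_k (the less-preferred answer). This follows the layout used in Hugging Face’s DPO examples, but treat both the field names and the “Question: … Answer:” template as assumptions. The sample record is hypothetical.

```python
def to_dpo_record(sample):
    """Map one paired Stack Exchange record to the prompt/chosen/rejected layout.

    Assumed fields: "question", "response_j" (preferred), "response_k" (less preferred).
    """
    return {
        "prompt": "Question: " + sample["question"] + "\n\nAnswer: ",
        "chosen": sample["response_j"],
        "rejected": sample["response_k"],
    }


# Hypothetical record, for illustration only.
record = {
    "question": "How do I reverse a list in Python?",
    "response_j": "Use my_list[::-1] for a reversed copy, or my_list.reverse() in place.",
    "response_k": "You can't.",
}
print(to_dpo_record(record))
```

With the Hugging Face datasets library, a function like this would typically be applied to every row with `Dataset.map`.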

Then training the model is a concise one-liner, similar to the SFT training code discussed previously.

Fine-Tuning Optimization Frameworks

DPO reduced the manual work required to align LLMs beyond basic supervised fine-tuning by offering a straightforward pipeline for curating a dataset and training on it. Even so, the hardware needed to fine-tune larger LLMs may still be out of consumers’ reach.

Fortunately, as fine-tuning techniques improved, methods for adapting fine-tuning to consumer hardware progressed as well. These methods make it possible to fine-tune models with billions of parameters on modest hardware with little to no drop in performance. Together, they are referred to as “Parameter-Efficient Fine-Tuning,” or PEFT.

PEFT also addresses the deployment of LLMs on consumer hardware. A fully fine-tuned model is the same size as the original LLM and has a memory footprint of similar magnitude, so merely running it on consumer hardware is a challenge.

Low Rank Adaptation (LoRA)

Low Rank Adaptation (LoRA) was one of the first and most effective methods balancing both performance and cost when it comes to fine-tuning LLMs. LoRA relies on some fundamental Linear Algebra concepts, namely the rank of a matrix and its rank decomposition.

Matrix Rank

The rank of a matrix is defined as the maximum number of linearly independent column vectors (or row vectors) in the matrix. A set of vectors is considered linearly independent if no vector in the set can be expressed as a linear combination of the others. In other words, you can’t recreate any vector in the set by adding and scaling (multiplying by numbers) the other vectors in the set.

The rank of a matrix will always be less than or equal to the smaller dimension of the matrix (number of rows or the number of columns).

Intuitively, think of each column (or row) vector in a matrix as a direction you can travel in space. The dimension of the space is determined by the number of entries in each vector (e.g., 3 entries for 3-dimensional space). When you plot all the vectors from a matrix in this space, they span a certain shape or volume. The rank of the matrix tells you the maximum number of dimensions this shape or volume can have.

For example, if you have a matrix with three vectors in 3-dimensional space (3D), there are a few possibilities:

  • If all three vectors point in completely different directions and none of them can be made by combining the others, they span a full 3D space. You can think of it as forming a 3D parallelepiped (a skewed box). This matrix has a rank of 3 because you have three dimensions of movement without redundancy.

  • If one of the vectors can be made by combining the other two, or if two of the vectors lie exactly in the same direction, then no matter how you try to stretch or skew, you’ll only get a flat shape, like a sheet of paper or a plane. This means you’ve lost a dimension; you’re now operating in 2D. Hence, the matrix has a rank of 2.

  • Similarly, if all three vectors lie exactly in the same line, you’re down to just a 1-dimensional line, and the rank of the matrix is 1.
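The three cases can be checked numerically. In the hypothetical toy matrices below, each column plays the role of one vector. NumPy’s `np.linalg.matrix_rank` computes the rank (via a singular value decomposition).

```python
import numpy as np

# Each column is one 3D vector.
full = np.array([[1, 0, 0],
                 [0, 1, 0],
                 [0, 0, 1]])    # three independent directions
plane = np.array([[1, 0, 1],
                  [0, 1, 1],
                  [0, 0, 0]])   # third column = first + second
line = np.array([[1, 2, 3],
                 [2, 4, 6],
                 [3, 6, 9]])    # every column is a multiple of [1, 2, 3]

for name, m in [("full", full), ("plane", plane), ("line", line)]:
    print(name, np.linalg.matrix_rank(m))
# full 3
# plane 2
# line 1
```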

In the context of LoRA, we can interpret the rank of a matrix as the dimensionality of the feature space it represents: the higher the rank, the larger the feature space the matrix spans.

The LoRA Approach

During fine-tuning, the model parameters are updated through gradient descent, starting from a set of pre-trained weights $W$ and updating them iteratively. We represent the result as $W + \Delta W$. In typical (full) fine-tuning, $\Delta W$ has the same dimensions as $W$. This makes each update expensive in compute and memory, and every fine-tuned model requires storing a full-size copy of the weights.

Building on previous work that observed a similar phenomenon, the authors hypothesize that during fine-tuning, the update matrix $\Delta W \in \mathbb{R}^{d \times k}$ has a low “intrinsic rank.” This allows them to represent $\Delta W$ through a low-rank decomposition $BA$, where $B \in \mathbb{R}^{d \times r}$, $A \in \mathbb{R}^{r \times k}$, and $r \ll \min(d, k)$. The product $BA$ has the same $d \times k$ shape as $\Delta W$, but its rank is at most $r$, and it requires only $r(d + k)$ trainable parameters instead of $dk$. In essence, pre-training requires a full-rank weight matrix to capture all the information, whereas fine-tuning can capture the essential information in a fine-tuning dataset without needing a full-rank update.
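To see why this matters, compare parameter counts for a single weight matrix. The sketch assumes a square 4096 × 4096 projection (the hidden size of a LLaMA-7B-scale model) and a hypothetical rank of $r = 8$:

```python
def lora_param_counts(d, k, r):
    """Trainable parameters for a dense update vs. a rank-r LoRA update of a d x k matrix."""
    full = d * k          # every entry of Delta W
    lora = d * r + r * k  # entries of B (d x r) plus A (r x k)
    return full, lora


full, lora = lora_param_counts(4096, 4096, 8)
print(full, lora, f"{lora / full:.2%}")  # 16777216 65536 0.39%
```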

The authors keep the update matrix separate from the original pre-trained weights, modifying the forward pass to:

$$h = W_0 x + \Delta W x = W_0 x + BAx$$

Here, $W_{0}$ represents the pre-trained weight matrix and $x$ the input features.

The rank $r$ is a hyperparameter chosen by the user; it trades memory use and speed against performance. The savings LoRA offers can be substantial. For GPT-3 175B, the LoRA paper reports a 10,000-fold reduction in trainable parameters, so only about 0.01% as many parameters as the original model are trained.

During training, the update matrix is scaled by $\frac{\alpha}{r}$ where $\alpha$ is a hyper-parameter. A typical default sets $\alpha$ to twice the value of $r$ while a typical $r$ value sits around 8-32.
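The modified forward pass, including the $\frac{\alpha}{r}$ scaling, can be sketched as a small NumPy layer. Following the LoRA paper, $A$ starts from small random values and $B$ from zeros, so the layer initially behaves exactly like the frozen pre-trained layer. The class name and the initialization scale 0.01 are illustrative assumptions, and gradient updates are omitted.

```python
import numpy as np


class LoRALinear:
    """A frozen linear layer W0 plus a trainable low-rank update (alpha / r) * B @ A."""

    def __init__(self, W0, r=8, alpha=16, seed=0):
        rng = np.random.default_rng(seed)
        d, k = W0.shape
        self.W0 = W0                                 # frozen pre-trained weights, (d, k)
        self.A = rng.normal(0.0, 0.01, size=(r, k))  # small random init
        self.B = np.zeros((d, r))                    # zero init, so B @ A = 0 at the start
        self.scale = alpha / r

    def forward(self, x):
        # h = W0 x + (alpha / r) * B (A x); multiplying right to left never
        # materializes the full d x k update matrix.
        return self.W0 @ x + self.scale * (self.B @ (self.A @ x))


rng = np.random.default_rng(1)
W0 = rng.normal(size=(16, 12))
layer = LoRALinear(W0, r=4, alpha=8)
x = rng.normal(size=12)
print(np.allclose(layer.forward(x), W0 @ x))  # True: the update starts at zero
```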

LoRA also makes it much more efficient to deploy several fine-tuned models built on the same foundation model. Because each low-rank update is stored separately from the pre-trained weights, $BA$ can be merged into $W_0$ for inference. Switching the model to another downstream task then only requires subtracting $BA$ and adding a different $B'A'$, with minimal memory overhead.
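Merging and swapping adapters is plain matrix arithmetic. The sketch below uses hypothetical toy sizes and random adapters. It shows that subtracting a merged update recovers the base weights (up to floating-point error), after which another task’s adapter can be folded in.

```python
import numpy as np


def merge(W, B, A, scale):
    """Fold a LoRA update into the weights: W + scale * B @ A."""
    return W + scale * (B @ A)


def unmerge(W, B, A, scale):
    """Remove a previously merged LoRA update."""
    return W - scale * (B @ A)


rng = np.random.default_rng(0)
d, k, r, scale = 6, 4, 2, 2.0  # toy sizes; scale plays the role of alpha / r
W0 = rng.normal(size=(d, k))
B1, A1 = rng.normal(size=(d, r)), rng.normal(size=(r, k))  # adapter for task 1
B2, A2 = rng.normal(size=(d, r)), rng.normal(size=(r, k))  # adapter for task 2

W_task1 = merge(W0, B1, A1, scale)                                # serve task 1
W_task2 = merge(unmerge(W_task1, B1, A1, scale), B2, A2, scale)   # switch to task 2
print(np.allclose(unmerge(W_task1, B1, A1, scale), W0))  # True
```

Each adapter stores only $r(d + k)$ numbers. Keeping many adapters around therefore costs little compared with keeping many full copies of $W_0$.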

The concept of LoRA can be applied to any other type of foundation model that was pre-trained on a large dataset and requires substantial compute to fine-tune directly. For example, both diffusion models and audio generation models can use LoRA to be fine-tuned on consumer hardware.

In practice, the TRL library provides a minimal working example for fine-tuning LLaMA models with custom datasets.

Quantized Low Rank Adaptation (QLoRA)

QLoRA further improves on the efficiency of LoRA through 4-bit NormalFloat quantization, Double Quantization, and Paged Optimizers. QLoRA is not a new fine-tuning technique in itself. Rather, it is a set of memory optimizations built on top of LoRA: the pre-trained model is stored in 4-bit precision and kept frozen, while the LoRA adapters are trained in higher precision.

Here’s an intuitive explanation of each concept:

  • 4-bit NormalFloat Quantization: 4-bit NormalFloat (NF4) quantization condenses neural network weights into a compact format using only 4 bits per weight. Pre-trained weights typically follow a roughly normal distribution, so NF4 maps them onto 16 levels derived from the quantiles of a standard normal distribution. The weights are processed in small blocks. Each block is rescaled into the range [-1, 1] by dividing by its absolute maximum value, and each weight is then stored as the index of the nearest level. Because the levels come from quantiles, each of the 16 levels is expected to receive roughly the same number of weights, which makes good use of the limited precision and helps preserve the network’s performance. Notably, the level set includes an exact representation of zero, which is important for encoding elements such as padding or intentional zeros.

  • Double Quantization: Imagine you’re packing for a trip with a bag inside a bag, and you want both bags packed as efficiently as possible. Double Quantization applies this idea to storing a quantized network. First, the method quantizes (packs) the network weights. Then it also quantizes the information needed to unpack them, namely the quantization constants. This two-step packing further reduces memory usage without significantly affecting the network’s performance.

  • Paged Optimizers: Paged Optimizers use the unified memory feature of NVIDIA GPUs to make efficient use of both GPU and CPU memory. When the GPU runs out of memory, optimizer states are temporarily moved to CPU RAM, just as you might store excess items in a spare room when your main space is full. They are paged back to the GPU when they are needed again. This helps handle memory spikes during training that would otherwise exceed the GPU’s memory and halt the computation with an out-of-memory error.
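The NormalFloat quantization step described above can be sketched in NumPy. This is a simplified approximation in the spirit of NF4, not the bitsandbytes implementation. The asymmetric split (8 positive levels, 7 negative levels, and an exact zero) and the offset 0.9677083 follow the reference construction as we understand it. The block handling is reduced to a single hypothetical block of 64 weights.

```python
import numpy as np
from statistics import NormalDist


def nf4_like_levels(offset=0.9677083):
    """16 levels from standard-normal quantiles, normalized to [-1, 1] (NF4-style)."""
    ppf = NormalDist().inv_cdf  # inverse CDF (quantile function) of N(0, 1)
    positive = [ppf(p) for p in np.linspace(offset, 0.5, 9)[:-1]]   # 8 levels > 0
    negative = [-ppf(p) for p in np.linspace(offset, 0.5, 8)[:-1]]  # 7 levels < 0
    levels = np.sort(np.array(positive + [0.0] + negative))       # plus an exact zero
    return levels / levels.max()


def quantize_block(w, levels):
    """Absmax-normalize one block of weights and snap each value to the nearest level."""
    absmax = float(np.abs(w).max()) or 1.0  # guard against an all-zero block
    dist = np.abs(w[:, None] / absmax - levels[None, :])
    codes = dist.argmin(axis=1).astype(np.uint8)  # 4-bit indices (0..15) stored in uint8
    return codes, absmax


def dequantize_block(codes, absmax, levels):
    return levels[codes] * absmax


levels = nf4_like_levels()
w = np.random.default_rng(0).normal(scale=0.02, size=64)  # one block of 64 weights
w[3] = 0.0
codes, absmax = quantize_block(w, levels)
w_hat = dequantize_block(codes, absmax, levels)
print(np.abs(w - w_hat).max() / absmax)  # worst-case error relative to the block's absmax
```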

Note that QLoRA does introduce additional training-time overhead due to the quantization and de-quantization of the weights.
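The memory saving from Double Quantization can be worked out with simple arithmetic. The calculation uses the settings reported in the QLoRA paper. Weights are grouped in blocks of 64, each with one 32-bit constant. With Double Quantization, those constants are quantized to 8 bits and grouped into blocks of 256, with one 32-bit second-level constant per group. The resulting overhead per parameter is:

```python
def constant_overhead_bits(block_size=64, const_bits=32, double_quant=False,
                           dq_bits=8, dq_block_size=256):
    """Extra bits per weight spent storing quantization constants (one per block)."""
    if not double_quant:
        # One full-precision constant for every `block_size` weights.
        return const_bits / block_size
    # The first-level constants are themselves quantized to `dq_bits`; one
    # full-precision second-level constant covers `dq_block_size` of them.
    return dq_bits / block_size + const_bits / (block_size * dq_block_size)


single = constant_overhead_bits()                   # 0.5 bits per weight
double = constant_overhead_bits(double_quant=True)  # about 0.127 bits per weight
print(single, round(double, 3), round(single - double, 3))  # 0.5 0.127 0.373
```

The difference, about 0.373 bits per parameter, adds up to roughly 3 GB for a 65B-parameter model.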

Takeaways

This article aims to provide a brief overview of the often confusing, jargon-heavy fine-tuning pipeline. It shows the vast possibilities of fine-tuning and how researchers have democratized Large Language Models for anyone to use.

In pre-training, the most crucial ingredient, the dataset, is often proprietary. Fine-tuning, by contrast, largely follows the open-source spirit, with a vast number of curated, ready-to-use datasets available on platforms like HuggingFace.

In fact, the Open LLM Leaderboard hosted on HuggingFace contains many fine-tuned models that have iterated on each other to create increasingly powerful models. The future remains bright for open access to fine-tuning foundational models.
