LLM Finetuning: Demystifying Huggingface Trainer 🚀

LLM
Tutorial on finetuning LLMs via HF transformers library with wandb logging
Author

Sachin Abeywardana

Published

December 20, 2024

For a long time, I avoided using the Hugging Face Trainer because it didn’t offer the level of fine-grained control I preferred compared to pure PyTorch. Additionally, I struggled to find a comprehensive tutorial that demonstrated how to log examples post-training—something I consider essential for evaluating any training run. In this blog, I’ll walk you through training a large language model (LLM), integrating Weights & Biases (wandb) for tracking, and highlight some key gotchas to watch out for along the way.

In this tutorial, we’ll use the Qwen2.5-1.5B-Instruct model to fine-tune a detector for prompt injection attempts, leveraging the xTRam1/safe-guard-prompt-injection dataset. While this dataset is a great starting point, it’s worth noting that its labels can be subjective—after reviewing the examples, I found that I would have classified some cases differently. If you’re building your own prompt-injection detection model, it’s important to critically evaluate and curate your dataset.

The full implementation is available in this kaggle kernel. If you find it helpful, please consider upvoting!

Data Setup

For Hugging Face models to compute the loss, they require input_ids, attention_mask, and labels to be provided. This means we need a custom collate function that generates these three items in a dictionary. Additionally, since we’re using an *-instruct model, the input must adhere to a specific format. While the model can technically train without this formatting, doing so can significantly slow down convergence.

To add another layer of complexity, different models might have slightly varying prompt structures. However, the format outlined below appears to be the most widely used and serves as a good starting point.

def _format_message(tokenizer: transformers.PreTrainedTokenizer, query: str, answer: Optional[int] = None) -> str:
    messages = [
        {"role": "system", "content": "You are a helpful assistant. You reply with just yes or no!"},
        {"role": "user", "content": f"Is the following a prompt injection attempt. {query}"},
    ]
    if answer is not None:
        messages.append({"role": "assistant", "content": "yes" if answer == 1 else "no"})
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=answer is None,
    )
    return text.strip()

Calculating Loss

In most tutorials I’ve encountered, the loss is either calculated over the entire sequence or masked only for padded values. However, when dealing with a scenario where the input includes a query and the goal is to generate a corresponding answer, it’s unnecessary to train the model to regenerate the query itself. Instead, we should focus the loss calculation on the generated answer portion of the sequence.

To achieve this, we can use the following snippet to calculate the appropriate loss mask:

def _get_assistant_mask(labels: torch.Tensor, assistant_token: int) -> torch.Tensor:
    last_indices = (labels == assistant_token).cumsum(dim=1).argmax(dim=1)
    mask = torch.arange(labels.size(1), device=labels.device).unsqueeze(0) <= last_indices.unsqueeze(1)
    return mask

Explanation

To mask everything before the word “assistant,” we first identify its position within the sequence. In our example, we know the output only contains “yes” or “no.” However, if your output may also include occurrences of the word “assistant,” you’ll need to adapt this approach accordingly.

Here, labels is simply the same as input_ids. The line (labels == assistant_token).cumsum(dim=1) identifies all positions where the token for “assistant” appears by performing a cumulative sum along the sequence. This ensures we find the last occurrence of the word “assistant.” Using argmax, we then locate the position where this maximum cumulative value occurs, effectively giving us the final index of “assistant” in each sequence.

To create the mask efficiently and avoid using a for loop, we construct an array of shape (1, labels.shape[1]) with values ranging from 0 to labels.shape[1] - 1. Next, we reshape the last_indices we calculated earlier to (labels.shape[0], 1) using unsqueeze(1). Broadcasting is then applied to generate a matrix of the same shape as labels, allowing us to create the mask that selectively includes only the relevant portions of the sequence for loss calculation.

Callbacks and Wandb logging

Callbacks are one of the features that elevate the Hugging Face Trainer into a fully-fledged PyTorch powerhouse. While the loss calculation is abstracted within the transformers Trainer class, callbacks provide us with the flexibility to manipulate or extend the training process as needed.

In this example, we’ll use callbacks to generate model responses over the validation set. This will be done both at the start of the training cycle and at the end of each epoch. This allows us to monitor the model’s performance and behavior throughout the training process.

A complete list of available callback options can be found here.

class WandbPredictionProgressCallback(transformers.TrainerCallback):
    def __init__(self, trainer, tokenizer, val_dataset, num_samples=100):
        super().__init__()
        self.trainer = trainer
        self.tokenizer = tokenizer
        self.valid_dataloader = DataLoader(
            val_dataset.select(range(num_samples)),
            batch_size=BATCH_SIZE,
            pin_memory=True,
            shuffle=False,
            drop_last=False,
        )

    @torch.inference_mode()
    def log_examples(self, state, model):
        logger.info(f"Starting to log examples at global step {state.global_step}")
        model.eval()
        output_texts = []
        prompt_texts = []
        for batch in self.valid_dataloader:
            texts = [_format_message(self.tokenizer, text) for text in batch["text"]]
            prompt_texts.extend(texts)
            tokenized_text = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True, padding_side="left").to(model.device)
            output_ids = model.generate(
                **tokenized_text,
                max_new_tokens=5,
                num_beams=1,
                do_sample=False,
            )
            output_texts.extend(
                [
                    self.tokenizer.decode(
                        output[attention_mask.argmax().item():][attention_mask.sum().item():], 
                        skip_special_tokens=True
                    ) for attention_mask, output in zip(tokenized_text.attention_mask, output_ids)
                ]
            )

        df = pd.DataFrame(
            {
                "query_text": self.valid_dataloader.dataset["text"],
                "actual_answer": ["yes" if label == 1 else "no" for label in self.valid_dataloader.dataset["label"]],
                "predicted_answer": output_texts,
                "prompt_text": prompt_texts,
            }
        )

        wandb.log({f"validation_results_gs_{state.global_step}": wandb.Table(dataframe=df)})
        logger.info(f"Finsihed logging examples at global step {state.global_step}")
        model.train()

    def on_train_begin(self, args, state, control, **kwargs):
        super().on_evaluate(args, state, control, **kwargs)
        self.log_examples(state, self.trainer.model)
        
    def on_epoch_end(self, args, state, control, **kwargs):
        super().on_evaluate(args, state, control, **kwargs)
        self.log_examples(state, self.trainer.model)

Things to note:

  • We access the model by doing trainer.model. Not sure if this was necessary.
  • model.generate only generates up to 5 new tokens, and does so in a deterministic way.
  • To ensure proper alignment when generating text, we need to pad from the left. This is why we set padding_side="left". Padding from the right, which is the default behaviour, can result in the shorter sequence receiving padding in the middle of the input. This disrupts the model’s ability to generate coherent text, as the padded tokens interfere with the sequence structure. By padding on the left, the input sequence remains contiguous, preserving the logical flow for text generation.
  • To decode just the answer we need to:
    • output[attention_mask.argmax().item():] for that row because padding is from the left.
    • [attention_mask.sum().item():] of the result of above to remove input text.
    • skip_special_tokens=True to remove special tokens that make the output look messy.

Training

The final training code is as follows:

collate_fn = CollateFn(tokenizer)
training_args = transformers.TrainingArguments(
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION,
    warmup_steps=50,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    logging_steps=25,
    save_strategy="steps",
    save_steps=250,
    save_total_limit=1,
    optim="paged_adamw_8bit", # for 8-bit, keep this, else adamw_hf
    bf16=True, # underlying precision for 8bit
    output_dir=f"./{model_name}-prompt-injection",
    hub_model_id=f"sachin/{model_name}-prompt-injection",
    report_to="wandb",
    remove_unused_columns=False,
    max_grad_norm=1.0,
)

trainer = transformers.Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=train_ds,
)
wandb_callback = WandbPredictionProgressCallback(trainer, tokenizer, valid_ds)
trainer.add_callback(wandb_callback)

trainer.train()

The Gotchas

  • I had to use QLora to train the model despite the model only being a 1.5B model. LORA unfortunately didn’t cut it. Perhaps due to having some sentences that were too long to fit in memory.
  • Don’t forget to set add_generation_prompt=True when you are evaluating text, i.e. after training and when you are testing for the solution.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=answer is None,
)
  • Set model.eval() when inferring over the validation example and model.train() once you are done. Also use @torch.inference_mode() decorator over the function that you validating.

Results

Full results can be found in this wandb report.

Before training

As can be seen the predictions have extra text on top of being wrong.

image of wandb table of actual and predicted answers

After training

And after training for 514 steps (of batch size 4) we can see that the answers match up. We can also see that it selects the lower case setting as opposed to some of the upper case answers we saw earlier. image of wandb table of actual and predicted answers

Summary

I have summarised the above in four take-aways:

  1. Introduction to Hugging Face Trainer

While the Hugging Face Trainer simplifies many aspects of training, its lack of fine-grained control initially made it less appealing. Logging examples post-training was also not well-documented. This tutorial demonstrates training a large language model (LLM), using Weights & Biases (wandb) for tracking, and tackling common challenges.

  1. Collate Function Requirements

Hugging Face models need input_ids, attention_mask, and labels for loss calculation. Custom collate functions ensure these are correctly generated. For *-instruct models, inputs must follow a specific format to optimize convergence. This format varies slightly by model but has common patterns.

  1. Targeted Loss Masking

When the input includes a query and the goal is to generate an answer, it’s unnecessary to calculate loss for regenerating the query. A custom mask ensures the loss focuses only on the answer portion. Techniques like cumulative sum and broadcasting are used to efficiently identify and mask everything before the target token (e.g., “assistant”).

  1. Using Callbacks for Enhanced Functionality

Callbacks in the Hugging Face Trainer enable customization of the training loop. In this example, callbacks generate responses over the validation set at the beginning of training and at the end of each epoch. This provides insights into model performance during training. The full range of callbacks is available in the Hugging Face documentation.

By integrating these techniques, you can fine-tune models effectively, gain detailed insights during training, and handle challenges like loss masking and sequence formatting with precision. The full implementation can be found in this Kaggle kernel.

References + Kudos

  1. Merve noyan’s blog, “Fine-tune SmolVLM on Visual Question Answering using Consumer GPU with QLoRA”

  2. Wandb docs on transformers