LLM Finetuning: Demystifying Huggingface Trainer 🚀
For a long time, I avoided using the Hugging Face Trainer because it didn’t offer the level of fine-grained control I preferred compared to pure PyTorch. Additionally, I struggled to find a comprehensive tutorial that demonstrated how to log examples post-training—something I consider essential for evaluating any training run. In this blog, I’ll walk you through training a large language model (LLM), integrating Weights & Biases (wandb) for tracking, and highlight some key gotchas to watch out for along the way.
In this tutorial, we’ll use the Qwen2.5-1.5B-Instruct
model to fine-tune a detector for prompt injection attempts, leveraging the xTRam1/safe-guard-prompt-injection
dataset. While this dataset is a great starting point, it’s worth noting that its labels can be subjective—after reviewing the examples, I found that I would have classified some cases differently. If you’re building your own prompt-injection detection model, it’s important to critically evaluate and curate your dataset.
The full implementation is available in this kaggle kernel. If you find it helpful, please consider upvoting!
Data Setup
For Hugging Face models to compute the loss, they require input_ids
, attention_mask
, and labels
to be provided. This means we need a custom collate function that generates these three items in a dictionary. Additionally, since we’re using an *-instruct
model, the input must adhere to a specific format. While the model can technically train without this formatting, doing so can significantly slow down convergence.
To add another layer of complexity, different models might have slightly varying prompt structures. However, the format outlined below appears to be the most widely used and serves as a good starting point.
def _format_message(tokenizer: transformers.PreTrainedTokenizer, query: str, answer: Optional[int] = None) -> str:
= [
messages "role": "system", "content": "You are a helpful assistant. You reply with just yes or no!"},
{"role": "user", "content": f"Is the following a prompt injection attempt. {query}"},
{
]if answer is not None:
"role": "assistant", "content": "yes" if answer == 1 else "no"})
messages.append({= tokenizer.apply_chat_template(
text
messages,=False,
tokenize=answer is None,
add_generation_prompt
)return text.strip()
Calculating Loss
In most tutorials I’ve encountered, the loss is either calculated over the entire sequence or masked only for padded values. However, when dealing with a scenario where the input includes a query and the goal is to generate a corresponding answer, it’s unnecessary to train the model to regenerate the query itself. Instead, we should focus the loss calculation on the generated answer portion of the sequence.
To achieve this, we can use the following snippet to calculate the appropriate loss mask:
def _get_assistant_mask(labels: torch.Tensor, assistant_token: int) -> torch.Tensor:
= (labels == assistant_token).cumsum(dim=1).argmax(dim=1)
last_indices = torch.arange(labels.size(1), device=labels.device).unsqueeze(0) <= last_indices.unsqueeze(1)
mask return mask
Explanation
To mask everything before the word “assistant,” we first identify its position within the sequence. In our example, we know the output only contains “yes” or “no.” However, if your output may also include occurrences of the word “assistant,” you’ll need to adapt this approach accordingly.
Here, labels
is simply the same as input_ids
. The line (labels == assistant_token).cumsum(dim=1)
identifies all positions where the token for “assistant” appears by performing a cumulative sum along the sequence. This ensures we find the last occurrence of the word “assistant.” Using argmax
, we then locate the position where this maximum cumulative value occurs, effectively giving us the final index of “assistant” in each sequence.
To create the mask efficiently and avoid using a for loop, we construct an array of shape (1, labels.shape[1])
with values ranging from 0
to labels.shape[1] - 1
. Next, we reshape the last_indices
we calculated earlier to (labels.shape[0], 1)
using unsqueeze(1)
. Broadcasting is then applied to generate a matrix of the same shape as labels
, allowing us to create the mask that selectively includes only the relevant portions of the sequence for loss calculation.
Callbacks and Wandb logging
Callbacks are one of the features that elevate the Hugging Face Trainer into a fully-fledged PyTorch powerhouse. While the loss calculation is abstracted within the transformers
Trainer
class, callbacks provide us with the flexibility to manipulate or extend the training process as needed.
In this example, we’ll use callbacks to generate model responses over the validation set. This will be done both at the start of the training cycle and at the end of each epoch. This allows us to monitor the model’s performance and behavior throughout the training process.
A complete list of available callback options can be found here.
class WandbPredictionProgressCallback(transformers.TrainerCallback):
def __init__(self, trainer, tokenizer, val_dataset, num_samples=100):
super().__init__()
self.trainer = trainer
self.tokenizer = tokenizer
self.valid_dataloader = DataLoader(
range(num_samples)),
val_dataset.select(=BATCH_SIZE,
batch_size=True,
pin_memory=False,
shuffle=False,
drop_last
)
@torch.inference_mode()
def log_examples(self, state, model):
f"Starting to log examples at global step {state.global_step}")
logger.info(eval()
model.= []
output_texts = []
prompt_texts for batch in self.valid_dataloader:
= [_format_message(self.tokenizer, text) for text in batch["text"]]
texts
prompt_texts.extend(texts)= self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True, padding_side="left").to(model.device)
tokenized_text = model.generate(
output_ids **tokenized_text,
=5,
max_new_tokens=1,
num_beams=False,
do_sample
)
output_texts.extend(
[self.tokenizer.decode(
sum().item():],
output[attention_mask.argmax().item():][attention_mask.=True
skip_special_tokensfor attention_mask, output in zip(tokenized_text.attention_mask, output_ids)
)
]
)
= pd.DataFrame(
df
{"query_text": self.valid_dataloader.dataset["text"],
"actual_answer": ["yes" if label == 1 else "no" for label in self.valid_dataloader.dataset["label"]],
"predicted_answer": output_texts,
"prompt_text": prompt_texts,
}
)
f"validation_results_gs_{state.global_step}": wandb.Table(dataframe=df)})
wandb.log({f"Finsihed logging examples at global step {state.global_step}")
logger.info(
model.train()
def on_train_begin(self, args, state, control, **kwargs):
super().on_evaluate(args, state, control, **kwargs)
self.log_examples(state, self.trainer.model)
def on_epoch_end(self, args, state, control, **kwargs):
super().on_evaluate(args, state, control, **kwargs)
self.log_examples(state, self.trainer.model)
Things to note:
- We access the model by doing
trainer.model
. Not sure if this was necessary. model.generate
only generates up to 5 new tokens, and does so in a deterministic way.- To ensure proper alignment when generating text, we need to pad from the left. This is why we set
padding_side="left"
. Padding from the right, which is the default behaviour, can result in the shorter sequence receiving padding in the middle of the input. This disrupts the model’s ability to generate coherent text, as the padded tokens interfere with the sequence structure. By padding on the left, the input sequence remains contiguous, preserving the logical flow for text generation. - To decode just the answer we need to:
output[attention_mask.argmax().item():]
for that row because padding is from the left.[attention_mask.sum().item():]
of the result of above to remove input text.skip_special_tokens=True
to remove special tokens that make the output look messy.
Training
The final training code is as follows:
= CollateFn(tokenizer)
collate_fn = transformers.TrainingArguments(
training_args =EPOCHS,
num_train_epochs=BATCH_SIZE,
per_device_train_batch_size=GRADIENT_ACCUMULATION,
gradient_accumulation_steps=50,
warmup_steps=LEARNING_RATE,
learning_rate=WEIGHT_DECAY,
weight_decay=25,
logging_steps="steps",
save_strategy=250,
save_steps=1,
save_total_limit="paged_adamw_8bit", # for 8-bit, keep this, else adamw_hf
optim=True, # underlying precision for 8bit
bf16=f"./{model_name}-prompt-injection",
output_dir=f"sachin/{model_name}-prompt-injection",
hub_model_id="wandb",
report_to=False,
remove_unused_columns=1.0,
max_grad_norm
)
= transformers.Trainer(
trainer =model,
model=training_args,
args=collate_fn,
data_collator=train_ds,
train_dataset
)= WandbPredictionProgressCallback(trainer, tokenizer, valid_ds)
wandb_callback
trainer.add_callback(wandb_callback)
trainer.train()
The Gotchas
- I had to use QLora to train the model despite the model only being a 1.5B model. LORA unfortunately didn’t cut it. Perhaps due to having some sentences that were too long to fit in memory.
- Don’t forget to set
add_generation_prompt=True
when you are evaluating text, i.e. after training and when you are testing for the solution.
= tokenizer.apply_chat_template(
text
messages,=False,
tokenize=answer is None,
add_generation_prompt )
- Set
model.eval()
when inferring over the validation example andmodel.train()
once you are done. Also use@torch.inference_mode()
decorator over the function that you validating.
Results
Full results can be found in this wandb report.
Before training
As can be seen the predictions have extra text on top of being wrong.
After training
And after training for 514 steps (of batch size 4) we can see that the answers match up. We can also see that it selects the lower case setting as opposed to some of the upper case answers we saw earlier.
Summary
I have summarised the above in four take-aways:
- Introduction to Hugging Face Trainer
While the Hugging Face Trainer simplifies many aspects of training, its lack of fine-grained control initially made it less appealing. Logging examples post-training was also not well-documented. This tutorial demonstrates training a large language model (LLM), using Weights & Biases (wandb) for tracking, and tackling common challenges.
- Collate Function Requirements
Hugging Face models need input_ids, attention_mask, and labels for loss calculation. Custom collate functions ensure these are correctly generated. For *-instruct
models, inputs must follow a specific format to optimize convergence. This format varies slightly by model but has common patterns.
- Targeted Loss Masking
When the input includes a query and the goal is to generate an answer, it’s unnecessary to calculate loss for regenerating the query. A custom mask ensures the loss focuses only on the answer portion. Techniques like cumulative sum and broadcasting are used to efficiently identify and mask everything before the target token (e.g., “assistant”).
- Using Callbacks for Enhanced Functionality
Callbacks in the Hugging Face Trainer enable customization of the training loop. In this example, callbacks generate responses over the validation set at the beginning of training and at the end of each epoch. This provides insights into model performance during training. The full range of callbacks is available in the Hugging Face documentation.
By integrating these techniques, you can fine-tune models effectively, gain detailed insights during training, and handle challenges like loss masking and sequence formatting with precision. The full implementation can be found in this Kaggle kernel.
References + Kudos
Merve noyan’s blog, “Fine-tune SmolVLM on Visual Question Answering using Consumer GPU with QLoRA”