Fine-tuning Google T5-Small on Summarization

Fine-tune the T5-Small model on the SAMSum dialogue summarization dataset using HuggingFace Transformers.

1. Import Libraries

Install all required libraries for dataset loading, model training, and evaluation.

In [ ]:
!pip install -q datasets "transformers[sentencepiece]" accelerate sacrebleu rouge_score py7zr

Import all required libraries and suppress warnings.

In [ ]:
from datasets import load_dataset
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import DataCollatorForSeq2Seq
from transformers import TrainingArguments, Trainer
from transformers import pipeline
import warnings
warnings.filterwarnings("ignore")

2. Load Model & Tokenizer

Load the T5-Small tokenizer and model from the HuggingFace Hub and move the model to the GPU if one is available.

Popular summarization models sorted by size (smallest → largest):

| Model | Params | Prefix needed |
|---|---|---|
| `t5-small` | 60M | `"summarize: "` |
| `google/flan-t5-small` | 80M | `"summarize: "` |
| `facebook/bart-base` | 139M | none |
| `t5-base` | 220M | `"summarize: "` |
| `google/flan-t5-base` | 250M | `"summarize: "` |
| `sshleifer/distilbart-cnn-12-6` | 306M | none |
| `facebook/bart-large` | 406M | none |
| `google/pegasus-cnn_dailymail` | 568M | none |
| `t5-large` | 770M | `"summarize: "` |
In [ ]:
model_checkpoint = "t5-small"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device)

3. Load SAMSum Dataset

Load the SAMSum dataset — a collection of messenger-style dialogues paired with human-written summaries.

In [ ]:
dataset = load_dataset("knkarthick/samsum")
dataset
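Each SAMSum record pairs a multi-turn dialogue with a short human-written summary. A hand-written sketch of the record structure (illustrative values, not pulled from the dataset itself):

```python
# A SAMSum-style record: the real dataset exposes the same three string
# fields ("id", "dialogue", "summary") for every example.
example = {
    "id": "0000000",  # hypothetical id
    "dialogue": "Amanda: I baked cookies. Do you want some?\nJerry: Sure!",
    "summary": "Amanda baked cookies and offers some to Jerry.",
}

# Training maps dialogue -> summary, so both fields must be non-empty strings.
assert set(example) == {"id", "dialogue", "summary"}
print(example["summary"])
```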

4. Tokenize the Dataset

Tokenize dialogues (input) and summaries (target). Prepend "summarize: " to each dialogue as T5 requires a task prefix. Truncate inputs to 1024 tokens and targets to 128 tokens.

In [ ]:
def tokenize_content(data):
    dialogues = data["dialogue"]
    summaries = data["summary"]

    # T5 is a text-to-text model, so the task is signalled with a prefix
    inputs = ["summarize: " + (d or "") for d in dialogues]
    targets = [s or "" for s in summaries]

    input_encoding = tokenizer(inputs, max_length=1024, truncation=True)
    # text_target tokenizes the labels correctly; it replaces the deprecated
    # as_target_tokenizer() context manager. Padding is left to the data collator.
    target_encoding = tokenizer(text_target=targets, max_length=128, truncation=True)

    return {
        "input_ids": input_encoding["input_ids"],
        "attention_mask": input_encoding["attention_mask"],
        "labels": target_encoding["input_ids"],
    }

tokenized_dataset = dataset.map(tokenize_content, batched=True)
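What the tokenizer call does to each example can be sketched in plain Python, using made-up integer ids rather than the real T5 vocabulary: truncate to `max_length`, and record which positions hold real tokens in the attention mask (here with right-padding to a fixed length, for illustration):

```python
def toy_encode(token_ids, max_length, pad_id=0):
    """Truncate to max_length, then right-pad; the mask marks real tokens."""
    ids = token_ids[:max_length]                    # truncation=True
    mask = [1] * len(ids)
    ids = ids + [pad_id] * (max_length - len(ids))  # pad to a fixed length
    mask = mask + [0] * (max_length - len(mask))
    return {"input_ids": ids, "attention_mask": mask}

enc = toy_encode([101, 7, 42, 9, 88], max_length=8)
print(enc)  # 5 real tokens, 3 padding positions masked out

# Sequences longer than max_length are cut down to exactly max_length
long_enc = toy_encode(list(range(20)), max_length=8)
assert len(long_enc["input_ids"]) == 8
```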

5. Setup Data Collator

Create a DataCollatorForSeq2Seq, which dynamically pads inputs and labels to the longest sequence in each batch at runtime (more efficient than padding everything to a fixed length) and pads labels with -100 so padded positions are ignored by the loss.

In [ ]:
seq2seq_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
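Why dynamic padding saves work can be shown with a toy batch of sequence lengths, in pure Python with made-up numbers: static padding pads every example to a global maximum, while per-batch padding only pads to the longest example in each batch.

```python
lengths = [12, 15, 300, 18, 22, 900, 14, 16]  # hypothetical example lengths
global_max = 1024                              # static padding target
batch_size = 2

# Static padding: every example costs global_max tokens
static_tokens = len(lengths) * global_max

# Dynamic padding: each batch only pads to its own longest example
batches = [lengths[i:i + batch_size] for i in range(0, len(lengths), batch_size)]
dynamic_tokens = sum(len(b) * max(b) for b in batches)

print(static_tokens, dynamic_tokens)
assert dynamic_tokens < static_tokens
```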

6. Define Training Arguments

Configure all training hyperparameters: 1 epoch, batch size 1 with gradient accumulation of 16 steps (effective batch size = 16), warmup for 500 steps, weight decay for regularization.

In [ ]:
training_args = TrainingArguments(
    output_dir="t5-samsum-model",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    warmup_steps=500,
    weight_decay=0.01,
    logging_steps=10,
    eval_strategy="steps",   # enables periodic evaluation (older versions: evaluation_strategy)
    eval_steps=500,
    save_steps=1_000_000,    # must be an int; effectively disables intermediate checkpoints
    gradient_accumulation_steps=16,
    report_to="none"
)
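Two of these numbers are worth unpacking. The effective batch size is `per_device_train_batch_size * gradient_accumulation_steps`, and `warmup_steps=500` means the learning rate ramps linearly from 0 to its peak over the first 500 optimizer steps. A sketch of linear warmup (the `peak_lr` value and the post-warmup behavior are assumptions; the decay shape depends on the chosen scheduler):

```python
per_device_batch = 1
accumulation_steps = 16
effective_batch = per_device_batch * accumulation_steps
assert effective_batch == 16

def warmup_lr(step, peak_lr=5e-5, warmup_steps=500):
    """Linear warmup: scale the peak LR by step/warmup_steps until warmup ends."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    return peak_lr  # what happens after warmup depends on the decay schedule

print(warmup_lr(0), warmup_lr(250), warmup_lr(500))
```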

7. Initialize Trainer

Wire everything together — model, tokenizer, data collator, training args, and train/validation splits — into HuggingFace's Trainer.

In [ ]:
trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=seq2seq_collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"]
)

8. Train the Model

Start training. The Trainer handles the full loop — forward pass, loss computation, backprop, weight updates, logging, and evaluation.

In [ ]:
trainer.train()
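With `gradient_accumulation_steps=16`, the loop the Trainer runs can be sketched schematically (pure Python with a stand-in "gradient", not real torch code): gradients from 16 micro-batches are accumulated before each optimizer update.

```python
accumulation_steps = 16
num_micro_batches = 64          # hypothetical epoch length
optimizer_updates = 0
accumulated_grad = 0.0

for step in range(1, num_micro_batches + 1):
    grad = 1.0                  # stand-in for loss.backward() on one micro-batch
    accumulated_grad += grad    # gradients add up across micro-batches
    if step % accumulation_steps == 0:
        optimizer_updates += 1  # optimizer.step() followed by zero_grad()
        accumulated_grad = 0.0

print(optimizer_updates)  # one weight update per 16 micro-batches
```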

9. Save Model & Tokenizer

Save the fine-tuned model weights and tokenizer to disk so they can be reloaded for inference without retraining.

In [ ]:
model.save_pretrained("t5_samsum_finetuned_model")
tokenizer.save_pretrained("t5_samsum_tokenizer")

10. Reload & Setup for Inference

Reload the saved model and tokenizer from disk, then wrap them in HuggingFace's pipeline for simple one-call inference.

In [ ]:
tokenizer = AutoTokenizer.from_pretrained("t5_samsum_tokenizer")
model = AutoModelForSeq2SeqLM.from_pretrained("t5_samsum_finetuned_model")
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
summarizer = pipeline("summarization", model=model, tokenizer=tokenizer)

11. Test on Sample Dialogue

Define a sample multi-turn dialogue to test the model's summarization ability.

In [ ]:
sample_text = '''Luffy: Naruto! You won the ramen eating contest again?! That's your fifth win this month!

Naruto: Believe it, Luffy! Ichiraku's secret menu is my new training ground. Gotta keep up the chakra and the appetite!

Luffy: Haha! I like that! I trained by eating 20 meat-on-the-bone last night. Zoro thought I was insane.

Naruto: Bro, I've fought Akatsuki, and even I think that's dangerous. What's next? Competing with Goku?

Luffy: Maybe! But first I wanna become the Pirate King. Then I'll eat ramen on the moon!

Naruto: You sure talk big, rubber boy. But I respect that. Becoming Hokage wasn't easy either.

Luffy: We're kinda the same, huh? Chasing dreams, fighting crazy villains, making loyal friends.

Naruto: True that. Though I don't have a reindeer doctor or a skeleton with an afro.

Luffy: And I don't have a giant fox inside me. We're even!

Naruto: Hey, wanna team up for a mission? I heard there's a lost treasure in the Hidden Mist village.

Luffy: Treasure?! I'm in! Let's go find it, and maybe snack along the way.

Naruto: Deal. I'll bring the kunai, you bring the appetite.

Luffy: This is gonna be epic! Let's GO!!!

Naruto: Dattebayo!!!'''

12. Generate & Display Summary

Run the summarizer on the sample dialogue using greedy decoding (do_sample=False). Display the output as formatted Markdown.

In [ ]:
from IPython.display import Markdown, display

result = summarizer(sample_text, max_length=100, min_length=30, do_sample=False)
display(Markdown(f"**Summary:** {result[0]['summary_text']}"))
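`do_sample=False` makes generation deterministic: at each step the decoder emits the single highest-probability token rather than sampling from the distribution. A toy sketch with made-up per-step distributions:

```python
# Each inner list is a (made-up) probability distribution over a 4-token vocabulary,
# one per decoding step.
step_probs = [
    [0.1, 0.7, 0.1, 0.1],
    [0.05, 0.05, 0.8, 0.1],
    [0.6, 0.2, 0.1, 0.1],
]

def greedy_decode(all_probs):
    """Pick the argmax token at each step, as do_sample=False does."""
    return [max(range(len(p)), key=p.__getitem__) for p in all_probs]

tokens = greedy_decode(step_probs)
print(tokens)  # the argmax index of each step's distribution
```

Running the same input through a greedy decoder twice always yields the same token sequence, which is why it is a reasonable default for evaluating a fine-tuned summarizer.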