Next Word Prediction using LSTM
Trains an LSTM model on Medium article titles to predict the next word in a sequence.
Importing Libraries
Install NLTK and import all required libraries for deep learning, data processing, and tokenization.
!pip install nltk
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from collections import Counter
from torch.utils.data import Dataset, DataLoader
from nltk.tokenize import word_tokenize
import nltk
Loading The Data
Load the Medium articles dataset from CSV and preview the first few rows.
df = pd.read_csv('../input/medium-articles-dataset/medium_data.csv')
df.head(5)
Preparing The Training Data
Inspect dataset structure and column types.
df.info()
Combine all article titles into a single newline-separated string to use as the training corpus.
document = '\n'.join(df['title'].dropna().astype(str))
Download the NLTK punkt tokenizer and tokenize the entire document into lowercase word tokens.
nltk.download('punkt')  # newer NLTK versions may also require nltk.download('punkt_tab')
tokens = word_tokenize(document.lower())
Build a vocabulary mapping each unique token to an integer index. Reserve index 0 for unknown tokens.
vocab = {'<unk>': 0}
for token in Counter(tokens).keys():
    if token not in vocab:
        vocab[token] = len(vocab)
len(vocab)
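The vocabulary construction above can be sketched on a toy corpus (the token list here is a made-up example, not the Medium data):

```python
from collections import Counter

toy_tokens = ["deep", "learning", "with", "deep", "networks"]

# Reserve index 0 for the unknown token, then assign consecutive
# indices to each unique token in first-seen order.
toy_vocab = {'<unk>': 0}
for token in Counter(toy_tokens).keys():
    if token not in toy_vocab:
        toy_vocab[token] = len(toy_vocab)

print(toy_vocab)
# Counter preserves first-occurrence order (Python 3.7+),
# so 'deep' gets index 1 even though it appears twice.
```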
Split the document back into individual titles (one per line).
input_sentences = document.split('\n')
Convert each token in a sentence to its corresponding vocabulary index. Unknown tokens map to index 0.
def text_to_indices(sentence, vocab):
    numerical_sentence = []
    for token in sentence:
        if token in vocab:
            numerical_sentence.append(vocab[token])
        else:
            numerical_sentence.append(vocab['<unk>'])
    return numerical_sentence
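As a quick sanity check, known tokens map to their indices and anything out of vocabulary maps to 0. The one-liner below is equivalent to text_to_indices, shown here against a tiny made-up vocabulary:

```python
toy_vocab = {'<unk>': 0, 'how': 1, 'to': 2, 'train': 3}

# dict.get with a default does the same job as the if/else above.
def text_to_indices(sentence, vocab):
    return [vocab.get(token, vocab['<unk>']) for token in sentence]

print(text_to_indices(['how', 'to', 'train', 'dragons'], toy_vocab))
# 'dragons' is not in the vocabulary, so it becomes 0
```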
Apply text_to_indices to every tokenized sentence to produce a list of index sequences.
input_numerical_sentences = []
for sentence in input_sentences:
    input_numerical_sentences.append(text_to_indices(word_tokenize(sentence.lower()), vocab))
len(input_numerical_sentences)
Generate n-gram style training sequences: for each sentence, create all partial prefix sequences (length 2 to full).
training_sequence = []
for sentence in input_numerical_sentences:
    for i in range(1, len(sentence)):
        training_sequence.append(sentence[:i+1])
len(training_sequence)
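On a single toy index sequence, the prefix expansion looks like this (the indices are illustrative):

```python
toy_sentence = [4, 7, 2, 9]  # token indices for a four-word title

training_sequence = []
for i in range(1, len(toy_sentence)):
    # Each prefix is one token longer than the last; the final
    # element of each prefix will become the prediction target.
    training_sequence.append(toy_sentence[:i + 1])

print(training_sequence)
# [[4, 7], [4, 7, 2], [4, 7, 2, 9]]
```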
Find the length of the longest sequence to determine the padding length.
len_list = [len(sequence) for sequence in training_sequence]
max_len = max(len_list)
max_len
Left-pad each sequence with zeros so all sequences have the same length.
padded_training_sequence = []
for sequence in training_sequence:
    padded_training_sequence.append([0] * (max_len - len(sequence)) + sequence)
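Left-padding on the toy sequences above, with an assumed max_len of 4:

```python
toy_sequences = [[4, 7], [4, 7, 2, 9]]
max_len = 4

# Prepend zeros so every sequence reaches max_len; padding on the
# left keeps the target token in the final position of each row.
padded = [[0] * (max_len - len(seq)) + seq for seq in toy_sequences]
print(padded)
# [[0, 0, 4, 7], [4, 7, 2, 9]]
```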
Convert the padded list to a PyTorch tensor.
padded_training_sequence = torch.tensor(padded_training_sequence, dtype=torch.long)
Split the tensor into inputs X (all tokens except the last) and targets y (last token to predict).
X = padded_training_sequence[:, :-1]
y = padded_training_sequence[:, -1]
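The tensor slicing above splits each row into its prefix (input) and final token (target). The same operation on plain lists, continuing the toy example:

```python
padded = [[0, 0, 4, 7], [4, 7, 2, 9]]

# Equivalent of padded[:, :-1] and padded[:, -1] on a 2-D tensor:
X = [row[:-1] for row in padded]  # everything except the last token
y = [row[-1] for row in padded]   # the token the model must predict

print(X)  # [[0, 0, 4], [4, 7, 2]]
print(y)  # [7, 9]
```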
Dataset & DataLoader
Define a custom PyTorch Dataset that wraps X and y for use with DataLoader.
class CustomDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
Instantiate the dataset and wrap it in a DataLoader with batch size 32 and shuffling enabled.
dataset = CustomDataset(X, y)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
LSTM Model
Define the LSTM model: an embedding layer (vocab → 100-dim), an LSTM (100 → 150-dim hidden), and a fully connected output layer (150 → vocab size). The final hidden state is used for next-word prediction.
class LSTMModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, 100)
        self.lstm = nn.LSTM(100, 150, batch_first=True)
        self.fc = nn.Linear(150, vocab_size)

    def forward(self, x):
        embedded = self.embedding(x)
        intermediate_hidden_states, (final_hidden_state, final_cell_state) = self.lstm(embedded)
        output = self.fc(final_hidden_state.squeeze(0))
        return output
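To see why the final hidden state has the right shape for next-word classification, a dummy forward pass with the same layer sizes (and a made-up vocabulary of 10 tokens) traces the tensors:

```python
import torch
import torch.nn as nn

vocab_size = 10  # hypothetical tiny vocab; the real model uses len(vocab)

embedding = nn.Embedding(vocab_size, 100)
lstm = nn.LSTM(100, 150, batch_first=True)
fc = nn.Linear(150, vocab_size)

x = torch.randint(0, vocab_size, (32, 5))  # (batch, seq_len)
emb = embedding(x)                         # (32, 5, 100)
out, (h_n, c_n) = lstm(emb)                # h_n: (1, 32, 150), the last hidden state
logits = fc(h_n.squeeze(0))                # (32, 10): one score per vocab word
print(logits.shape)
```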
Instantiate the model, detect available device (GPU or CPU), and move the model to that device.
model = LSTMModel(len(vocab))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
Set training hyperparameters: 50 epochs, learning rate 0.001. Use CrossEntropyLoss and Adam optimizer.
epochs = 50
learning_rate = 0.001
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
Training
Run the training loop: for each epoch, iterate over all batches — forward pass, compute loss, backpropagate, update weights. Print total loss per epoch.
for epoch in range(epochs):
    total_loss = 0
    for batch_x, batch_y in dataloader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        optimizer.zero_grad()
        output = model(batch_x)
        loss = criterion(output, batch_y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch: {epoch + 1}, Loss: {total_loss:.4f}")
Trying Our Model
Define a prediction function that tokenizes input text, pads it, runs a forward pass, and returns the input text with the predicted next word appended.
import time
def prediction(model, vocab, text):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()
    tokenized_text = word_tokenize(text.lower())
    numerical_text = text_to_indices(tokenized_text, vocab)
    # Left-pad to the input length the model was trained on (max_len - 1 tokens)
    padded_text = torch.tensor([0] * (max_len - 1 - len(numerical_text)) + numerical_text, dtype=torch.long).unsqueeze(0)
    padded_text = padded_text.to(device)
    with torch.no_grad():
        output = model(padded_text)
    value, index = torch.max(output, dim=1)
    predicted_token = list(vocab.keys())[index.item()]
    return text + " " + predicted_token
Test single next-word predictions on sample inputs.
print(prediction(model, vocab, "Databricks: How to Save Files in"))
print(prediction(model, vocab, "A Step-by-Step Implementation of"))
Autoregressively generate the next 10 tokens by feeding each prediction back as input.
num_tokens = 10
input_text = "A Step-by-Step Implementation of"
for i in range(num_tokens):
    output_text = prediction(model, vocab, input_text)
    print(output_text)
    input_text = output_text
    time.sleep(0.5)