|
import re |
|
import gradio as gr |
|
from datasets import load_dataset |
|
import torch |
|
from torch.utils.data import random_split |
|
from collections import Counter |
|
import torch.nn as nn |
|
|
|
|
|
class LSTMClassifier(nn.Module):
    """Bidirectional stacked-LSTM classifier over sequences of token indices.

    Pipeline: embedding -> bidirectional LSTM -> concatenation of the last
    layer's final forward and backward hidden states -> dropout ->
    linear + ReLU -> dropout -> linear layer producing class logits.

    Submodule names (``embedding``, ``lstm``, ``dropout``, ``fc1``, ``fc2``)
    are part of the checkpoint format (``state_dict`` keys); do not rename.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Size of each token embedding.
        hidden_dim: Hidden size of each LSTM direction and of ``fc1``'s output.
        num_layers: Number of stacked LSTM layers.
        lstm_dropout: Dropout applied between stacked LSTM layers. Forced to
            0 when ``num_layers == 1`` because there is no inter-layer
            connection to drop (PyTorch warns otherwise).
        dropout: Dropout probability applied before each linear layer.
        num_classes: Number of logits produced by the final layer.
    """

    def __init__(
        self,
        vocab_size,
        embedding_dim=200,
        hidden_dim=256,
        num_layers=2,
        lstm_dropout=0.3,
        dropout=0.4,
        num_classes=2,
    ):
        super().__init__()

        # Index 0 is the padding token: its embedding is not trained.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=lstm_dropout if num_layers > 1 else 0.0,
        )

        self.dropout = nn.Dropout(dropout)

        # Forward and backward final states are concatenated, hence * 2.
        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return unnormalised class logits for a batch of index sequences.

        Args:
            x: Integer tensor of token indices; batch-first, i.e.
                ``(batch, seq_len)``, since the LSTM uses ``batch_first=True``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding raw logits.
        """
        embedded = self.embedding(x)

        # Only the final hidden states are needed; per-step outputs and the
        # cell state are discarded.
        _, (hidden, _) = self.lstm(embedded)

        # hidden[-2] / hidden[-1] are the last layer's forward / backward
        # final states.
        features = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
        features = self.dropout(features)

        features = torch.relu(self.fc1(features))
        features = self.dropout(features)

        return self.fc2(features)
|
|
|
|
|
def create_vocabulary(ds, max_words=10000):
    """Build a word -> index mapping from the ``"sms"`` field of a dataset.

    Each message is lower-cased, stripped of every character that is neither
    a word character nor whitespace, and split on whitespace. The most
    frequent words receive consecutive indices after the two reserved tokens.

    Args:
        ds: Iterable of mappings, each with an ``"sms"`` string entry.
        max_words: Upper bound on the vocabulary size, including the
            reserved ``<PAD>`` (0) and ``<UNK>`` (1) entries.

    Returns:
        dict mapping each kept word to its integer index. Ties in frequency
        are resolved in order of first appearance.
    """
    # Compiled once instead of being looked up for every message.
    punctuation = re.compile(r"[^\w\s]")

    # Count incrementally so the whole corpus is never held as one flat list.
    word_counts = Counter()
    for example in ds:
        cleaned = punctuation.sub("", example["sms"].lower())
        word_counts.update(cleaned.split())

    word2idx = {
        "<PAD>": 0,
        "<UNK>": 1,
    }
    # Two slots are already taken by the reserved tokens.
    for word, _ in word_counts.most_common(max_words - 2):
        word2idx[word] = len(word2idx)

    return word2idx
|
|
|
|
|
def create_splits(ds, train_fraction=0.8, seed=42):
    """Split ``ds['train']`` into reproducible train and test subsets.

    Args:
        ds: Mapping whose ``'train'`` entry is a sized, indexable dataset.
        train_fraction: Share of examples assigned to the training subset;
            the size is ``int(train_fraction * len(dataset))`` and the
            remainder goes to the test subset.
        seed: Seed for the generator that drives the random permutation, so
            repeated calls yield the same split.

    Returns:
        ``(train_dataset, test_dataset)`` as produced by
        ``torch.utils.data.random_split``.

    Raises:
        ValueError: If ``train_fraction`` is outside ``[0, 1]``.
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(
            f"train_fraction must be within [0, 1], got {train_fraction!r}"
        )

    full_dataset = ds['train']
    total = len(full_dataset)
    train_size = int(train_fraction * total)
    test_size = total - train_size

    train_dataset, test_dataset = random_split(
        full_dataset,
        [train_size, test_size],
        generator=torch.Generator().manual_seed(seed),
    )
    return train_dataset, test_dataset
|
|
|
|
|
# Load the SMS spam dataset and derive the vocabulary from its training split.
# NOTE(review): presumably this mirrors the split/vocabulary used when
# 'best_model.pth' was trained, so token indices match the checkpoint -
# confirm against the training script.
ds = load_dataset("ucirvine/sms_spam")
train_dataset, test_dataset = create_splits(ds)
vocab = create_vocabulary(train_dataset)

# Choose the device before loading so the checkpoint can be remapped onto it;
# without map_location a checkpoint saved from CUDA fails to load on a
# CPU-only host.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Embedding size 100 (not the class default) must match the checkpoint.
model = LSTMClassifier(len(vocab), 100)
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model = model.to(device)
|
|
|
|
|
def predict_text(model, text, word2idx, device, max_length=50):
    """Classify a single message as spam or ham.

    The text is normalised the same way ``create_vocabulary`` normalises its
    corpus (lower-cased, every character that is neither a word character
    nor whitespace removed) so that tokens such as ``"free!"`` resolve to the
    vocabulary entry ``"free"`` instead of ``<UNK>``.
    NOTE(review): assumes the training pipeline applied the same
    normalisation - confirm against the training script.

    Args:
        model: Callable module mapping a ``(1, max_length)`` integer tensor
            to a ``(1, num_classes)`` tensor of logits; class 1 means spam.
        text: Raw message text.
        word2idx: Vocabulary containing at least ``'<PAD>'`` and ``'<UNK>'``.
        device: Device on which the input tensor is placed.
        max_length: Fixed sequence length; shorter inputs are right-padded
            with ``<PAD>``, longer ones truncated.

    Returns:
        dict with ``'prediction'`` (``'spam'`` or ``'ham'``) and
        ``'confidence'`` (softmax probability of the predicted class).
    """
    model.eval()

    # Same normalisation as the vocabulary builder.
    normalized = re.sub(r"[^\w\s]", "", text.lower())

    unk_id = word2idx['<UNK>']
    pad_id = word2idx['<PAD>']

    # Truncate first, then right-pad up to exactly max_length entries.
    indices = [word2idx.get(word, unk_id) for word in normalized.split()]
    indices = indices[:max_length]
    indices += [pad_id] * (max_length - len(indices))

    with torch.no_grad():
        # Explicit integer dtype: these are embedding lookup indices.
        input_tensor = torch.tensor(indices, dtype=torch.long).unsqueeze(0).to(device)
        outputs = model(input_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        predicted_class = int(torch.argmax(outputs, dim=1).item())

    return {
        'prediction': 'spam' if predicted_class == 1 else 'ham',
        'confidence': probabilities[0, predicted_class].item(),
    }
|
|
|
|
|
# Gradio front end: one text box in, one text box out. The callback binds the
# module-level model, vocabulary and device into predict_text.
message_box = gr.Textbox(lines=2, placeholder="Enter your text here...")
result_box = gr.Textbox()

interface = gr.Interface(
    fn=lambda text: predict_text(model, text, vocab, device),
    inputs=message_box,
    outputs=result_box,
    title="SMS Spam Classifier",
    description="Enter a text message to predict if it's spam or ham.",
)

# share=True asks Gradio to create a public share link in addition to the
# local server.
interface.launch(share=True)
|
|