k-code commited on
Commit
24b3cda
·
1 Parent(s): 6f242b5

push to hugging face

Browse files
Files changed (4) hide show
  1. .gradio/certificate.pem +31 -0
  2. README.md +29 -0
  3. app.py +135 -0
  4. requirements.txt +70 -0
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
README.md ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: SMS Spam Classifier
3
+ emoji: 📱
4
+ colorFrom: blue
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 4.0.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ # SMS Spam Classifier
13
+
14
+ This application uses a bidirectional LSTM model to classify SMS messages as either spam or legitimate (ham). Simply enter your text message, and the model will predict whether it's spam or not, along with a confidence score.
15
+
16
+ ## Usage
17
+
18
+ 1. Enter your text message in the input box
19
+ 2. Click submit
20
+ 3. The model will return its prediction (spam/ham) and confidence level
21
+
22
+ ## Model
23
+
24
+ The classifier uses a bidirectional LSTM architecture with:
25
+
26
+ - Word embeddings
27
+ - 2 LSTM layers
28
+ - Dropout for regularization
29
+ - Dense layers with ReLU activation
app.py CHANGED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import gradio as gr
3
+ from datasets import load_dataset
4
+ import torch
5
+ from torch.utils.data import random_split
6
+ from collections import Counter
7
+ import torch.nn as nn
8
+
9
+
10
+ class LSTMClassifier(nn.Module):
11
+ def __init__(self, vocab_size, embedding_dim=200, hidden_dim=256):
12
+ super(LSTMClassifier, self).__init__()
13
+ self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
14
+
15
+ self.lstm = nn.LSTM(
16
+ embedding_dim,
17
+ hidden_dim,
18
+ num_layers=2,
19
+ batch_first=True,
20
+ bidirectional=True,
21
+ dropout=0.3,
22
+ )
23
+
24
+ # Dropout layer
25
+ self.dropout = nn.Dropout(0.4)
26
+
27
+ # Additional dense layers
28
+ self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim)
29
+ self.fc2 = nn.Linear(hidden_dim, 2)
30
+
31
+ def forward(self, x):
32
+ embedded = self.embedding(x)
33
+
34
+ lstm_out, (hidden, cell) = self.lstm(embedded)
35
+
36
+ # Concatenate forward and backward hidden states
37
+ hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
38
+ hidden = self.dropout(hidden)
39
+
40
+ # Additional layer with ReLU activation
41
+ hidden = torch.relu(self.fc1(hidden))
42
+ hidden = self.dropout(hidden)
43
+
44
+ # Final classification layer
45
+ out = self.fc2(hidden)
46
+ return out
47
+
48
+
49
+ def create_vocabulary(ds, max_words=10000):
50
+ word2idx = {
51
+ "<PAD>": 0,
52
+ "<UNK>": 1,
53
+ }
54
+ words = []
55
+ for example in ds:
56
+ text = example["sms"]
57
+ text = text.lower()
58
+ text = re.sub(r"[^\w\s]", "", text)
59
+ words.extend(text.split())
60
+
61
+ word_counts = Counter(words)
62
+ common_words = word_counts.most_common(max_words - 2)
63
+ for word, _ in common_words:
64
+ word2idx[word] = len(word2idx)
65
+
66
+ return word2idx
67
+
68
+
69
+ def create_splits(ds):
70
+ # 80/20 split
71
+ full_dataset = ds['train']
72
+ train_size = int(0.8 * len(full_dataset))
73
+ test_size = len(full_dataset) - train_size
74
+
75
+ train_dataset, test_dataset = random_split(
76
+ full_dataset,
77
+ [train_size, test_size],
78
+ generator=torch.Generator().manual_seed(42),
79
+ )
80
+ return train_dataset, test_dataset
81
+
82
+
83
+ ds = load_dataset("ucirvine/sms_spam")
84
+ train_dataset, test_dataset = create_splits(ds)
85
+ vocab = create_vocabulary(train_dataset)
86
+
87
+ # First recreate the model architecture
88
+ model = LSTMClassifier(len(vocab), 100)
89
+ # Load the saved state dict
90
+ model.load_state_dict(torch.load('best_model.pth'))
91
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
92
+
93
+ model = model.to(device)
94
+
95
+
96
+ def predict_text(model, text, word2idx, device, max_length=50):
97
+ # Set model to evaluation mode
98
+ model.eval()
99
+
100
+ # Preprocess the text (same as training)
101
+ text = text.lower()
102
+ words = text.split()
103
+
104
+ # Convert words to indices
105
+ indices = [word2idx.get(word, word2idx['<UNK>']) for word in words]
106
+
107
+ # Pad or truncate
108
+ if len(indices) < max_length:
109
+ indices += [word2idx['<PAD>']] * (max_length - len(indices))
110
+ else:
111
+ indices = indices[:max_length]
112
+
113
+ # Convert to tensor
114
+ with torch.no_grad():
115
+ input_tensor = torch.tensor(indices).unsqueeze(
116
+ 0).to(device) # Add batch dimension
117
+ outputs = model(input_tensor)
118
+ probabilities = torch.softmax(outputs, dim=1)
119
+ prediction = torch.argmax(outputs, dim=1)
120
+
121
+ return {
122
+ 'prediction': 'spam' if prediction.item() == 1 else 'ham',
123
+ 'confidence': probabilities[0][prediction].item()
124
+ }
125
+
126
+
127
+ interface = gr.Interface(
128
+ fn=lambda text: predict_text(model, text, vocab, device),
129
+ inputs=gr.Textbox(lines=2, placeholder="Enter your text here..."),
130
+ outputs=gr.Textbox(),
131
+ title="SMS Spam Classifier",
132
+ description="Enter a text message to predict if it's spam or ham.",
133
+ )
134
+
135
+ interface.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ aiohappyeyeballs==2.6.1
3
+ aiohttp==3.11.14
4
+ aiosignal==1.3.2
5
+ annotated-types==0.7.0
6
+ anyio==4.9.0
7
+ attrs==25.3.0
8
+ certifi==2025.1.31
9
+ charset-normalizer==3.4.1
10
+ click==8.1.8
11
+ datasets==3.4.1
12
+ dill==0.3.8
13
+ fastapi==0.115.12
14
+ ffmpy==0.5.0
15
+ filelock==3.18.0
16
+ frozenlist==1.5.0
17
+ fsspec==2024.12.0
18
+ gradio==5.23.1
19
+ gradio_client==1.8.0
20
+ groovy==0.1.2
21
+ h11==0.14.0
22
+ httpcore==1.0.7
23
+ httpx==0.28.1
24
+ huggingface-hub==0.29.3
25
+ idna==3.10
26
+ Jinja2==3.1.6
27
+ markdown-it-py==3.0.0
28
+ MarkupSafe==3.0.2
29
+ mdurl==0.1.2
30
+ mpmath==1.3.0
31
+ multidict==6.2.0
32
+ multiprocess==0.70.16
33
+ networkx==3.4.2
34
+ numpy==2.2.4
35
+ orjson==3.10.16
36
+ packaging==24.2
37
+ pandas==2.2.3
38
+ pillow==11.1.0
39
+ propcache==0.3.1
40
+ pyarrow==19.0.1
41
+ pydantic==2.10.6
42
+ pydantic_core==2.27.2
43
+ pydub==0.25.1
44
+ Pygments==2.19.1
45
+ python-dateutil==2.9.0.post0
46
+ python-multipart==0.0.20
47
+ pytz==2025.2
48
+ PyYAML==6.0.2
49
+ requests==2.32.3
50
+ rich==13.9.4
51
+ ruff==0.11.2
52
+ safehttpx==0.1.6
53
+ semantic-version==2.10.0
54
+ shellingham==1.5.4
55
+ six==1.17.0
56
+ sniffio==1.3.1
57
+ starlette==0.46.1
58
+ sympy==1.13.1
59
+ tomlkit==0.13.2
60
+ torch==2.6.0
61
+ torchvision==0.21.0
62
+ tqdm==4.67.1
63
+ typer==0.15.2
64
+ typing_extensions==4.13.0
65
+ tzdata==2025.2
66
+ urllib3==2.3.0
67
+ uvicorn==0.34.0
68
+ websockets==15.0.1
69
+ xxhash==3.5.0
70
+ yarl==1.18.3