milorable committed on
Commit
83c88ac
·
verified ·
1 Parent(s): cdc054c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +411 -0
app.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
3
+
4
+ from PIL import Image
5
+ from huggingface_hub import hf_hub_download
6
+
7
+ unicorn_image_path = "unicorn.png"
8
+
9
+ import gradio as gr
10
+ from transformers import (
11
+ DistilBertTokenizerFast,
12
+ DistilBertForSequenceClassification,
13
+ AutoTokenizer,
14
+ AutoModelForSequenceClassification,
15
+ )
16
+ from huggingface_hub import hf_hub_download
17
+ import torch
18
+ import pickle
19
+ import numpy as np
20
+ from tensorflow.keras.models import load_model
21
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
22
+ import re
23
+
24
# NOTE(security): the tokenizers below are unpickled from files downloaded
# from the Hugging Face Hub. pickle.load can execute arbitrary code, so the
# repositories listed here must be trusted.
def _load_keras_bundle(repo_id, model_filename, tokenizer_filename="my_tokenizer.pkl"):
    """Download and load a Keras model and its pickled tokenizer from one Hub repo.

    Args:
        repo_id: Hugging Face Hub repository that holds both files.
        model_filename: Name of the saved Keras model file in that repository.
        tokenizer_filename: Name of the pickled tokenizer file in that repository.

    Returns:
        A ``(model_path, model, tokenizer_path, tokenizer)`` tuple, where the
        paths are the local paths returned by ``hf_hub_download``.
    """
    model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
    model = load_model(model_path)
    tokenizer_path = hf_hub_download(repo_id=repo_id, filename=tokenizer_filename)
    with open(tokenizer_path, "rb") as tokenizer_file:
        tokenizer = pickle.load(tokenizer_file)
    return model_path, model, tokenizer_path, tokenizer


gru_repo_id = "arjahojnik/GRU-sentiment-model"
gru_model_path, gru_model, gru_tokenizer_path, gru_tokenizer = _load_keras_bundle(
    gru_repo_id, "best_GRU_tuning_model.h5"
)

lstm_repo_id = "arjahojnik/LSTM-sentiment-model"
lstm_model_path, lstm_model, lstm_tokenizer_path, lstm_tokenizer = _load_keras_bundle(
    lstm_repo_id, "LSTM_model.h5"
)

bilstm_repo_id = "arjahojnik/BiLSTM-sentiment-model"
bilstm_model_path, bilstm_model, bilstm_tokenizer_path, bilstm_tokenizer = _load_keras_bundle(
    bilstm_repo_id, "BiLSTM_model.h5"
)
44
+
45
# Matches every character that is neither an ASCII letter nor whitespace.
_NON_LETTER_PATTERN = re.compile(r"[^a-zA-Z\s]")


def preprocess_text(text):
    """Normalise *text* for the Keras tokenizers.

    The text is lower-cased, every character that is not an ASCII letter or
    whitespace is deleted (not replaced by a space), and leading/trailing
    whitespace is trimmed.
    """
    lowered = text.lower()
    return _NON_LETTER_PATTERN.sub("", lowered).strip()
49
+
50
# Sequence length every recurrent model is padded/truncated to before predict.
_RNN_MAX_SEQUENCE_LENGTH = 200


def _predict_with_keras(model, tokenizer, text):
    """Run one recurrent Keras model on *text* and return its rating.

    The text is cleaned with ``preprocess_text``, converted to an id sequence
    by *tokenizer*, padded to ``_RNN_MAX_SEQUENCE_LENGTH`` and fed to *model*.

    Returns:
        The index of the highest-probability class plus one, as a plain int
        (presumably a 1-5 star rating; confirm against the model's labels).
    """
    cleaned = preprocess_text(text)
    sequences = tokenizer.texts_to_sequences([cleaned])
    padded = pad_sequences(sequences, maxlen=_RNN_MAX_SEQUENCE_LENGTH)
    # verbose=0 keeps Keras from printing a progress bar for every request.
    probabilities = model.predict(padded, verbose=0)
    return int(np.argmax(probabilities, axis=1)[0]) + 1


def predict_with_gru(text):
    """Return the GRU model's rating for *text*."""
    return _predict_with_keras(gru_model, gru_tokenizer, text)


def predict_with_lstm(text):
    """Return the LSTM model's rating for *text*."""
    return _predict_with_keras(lstm_model, lstm_tokenizer, text)


def predict_with_bilstm(text):
    """Return the BiLSTM model's rating for *text*."""
    return _predict_with_keras(bilstm_model, bilstm_tokenizer, text)
73
+
74
# Hugging Face Hub checkpoints used by the transformer predictors.
_DISTILBERT_REPO = "nhull/distilbert-sentiment-model"
_BERT_MULTILINGUAL_REPO = "nlptown/bert-base-multilingual-uncased-sentiment"
_TINYBERT_REPO = "elo4/TinyBERT-sentiment-model"
_ROBERTA_REPO = "ordek899/roberta_1to5rating_pred_for_restaur_trained_on_hotels"

# Registry of tokenizer/model pairs keyed by display name. The
# "Logistic Regression" entry is intentionally empty: that model is loaded
# from pickle files further down rather than through transformers.
models = {
    "DistilBERT": {
        "tokenizer": DistilBertTokenizerFast.from_pretrained(_DISTILBERT_REPO),
        "model": DistilBertForSequenceClassification.from_pretrained(_DISTILBERT_REPO),
    },
    "Logistic Regression": {},
    "BERT Multilingual (NLP Town)": {
        "tokenizer": AutoTokenizer.from_pretrained(_BERT_MULTILINGUAL_REPO),
        "model": AutoModelForSequenceClassification.from_pretrained(_BERT_MULTILINGUAL_REPO),
    },
    "TinyBERT": {
        "tokenizer": AutoTokenizer.from_pretrained(_TINYBERT_REPO),
        "model": AutoModelForSequenceClassification.from_pretrained(_TINYBERT_REPO),
    },
    "RoBERTa": {
        "tokenizer": AutoTokenizer.from_pretrained(_ROBERTA_REPO),
        "model": AutoModelForSequenceClassification.from_pretrained(_ROBERTA_REPO),
    },
}

# NOTE(security): both artefacts below are unpickled from downloaded files;
# pickle.load can execute arbitrary code, so the repository must be trusted.
logistic_regression_repo = "nhull/logistic-regression-model"
log_reg_model_path = hf_hub_download(
    repo_id=logistic_regression_repo,
    filename="logistic_regression_model.pkl",
)
with open(log_reg_model_path, "rb") as model_file:
    log_reg_model = pickle.load(model_file)

vectorizer_path = hf_hub_download(
    repo_id=logistic_regression_repo,
    filename="tfidf_vectorizer.pkl",
)
with open(vectorizer_path, "rb") as vectorizer_file:
    vectorizer = pickle.load(vectorizer_file)

# Move every registered transformer model to the selected device; entries
# without a "model" key (logistic regression) are skipped.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
for model_data in models.values():
    if "model" not in model_data:
        continue
    model_data["model"].to(device)
107
+
108
def _predict_with_transformer(model_key, text):
    """Classify *text* with the tokenizer/model pair stored in ``models[model_key]``.

    The text is tokenised as a batch of one (truncated to 512 tokens), moved
    to ``device`` and run through the model without gradient tracking.

    Returns:
        The index of the largest logit plus one, as a plain int.
    """
    entry = models[model_key]
    encodings = entry["tokenizer"](
        [text],
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        logits = entry["model"](**encodings).logits
    return int(logits.argmax(dim=-1)[0].item()) + 1


def predict_with_distilbert(text):
    """Return the DistilBERT model's rating for *text*."""
    return _predict_with_transformer("DistilBERT", text)


def predict_with_logistic_regression(text):
    """Return the logistic-regression prediction for *text*.

    The text is vectorised with the unpickled ``vectorizer`` and the model's
    predicted label is returned as an int without any offset.
    """
    features = vectorizer.transform([text])
    return int(log_reg_model.predict(features)[0])


def predict_with_bert_multilingual(text):
    """Return the NLP Town multilingual BERT model's rating for *text*."""
    return _predict_with_transformer("BERT Multilingual (NLP Town)", text)


def predict_with_tinybert(text):
    """Return the TinyBERT model's rating for *text*."""
    return _predict_with_transformer("TinyBERT", text)


def predict_with_roberta_ordek899(text):
    """Return the RoBERTa model's rating for *text*."""
    return _predict_with_transformer("RoBERTa", text)
152
+
153
def analyze_sentiment_and_statistics(text):
    """Run every model on *text* and summarise how their scores compare.

    Returns:
        A ``(results, statistics)`` pair. ``results`` maps each model's
        display name to its integer score. ``statistics`` always has an
        "Average Score" entry (two decimals); it additionally has "Message"
        when all models agree, or "Lowest Score" / "Highest Score" (each
        naming the models that produced it) when they do not.
    """
    predictors = (
        ("Logistic Regression", predict_with_logistic_regression),
        ("GRU Model", predict_with_gru),
        ("LSTM Model", predict_with_lstm),
        ("BiLSTM Model", predict_with_bilstm),
        ("DistilBERT", predict_with_distilbert),
        ("BERT Multilingual (NLP Town)", predict_with_bert_multilingual),
        ("TinyBERT", predict_with_tinybert),
        ("RoBERTa", predict_with_roberta_ordek899),
    )
    results = {label: predict(text) for label, predict in predictors}

    scores = list(results.values())
    lowest, highest = min(scores), max(scores)
    average_text = f"{np.mean(scores):.2f}"

    # Equal extremes mean every model returned the same score.
    if lowest == highest:
        statistics = {
            "Message": "All models predict the same score.",
            "Average Score": average_text,
        }
        return results, statistics

    def models_scoring(target):
        # Comma-separated names of the models whose score equals *target*.
        return ", ".join(name for name, score in results.items() if score == target)

    statistics = {
        "Lowest Score": f"{lowest} (Models: {models_scoring(lowest)})",
        "Highest Score": f"{highest} (Models: {models_scoring(highest)})",
        "Average Score": average_text,
    }
    return results, statistics
183
+
184
# NOTE(review): the nesting of the layout below was reconstructed from the
# order of the statements; confirm it against the rendered app.

# Page styling: centred container, floating unicorn image, pink analyse button.
_CUSTOM_CSS = """
.gradio-container {
    max-width: 900px;
    margin: auto;
    padding: 20px;
}
h1 {
    text-align: center;
    font-size: 2.5rem;
}
.unicorn-image {
    display: block;
    margin: auto;
    width: 300px; /* Larger size */
    height: auto;
    border-radius: 20px;
    margin-bottom: 20px;
    animation: magical-float 5s ease-in-out infinite; /* Gentle floating animation */
}
@keyframes magical-float {
    0% {
        transform: translate(0, 0) rotate(0deg); /* Start position */
    }
    25% {
        transform: translate(10px, -10px) rotate(3deg); /* Slightly up and right, tilted */
    }
    50% {
        transform: translate(0, -20px) rotate(0deg); /* Higher point, back to straight */
    }
    75% {
        transform: translate(-10px, -10px) rotate(-3deg); /* Slightly up and left, tilted */
    }
    100% {
        transform: translate(0, 0) rotate(0deg); /* Return to start position */
    }
}
footer {
    text-align: center;
    margin-top: 20px;
    font-size: 14px;
    color: gray;
}
.custom-analyze-button {
    background-color: #e8a4c9;
    color: white;
    font-size: 1rem;
    padding: 10px 20px;
    border-radius: 10px;
    border: none;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    transition: transform 0.2s, background-color 0.2s;
}
.custom-analyze-button:hover {
    background-color: #d693b8;
    transform: scale(1.05);
}
"""

sample_reviews = [
    "The hotel was fantastic! Clean rooms and excellent service.",
    "The food was horrible, and the staff was rude.",
    "Amazing experience overall. Highly recommend!",
    "It was okay, not great but not terrible either.",
    "Terrible! The room was dirty, and the service was non-existent.",
]

# Keys of the results dict, in the order of the output boxes wired to the
# analyse button below.
_RESULT_ORDER = (
    "Logistic Regression",
    "GRU Model",
    "LSTM Model",
    "BiLSTM Model",
    "DistilBERT",
    "BERT Multilingual (NLP Town)",
    "TinyBERT",
    "RoBERTa",
)
_BLANK_PREDICTIONS = ("",) * len(_RESULT_ORDER)
_NO_STATISTICS = "No statistics can be shown."


def update_textbox(selected_sample):
    """Return the text to place in the review box for a dropdown selection.

    The "Select an option" placeholder and a cleared dropdown (``None``)
    both map to an empty string; any other choice is returned unchanged.
    """
    if not selected_sample or selected_sample == "Select an option":
        return ""
    return selected_sample


def convert_to_stars(rating):
    """Render *rating* as a five-character star string, e.g. 3 -> "★★★☆☆".

    Ratings outside 0..5 are clamped so the result is always five characters.
    """
    filled = max(0, min(5, int(rating)))
    return "★" * filled + "☆" * (5 - filled)


def _format_statistics(statistics):
    """Turn the statistics dict into the multi-line text shown in the UI."""
    lines = ["Statistics:"]
    if "Message" in statistics:
        lines.append(statistics["Message"])
    else:
        lines.append(statistics["Lowest Score"])
        lines.append(statistics["Highest Score"])
    lines.append(f"Average Score: {statistics['Average Score']}")
    return "\n".join(lines)


def process_input_and_analyze(text_input):
    """Validate the review text, run all models and build the UI outputs.

    Returns:
        A 10-tuple: eight star strings (one per model, in ``_RESULT_ORDER``),
        a feedback message and the statistics text. For empty input, a single
        character or a digits-only string no model is run and the eight
        predictions are empty strings.
    """
    # A missing value is treated like empty text instead of raising.
    stripped = (text_input or "").strip()

    if not stripped:
        return (
            *_BLANK_PREDICTIONS,
            "Are you sure you wrote something? Try again! 🧐",
            _NO_STATISTICS,
        )

    if len(stripped) == 1 or stripped.isdigit():
        return (
            *_BLANK_PREDICTIONS,
            "Why not write something that makes sense? 🤔",
            _NO_STATISTICS,
        )

    results, statistics = analyze_sentiment_and_statistics(text_input)

    # Short inputs are still analysed; only the feedback message differs.
    if len(text_input.split()) < 5:
        feedback_message = "Maybe try with some longer text next time. 😉"
    else:
        feedback_message = "Sentiment analysis completed successfully! 😊"

    stars = tuple(convert_to_stars(results[name]) for name in _RESULT_ORDER)
    return (*stars, feedback_message, _format_statistics(statistics))


with gr.Blocks(css=_CUSTOM_CSS) as demo:
    gr.Image(
        value=unicorn_image_path,
        type="filepath",
        elem_classes=["unicorn-image"],
    )

    gr.Markdown("# Sentiment Analysis Demo")
    gr.Markdown(
        """
        Welcome! A magical unicorn 🦄 will guide you through this sentiment analysis journey! 🎉
        This app lets you explore how different models interpret sentiment and compare their predictions.
        **Enjoy the magic!**
        """
    )

    # Input area: free-text box, sample picker and the analyse button.
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Enter your text here:",
                lines=3,
                placeholder="Type your hotel/restaurant review here...",
            )
            sample_dropdown = gr.Dropdown(
                choices=["Select an option"] + sample_reviews,
                label="Or select a sample review:",
                value=None,
                interactive=True,
            )
            # Picking a sample copies it into the text box.
            sample_dropdown.change(
                update_textbox,
                inputs=[sample_dropdown],
                outputs=[text_input],
            )
            analyze_button = gr.Button(
                "Analyze Sentiment",
                elem_classes=["custom-analyze-button"],
            )

    # One read-only box per model, grouped by model family.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Machine Learning")
            log_reg_output = gr.Textbox(label="Logistic Regression", interactive=False)

        with gr.Column():
            gr.Markdown("### Deep Learning")
            gru_output = gr.Textbox(label="GRU Model", interactive=False)
            lstm_output = gr.Textbox(label="LSTM Model", interactive=False)
            bilstm_output = gr.Textbox(label="BiLSTM Model", interactive=False)

        with gr.Column():
            gr.Markdown("### Transformers")
            distilbert_output = gr.Textbox(label="DistilBERT", interactive=False)
            bert_output = gr.Textbox(label="BERT Multilingual", interactive=False)
            tinybert_output = gr.Textbox(label="TinyBERT", interactive=False)
            roberta_output = gr.Textbox(label="RoBERTa", interactive=False)

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Feedback")
            feedback_output = gr.Textbox(label="Feedback", interactive=False)

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Statistics")
            stats_output = gr.Textbox(label="Statistics", interactive=False)

    gr.Markdown(
        """
        <footer>
        This demo was built as a part of the NLP course at the University of Zagreb.
        Check out our GitHub repository:
        <a href="https://github.com/FFZG-NLP-2024/TripAdvisor-Sentiment/" target="_blank">TripAdvisor Sentiment Analysis</a>
        or explore our HuggingFace collection:
        <a href="https://huggingface.co/collections/nhull/nlp-zg-6794604b85fd4216e6470d38" target="_blank">NLP Zagreb HuggingFace Collection</a>.
        </footer>
        """
    )

    # The order of outputs must match the tuple returned by
    # process_input_and_analyze.
    analyze_button.click(
        process_input_and_analyze,
        inputs=[text_input],
        outputs=[
            log_reg_output,
            gru_output,
            lstm_output,
            bilstm_output,
            distilbert_output,
            bert_output,
            tinybert_output,
            roberta_output,
            feedback_output,
            stats_output,
        ],
    )

demo.launch()