nhull committed on
Commit
8272328
·
verified ·
1 Parent(s): 0fa2e5e

Fix issues

Browse files
Files changed (1) hide show
  1. app.py +61 -76
app.py CHANGED
@@ -1,11 +1,10 @@
1
  import os
2
- os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # Disable GPU and enforce CPU execution
3
 
4
  from PIL import Image
5
  from huggingface_hub import hf_hub_download
6
 
7
- # Load a fun unicorn image
8
- unicorn_image_path = "unicorn.png"
9
 
10
  import gradio as gr
11
  from transformers import (
@@ -22,7 +21,6 @@ from tensorflow.keras.models import load_model
22
  from tensorflow.keras.preprocessing.sequence import pad_sequences
23
  import re
24
 
25
- # Load GRU, LSTM, and BiLSTM models and tokenizers
26
  gru_repo_id = "arjahojnik/GRU-sentiment-model"
27
  gru_model_path = hf_hub_download(repo_id=gru_repo_id, filename="best_GRU_tuning_model.h5")
28
  gru_model = load_model(gru_model_path)
@@ -44,13 +42,11 @@ bilstm_tokenizer_path = hf_hub_download(repo_id=bilstm_repo_id, filename="my_tok
44
  with open(bilstm_tokenizer_path, "rb") as f:
45
  bilstm_tokenizer = pickle.load(f)
46
 
47
- # Preprocessing function for text
48
  def preprocess_text(text):
49
  text = text.lower()
50
  text = re.sub(r"[^a-zA-Z\s]", "", text).strip()
51
  return text
52
 
53
- # Prediction functions for GRU, LSTM, and BiLSTM
54
  def predict_with_gru(text):
55
  cleaned = preprocess_text(text)
56
  seq = gru_tokenizer.texts_to_sequences([cleaned])
@@ -75,13 +71,12 @@ def predict_with_bilstm(text):
75
  predicted_class = np.argmax(probs, axis=1)[0]
76
  return int(predicted_class + 1)
77
 
78
- # Load other models
79
  models = {
80
  "DistilBERT": {
81
  "tokenizer": DistilBertTokenizerFast.from_pretrained("nhull/distilbert-sentiment-model"),
82
  "model": DistilBertForSequenceClassification.from_pretrained("nhull/distilbert-sentiment-model"),
83
  },
84
- "Logistic Regression": {}, # Placeholder for logistic regression
85
  "BERT Multilingual (NLP Town)": {
86
  "tokenizer": AutoTokenizer.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment"),
87
  "model": AutoModelForSequenceClassification.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment"),
@@ -96,7 +91,6 @@ models = {
96
  }
97
  }
98
 
99
- # Logistic regression model and TF-IDF vectorizer
100
  logistic_regression_repo = "nhull/logistic-regression-model"
101
  log_reg_model_path = hf_hub_download(repo_id=logistic_regression_repo, filename="logistic_regression_model.pkl")
102
  with open(log_reg_model_path, "rb") as model_file:
@@ -106,13 +100,11 @@ vectorizer_path = hf_hub_download(repo_id=logistic_regression_repo, filename="tf
106
  with open(vectorizer_path, "rb") as vectorizer_file:
107
  vectorizer = pickle.load(vectorizer_file)
108
 
109
- # Move HuggingFace models to device
110
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
111
  for model_data in models.values():
112
  if "model" in model_data:
113
  model_data["model"].to(device)
114
 
115
- # Prediction functions for other models
116
  def predict_with_distilbert(text):
117
  tokenizer = models["DistilBERT"]["tokenizer"]
118
  model = models["DistilBERT"]["model"]
@@ -158,7 +150,6 @@ def predict_with_roberta_ordek899(text):
158
  predictions = logits.argmax(axis=-1).cpu().numpy()
159
  return int(predictions[0] + 1)
160
 
161
- # Unified function for analysis
162
  def analyze_sentiment_and_statistics(text):
163
  results = {
164
  "Logistic Regression": predict_with_logistic_regression(text),
@@ -190,7 +181,6 @@ def analyze_sentiment_and_statistics(text):
190
  }
191
  return results, statistics
192
 
193
- # Gradio Interface
194
  with gr.Blocks(
195
  css="""
196
  .gradio-container {
@@ -252,10 +242,9 @@ with gr.Blocks(
252
  }
253
  """
254
  ) as demo:
255
- # Add the unicorn image at the start
256
  gr.Image(
257
- value=unicorn_image_path, # File path or URL
258
- type="filepath", # Correct type for file paths
259
  elem_classes=["unicorn-image"]
260
  )
261
 
@@ -269,7 +258,6 @@ with gr.Blocks(
269
  """
270
  )
271
 
272
-
273
  with gr.Row():
274
  with gr.Column():
275
  text_input = gr.Textbox(
@@ -285,12 +273,15 @@ with gr.Blocks(
285
  "Terrible! The room was dirty, and the service was non-existent."
286
  ]
287
  sample_dropdown = gr.Dropdown(
288
- choices=sample_reviews,
289
  label="Or select a sample review:",
 
290
  interactive=True
291
  )
292
 
293
  def update_textbox(selected_sample):
 
 
294
  return selected_sample
295
 
296
  sample_dropdown.change(
@@ -318,12 +309,16 @@ with gr.Blocks(
318
  tinybert_output = gr.Textbox(label="TinyBERT", interactive=False)
319
  roberta_output = gr.Textbox(label="RoBERTa", interactive=False)
320
 
 
 
 
 
 
321
  with gr.Row():
322
  with gr.Column():
323
  gr.Markdown("### Statistics")
324
  stats_output = gr.Textbox(label="Statistics", interactive=False)
325
 
326
- # Add footer
327
  gr.Markdown(
328
  """
329
  <footer>
@@ -335,80 +330,69 @@ with gr.Blocks(
335
  </footer>
336
  """
337
  )
 
 
 
 
338
  def process_input_and_analyze(text_input):
339
- # Check for empty input
340
  if not text_input.strip():
341
  funny_message = "Are you sure you wrote something? Try again! 🧐"
342
  return (
343
- funny_message, # Logistic Regression
344
- funny_message, # GRU
345
- funny_message, # LSTM
346
- funny_message, # BiLSTM
347
- funny_message, # DistilBERT
348
- funny_message, # BERT Multilingual
349
- funny_message, # TinyBERT
350
- funny_message, # RoBERTa
351
- "No statistics to display, as nothing was input. 🤷‍♀️"
352
  )
353
 
354
- # Check for one letter/number input
355
  if len(text_input.strip()) == 1 or text_input.strip().isdigit():
356
  funny_message = "Why not write something that makes sense? 🤔"
357
  return (
358
- funny_message, # Logistic Regression
359
- funny_message, # GRU
360
- funny_message, # LSTM
361
- funny_message, # BiLSTM
362
- funny_message, # DistilBERT
363
- funny_message, # BERT Multilingual
364
- funny_message, # TinyBERT
365
- funny_message, # RoBERTa
366
- "No statistics to display for one letter or number. 😅"
367
  )
368
 
369
- # Check if the review is shorter than 5 words
370
  if len(text_input.split()) < 5:
371
  results, statistics = analyze_sentiment_and_statistics(text_input)
372
  short_message = "Maybe try with some longer text next time. 😉"
 
 
 
 
 
373
  return (
374
- f"{results['Logistic Regression']} - {short_message}",
375
- f"{results['GRU Model']} - {short_message}",
376
- f"{results['LSTM Model']} - {short_message}",
377
- f"{results['BiLSTM Model']} - {short_message}",
378
- f"{results['DistilBERT']} - {short_message}",
379
- f"{results['BERT Multilingual (NLP Town)']} - {short_message}",
380
- f"{results['TinyBERT']} - {short_message}",
381
- f"{results['RoBERTa']} - {short_message}",
382
- f"Statistics:\n{statistics['Lowest Score']}\n{statistics['Highest Score']}\nAverage Score: {statistics['Average Score']}\n{short_message}"
 
383
  )
384
 
385
- # Proceed with normal sentiment analysis if none of the above conditions apply
386
  results, statistics = analyze_sentiment_and_statistics(text_input)
 
 
387
  if "Message" in statistics:
388
- return (
389
- results["Logistic Regression"],
390
- results["GRU Model"],
391
- results["LSTM Model"],
392
- results["BiLSTM Model"],
393
- results["DistilBERT"],
394
- results["BERT Multilingual (NLP Town)"],
395
- results["TinyBERT"],
396
- results["RoBERTa"],
397
- f"Statistics:\n{statistics['Message']}\nAverage Score: {statistics['Average Score']}"
398
- )
399
  else:
400
- return (
401
- results["Logistic Regression"],
402
- results["GRU Model"],
403
- results["LSTM Model"],
404
- results["BiLSTM Model"],
405
- results["DistilBERT"],
406
- results["BERT Multilingual (NLP Town)"],
407
- results["TinyBERT"],
408
- results["RoBERTa"],
409
- f"Statistics:\n{statistics['Lowest Score']}\n{statistics['Highest Score']}\nAverage Score: {statistics['Average Score']}"
410
- )
411
-
 
 
 
412
  analyze_button.click(
413
  process_input_and_analyze,
414
  inputs=[text_input],
@@ -420,9 +404,10 @@ with gr.Blocks(
420
  distilbert_output,
421
  bert_output,
422
  tinybert_output,
423
- roberta_output,
424
- stats_output
 
425
  ]
426
  )
427
 
428
- demo.launch()
 
1
  import os
2
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
3
 
4
  from PIL import Image
5
  from huggingface_hub import hf_hub_download
6
 
7
+ unicorn_image_path = "scripts/demo/unicorn.png"
 
8
 
9
  import gradio as gr
10
  from transformers import (
 
21
  from tensorflow.keras.preprocessing.sequence import pad_sequences
22
  import re
23
 
 
24
  gru_repo_id = "arjahojnik/GRU-sentiment-model"
25
  gru_model_path = hf_hub_download(repo_id=gru_repo_id, filename="best_GRU_tuning_model.h5")
26
  gru_model = load_model(gru_model_path)
 
42
  with open(bilstm_tokenizer_path, "rb") as f:
43
  bilstm_tokenizer = pickle.load(f)
44
 
 
45
  def preprocess_text(text):
46
  text = text.lower()
47
  text = re.sub(r"[^a-zA-Z\s]", "", text).strip()
48
  return text
49
 
 
50
  def predict_with_gru(text):
51
  cleaned = preprocess_text(text)
52
  seq = gru_tokenizer.texts_to_sequences([cleaned])
 
71
  predicted_class = np.argmax(probs, axis=1)[0]
72
  return int(predicted_class + 1)
73
 
 
74
  models = {
75
  "DistilBERT": {
76
  "tokenizer": DistilBertTokenizerFast.from_pretrained("nhull/distilbert-sentiment-model"),
77
  "model": DistilBertForSequenceClassification.from_pretrained("nhull/distilbert-sentiment-model"),
78
  },
79
+ "Logistic Regression": {},
80
  "BERT Multilingual (NLP Town)": {
81
  "tokenizer": AutoTokenizer.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment"),
82
  "model": AutoModelForSequenceClassification.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment"),
 
91
  }
92
  }
93
 
 
94
  logistic_regression_repo = "nhull/logistic-regression-model"
95
  log_reg_model_path = hf_hub_download(repo_id=logistic_regression_repo, filename="logistic_regression_model.pkl")
96
  with open(log_reg_model_path, "rb") as model_file:
 
100
  with open(vectorizer_path, "rb") as vectorizer_file:
101
  vectorizer = pickle.load(vectorizer_file)
102
 
 
103
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
104
  for model_data in models.values():
105
  if "model" in model_data:
106
  model_data["model"].to(device)
107
 
 
108
  def predict_with_distilbert(text):
109
  tokenizer = models["DistilBERT"]["tokenizer"]
110
  model = models["DistilBERT"]["model"]
 
150
  predictions = logits.argmax(axis=-1).cpu().numpy()
151
  return int(predictions[0] + 1)
152
 
 
153
  def analyze_sentiment_and_statistics(text):
154
  results = {
155
  "Logistic Regression": predict_with_logistic_regression(text),
 
181
  }
182
  return results, statistics
183
 
 
184
  with gr.Blocks(
185
  css="""
186
  .gradio-container {
 
242
  }
243
  """
244
  ) as demo:
 
245
  gr.Image(
246
+ value=unicorn_image_path,
247
+ type="filepath",
248
  elem_classes=["unicorn-image"]
249
  )
250
 
 
258
  """
259
  )
260
 
 
261
  with gr.Row():
262
  with gr.Column():
263
  text_input = gr.Textbox(
 
273
  "Terrible! The room was dirty, and the service was non-existent."
274
  ]
275
  sample_dropdown = gr.Dropdown(
276
+ choices=["Select an option"] + sample_reviews,
277
  label="Or select a sample review:",
278
+ value=None,
279
  interactive=True
280
  )
281
 
282
  def update_textbox(selected_sample):
283
+ if selected_sample == "Select an option":
284
+ return ""
285
  return selected_sample
286
 
287
  sample_dropdown.change(
 
309
  tinybert_output = gr.Textbox(label="TinyBERT", interactive=False)
310
  roberta_output = gr.Textbox(label="RoBERTa", interactive=False)
311
 
312
+ with gr.Row():
313
+ with gr.Column():
314
+ gr.Markdown("### Feedback")
315
+ feedback_output = gr.Textbox(label="Feedback", interactive=False)
316
+
317
  with gr.Row():
318
  with gr.Column():
319
  gr.Markdown("### Statistics")
320
  stats_output = gr.Textbox(label="Statistics", interactive=False)
321
 
 
322
  gr.Markdown(
323
  """
324
  <footer>
 
330
  </footer>
331
  """
332
  )
333
+
334
def convert_to_stars(rating, max_stars=5):
    """Render a numeric rating as a fixed-width row of star characters.

    Args:
        rating: Integer-like rating to display. Values outside
            ``0..max_stars`` are clamped so the row never changes width.
        max_stars: Total number of symbols in the row (default 5).

    Returns:
        A string of ``max_stars`` characters: filled stars for the rating
        followed by empty stars for the remainder.
    """
    # Clamp so an unexpected prediction (e.g. 6 or -1) cannot produce a
    # row that is longer than max_stars.
    filled = min(max(int(rating), 0), max_stars)
    return "★" * filled + "☆" * (max_stars - filled)
336
+
337
  def process_input_and_analyze(text_input):
 
338
  if not text_input.strip():
339
  funny_message = "Are you sure you wrote something? Try again! 🧐"
340
  return (
341
+ "", "", "", "", "", "", "", "",
342
+ funny_message,
343
+ "No statistics can be shown."
 
 
 
 
 
 
344
  )
345
 
 
346
  if len(text_input.strip()) == 1 or text_input.strip().isdigit():
347
  funny_message = "Why not write something that makes sense? 🤔"
348
  return (
349
+ "", "", "", "", "", "", "", "",
350
+ funny_message,
351
+ "No statistics can be shown."
 
 
 
 
 
 
352
  )
353
 
 
354
  if len(text_input.split()) < 5:
355
  results, statistics = analyze_sentiment_and_statistics(text_input)
356
  short_message = "Maybe try with some longer text next time. 😉"
357
+ stats_text = (
358
+ f"Statistics:\n{statistics['Lowest Score']}\n{statistics['Highest Score']}\n"
359
+ f"Average Score: {statistics['Average Score']}"
360
+ if "Message" not in statistics else f"Statistics:\n{statistics['Message']}"
361
+ )
362
  return (
363
+ convert_to_stars(results['Logistic Regression']),
364
+ convert_to_stars(results['GRU Model']),
365
+ convert_to_stars(results['LSTM Model']),
366
+ convert_to_stars(results['BiLSTM Model']),
367
+ convert_to_stars(results['DistilBERT']),
368
+ convert_to_stars(results['BERT Multilingual (NLP Town)']),
369
+ convert_to_stars(results['TinyBERT']),
370
+ convert_to_stars(results['RoBERTa']),
371
+ short_message,
372
+ stats_text
373
  )
374
 
 
375
  results, statistics = analyze_sentiment_and_statistics(text_input)
376
+ feedback_message = "Sentiment analysis completed successfully! 😊"
377
+
378
  if "Message" in statistics:
379
+ stats_text = f"Statistics:\n{statistics['Message']}\nAverage Score: {statistics['Average Score']}"
 
 
 
 
 
 
 
 
 
 
380
  else:
381
+ stats_text = f"Statistics:\n{statistics['Lowest Score']}\n{statistics['Highest Score']}\nAverage Score: {statistics['Average Score']}"
382
+
383
+ return (
384
+ convert_to_stars(results["Logistic Regression"]),
385
+ convert_to_stars(results["GRU Model"]),
386
+ convert_to_stars(results["LSTM Model"]),
387
+ convert_to_stars(results["BiLSTM Model"]),
388
+ convert_to_stars(results["DistilBERT"]),
389
+ convert_to_stars(results["BERT Multilingual (NLP Town)"]),
390
+ convert_to_stars(results["TinyBERT"]),
391
+ convert_to_stars(results["RoBERTa"]),
392
+ feedback_message,
393
+ stats_text
394
+ )
395
+
396
  analyze_button.click(
397
  process_input_and_analyze,
398
  inputs=[text_input],
 
404
  distilbert_output,
405
  bert_output,
406
  tinybert_output,
407
+ roberta_output,
408
+ feedback_output,
409
+ stats_output
410
  ]
411
  )
412
 
413
+ demo.launch()