philipobiorah committed on
Commit
3f74701
·
verified ·
1 Parent(s): 93308cb

add confidence level of prediction to display

Browse files
Files changed (1) hide show
  1. main.py +17 -11
main.py CHANGED
@@ -1,5 +1,5 @@
1
  import os
2
- from flask import Flask, jsonify, request, render_template
3
  import pandas as pd
4
  import torch
5
  from transformers import BertTokenizer, BertForSequenceClassification
@@ -9,7 +9,6 @@ matplotlib.use('Agg') # Prevents GUI issues for Matplotlib
9
  import matplotlib.pyplot as plt
10
  import base64
11
  from io import BytesIO
12
- from flask import send_file
13
 
14
 
15
  # Ensure the file exists in the current directory
@@ -36,18 +35,23 @@ model = BertForSequenceClassification.from_pretrained(MODEL_NAME)
36
 
37
  model.eval()
38
 
39
- # Function to Predict Sentiment
40
  def predict_sentiment(text):
41
  if not text.strip():
42
- return "Neutral" # Avoid processing empty text
43
 
44
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
45
 
46
  with torch.no_grad():
47
  outputs = model(**inputs)
48
 
49
- sentiment = outputs.logits.argmax(dim=1).item()
50
- return "Positive" if sentiment == 1 else "Negative"
 
 
 
 
 
51
 
52
  @app.route('/')
53
  def upload_file():
@@ -67,8 +71,8 @@ def analyze_text():
67
  if not text:
68
  return jsonify({"error": "No text provided!"}), 400 # Return JSON error message
69
 
70
- sentiment = predict_sentiment(text)
71
- return jsonify({"sentiment": sentiment}) # Return JSON response
72
 
73
  @app.route('/uploader', methods=['POST'])
74
  def upload_file_post():
@@ -86,8 +90,10 @@ def upload_file_post():
86
  if 'review' not in data.columns:
87
  return "Error: CSV file must contain a 'review' column!", 400
88
 
89
- # Predict sentiment for each review
90
- data['sentiment'] = data['review'].astype(str).apply(predict_sentiment)
 
 
91
 
92
  # Generate summary
93
  sentiment_counts = data['sentiment'].value_counts().to_dict()
@@ -114,4 +120,4 @@ def upload_file_post():
114
  return f"Error processing file: {str(e)}", 500
115
 
116
  if __name__ == '__main__':
117
- app.run(host='0.0.0.0', port=7860, debug=True)
 
1
  import os
2
+ from flask import Flask, jsonify, request, render_template, send_file
3
  import pandas as pd
4
  import torch
5
  from transformers import BertTokenizer, BertForSequenceClassification
 
9
  import matplotlib.pyplot as plt
10
  import base64
11
  from io import BytesIO
 
12
 
13
 
14
  # Ensure the file exists in the current directory
 
35
 
36
  model.eval()
37
 
38
# Function to Predict Sentiment + Confidence Score
def predict_sentiment(text) -> dict:
    """Classify *text* as positive or negative and report the confidence.

    Args:
        text: The review text to classify. Input longer than 512 tokens is
            truncated by the tokenizer call below.

    Returns:
        A dict with two keys:
        ``"sentiment"`` -- ``"Positive"`` when the highest-probability class
        index is 1, otherwise ``"Negative"``; ``"Neutral"`` when there is no
        text to classify.
        ``"confidence"`` -- softmax probability of the predicted class as a
        percentage rounded to two decimals; ``0.0`` for the Neutral case.
    """
    # Non-string input (None, or a float NaN from an empty CSV cell) has no
    # text to classify; treat it like blank text instead of crashing on
    # ``.strip()``.
    if not isinstance(text, str) or not text.strip():
        return {"sentiment": "Neutral", "confidence": 0.0}

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # Convert logits to probabilities; [0] selects the single input's row.
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)[0]
    # Predicted class index (0 = Negative, 1 = Positive).
    sentiment_idx = probabilities.argmax().item()
    # Probability of the predicted class, expressed as a percentage.
    confidence = probabilities[sentiment_idx].item() * 100

    sentiment_label = "Positive" if sentiment_idx == 1 else "Negative"

    return {"sentiment": sentiment_label, "confidence": round(confidence, 2)}
55
 
56
  @app.route('/')
57
  def upload_file():
 
71
  if not text:
72
  return jsonify({"error": "No text provided!"}), 400 # Return JSON error message
73
 
74
+ result = predict_sentiment(text)
75
+ return jsonify(result) # Return JSON response including confidence score
76
 
77
  @app.route('/uploader', methods=['POST'])
78
  def upload_file_post():
 
90
  if 'review' not in data.columns:
91
  return "Error: CSV file must contain a 'review' column!", 400
92
 
93
+ # Predict sentiment & confidence for each review
94
+ results = data['review'].astype(str).apply(predict_sentiment)
95
+ data['sentiment'] = results.apply(lambda x: x['sentiment'])
96
+ data['confidence'] = results.apply(lambda x: f"{x['confidence']}%")
97
 
98
  # Generate summary
99
  sentiment_counts = data['sentiment'].value_counts().to_dict()
 
120
  return f"Error processing file: {str(e)}", 500
121
 
122
  if __name__ == '__main__':
123
+ app.run(host='0.0.0.0', port=7860, debug=True)