shrish191 commited on
Commit
56b85a5
·
verified ·
1 Parent(s): 5bca558

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -0
app.py CHANGED
@@ -55,6 +55,7 @@ demo = gr.Interface(
55
 
56
  demo.launch()
57
  '''
 
58
  import gradio as gr
59
  from transformers import TFBertForSequenceClassification, BertTokenizer
60
  import tensorflow as tf
@@ -136,5 +137,81 @@ demo = gr.Interface(
136
  )
137
 
138
  demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
 
 
55
 
56
  demo.launch()
57
  '''
58
+ '''
59
  import gradio as gr
60
  from transformers import TFBertForSequenceClassification, BertTokenizer
61
  import tensorflow as tf
 
137
  )
138
 
139
  demo.launch()
140
+ '''
141
+ import gradio as gr
142
+ from transformers import TFBertForSequenceClassification, BertTokenizer
143
+ import tensorflow as tf
144
+ import praw
145
+ import os
146
+
147
+ # Load model and tokenizer from Hugging Face
148
+ model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
149
+ tokenizer = BertTokenizer.from_pretrained("shrish191/sentiment-bert")
150
+
151
+ # Label mapping
152
+ LABELS = {
153
+ 0: "Neutral",
154
+ 1: "Positive",
155
+ 2: "Negative"
156
+ }
157
+
158
+ # Reddit API setup (credentials loaded securely from secrets)
159
+ reddit = praw.Reddit(
160
+ client_id=os.getenv("REDDIT_CLIENT_ID"),
161
+ client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
162
+ user_agent=os.getenv("REDDIT_USER_AGENT", "sentiment-classifier-script")
163
+ )
164
+
165
+ # Reddit post fetcher
166
+ def fetch_reddit_text(reddit_url):
167
+ try:
168
+ submission = reddit.submission(url=reddit_url)
169
+ return f"{submission.title}\n\n{submission.selftext}"
170
+ except Exception as e:
171
+ return f"Error fetching Reddit post: {str(e)}"
172
+
173
+ # Main sentiment function
174
+ def classify_sentiment(text_input, reddit_url):
175
+ if reddit_url.strip():
176
+ text = fetch_reddit_text(reddit_url)
177
+ elif text_input.strip():
178
+ text = text_input
179
+ else:
180
+ return "[!] Please enter some text or a Reddit post URL."
181
+
182
+ if text.lower().startswith("error") or "Unable to extract" in text:
183
+ return f"[!] {text}"
184
+
185
+ try:
186
+ inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
187
+ outputs = model(inputs)
188
+ probs = tf.nn.softmax(outputs.logits, axis=1)
189
+ pred_label = tf.argmax(probs, axis=1).numpy()[0]
190
+ confidence = float(tf.reduce_max(probs).numpy())
191
+ return f"Prediction: {LABELS[pred_label]} (Confidence: {confidence:.2f})"
192
+ except Exception as e:
193
+ return f"[!] Prediction error: {str(e)}"
194
+
195
+ # Gradio UI
196
+ demo = gr.Interface(
197
+ fn=classify_sentiment,
198
+ inputs=[
199
+ gr.Textbox(
200
+ label="Text Input (can be tweet or any content)",
201
+ placeholder="Paste tweet or type any content here...",
202
+ lines=4
203
+ ),
204
+ gr.Textbox(
205
+ label="Reddit Post URL",
206
+ placeholder="Paste a Reddit post URL (optional)",
207
+ lines=1
208
+ ),
209
+ ],
210
+ outputs="text",
211
+ title="Multilingual Sentiment Analysis",
212
+ description="🔍 Paste any text (including tweet content) OR a Reddit post URL to analyze sentiment.\n\n💡 Tweet URLs are not supported directly due to platform restrictions. Please paste tweet content manually."
213
+ )
214
+
215
+ demo.launch()
216
 
217