shrish191 commited on
Commit
c48887f
·
verified ·
1 Parent(s): 4ff4a8a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -220,12 +220,12 @@ import tensorflow as tf
220
  import praw
221
  import os
222
 
223
- # Fallback imports
224
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
225
  import torch
226
  from scipy.special import softmax
227
 
228
- # Load main model and tokenizer
229
  model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
230
  tokenizer = BertTokenizer.from_pretrained("shrish191/sentiment-bert")
231
 
@@ -235,7 +235,7 @@ LABELS = {
235
  2: "Negative"
236
  }
237
 
238
- # Load fallback model and tokenizer
239
  fallback_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
240
  fallback_tokenizer = AutoTokenizer.from_pretrained(fallback_model_name)
241
  fallback_model = AutoModelForSequenceClassification.from_pretrained(fallback_model_name)
@@ -254,7 +254,7 @@ def fetch_reddit_text(reddit_url):
254
  except Exception as e:
255
  return f"Error fetching Reddit post: {str(e)}"
256
 
257
- # Fallback classifier using RoBERTa
258
  def fallback_classifier(text):
259
  encoded_input = fallback_tokenizer(text, return_tensors='pt', truncation=True, padding=True)
260
  with torch.no_grad():
 
220
  import praw
221
  import os
222
 
223
+
224
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
225
  import torch
226
  from scipy.special import softmax
227
 
228
+
229
  model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
230
  tokenizer = BertTokenizer.from_pretrained("shrish191/sentiment-bert")
231
 
 
235
  2: "Negative"
236
  }
237
 
238
+
239
  fallback_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
240
  fallback_tokenizer = AutoTokenizer.from_pretrained(fallback_model_name)
241
  fallback_model = AutoModelForSequenceClassification.from_pretrained(fallback_model_name)
 
254
  except Exception as e:
255
  return f"Error fetching Reddit post: {str(e)}"
256
 
257
+
258
  def fallback_classifier(text):
259
  encoded_input = fallback_tokenizer(text, return_tensors='pt', truncation=True, padding=True)
260
  with torch.no_grad():