shrish191 commited on
Commit
815e99c
·
verified ·
1 Parent(s): 75d548f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -9
app.py CHANGED
@@ -51,22 +51,32 @@ import gradio as gr
51
  from transformers import TFBertForSequenceClassification, AutoTokenizer
52
  import tensorflow as tf
53
 
 
54
  model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
55
  tokenizer = AutoTokenizer.from_pretrained("shrish191/sentiment-bert")
56
 
57
  def classify_sentiment(text):
58
  text = text.lower().strip()
59
  inputs = tokenizer(text, return_tensors="tf", padding=True, truncation=True)
60
- predictions = model(inputs, training=False).logits # Prevent dropout at inference
61
- label = tf.argmax(predictions, axis=1).numpy()[0]
 
 
 
62
  labels = model.config.id2label
63
- print(f"Text: {text} | Prediction: {label} | Logits: {predictions.numpy()}")
64
- return labels[str(label)]
65
 
66
- demo = gr.Interface(fn=classify_sentiment,
67
- inputs=gr.Textbox(placeholder="Enter a tweet..."),
68
- outputs="text",
69
- title="Tweet Sentiment Classifier",
70
- description="Multilingual BERT-based Sentiment Analysis")
 
 
 
 
 
 
71
 
72
  demo.launch()
 
 
51
  from transformers import TFBertForSequenceClassification, AutoTokenizer
52
  import tensorflow as tf
53
 
54
+ # Load model and tokenizer
55
  model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
56
  tokenizer = AutoTokenizer.from_pretrained("shrish191/sentiment-bert")
57
 
58
  def classify_sentiment(text):
59
  text = text.lower().strip()
60
  inputs = tokenizer(text, return_tensors="tf", padding=True, truncation=True)
61
+ outputs = model(inputs, training=False)
62
+ logits = outputs.logits
63
+ label_id = tf.argmax(logits, axis=1).numpy()[0]
64
+
65
+ # Ensure label ID is a string before looking it up
66
  labels = model.config.id2label
67
+ label_name = labels.get(str(label_id), "Unknown")
 
68
 
69
+ print(f"Text: {text} | Label ID: {label_id} | Label: {label_name} | Logits: {logits.numpy()}")
70
+ return label_name
71
+
72
+ # Gradio UI
73
+ demo = gr.Interface(
74
+ fn=classify_sentiment,
75
+ inputs=gr.Textbox(placeholder="Enter a tweet..."),
76
+ outputs="text",
77
+ title="Tweet Sentiment Classifier",
78
+ description="Multilingual BERT-based Sentiment Analysis"
79
+ )
80
 
81
  demo.launch()
82
+