shrish191 commited on
Commit
5f8ec9b
·
verified ·
1 Parent(s): 815e99c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -11
app.py CHANGED
@@ -48,28 +48,31 @@ demo = gr.Interface(fn=classify_sentiment,
48
  demo.launch()
49
  '''
50
  import gradio as gr
51
- from transformers import TFBertForSequenceClassification, AutoTokenizer
52
  import tensorflow as tf
53
 
54
- # Load model and tokenizer
55
  model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
56
- tokenizer = AutoTokenizer.from_pretrained("shrish191/sentiment-bert")
57
 
58
  def classify_sentiment(text):
59
  text = text.lower().strip()
60
  inputs = tokenizer(text, return_tensors="tf", padding=True, truncation=True)
61
  outputs = model(inputs, training=False)
62
  logits = outputs.logits
63
- label_id = tf.argmax(logits, axis=1).numpy()[0]
64
 
65
- # Ensure label ID is a string before looking it up
66
- labels = model.config.id2label
67
- label_name = labels.get(str(label_id), "Unknown")
68
-
69
- print(f"Text: {text} | Label ID: {label_id} | Label: {label_name} | Logits: {logits.numpy()}")
70
- return label_name
 
 
 
71
 
72
- # Gradio UI
73
  demo = gr.Interface(
74
  fn=classify_sentiment,
75
  inputs=gr.Textbox(placeholder="Enter a tweet..."),
@@ -78,5 +81,6 @@ demo = gr.Interface(
78
  description="Multilingual BERT-based Sentiment Analysis"
79
  )
80
 
 
81
  demo.launch()
82
 
 
48
  demo.launch()
49
  '''
50
  import gradio as gr
51
+ from transformers import TFBertForSequenceClassification, BertTokenizer
52
  import tensorflow as tf
53
 
54
+ # Load model and tokenizer from Hugging Face Hub
55
  model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
56
+ tokenizer = BertTokenizer.from_pretrained("shrish191/sentiment-bert")
57
 
58
  def classify_sentiment(text):
59
  text = text.lower().strip()
60
  inputs = tokenizer(text, return_tensors="tf", padding=True, truncation=True)
61
  outputs = model(inputs, training=False)
62
  logits = outputs.logits
63
+ label_id = int(tf.argmax(logits, axis=1).numpy()[0])
64
 
65
+ # Handle label mapping correctly
66
+ raw_labels = model.config.id2label
67
+ if isinstance(list(raw_labels.keys())[0], str):
68
+ label = raw_labels.get(str(label_id), "Unknown")
69
+ else:
70
+ label = raw_labels.get(label_id, "Unknown")
71
+
72
+ print(f"Text: {text} | Label ID: {label_id} | Label: {label} | Logits: {logits.numpy()}")
73
+ return label
74
 
75
+ # Define the Gradio interface
76
  demo = gr.Interface(
77
  fn=classify_sentiment,
78
  inputs=gr.Textbox(placeholder="Enter a tweet..."),
 
81
  description="Multilingual BERT-based Sentiment Analysis"
82
  )
83
 
84
+ # Launch the app
85
  demo.launch()
86