shrish191 commited on
Commit
28143df
·
verified ·
1 Parent(s): 3f08853

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -0
app.py CHANGED
@@ -359,6 +359,7 @@ demo = gr.Interface(
359
 
360
  demo.launch()
361
  '''
 
362
  import gradio as gr
363
  from transformers import TFBertForSequenceClassification, BertTokenizer
364
  import tensorflow as tf
@@ -533,7 +534,122 @@ demo = gr.TabbedInterface(
533
  )
534
 
535
  demo.launch()
 
 
 
 
 
 
536
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
537
 
538
 
539
 
 
359
 
360
  demo.launch()
361
  '''
362
+ '''
363
  import gradio as gr
364
  from transformers import TFBertForSequenceClassification, BertTokenizer
365
  import tensorflow as tf
 
534
  )
535
 
536
  demo.launch()
537
+ '''
538
+ import gradio as gr
539
+ from transformers import TFBertForSequenceClassification, BertTokenizer
540
+ import tensorflow as tf
541
+ from sklearn.metrics import accuracy_score, f1_score, classification_report
542
+ import numpy as np
543
 
544
+ # Load models
545
+ model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
546
+ tokenizer = BertTokenizer.from_pretrained("shrish191/sentiment-bert")
547
+ LABELS = {0: "Neutral", 1: "Positive", 2: "Negative"}
548
+
549
+ # Load fallback model
550
+ fallback_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
551
+ fallback_tokenizer = AutoTokenizer.from_pretrained(fallback_model_name)
552
+ fallback_model = AutoModelForSequenceClassification.from_pretrained(fallback_model_name)
553
+
554
+ def analyze_text(text, true_label=None):
555
+ try:
556
+ # Main model prediction
557
+ inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
558
+ outputs = model(inputs)
559
+ probs = tf.nn.softmax(outputs.logits, axis=1)
560
+ main_pred = LABELS[tf.argmax(probs, axis=1).numpy()[0]]
561
+
562
+ # Fallback model prediction
563
+ fallback_inputs = fallback_tokenizer(text, return_tensors='pt', truncation=True, padding=True)
564
+ with torch.no_grad():
565
+ fallback_outputs = fallback_model(**fallback_inputs)
566
+ fallback_scores = softmax(fallback_outputs.logits.numpy()[0])
567
+ fallback_pred = ['Negative', 'Neutral', 'Positive'][np.argmax(fallback_scores)]
568
+
569
+ # Initialize results
570
+ result = f"""Main Model Prediction: {main_pred}
571
+ Fallback Model Prediction: {fallback_pred}"""
572
+
573
+ # Calculate metrics if true label provided
574
+ if true_label:
575
+ # Convert labels to numerical format
576
+ label_map = {v: k for k, v in LABELS.items()}
577
+ y_true = [label_map[true_label]]
578
+
579
+ # Main model metrics
580
+ y_pred_main = [label_map[main_pred]]
581
+ main_acc = accuracy_score(y_true, y_pred_main)
582
+ main_f1 = f1_score(y_true, y_pred_main, average='weighted')
583
+
584
+ # Fallback model metrics
585
+ fallback_label_map = {'Negative': 2, 'Neutral': 0, 'Positive': 1}
586
+ y_pred_fallback = [fallback_label_map[fallback_pred]]
587
+ fallback_acc = accuracy_score(y_true, y_pred_fallback)
588
+ fallback_f1 = f1_score(y_true, y_pred_fallback, average='weighted')
589
+
590
+ # Classification report
591
+ report = classification_report(
592
+ y_true, y_pred_main,
593
+ target_names=LABELS.values(),
594
+ output_dict=True
595
+ )
596
+
597
+ # Format metrics
598
+ metrics = f"""
599
+ \n\nPERFORMANCE METRICS (Single Sample):
600
+ ------------------------------------
601
+ Main Model:
602
+ Accuracy: {main_acc:.4f}
603
+ F1 Score: {main_f1:.4f}
604
+
605
+ Fallback Model:
606
+ Accuracy: {fallback_acc:.4f}
607
+ F1 Score: {fallback_f1:.4f}
608
+
609
+ Classification Report:
610
+ {classification_report(y_true, y_pred_main, target_names=LABELS.values())}
611
+ """
612
+
613
+ result += metrics
614
+
615
+ return result
616
+
617
+ except Exception as e:
618
+ return f"Error: {str(e)}"
619
+
620
+ # Gradio interface
621
+ demo = gr.Interface(
622
+ fn=analyze_text,
623
+ inputs=[
624
+ gr.Textbox(
625
+ label="Input Text",
626
+ placeholder="Enter text to analyze...",
627
+ lines=4
628
+ ),
629
+ gr.Dropdown(
630
+ label="True Label (optional, for metrics)",
631
+ choices=list(LABELS.values()),
632
+ value=None
633
+ )
634
+ ],
635
+ outputs=gr.Textbox(
636
+ label="Analysis Results",
637
+ lines=10
638
+ ),
639
+ title="Sentiment Analysis with Performance Metrics",
640
+ description="""Enter text and optionally select true label to generate:
641
+ - Predictions from both models
642
+ - Accuracy scores
643
+ - F1 scores
644
+ - Classification report""",
645
+ examples=[
646
+ ["I absolutely love this new feature!", "Positive"],
647
+ ["This is the worst experience ever!", "Negative"],
648
+ ["The product seems okay, nothing special.", "Neutral"]
649
+ ]
650
+ )
651
+
652
+ demo.launch()
653
 
654
 
655