|
|
|
import gradio as gr
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.sequence import pad_sequences
import numpy as np
import json
import pickle
import nltk
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
import re
import contractions
from huggingface_hub import hf_hub_download
import warnings
from sklearn.exceptions import InconsistentVersionWarning

# Suppress scikit-learn's version-mismatch warning raised when unpickling
# artifacts such as the label encoder.
warnings.filterwarnings("ignore", category=InconsistentVersionWarning)

# NLTK resources needed for tokenization and lemmatization.
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)
nltk.download('wordnet', quiet=True)
nltk.download('omw-1.4', quiet=True)

lemmatizer = WordNetLemmatizer()
|
|
class LuongAttention(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(LuongAttention, self).__init__(**kwargs)

    def build(self, input_shape):
        self.W = self.add_weight(name='attention_weight',
                                 shape=(input_shape[-1], input_shape[-1]),
                                 initializer='glorot_normal',
                                 trainable=True)
        self.b = self.add_weight(name='attention_bias',
                                 shape=(input_shape[-1],),
                                 initializer='zeros',
                                 trainable=True)
        super(LuongAttention, self).build(input_shape)

    def call(self, inputs):
        e = tf.keras.backend.tanh(tf.keras.backend.dot(inputs, self.W) + self.b)
        alpha = tf.keras.backend.softmax(e, axis=1)
        context = inputs * alpha
        context = tf.keras.backend.sum(context, axis=1)
        return context

    def get_config(self):
        config = super(LuongAttention, self).get_config()
        return config
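# Note: the layer pools over the sequence axis, mapping a (batch, timesteps, units)
# tensor to a (batch, units) context vector; it is passed to load_model via
# custom_objects below so the saved model can be rebuilt with this custom layer.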
|
|
|
|
|
# Fetch the trained model and preprocessing artifacts from the Hugging Face Hub.
model_path = hf_hub_download(repo_id="logasanjeev/sentiment-analysis-bilstm-luong", filename="sentiment_model.h5")
tokenizer_path = hf_hub_download(repo_id="logasanjeev/sentiment-analysis-bilstm-luong", filename="tokenizer.pkl")
encoder_path = hf_hub_download(repo_id="logasanjeev/sentiment-analysis-bilstm-luong", filename="label_encoder.pkl")

# Register the custom attention layer and stand in a dummy callable for the
# custom "focal_loss_fn" so the saved .h5 file deserializes; the loss itself is
# never evaluated at inference time.
model = load_model(
    model_path,
    custom_objects={
        "LuongAttention": LuongAttention,
        "focal_loss_fn": lambda y_true, y_pred: y_true
    }
)

with open(tokenizer_path, "rb") as f:
    tokenizer = pickle.load(f)
with open(encoder_path, "rb") as f:
    label_encoder = pickle.load(f)

# Optimized decision threshold: probabilities at or above it are labeled positive.
OPTIMAL_THRESHOLD = 0.5173
|
|
def clean_text(text):
    """Normalize input text: expand contractions, lowercase, strip URLs, mentions,
    HTML, digits and punctuation, then lemmatize verbs."""
    if not isinstance(text, str):
        text = str(text)
    text = contractions.fix(text)  # expand contractions ("don't" -> "do not")
    text = text.lower()
    text = re.sub(r'http\S+|www\S+|https\S+', '', text, flags=re.MULTILINE)  # strip URLs
    text = re.sub(r'@\w+|#\w+', '', text)  # strip mentions and hashtags
    text = re.sub(r'<.*?>+', '', text)  # strip HTML tags
    text = re.sub(r'\n', '', text)  # strip newlines
    text = re.sub(r'\w*\d\w*', '', text)  # drop tokens containing digits
    text = re.sub(r'[^\w\s]', '', text)  # drop punctuation
    text = ' '.join(text.split())  # collapse whitespace
    tokens = word_tokenize(text)
    tokens = [lemmatizer.lemmatize(token, pos='v') for token in tokens]  # verb lemmatization
    return ' '.join(tokens).strip()
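# Illustration (approximate): clean_text("I'm loving it!") returns roughly
# "i be love it" -- the contraction is expanded, punctuation is stripped, and
# verb-mode lemmatization maps "am" -> "be" and "loving" -> "love".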
|
|
|
|
|
def predict_sentiment(text):
    if not text or not isinstance(text, str) or len(text.strip()) < 3:
        return "Please enter a valid sentence.", None, None

    # Preprocess and reject inputs whose token ids are all <= 1 (padding / out-of-vocabulary).
    cleaned = clean_text(text)
    seq = tokenizer.texts_to_sequences([cleaned])
    if not seq or not any(x > 1 for x in seq[0]):
        return "Text too short or invalid.", None, None

    # Pad/truncate to the fixed input length expected by the model.
    max_len = 60
    pad = pad_sequences(seq, maxlen=max_len, padding='post', truncating='post')

    # Run inference on CPU; the model outputs the probability of the positive class.
    with tf.device('/CPU:0'):
        prob = model.predict(pad, verbose=0)[0][0]

    # Apply the optimized threshold to pick the class label.
    label_idx = int(prob >= OPTIMAL_THRESHOLD)
    sentiment = label_encoder.inverse_transform([label_idx])[0].lower()

    emoji = {"negative": "😣", "positive": "😊"}
    probs_dict = {
        "Negative": 1 - prob,
        "Positive": prob
    }

    return (
        f"**Sentiment**: {sentiment.capitalize()} {emoji[sentiment]}",
        probs_dict,
        cleaned
    )
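# Illustrative direct call (outside the Gradio UI):
#   markdown, probs, cleaned = predict_sentiment("Not bad at all.")
#   # markdown is the "**Sentiment**: ..." string, probs is the
#   # {"Negative": ..., "Positive": ...} dict shown in the Label component.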
|
|
|
|
|
css = """
body { font-family: 'Arial', sans-serif; }
.gradio-container { max-width: 800px; margin: auto; }
h1 { color: #1a73e8; text-align: center; }
.textbox { border-radius: 8px; }
.output-text { font-size: 1.2em; font-weight: bold; }
.footer { text-align: center; color: #666; }
.prob-bar { margin-top: 10px; }
button { border-radius: 6px; }
"""
|
|
with gr.Blocks(theme="soft", css=css) as demo:
    gr.Markdown(
        """
        # Sentiment Analysis App
        Predict the sentiment of your text (Negative or Positive) using a BiLSTM model with Luong attention. Optimized threshold (0.5173) for 86.58% accuracy. Try it out!
        """
    )

    with gr.Row():
        with gr.Column(scale=3):
            text_input = gr.Textbox(
                label="Your Text",
                placeholder="e.g., I wouldn't recommend it to anyone",
                lines=2
            )
            predict_btn = gr.Button("Analyze Sentiment", variant="primary")

    output_text = gr.Markdown()
    prob_plot = gr.Label(label="Probability Distribution")
    cleaned_text = gr.Textbox(label="Cleaned Text", interactive=False)

    examples = gr.Examples(
        examples=[
            "Not bad at all.",
            "Just what I needed today — a flat tire and a rainstorm. Living the dream!",
            "The movie was visually stunning, but the story was painfully slow.",
            "I wouldn’t recommend it to someone I like.",
            "For once, he didn’t mess it up."
        ],
        inputs=text_input
    )

    predict_btn.click(
        fn=predict_sentiment,
        inputs=text_input,
        outputs=[output_text, prob_plot, cleaned_text]
    )

    gr.Markdown(
        """
        <div class='footer'>
        Created by logasanjeev | Powered by Hugging Face & Gradio
        </div>
        """
    )
|
|
demo.launch() |
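# demo.launch() starts the Gradio server with default settings; share=True
# (a standard Gradio launch option) could be passed to expose a temporary
# public link when running locally.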