File size: 1,286 Bytes
b334778 19f79d5 5d50a44 9fc880b 5d50a44 4146933 10a975c 19f79d5 10a975c fdeaa3e 5d50a44 b334778 10a975c 5d50a44 b334778 5d50a44 b334778 5d50a44 b334778 10a975c 5d50a44 b334778 9fc880b 5d50a44 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import gradio as gr
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
# Fine-tuned DistilBERT checkpoint name (passed to from_pretrained) and the
# tokenizer/model pair loaded from it.
model_name = "bhadresh-savani/distilbert-base-uncased-finetuned-sentiment"
tokenizer = DistilBertTokenizer.from_pretrained(model_name)
model = DistilBertForSequenceClassification.from_pretrained(model_name)

# Run on the GPU when CUDA is present, otherwise fall back to the CPU.
# `device` stays a plain string because predict_sentiment reuses it to move
# the tokenized inputs onto the same device as the model.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = model.to(device)
# Define the prediction function
def predict_sentiment(text):
    """Classify *text* and return the index of the highest-scoring class.

    Args:
        text: Input string to classify. Text longer than the model's maximum
            sequence length is truncated by the tokenizer (``truncation=True``).

    Returns:
        int: the ``argmax`` over the model's output logits. It is a class
        index, not a label name; ``model.config.id2label`` maps it to a name.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
    # Inference only: disable autograd so no gradient graph is built and kept
    # in memory for every request (the result is identical either way).
    with torch.no_grad():
        outputs = model(**inputs)
    predictions = torch.argmax(outputs.logits, dim=-1)
    # .item() assumes a single input string, i.e. one prediction in the batch.
    return predictions.item()
# Gradio interface
def _predict_label(text):
    """Return the class name the model assigns to *text*, for display in the UI.

    ``predict_sentiment`` yields a bare class index, which would otherwise show
    up in the result box as an opaque number. Translate it through the model
    config's ``id2label`` table, falling back to the index as a string if the
    table has no entry for it.
    """
    label_id = predict_sentiment(text)
    return model.config.id2label.get(label_id, str(label_id))


with gr.Blocks() as sentiment_app:
    gr.Markdown("<h1>Sentiment Analysis with DistilBERT</h1>")
    input_box = gr.Textbox(label="Input Text", placeholder="Enter text to analyze sentiment")
    output_box = gr.Textbox(label="Sentiment Result", placeholder="Sentiment result will appear here")
    submit_button = gr.Button("Analyze Sentiment")
    # Trigger the same prediction from the button or by pressing Enter in the input box.
    submit_button.click(fn=_predict_label, inputs=input_box, outputs=output_box)
    input_box.submit(fn=_predict_label, inputs=input_box, outputs=output_box)

# Launch the app only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    sentiment_app.launch()
|