# Spaces status: Runtime error
import gradio as gr
import torch
from transformers import AutoTokenizer

from model import SentimentClassifier
# Load the trained weights. map_location='cpu' lets a checkpoint that was
# saved from GPU tensors be deserialized on a host without CUDA (plain
# torch.load would raise there); load_state_dict then copies the weights
# into the freshly constructed model.
model_state_dict = torch.load('sentiment_model.pth', map_location='cpu')
# NOTE(review): 2 is presumably the number of output classes — predict()
# maps class 0 to "negative" and anything else to "positive". Confirm
# against SentimentClassifier's constructor.
model = SentimentClassifier(2)
model.load_state_dict(model_state_dict)
# Inference only: put the module in eval mode (affects layers such as
# dropout, if the classifier contains any).
model.eval()
# Must match the vocabulary the classifier was trained with — assumed to be
# bert-base-uncased; TODO confirm against the training script.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
def preprocess(text):
    """Tokenize *text* for the sentiment classifier.

    The text is encoded with the module-level tokenizer, truncated to at
    most 512 tokens, padded up to exactly 512, and returned as PyTorch
    tensors.

    Args:
        text: The raw input string.

    Returns:
        The tokenizer's output; ``predict`` reads its ``'input_ids'`` and
        ``'attention_mask'`` entries.
    """
    return tokenizer(
        text,
        return_tensors='pt',
        max_length=512,
        truncation=True,
        padding='max_length',
    )
# Classify a single review with the loaded model.
def predict(review):
    """Classify *review* and return a human-readable verdict.

    The text is tokenized with ``preprocess`` and run through ``model`` with
    gradient tracking disabled. The index of the largest value in the first
    element of the model output is taken as the predicted class.

    Args:
        review: The raw review text entered by the user.

    Returns:
        "It was a negative review" when the predicted class is 0,
        otherwise "It was a positive review".
    """
    encoded = preprocess(review)
    with torch.no_grad():
        raw_output = model(encoded['input_ids'], encoded['attention_mask'])
    # Index 0 presumably selects the single review in the batch — confirm
    # against what SentimentClassifier.forward returns.
    label = torch.argmax(raw_output[0]).item()
    if label != 0:
        return "It was a positive review"
    return "It was a negative review"
# Build the Gradio UI: one text box for the review, one for the verdict.
# The gr.inputs / gr.outputs namespaces were deprecated in Gradio 3 and
# removed in Gradio 4 (AttributeError at startup); the top-level
# gr.Textbox component works on both versions.
input_text = gr.Textbox(label="Input Text")
output_text = gr.Textbox(label="Output Text")
interface = gr.Interface(fn=predict, inputs=input_text, outputs=output_text)

# Start the web app.
interface.launch()