# thugCodeNinja's picture
# Update app.py
# b8985e3 verified
# raw
# history blame
# 3.26 kB
import os

import gradio as gr
import requests
import shap
import torch
from IPython.core.display import HTML
from torch.nn.functional import softmax
from transformers import RobertaTokenizer, RobertaForSequenceClassification, pipeline
# Load the fine-tuned RoBERTa classifier and its tokenizer from a local
# checkpoint directory. The same pair is used for the direct forward pass in
# process_text and, wrapped in a text-classification pipeline, for SHAP.
# (An earlier revision loaded the pipeline straight from the Hub checkpoint
# "thugCodeNinja/robertatemp" instead.)
model_dir = 'temp'
model = RobertaForSequenceClassification.from_pretrained(model_dir)
tokenizer = RobertaTokenizer.from_pretrained(model_dir)
pipe = pipeline(
    task="text-classification",
    tokenizer=tokenizer,
    model=model,
)
_SEARCH_URL = 'https://www.googleapis.com/customsearch/v1'
# Programmable Search Engine id ("cx"); overridable via GOOGLE_CSE_ID.
_DEFAULT_SEARCH_ENGINE_ID = '53d064810efa44ce7'


def _read_input_text(input_text, input_file):
    """Return the text to analyse, preferring the textbox over the upload.

    ``input_file`` is whatever ``gr.File`` passes in. Depending on the Gradio
    version/configuration that may be a path string, raw bytes, a tempfile
    wrapper exposing ``.name``, or a readable file object; all are accepted.
    Byte content is decoded as UTF-8. Returns ``None`` when neither input is
    provided.
    """
    if input_text:
        return input_text
    if input_file is None:
        return None

    if isinstance(input_file, (bytes, bytearray)):
        data = bytes(input_file)
    else:
        if isinstance(input_file, (str, os.PathLike)):
            path = input_file
        elif isinstance(getattr(input_file, 'name', None), str) and os.path.isfile(input_file.name):
            # Re-open by path: the wrapper handed over may already be closed.
            path = input_file.name
        else:
            path = None

        if path is None:
            data = input_file.read()
        else:
            with open(path, 'rb') as handle:
                data = handle.read()

    return data.decode('utf-8') if isinstance(data, bytes) else data


def _search(query, timeout=10):
    """Query the Google Custom Search JSON API for ``query``.

    The API key comes from the ``GOOGLE_API_KEY`` environment variable rather
    than source code (the key previously hard-coded here is public and should
    be rotated). Returns the decoded JSON response, or ``{'error': message}``
    when the key is missing, the request fails, or the body is not JSON.
    """
    api_key = os.environ.get('GOOGLE_API_KEY')
    if not api_key:
        return {'error': 'GOOGLE_API_KEY environment variable is not set'}
    params = {
        'key': api_key,
        'cx': os.environ.get('GOOGLE_CSE_ID', _DEFAULT_SEARCH_ENGINE_ID),
        # Passed via params so requests URL-encodes it; interpolating the raw
        # text into the URL broke on characters such as '&', '#' and '+'.
        'q': query,
    }
    try:
        # A timeout keeps one slow search from hanging the request forever.
        response = requests.get(_SEARCH_URL, params=params, timeout=timeout)
        return response.json()
    except (requests.RequestException, ValueError) as exc:
        return {'error': str(exc)}


def _find_plagiarism(text, limit=5):
    """Return up to ``limit`` ``[title, link]`` rows for pages matching ``text``.

    Returns an empty list when the search failed or produced no ``items``.
    """
    # NOTE(review): the whole input is sent as the query; very long texts
    # probably exceed the API's query limits and come back as an error (so no
    # rows). Consider searching a short excerpt instead — TODO confirm limits.
    results = _search(text)
    if not isinstance(results, dict) or 'items' not in results:
        return []
    return [
        [item.get('title', ''), item.get('link', '')]
        for item in results['items'][:limit]
    ]


def process_text(input_text, input_file):
    """Classify text as human- or ChatGPT-written and explain the decision.

    Args:
        input_text: Textbox contents; used whenever it is non-empty.
        input_file: Value of the ``gr.File`` upload; read only when the
            textbox is empty.

    Returns:
        A 5-tuple ``(text, probability, label, shap_html, similar_articles)``:
        the analysed text, the winning class probability as a percentage
        string rounded to 2 decimals, ``'Human'`` or ``'Chat-GPT'``, the SHAP
        text-plot HTML, and up to 5 ``[title, link]`` web search hits.

    Raises:
        gr.Error: if neither input provides any text.
    """
    text = _read_input_text(input_text, input_file)
    if not text:
        raise gr.Error('Please enter some text or upload a non-empty text file.')

    # truncation=True caps the input at the tokenizer's model_max_length so
    # long inputs no longer overflow RoBERTa's position embeddings (assumes
    # the checkpoint's tokenizer config sets model_max_length — confirm).
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = softmax(logits, dim=1)
    max_prob, predicted_class_id = torch.max(probs, dim=1)
    prob = str(round(max_prob.item() * 100, 2))
    raw_label = model.config.id2label[predicted_class_id.item()]
    # The fine-tuned head has generic labels; LABEL_0 is the human class.
    final_label = 'Human' if raw_label == 'LABEL_0' else 'Chat-GPT'

    # NOTE(review): the SHAP explainer calls `pipe`, which does not truncate,
    # so extremely long inputs may still fail here — TODO confirm.
    explainer = shap.Explainer(pipe)
    shap_values = explainer([text])
    # With display=False, shap.plots.text returns the plot as an HTML string.
    shap_plot_html = shap.plots.text(shap_values, display=False)

    similar_articles = _find_plagiarism(text)
    return text, prob, final_label, shap_plot_html, similar_articles
# Gradio UI: a textbox or an uploaded text file in; classification, SHAP
# explanation and web-search matches out (in process_text's return order).
text_input = gr.Textbox(label="Enter text")
file_input = gr.File(label="Upload a text file")
outputs = [
    gr.Textbox(label="Processed text"),
    gr.Textbox(label="Probability"),
    gr.Textbox(label="Label"),
    gr.HTML(label="SHAP Plot"),
    gr.Dataframe(label="Similar Articles", headers=["Title", "Link"], row_count=5),
]
title = "Group 2- ChatGPT text detection module"
description = '''Please upload text files and text input responsibly and await the explainable results. The approach in place fine-tunes a RoBERTa model for text classification. Once the classification is done, the decision is explained through the SHAP text plot.
The SHAP plot shows how much each token pushed the prediction towards the reported label and probability.'''
# Bound to a name so tooling that looks for a module-level `demo` can find it.
demo = gr.Interface(
    fn=process_text,
    title=title,
    description=description,
    inputs=[text_input, file_input],
    outputs=outputs,
)
demo.launch()