# Gradio demo: a DialoGPT-based chatbot that also returns grammar feedback
# produced by a T5 grammar-correction model.
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration, T5Tokenizer
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")
grammar_tokenizer = T5Tokenizer.from_pretrained('deep-learning-analytics/GrammarCorrector')
grammar_model = T5ForConditionalGeneration.from_pretrained('deep-learning-analytics/GrammarCorrector')
import torch
import gradio as gr
def chat(message, history=None):
    """Generate the bot's reply to *message* and extend the conversation.

    Parameters
    ----------
    message : str
        The user's new utterance.
    history : list[tuple[str, str, torch.Tensor]] | None
        Accumulated (user_message, bot_response, chat_history_ids) triples.
        Defaults to a fresh list per call; the original ``history=[]``
        mutable default was shared across calls, leaking one conversation
        into the next when Gradio's state was empty.

    Returns
    -------
    tuple
        ``(history, history, feedback_text)`` — the chatbot transcript,
        the updated session state, and the grammar feedback string.
    """
    if history is None:
        history = []
    # Encode the new user turn, terminated by the EOS token DialoGPT expects.
    new_user_input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors='pt')
    if history:
        # Continue generation from the token ids of the previous full exchange
        # (stored as the third element of the last history triple).
        bot_input_ids = torch.cat([history[-1][2], new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids
    chat_history_ids = model.generate(bot_input_ids, max_length=5000, pad_token_id=tokenizer.eos_token_id)
    # The reply is everything generated after the prompt tokens.
    response_ids = chat_history_ids[:, bot_input_ids.shape[-1]:][0]
    response = tokenizer.decode(response_ids, skip_special_tokens=True)
    history.append((message, response, chat_history_ids))
    return history, history, feedback(message)
def feedback(text):
    """Check *text* with the T5 grammar-correction model and phrase feedback.

    Parameters
    ----------
    text : str
        The user's message to check.

    Returns
    -------
    str
        Praise when the correction matches the input (ignoring a trailing
        period), otherwise the suggested corrected sentence.
    """
    batch = grammar_tokenizer([text], truncation=True, padding='max_length', max_length=64, return_tensors="pt")
    # NOTE(review): temperature has no effect without do_sample=True; beam
    # search (num_beams=2) is what actually drives decoding here — confirm
    # against the transformers generate() docs before removing it.
    corrections = grammar_model.generate(**batch, max_length=64, num_beams=2,
                                         num_return_sequences=1, temperature=1.5)
    corrected_text = grammar_tokenizer.decode(corrections[0], clean_up_tokenization_spaces=True, skip_special_tokens=True)
    print("The corrected text is: ", corrected_text)
    print("The orig text is: ", text)
    # Compare modulo a trailing period so "Hello" vs "Hello." counts as a match.
    # (Original bound the result to a local named `feedback`, shadowing this
    # function; returning directly avoids that.)
    if corrected_text.rstrip('.') == text.rstrip('.'):
        return 'Looks good! Keep up the good work'
    return f'\'{corrected_text}\' might be a little better'
# Copy shown around the demo UI.
title = "A chatbot that provides grammar feedback"
description = "A quick proof of concept using Gradio"
article = "<p style='text-align: center'><a href='https://docs.google.com/presentation/d/11fiO91MKZVgNoQJh5pn3Tw8-inHe6XbWYB2r1f701WI/edit?usp=sharing'> A conversational agent for Language learning</a> | <a href='https://github.com/ConorNugent/gradio-chatbot-demo'>Github Repo</a></p>"

# Seed prompts that appear as clickable examples under the input box.
examples = [
    ["Have you read the play what I wrote?"],
    ["Were do you live?"],
]

# Named components: a textbox plus session state in; the chatbot transcript,
# the updated state, and a one-line feedback box out.
message_box = gr.Textbox(label="Send messages here")
conversation = gr.Chatbot(label='Conversation')
feedback_box = gr.Textbox(label="Feedback", lines=1)

iface = gr.Interface(
    chat,
    inputs=[message_box, "state"],
    outputs=[conversation, "state", feedback_box],
    allow_screenshot=False,
    allow_flagging="never",
    title=title,
    description=description,
    article=article,
    examples=examples,
)
iface.launch()