File size: 2,811 Bytes
71f2227
3e149d5
 
71f2227
 
0a042d8
cdfe192
703a8b7
0a042d8
61a2cb3
 
fbd4cfd
 
0d07fde
fbd4cfd
 
870b1fe
482fbde
fd58ec5
 
2dfc510
ed80cd3
3e149d5
 
 
55f3dd8
c669d92
f13ff81
691bbb0
aa451fe
 
6def779
 
c669d92
 
f13ff81
08d35ca
703a8b7
5a4fdf8
 
 
 
 
a6bca23
6986639
5a4fdf8
 
db0e2dc
cdfe192
6bbb073
c556324
2ddaf14
6bbb073
2ddaf14
337fe11
cdfe192
5a4fdf8
 
 
 
db0e2dc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# Third-party imports grouped at the top of the file (conventional ordering;
# the original interleaved imports with model loading).
import torch
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
)

# DialoGPT-large: generates the conversational responses in chat().
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")
# T5-based grammar corrector used by feedback().
grammar_tokenizer = T5Tokenizer.from_pretrained('deep-learning-analytics/GrammarCorrector')
grammar_model = T5ForConditionalGeneration.from_pretrained('deep-learning-analytics/GrammarCorrector')



def chat(message, history=None):
    """Generate a chatbot reply to *message* plus grammar feedback on it.

    Parameters
    ----------
    message : str
        The user's latest utterance.
    history : list | None
        Gradio "state": a list of (user_message, bot_response,
        chat_history_ids) tuples from previous turns. Defaults to a fresh
        list each call.

    Returns
    -------
    tuple
        (history, history, feedback_text) — history appears twice because
        the Interface wires it to both the Chatbot display and the "state"
        output.
    """
    # Fix: the original used a mutable default (history=[]), which is shared
    # across calls and would leak conversation state between sessions
    # whenever the state argument is not supplied.
    if history is None:
        history = []
    # Append EOS so DialoGPT knows the user turn has ended.
    new_user_input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors='pt')
    if history:
        # The third element of the last turn holds the full token history.
        last_set_of_ids = history[-1][2]
        bot_input_ids = torch.cat([last_set_of_ids, new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids
    # NOTE(review): max_length=5000 exceeds DialoGPT's 1024-token context
    # window; very long conversations may degrade or error — consider
    # truncating the accumulated history.
    chat_history_ids = model.generate(bot_input_ids, max_length=5000, pad_token_id=tokenizer.eos_token_id)
    # Decode only the newly generated tokens (everything past the prompt).
    response_ids = chat_history_ids[:, bot_input_ids.shape[-1]:][0]
    response = tokenizer.decode(response_ids, skip_special_tokens=True)
    history.append((message, response, chat_history_ids))
    return history, history, feedback(message)


def feedback(text):
    """Run *text* through the T5 grammar corrector and phrase feedback.

    Parameters
    ----------
    text : str
        The user's message to check.

    Returns
    -------
    str
        Praise when the correction matches the input (ignoring trailing
        periods), otherwise a suggestion containing the corrected text.
    """
    batch = grammar_tokenizer([text], truncation=True, padding='max_length', max_length=64, return_tensors="pt")
    corrections = grammar_model.generate(**batch, max_length=64, num_beams=2, num_return_sequences=1, temperature=1.5)
    corrected_text = grammar_tokenizer.decode(corrections[0], clean_up_tokenization_spaces=True, skip_special_tokens=True)
    print("The corrected text is: ", corrected_text)
    print("The orig text is: ", text)
    # Strip trailing periods before comparing so a correction that merely
    # adds or removes a final "." is not reported as a mistake.
    if corrected_text.rstrip('.') == text.rstrip('.'):
        # (Original bound this to a local named `feedback`, shadowing the
        # function; returning directly avoids that.)
        return 'Looks good! Keep up the good work'
    return f'\'{corrected_text}\' might be a little better'


# Static copy shown in the Gradio interface header/footer.
title = "A chatbot that provides grammar feedback"
description = "A quick proof of concept using Gradio"
# HTML footer linking to the project slides and repository.
article = "<p style='text-align: center'><a href='https://docs.google.com/presentation/d/11fiO91MKZVgNoQJh5pn3Tw8-inHe6XbWYB2r1f701WI/edit?usp=sharing'> A conversational agent for Language learning</a> | <a href='https://github.com/ConorNugent/gradio-chatbot-demo'>Github Repo</a></p>"
# Example inputs (each deliberately contains a grammar error to exercise
# the feedback path).
examples = [
            ["Have you read the play what I wrote?"],
            ["Were do you live?"],
]

# Wire chat() into a Gradio Interface:
#   inputs  -> user textbox + "state" (the conversation history list)
#   outputs -> Chatbot display + updated "state" + the feedback textbox,
#              matching chat()'s (history, history, feedback) return.
# NOTE(review): allow_screenshot was removed in newer Gradio releases —
# confirm the pinned Gradio version still accepts it.
iface = gr.Interface(
    chat,
    [gr.Textbox(label="Send messages here"), "state"],
    [gr.Chatbot(label='Conversation'), "state", gr.Textbox(
            label="Feedback",
            lines=1
        )],
    allow_screenshot=False,
    allow_flagging="never",
    title=title, 
    description=description, 
    article=article, 
    examples=examples)
iface.launch()