Jaehan commited on
Commit
b73e8d0
·
1 Parent(s): 52b8a82

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -0
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+ from transformers import BlenderbotForConditionalGeneration
3
+ import torch
4
+ imprt gradio as gr
5
+
6
+ # model_name = "facebook/blenderbot-400M-distill"
7
+ model_name = "microsoft/DialoGPT-medium"
8
+ chat_tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ model = AutoModelForCausalLM.from_pretrained(model_name)
10
+
11
+ def conversation(user_input, chat_history=[]):
12
+ user_input_ids = chat_tokenizer(user_input+chat_tokenizer.eos_token, return_tensors="pt").input_ids
13
+
14
+ # maintain history in the tensor
15
+ chatbot_input_ids = torch.cat([torch.LongTensor(chat_hostory), user_input_ids], dim=-1)
16
+
17
+ # ger response
18
+ chat_history = model.generate(chatbot_input_ids, max_length=1000, pad_token_id=chat_tokenizer.eos_token_id).tolist()
19
+ print(chat_history)
20
+
21
+ response = chat_tokeniser.decode(chat_history[0]).split("<|endoftext|>")
22
+ print("Starting to print response")
23
+ print(response)
24
+
25
+ # html for display
26
+ html = "<div class='mychatbot'>"
27
+ for x, msg in enumerate(response):
28
+ if x%2 !=0:
29
+ msg = "ChatBot:" + msg
30
+ class = "bot"
31
+ else:
32
+ class = "user"
33
+
34
+ print("Value of x:")
35
+ print(x)
36
+ print("Message:")
37
+ print(msg)
38
+ html += "<div class='msg {}'></div>".format(class, msg)
39
+ html += "</div>"
40
+ print(html)
41
+ return html, chat_history
42
+
43
+ # UX
44
+ css = """
45
+ .mychat {display:flex;flex-direction:column}
46
+ .msg {padding:5px;margin-bottom:5px;border-radius:5px;width:75%}
47
+ .msg.user {background-color:lightblue;color:white}
48
+ .msg.bot {background-color:orange;color:white,align-self:self-end}
49
+ .footer
50
+ """
51
+ in_text = gr.inputs.Textbox(placeholder="Let's start a chat...")
52
+ gr.Interface(fn=conversation, theme="default", inputs=[in_text, "state"], outputs=["html", "state"], css=css).launch()