Jaehan committed on
Commit
cc28aad
·
1 Parent(s): 78eb766

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -33
app.py CHANGED
@@ -1,52 +1,58 @@
1
- from transformers import AutoModelForCausalLM, AutoTokenizer
2
- from transformers import BlenderbotForConditionalGeneration
3
  import torch
4
  import gradio as gr
5
 
6
- # model_name = "facebook/blenderbot-400M-distill"
7
  model_name = "microsoft/DialoGPT-medium"
8
- chat_tokenizer = AutoTokenizer.from_pretrained(model_name)
9
  model = AutoModelForCausalLM.from_pretrained(model_name)
10
 
11
- def conversation(user_input, chat_history=[]):
12
- user_input_ids = chat_tokenizer(user_input+chat_tokenizer.eos_token, return_tensors="pt").input_ids
13
 
14
- # maintain history in the tensor
15
- chatbot_input_ids = torch.cat([torch.LongTensor(chat_hostory), user_input_ids], dim=-1)
16
 
17
- # ger response
18
- chat_history = model.generate(chatbot_input_ids, max_length=1000, pad_token_id=chat_tokenizer.eos_token_id).tolist()
19
  print(chat_history)
20
 
21
- response = chat_tokenizer.decode(chat_history[0]).split("<|endoftext|>")
22
- print("Starting to print response")
 
23
  print(response)
24
-
25
  # html for display
26
- html = "<div class='mychatbot'>"
27
- for x, msg in enumerate(response):
28
- if x%2 !=0:
29
- msg = "ChatBot:" + msg
30
- clazz = "bot"
31
- else:
32
- clazz = "user"
33
-
34
- print("Value of x:")
35
  print(x)
36
- print("Message:")
37
- print(msg)
38
- html += "<div class='msg {}'></div>".format(clazz, msg)
 
39
  html += "</div>"
40
  print(html)
41
  return html, chat_history
42
 
43
- # UX
 
44
  css = """
45
- .mychatbot {display:flex;flex-direction:column}
46
- .msg {padding:5px;margin-bottom:5px;border-radius:5px;width:75%}
47
- .msg.user {background-color:lightblue;color:white}
48
- .msg.bot {background-color:orange;color:white,align-self:self-end}
49
- .footer
50
  """
51
- in_text = gr.inputs.Textbox(placeholder="Let's start a chat...")
52
- gr.Interface(fn=conversation, theme="default", inputs=[in_text, "state"], outputs=["html", "state"], css=css).launch()
 
 
 
 
 
from transformers import AutoModelForCausalLM, AutoTokenizer,BlenderbotForConditionalGeneration
import torch
import gradio as gr

# Alternative model kept for reference:
#model_name = "facebook/blenderbot-400M-distill"
model_name = "microsoft/DialoGPT-medium"

# Load tokenizer and causal-LM weights once at import time; both are
# reused by converse() on every call.
chat_token = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
10
def converse(user_input, chat_history=[]):
    """Generate one chat turn with DialoGPT and render the dialog as HTML.

    Parameters
    ----------
    user_input : str
        The user's latest message.
    chat_history : list
        Token ids of the conversation so far, as produced by a previous
        call (``generate(...).tolist()``).  None or empty means a new chat.

    Returns
    -------
    tuple[str, list]
        (HTML transcript, updated token-id history); the history is threaded
        back in through the gradio "state" output/input pair.
    """
    # NOTE(review): a mutable default argument is risky in general; it is
    # harmless here only because chat_history is rebound, never mutated
    # in place.  The guard also covers gradio passing None on the first turn.
    if not chat_history:
        chat_history = []

    # Encode the new message terminated by EOS — DialoGPT uses the EOS
    # token as its turn separator.
    user_input_ids = chat_token(user_input + chat_token.eos_token,
                                return_tensors='pt').input_ids

    # Prepend the running history (if any) so the model sees the whole
    # conversation.  Skipping the cat on an empty history avoids
    # concatenating a 1-D empty tensor with the 2-D input ids.
    if chat_history:
        chatbot_input_ids = torch.cat(
            [torch.LongTensor(chat_history), user_input_ids], dim=-1)
    else:
        chatbot_input_ids = user_input_ids

    # Bug fix: the original called mdl.generate, but the loaded model is
    # bound to the module-level name `model` — mdl is a NameError.
    chat_history = model.generate(
        chatbot_input_ids,
        max_length=1000,
        pad_token_id=chat_token.eos_token_id).tolist()
    print(chat_history)

    # Splitting the decoded ids on the EOS marker yields alternating
    # user/bot messages.
    response = chat_token.decode(chat_history[0]).split("<|endoftext|>")
    print("Starting to print response...")
    print(response)

    # html for display
    html = "<div class='mybot'>"
    for x, mesg in enumerate(response):
        if x%2!=0 :
            # odd indices are the bot's replies
            mesg="Bot:" + mesg
            clazz = "bot"
        else :
            clazz = "user"
        print("value of x")
        print(x)
        print("message")
        print (mesg)
        html += "<div class='mesg {}'> {}</div>".format(clazz, mesg)
    html += "</div>"
    print(html)
    return html, chat_history
43
 
44
+
45
+
46
  css = """
47
+ .mybot {display:flex;flex-direction:column}
48
+ .mesg {padding:5px;margin-bottom:5px;border-radius:5px;width:75%}
49
+ .mesg.user {background-color:lightblue;color:white}
50
+ .mesg.bot {background-color:orange;color:white,align-self:self-end}
51
+ .footer {display:none !important}
52
  """
53
+ in_text = gr.inputs.Textbox(placeholder="Let's start a chat")
54
+ gr.Interface(fn=converse,
55
+ theme="default",
56
+ inputs=[in_text, "state"],
57
+ outputs=["html", "state"],
58
+ css=css).launch()