Jaehan commited on
Commit
20cda87
·
1 Parent(s): cc28aad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -19
app.py CHANGED
@@ -2,34 +2,36 @@ from transformers import AutoModelForCausalLM, AutoTokenizer,BlenderbotForCondit
2
  import torch
3
  import gradio as gr
4
 
5
- #model_name = "facebook/blenderbot-400M-distill"
6
- model_name = "microsoft/DialoGPT-medium"
7
- chat_token = AutoTokenizer.from_pretrained(model_name)
8
- model = AutoModelForCausalLM.from_pretrained(model_name)
9
 
10
- def converse(user_input, chat_history=[]):
11
- user_input_ids = chat_token(user_input + chat_token.eos_token, return_tensors='pt').input_ids
 
 
 
12
 
13
  # keep history in the tensor
14
- chatbot_input_ids = torch.cat([torch.LongTensor(chat_history), user_input_ids], dim=-1)
15
 
16
  # get response
17
- chat_history = mdl.generate(chatbot_input_ids, max_length=1000, pad_token_id=chat_token.eos_token_id).tolist()
18
- print(chat_history)
19
 
20
- response = chat_token.decode(chat_history[0]).split("<|endoftext|>")
21
 
22
- print("Starting to print response...")
23
  print(response)
24
 
25
  # html for display
26
  html = "<div class='mybot'>"
27
  for x, mesg in enumerate(response):
28
  if x%2!=0 :
29
- mesg="Bot:" + mesg
30
- clazz = "bot"
31
  else :
32
- clazz = "user"
 
33
 
34
  print("value of x")
35
  print(x)
@@ -41,18 +43,16 @@ def converse(user_input, chat_history=[]):
41
  print(html)
42
  return html, chat_history
43
 
44
-
45
-
46
  css = """
47
- .mybot {display:flex;flex-direction:column}
48
  .mesg {padding:5px;margin-bottom:5px;border-radius:5px;width:75%}
49
  .mesg.user {background-color:lightblue;color:white}
50
  .mesg.bot {background-color:orange;color:white,align-self:self-end}
51
  .footer {display:none !important}
52
  """
53
- in_text = gr.inputs.Textbox(placeholder="Let's start a chat")
54
  gr.Interface(fn=converse,
55
  theme="default",
56
- inputs=[in_text, "state"],
57
  outputs=["html", "state"],
58
  css=css).launch()
 
2
  import torch
3
  import gradio as gr
4
 
5
+ chat_tkn = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
6
+ mdl = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
 
 
7
 
8
+ #chat_tkn = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
9
+ #mdl = BlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-400M-distill")
10
+
11
+ def converse(user_input, chat_history=[]):
12
+ user_input_ids = chat_tkn(user_input + chat_tkn.eos_token, return_tensors='pt').input_ids
13
 
14
  # keep history in the tensor
15
+ bot_input_ids = torch.cat([torch.LongTensor(chat_history), user_input_ids], dim=-1)
16
 
17
  # get response
18
+ chat_history = mdl.generate(bot_input_ids, max_length=1000, pad_token_id=chat_tkn.eos_token_id).tolist()
19
+ print (chat_history)
20
 
21
+ response = chat_tkn.decode(chat_history[0]).split("<|endoftext|>")
22
 
23
+ print("starting to print response")
24
  print(response)
25
 
26
  # html for display
27
  html = "<div class='mybot'>"
28
  for x, mesg in enumerate(response):
29
  if x%2!=0 :
30
+ mesg="Bot:"+mesg
31
+ clazz="bot"
32
  else :
33
+ clazz="user"
34
+
35
 
36
  print("value of x")
37
  print(x)
 
43
  print(html)
44
  return html, chat_history
45
 
 
 
46
  css = """
47
+ .mychat {display:flex;flex-direction:column}
48
  .mesg {padding:5px;margin-bottom:5px;border-radius:5px;width:75%}
49
  .mesg.user {background-color:lightblue;color:white}
50
  .mesg.bot {background-color:orange;color:white,align-self:self-end}
51
  .footer {display:none !important}
52
  """
53
+ text=gr.inputs.Textbox(placeholder="Lets chat")
54
  gr.Interface(fn=converse,
55
  theme="default",
56
+ inputs=[text, "state"],
57
  outputs=["html", "state"],
58
  css=css).launch()