nugentc commited on
Commit
61a2cb3
·
1 Parent(s): 5155493

swpa out chat logic

Browse files
Files changed (1) hide show
  1. app.py +17 -9
app.py CHANGED
@@ -7,17 +7,25 @@ import torch
7
  import gradio as gr
8
 
9
 
10
- def chat(message, history):
11
- history = history if history is not None else []
12
- new_user_input_ids = tokenizer.encode(message+tokenizer.eos_token, return_tensors='pt')
 
 
 
 
 
 
 
 
 
 
 
 
13
  bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
14
  history = model.generate(bot_input_ids, max_length=500, pad_token_id=tokenizer.eos_token_id).tolist()
15
- response = tokenizer.decode(history[0]).replace("<|endoftext|>", "\n")
16
- # pretty print last ouput tokens from bot
17
- # response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
18
- print("The response is ", [response])
19
- # history.append((message, response, new_user_input_ids, chat_history_ids))
20
- return response, feedback(message)
21
 
22
 
23
  def feedback(text):
 
7
  import gradio as gr
8
 
9
 
10
+ # def chat(message, history):
11
+ # history = history if history is not None else []
12
+ # new_user_input_ids = tokenizer.encode(message+tokenizer.eos_token, return_tensors='pt')
13
+ # bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
14
+ # history = model.generate(bot_input_ids, max_length=500, pad_token_id=tokenizer.eos_token_id).tolist()
15
+ # # response = tokenizer.decode(history[0]).replace("<|endoftext|>", "\n")
16
+ # # pretty print last ouput tokens from bot
17
+ # response = tokenizer.decode(bot_input_ids.shape[-1][0], skip_special_tokens=True)
18
+ # print("The response is ", [response])
19
+ # # history.append((message, response, new_user_input_ids, chat_history_ids))
20
+ # return response, history, feedback(message)
21
+
22
+
23
+ def chat(message, history=[]):
24
+ new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')
25
  bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
26
  history = model.generate(bot_input_ids, max_length=500, pad_token_id=tokenizer.eos_token_id).tolist()
27
+ response = tokenizer.decode(history[0]).replace("<|endoftext|>", "")
28
+ return response, history
 
 
 
 
29
 
30
 
31
  def feedback(text):