ivpich committed on
Commit
aec6377
·
verified ·
1 Parent(s): 2b211d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -9
app.py CHANGED
@@ -38,8 +38,6 @@ model = AutoModelForCausalLM.from_pretrained(
38
  device_map="auto",
39
  ignore_mismatched_sizes=True)
40
 
41
- eos_token_id = tokenizer.eos_token_id
42
-
43
  @spaces.GPU()
44
  def stream_chat(
45
  message: str,
@@ -64,20 +62,25 @@ def stream_chat(
64
 
65
  conversation.append({"role": "user", "content": message})
66
 
67
- input_text = tokenizer.apply_chat_template(conversation, tokenize=False)
68
- inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
 
 
 
 
 
69
  streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
70
 
71
  generate_kwargs = dict(
72
- input_ids=inputs,
73
  max_new_tokens=max_new_tokens,
 
74
  do_sample=False if temperature == 0 else True,
75
  top_p=top_p,
76
  top_k=top_k,
77
  temperature=temperature,
78
  streamer=streamer,
79
- pad_token_id=eos_token_id,
80
- eos_token_id=eos_token_id,
81
  )
82
 
83
  with torch.no_grad():
@@ -88,8 +91,6 @@ def stream_chat(
88
  for new_text in streamer:
89
  buffer += new_text
90
  yield buffer
91
- if eos_token_id in tokenizer.encode(new_text):
92
- break
93
 
94
 
95
  chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
 
38
  device_map="auto",
39
  ignore_mismatched_sizes=True)
40
 
 
 
41
  @spaces.GPU()
42
  def stream_chat(
43
  message: str,
 
62
 
63
  conversation.append({"role": "user", "content": message})
64
 
65
+ input_text = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(device)
66
+
67
+ terminators = [
68
+ tokenizer.eos_token_id,
69
+ tokenizer.convert_tokens_to_ids("<|eot_id|>")
70
+ ]
71
+
72
  streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
73
 
74
  generate_kwargs = dict(
75
+ input_ids=input_text,
76
  max_new_tokens=max_new_tokens,
77
+ eos_token_id=terminators,
78
  do_sample=False if temperature == 0 else True,
79
  top_p=top_p,
80
  top_k=top_k,
81
  temperature=temperature,
82
  streamer=streamer,
83
+ pad_token_id=10,
 
84
  )
85
 
86
  with torch.no_grad():
 
91
  for new_text in streamer:
92
  buffer += new_text
93
  yield buffer
 
 
94
 
95
 
96
  chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)