s-a-malik committed on
Commit 0120475 · 1 Parent(s): b501b77
Files changed (1):
  1. app.py +71 -102
app.py CHANGED
@@ -26,10 +26,10 @@ DESCRIPTION = """
 """
 
 EXAMPLES = [
-    ["What is the capital of France?", "You are a helpful assistant."],
-    ["Who landed on the moon?", "You are a knowledgeable historian."],
-    ["Who is Yarin Gal?", "You are a helpful assistant."],
-    ["Explain the theory of relativity in simple terms.", "You are an expert physicist explaining concepts to a layman."],
+    ["What is the capital of France?", ""],
+    ["Who landed on the moon?", ""],
+    ["Who is Yarin Gal?", ""],
+    ["Explain the theory of relativity in simple terms.", ""],
 ]
 
 if torch.cuda.is_available():
@@ -93,28 +93,7 @@ class CustomStreamer(TextIteratorStreamer):
 
 
 
-    # se_highlighted_text = ""
-    # acc_highlighted_text = ""
-    # for new_text in streamer:
-    # hidden_states = streamer.hidden_states_queue.get()
-
-    # # Semantic Uncertainty Probe
-    # se_token_embeddings = torch.stack([layer[0, -1, :].cpu() for layer in hidden_states])
-    # se_concat_layers = se_token_embeddings.numpy()[se_layer_range[0]:se_layer_range[1]].reshape(-1)
-    # se_probe_pred = se_probe.predict_proba(se_concat_layers.reshape(1, -1))[0][1] * 2 - 1
-
-    # # Accuracy Probe
-    # acc_token_embeddings = torch.stack([layer[0, -1, :].cpu() for layer in hidden_states])
-    # acc_concat_layers = acc_token_embeddings.numpy()[acc_layer_range[0]:acc_layer_range[1]].reshape(-1)
-    # acc_probe_pred = acc_probe.predict_proba(acc_concat_layers.reshape(1, -1))[0][1] * 2 - 1
-
-    # se_new_highlighted_text = highlight_text(new_text, se_probe_pred)
-    # acc_new_highlighted_text = highlight_text(new_text, acc_probe_pred)
-
-    # se_highlighted_text += se_new_highlighted_text
-    # acc_highlighted_text += acc_new_highlighted_text
-
-    # yield se_highlighted_text, acc_highlighted_text
+
 
 @spaces.GPU
 def generate(
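Note: the new generation loop below reads from `streamer.hidden_states_queue`, but the body of `CustomStreamer` is outside these hunks. A minimal sketch of one way such a streamer could be shaped is given here; the queue-based handoff and the `put_hidden_states` helper are assumptions for illustration, not the Space's actual implementation.

```python
from queue import Queue

from transformers import TextIteratorStreamer


class CustomStreamerSketch(TextIteratorStreamer):
    """Hypothetical sketch of a streamer that also exposes per-step hidden states.

    Only the hidden_states_queue attribute and the iteration behaviour are visible
    in this diff; how the real CustomStreamer fills the queue is not shown here.
    """

    def __init__(self, tokenizer, skip_prompt=True, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt=skip_prompt, **decode_kwargs)
        # The consumer pairs each streamed text chunk with one .get() on this queue.
        self.hidden_states_queue = Queue()

    def put_hidden_states(self, hidden_states):
        # Assumed entry point: called from wherever the per-step hidden states are
        # produced (e.g. a wrapper around the generation loop or a forward hook).
        self.hidden_states_queue.put(hidden_states)
```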
@@ -137,7 +116,8 @@ def generate(
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
     input_ids = input_ids.to(model.device)
 
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    # streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    streamer = CustomStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
     generation_kwargs = dict(
         input_ids=input_ids,
         max_new_tokens=max_new_tokens,
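Note: the switch above follows the standard Transformers streaming recipe: `model.generate` runs on a worker thread while the main thread iterates the streamer. A self-contained sketch of that pattern (the model id and prompt are placeholders, not taken from app.py):

```python
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "meta-llama/Llama-2-7b-chat-hf"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

input_ids = tokenizer("What is the capital of France?", return_tensors="pt").input_ids.to(model.device)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until completion, so it runs on a background thread
# while the streamer is consumed incrementally on this thread.
thread = Thread(target=model.generate, kwargs=dict(input_ids=input_ids, max_new_tokens=64, streamer=streamer))
thread.start()

for new_text in streamer:
    print(new_text, end="", flush=True)
thread.join()
```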
@@ -150,41 +130,84 @@ def generate(
         output_hidden_states=True,
         return_dict_in_generate=True,
     )
-
-    # Generate without threading
-    with torch.no_grad():
-        outputs = model.generate(**generation_kwargs)
-    generated_tokens = outputs.sequences[0, input_ids.shape[1]:]
-    generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
-    # hidden states
-    hidden = outputs.hidden_states # list of tensors, one for each token, then (batch size, sequence length, hidden size)
-
-    # TODO do this loop on the fly instead of waiting for the whole generation
+    # with threading
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
     se_highlighted_text = ""
     acc_highlighted_text = ""
-    for i in range(1, len(hidden)):
-
+    for new_text in streamer:
+        hidden_states = streamer.hidden_states_queue.get()
         # Semantic Uncertainty Probe
-        token_embeddings = torch.stack([generated_token[0, 0, :].cpu() for generated_token in hidden[i]]).numpy() # (num_layers, hidden_size)
+        token_embeddings = torch.stack([generated_token[0, 0, :].cpu() for generated_token in hidden_states]).numpy() # (num_layers, hidden_size)
         se_concat_layers = token_embeddings[se_layer_range[0]:se_layer_range[1]].reshape(-1)
         se_probe_pred = se_probe.predict_proba(se_concat_layers.reshape(1, -1))[0][1] * 2 - 1
 
         # Accuracy Probe
-        # acc_token_embeddings = torch.stack([layer[0, -1, :].cpu() for layer in hidden_states])
         acc_concat_layers = token_embeddings[acc_layer_range[0]:acc_layer_range[1]].reshape(-1)
         acc_probe_pred = (1 - acc_probe.predict_proba(acc_concat_layers.reshape(1, -1))[0][1]) * 2 - 1
 
-        output_id = outputs.sequences[0, input_ids.shape[1]+i]
-        output_word = tokenizer.decode(output_id)
-        print(output_id, output_word, se_probe_pred, acc_probe_pred)
+        print(new_text, se_probe_pred, acc_probe_pred)
 
-        se_new_highlighted_text = highlight_text(output_word, se_probe_pred)
-        acc_new_highlighted_text = highlight_text(output_word, acc_probe_pred)
+        se_new_highlighted_text = highlight_text(new_text, se_probe_pred)
+        acc_new_highlighted_text = highlight_text(new_text, acc_probe_pred)
         se_highlighted_text += f" {se_new_highlighted_text}"
         acc_highlighted_text += f" {acc_new_highlighted_text}"
 
+        yield se_highlighted_text, acc_highlighted_text
+
+    # Semantic Uncertainty Probe
+    # se_token_embeddings = torch.stack([layer[0, -1, :].cpu() for layer in hidden_states])
+    # se_concat_layers = se_token_embeddings.numpy()[se_layer_range[0]:se_layer_range[1]].reshape(-1)
+    # se_probe_pred = se_probe.predict_proba(se_concat_layers.reshape(1, -1))[0][1] * 2 - 1
+
+    # # Accuracy Probe
+    # acc_token_embeddings = torch.stack([layer[0, -1, :].cpu() for layer in hidden_states])
+    # acc_concat_layers = acc_token_embeddings.numpy()[acc_layer_range[0]:acc_layer_range[1]].reshape(-1)
+    # acc_probe_pred = acc_probe.predict_proba(acc_concat_layers.reshape(1, -1))[0][1] * 2 - 1
+
+    # se_new_highlighted_text = highlight_text(new_text, se_probe_pred)
+    # acc_new_highlighted_text = highlight_text(new_text, acc_probe_pred)
+
+    # se_highlighted_text += se_new_highlighted_text
+    # acc_highlighted_text += acc_new_highlighted_text
+
     # yield se_highlighted_text, acc_highlighted_text
-    return se_highlighted_text, acc_highlighted_text
+
+    # Generate without threading
+    # with torch.no_grad():
+    #     outputs = model.generate(**generation_kwargs)
+    # generated_tokens = outputs.sequences[0, input_ids.shape[1]:]
+    # generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+    # # hidden states
+    # hidden = outputs.hidden_states # list of tensors, one for each token, then (batch size, sequence length, hidden size)
+
+    # # TODO do this loop on the fly instead of waiting for the whole generation
+    # se_highlighted_text = ""
+    # acc_highlighted_text = ""
+
+    # for i in range(1, len(hidden)):
+
+    # # Semantic Uncertainty Probe
+    # token_embeddings = torch.stack([generated_token[0, 0, :].cpu() for generated_token in hidden[i]]).numpy() # (num_layers, hidden_size)
+    # se_concat_layers = token_embeddings[se_layer_range[0]:se_layer_range[1]].reshape(-1)
+    # se_probe_pred = se_probe.predict_proba(se_concat_layers.reshape(1, -1))[0][1] * 2 - 1
+
+    # # Accuracy Probe
+    # # acc_token_embeddings = torch.stack([layer[0, -1, :].cpu() for layer in hidden_states])
+    # acc_concat_layers = token_embeddings[acc_layer_range[0]:acc_layer_range[1]].reshape(-1)
+    # acc_probe_pred = (1 - acc_probe.predict_proba(acc_concat_layers.reshape(1, -1))[0][1]) * 2 - 1
+
+    # output_id = outputs.sequences[0, input_ids.shape[1]+i]
+    # output_word = tokenizer.decode(output_id)
+    # print(output_id, output_word, se_probe_pred, acc_probe_pred)
+
+    # se_new_highlighted_text = highlight_text(output_word, se_probe_pred)
+    # acc_new_highlighted_text = highlight_text(output_word, acc_probe_pred)
+    # se_highlighted_text += f" {se_new_highlighted_text}"
+    # acc_highlighted_text += f" {acc_new_highlighted_text}"
+
+    # # yield se_highlighted_text, acc_highlighted_text
+    # return se_highlighted_text, acc_highlighted_text
 
 
 
@@ -215,7 +238,7 @@ with gr.Blocks(title="Llama-2 7B Chat with Dual Probes", css="footer {visibility
 
     with gr.Column():
         max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
-        temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
+        temperature = gr.Slider(label="Temperature", minimum=0.01, maximum=2.0, step=0.1, value=0.01)
         top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
         top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
        repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
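Note: in the streaming loop earlier in this diff, each probe's positive-class probability is rescaled from [0, 1] to a signed score in [-1, 1] (`p * 2 - 1`, with the accuracy probe applying `1 - p` before the same mapping), and `highlight_text` turns that score into coloured text. A small sketch of the convention with a toy scikit-learn probe; the HTML colouring helper is an assumption, since the real `highlight_text` is defined elsewhere in app.py:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Toy stand-in: in the Space, se_probe / acc_probe are pre-trained classifiers over
# hidden states concatenated across a layer range.
rng = np.random.default_rng(0)
X = rng.normal(size=(64, 16))
y = rng.integers(0, 2, size=64)
probe = LogisticRegression().fit(X, y)

features = X[0].reshape(1, -1)
p = probe.predict_proba(features)[0][1]  # P(class 1)
se_score = p * 2 - 1                     # rescaled to [-1, 1], as in the diff
acc_score = (1 - p) * 2 - 1              # flipped variant used for the accuracy probe


def highlight_text_sketch(token: str, score: float) -> str:
    """Hypothetical highlighter: red tint for positive scores, green for negative,
    with opacity proportional to |score|. The real highlight_text may differ."""
    r, g = (255, 0) if score > 0 else (0, 255)
    return f'<span style="background-color: rgba({r}, {g}, 0, {abs(score):.2f})">{token}</span>'


print(highlight_text_sketch("Paris", se_score), acc_score)
```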
@@ -243,7 +266,6 @@ with gr.Blocks(title="Llama-2 7B Chat with Dual Probes", css="footer {visibility
         inputs=[message, system_prompt],
         outputs=[se_output, acc_output],
         fn=generate,
-
     )
 
     generate_btn.click(
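Note: the kwargs in this hunk appear to belong to a `gr.Examples` component that replays the EXAMPLES list through `generate` into the two probe outputs. A minimal sketch of that wiring under that assumption (the stub `generate` below just echoes its input):

```python
import gradio as gr


def generate(message, system_prompt):
    # Stub standing in for app.py's streaming generator.
    yield f"<b>{message}</b>", f"<i>{message}</i>"


with gr.Blocks() as demo:
    message = gr.Textbox(label="Message")
    system_prompt = gr.Textbox(label="System prompt")
    se_output = gr.HTML(label="Semantic Uncertainty")
    acc_output = gr.HTML(label="Accuracy Probe")
    generate_btn = gr.Button("Generate")

    gr.Examples(
        examples=[["What is the capital of France?", ""]],
        inputs=[message, system_prompt],
        outputs=[se_output, acc_output],
        fn=generate,
    )

    generate_btn.click(fn=generate, inputs=[message, system_prompt], outputs=[se_output, acc_output])

if __name__ == "__main__":
    demo.launch()
```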
@@ -252,59 +274,6 @@ with gr.Blocks(title="Llama-2 7B Chat with Dual Probes", css="footer {visibility
         outputs=[se_output, acc_output]
     )
 
-    # chat_interface = gr.ChatInterface(
-    # fn=generate,
-    # additional_inputs=[
-    # gr.Textbox(label="System prompt", lines=6),
-    # gr.Slider(
-    # label="Max new tokens",
-    # minimum=1,
-    # maximum=MAX_MAX_NEW_TOKENS,
-    # step=1,
-    # value=DEFAULT_MAX_NEW_TOKENS,
-    # ),
-    # gr.Slider(
-    # label="Temperature",
-    # minimum=0.1,
-    # maximum=4.0,
-    # step=0.1,
-    # value=0.6,
-    # ),
-    # gr.Slider(
-    # label="Top-p (nucleus sampling)",
-    # minimum=0.05,
-    # maximum=1.0,
-    # step=0.05,
-    # value=0.9,
-    # ),
-    # gr.Slider(
-    # label="Top-k",
-    # minimum=1,
-    # maximum=1000,
-    # step=1,
-    # value=50,
-    # ),
-    # gr.Slider(
-    # label="Repetition penalty",
-    # minimum=1.0,
-    # maximum=2.0,
-    # step=0.05,
-    # value=1.2,
-    # ),
-    # ],
-    # stop_btn=None,
-    # examples=[
-    # ["What is the capital of France?"],
-    # ["Who landed on the moon?"],
-    # ["Who is Yarin Gal?"]
-    # ],
-    # title="Llama-2 7B Chat with Streamable Semantic Uncertainty Probe",
-    # description=DESCRIPTION,
-    # )
-
-    # if __name__ == "__main__":
-    # chat_interface.launch()
-
 
 if __name__ == "__main__":
     demo.launch()
 