ACMCMC committed on
Commit a3c9adb · 1 Parent(s): f856f17

Bugfix - GPT-2 errors when using 1 input token (still investigating why)
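For context on the fix: the change below prepends a zeroed "dummy" embedding to the Simon Says prompt and masks that position out with the attention mask, so the model is never called with a length-1 `inputs_embeds` sequence. A minimal sketch of the same idea in isolation (GPT-2 here, with illustrative variable names rather than the ones in demo.py):

```python
import torch
import transformers

model = transformers.AutoModelForCausalLM.from_pretrained("openai-community/gpt2")

# A single soft-prompt "token" passed directly as an embedding: (batch, seq=1, hidden).
soft_prompt = torch.randn(1, 1, model.config.n_embd)

# Workaround: prepend a zeroed dummy embedding and mask it out, so the sequence
# fed to the model is longer than 1 but the dummy position is ignored.
dummy = torch.zeros_like(soft_prompt[:, 0:1, :])
inputs_embeds = torch.cat([dummy, soft_prompt], dim=1)
attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.long)
attention_mask[:, 0] = 0  # hide the dummy position

with torch.no_grad():
    logits = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask).logits
print(logits[0, -1].argmax().item())  # greedy next-token id after the soft prompt
```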

Files changed (1)
demo.py +224 -63
demo.py CHANGED
@@ -7,12 +7,45 @@ import tempfile
 from io import BytesIO
 import logging
 
-# Load the tokenizer and model
-tokenizer = transformers.AutoTokenizer.from_pretrained("openai-community/gpt2")
-model = transformers.AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
+# Add a dropdown to select the model
+model_options = [
+    "openai-community/gpt2",
+    "google/gemma-3-1b-it",
+    "meta-llama/Llama-3.2-1B",
+    "EleutherAI/pythia-160m",
+    "EleutherAI/pythia-14m",
+]
 
 
-# Update the optimization function in demo.py to align with the notebook
+def load_model_and_tokenizer(model_name):
+    """
+    Load the tokenizer and model based on the selected model name.
+
+    Parameters:
+    model_name (str): The name of the model to load.
+
+    Returns:
+    tuple: The loaded tokenizer and model.
+    """
+    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
+    model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
+    return tokenizer, model
+
+
+def get_embeddings(
+    input_ids: torch.Tensor, model: transformers.PreTrainedModel
+) -> torch.Tensor:
+    """
+    Get the embeddings for the input IDs.
+
+    Parameters:
+    input_ids (torch.Tensor): The input IDs for which to get the embeddings.
+    model (transformers.PreTrainedModel): The model to use for generating embeddings.
+
+    Returns:
+    torch.Tensor: The embeddings for the input IDs.
+    """
+    return model.get_input_embeddings()(input_ids).detach()
 
 
 def optimize_simon_says_prompt(
@@ -20,6 +53,9 @@ def optimize_simon_says_prompt(
     number_of_simon_says_tokens: int,
     n_steps: int,
     lr: float,
+    model: transformers.PreTrainedModel,
+    tokenizer: transformers.PreTrainedTokenizer,
+    add_eos_token: bool,
     progress=gr.Progress(track_tqdm=False),  # Gradio progress tracking
 ) -> tuple[str, torch.Tensor]:
     """
@@ -30,27 +66,49 @@ def optimize_simon_says_prompt(
     number_of_simon_says_tokens (int): Number of Simon Says tokens to optimize.
     n_steps (int): Number of optimization steps.
     lr (float): Learning rate for the optimization process.
+    model (transformers.PreTrainedModel): The model to use for optimization.
+    tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for tokenization.
+    add_eos_token (bool): Whether to add an EOS token to the input text.
     progress (gr.Progress): Gradio progress tracking.
 
     Returns:
     The optimized Simon Says prompt
     """
+    torch.manual_seed(42)  # Set a random seed for reproducibility
+
+    # Check if the EOS token checkbox is selected
+    if add_eos_token:
+        # We could've also used the tokenizer.eos_token_id, but this is easier because we don't need to potentially handle padding attention masks, batching issues, etc.
+        input_text += tokenizer.eos_token
+
     # Tokenize the input text
     tokens = tokenizer(
         input_text,
         return_tensors="pt",
         padding=False,
         truncation=True,
-        add_special_tokens=True,
+        add_special_tokens=False,
     )
-    embeddings = model.transformer.wte(tokens["input_ids"]).detach()
+    embeddings = get_embeddings(tokens["input_ids"], model)
 
     # Initialize a random Simon Says prompt
     simon_says_prompt = torch.randn(
-        1, number_of_simon_says_tokens, model.config.n_embd, requires_grad=True
+        1, number_of_simon_says_tokens, embeddings.size(-1), requires_grad=True
    )
+
+    dummy_prompt = torch.zeros_like(
+        simon_says_prompt[..., 0:1, :], requires_grad=False
+    )  # Add an extra dimension
+
+    attention_mask = torch.ones_like(
+        torch.cat([dummy_prompt, simon_says_prompt, embeddings], dim=1)[:, :, 0],
+        device=simon_says_prompt.device,
+        requires_grad=False,
+    )
+    # Set the first token to 0 in the attention mask
+    attention_mask[:, 0] = 0
+
     optimizer = torch.optim.Adam([simon_says_prompt], lr=lr)
-    loss_fn = torch.nn.CrossEntropyLoss()
 
     best_loss: float = float("inf")
     best_simon_says_prompt: torch.Tensor = None
@@ -60,19 +118,37 @@ def optimize_simon_says_prompt(
 
     for step in range(n_steps):
         optimizer.zero_grad()
-        expanded_prompt = torch.cat([simon_says_prompt, embeddings], dim=1)
-        logits = model(inputs_embeds=expanded_prompt).logits
-        probs = torch.softmax(logits[:, simon_says_prompt.size(-2) - 1 : -1], dim=-1)
+        expanded_prompt = torch.cat(
+            [dummy_prompt, simon_says_prompt, embeddings], dim=1
+        )
+        logits = model(
+            inputs_embeds=expanded_prompt, attention_mask=attention_mask
+        ).logits
+        probs = torch.softmax(logits[:, -embeddings.size(-2) - 1 : -1], dim=-1)
         ranks = (
             torch.sum(
                 probs > probs.gather(2, tokens["input_ids"].unsqueeze(-1)), dim=-1
            )
            + 1
        )
-        loss = loss_fn(
-            logits[:, simon_says_prompt.size(-2) - 1 : -1].reshape(-1, logits.size(-1)),
-            tokens["input_ids"].reshape(-1),
+
+        # If all ranks are 1, stop the optimization (perfect prediction)
+        if torch.all(ranks == 1):
+            best_simon_says_prompt = simon_says_prompt.detach().clone()
+            break
+
+        loss = torch.functional.F.cross_entropy(
+            input=logits[:, -embeddings.size(-2) - 1 : -1].reshape(-1, logits.size(-1)),
+            target=tokens["input_ids"].reshape(-1),
+            reduction="none",
        )
+        # Multiply the loss by the ranks to give more weight to the tokens with higher ranks - this is to speed up the optimization process and avoid getting stuck in local minima
+        # Weights should be between 0 and 1 - we can normalize the ranks to get weights and then apply softmax to get the final weights as a more stable distribution
+        token_weights = ranks.float() / ranks.float().max()
+        print(f"Token Ranks: {ranks}")
+        print(f"Token Weights: {token_weights}")
+        loss = loss * token_weights.reshape(-1)
+        loss = loss.mean()
         loss.backward()
         optimizer.step()
 
@@ -90,16 +166,15 @@ def optimize_simon_says_prompt(
             best_loss = loss.item()
             best_simon_says_prompt = simon_says_prompt.detach().clone()
 
-        # If all ranks are 1, stop the optimization (perfect prediction)
-        if torch.all(ranks == 1):
-            break
+    else:
+        # Show a Gradio warning saying that the optimization did not converge
+        gr.Warning(
+            "The optimization did not converge. The prompt will not generate the expected output."
+        )
 
     return best_simon_says_prompt
 
 
-# Modify the download_tensor function to save the tensor as a safetensors file
-
-
 def download_tensor(tensor):
     """
     Save a tensor to a safetensors file for download.
@@ -150,14 +225,20 @@ def upload_tensor(file):
     return tensor_data["optimized_tensor"]
 
 
-def greedy_decode_with_ss_prompt(
-    ss_prompt: torch.Tensor, progress=gr.Progress()
+@torch.inference_mode()
+def greedy_decode_with_simon_says_prompt(
+    simon_says_prompt: torch.Tensor,
+    model: transformers.PreTrainedModel,
+    tokenizer: transformers.PreTrainedTokenizer,
+    progress=gr.Progress(),
 ) -> str:
     """
     Perform greedy decoding using an uploaded optimized tensor and input text.
 
     Parameters:
-    ss_prompt (torch.Tensor): The uploaded optimized tensor.
+    simon_says_prompt (torch.Tensor): The uploaded optimized tensor.
+    model (transformers.PreTrainedModel): The model to use for decoding.
+    tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for decoding.
     progress (gr.Progress): Gradio progress tracking.
 
     Returns:
@@ -168,34 +249,57 @@ def greedy_decode_with_ss_prompt(
 
     progress(0, desc="Starting greedy decoding...")
 
-    with torch.no_grad():
-        for i in progress.tqdm(range(150), desc="Decoding..."):
-            if len(generated_tokens) == 0:
-                expanded_prompt = ss_prompt
-            else:
-                expanded_prompt = torch.cat(
-                    [
-                        ss_prompt,
-                        model.transformer.wte(
-                            torch.tensor(generated_tokens).unsqueeze(0)
-                        ).detach(),
-                    ],
-                    dim=1,
-                )
-
-            logits = model(inputs_embeds=expanded_prompt).logits
-            next_token_logits = logits[0, -1, :]
-            next_token = next_token_logits.argmax().item()
-
-            logging.info(
-                f"Step {i}, Next Token: {next_token}, Logit: {next_token_logits[next_token].item()}"
+    # Add an extra dimension with all 0s to the start of the prompt - this is just a bugfix because GPT-2 can't handle a prompt of size 1 (still investigating why)
+    dummy_prompt = torch.zeros_like(
+        simon_says_prompt[..., 0:1, :]
+    )  # Add an extra dimension
+    simon_says_prompt_with_dummy = torch.cat(
+        [
+            dummy_prompt,
+            simon_says_prompt,
+        ],
+        dim=1,
+    )
+
+    for i in progress.tqdm(range(100), desc="Decoding..."):
+        if len(generated_tokens) == 0:
+            expanded_prompt = simon_says_prompt_with_dummy
+        else:
+            expanded_prompt = torch.cat(
+                [
+                    simon_says_prompt_with_dummy,
+                    get_embeddings(
+                        torch.tensor(
+                            generated_tokens, device=simon_says_prompt.device
+                        ).unsqueeze(0),
+                        model,
+                    ),
+                ],
+                dim=1,
             )
 
-            generated_tokens.append(next_token)
-            all_logits.append(next_token_logits)
+        attention_mask = torch.ones_like(
+            expanded_prompt[:, :, 0], device=simon_says_prompt.device
+        )
+        # Set the first token to 0 in the attention mask
+        attention_mask[:, 0] = 0
+
+        logits = model(
+            inputs_embeds=expanded_prompt,
+            attention_mask=attention_mask,
+        ).logits
+        next_token_logits = logits[0, -1, :]
+        next_token = next_token_logits.argmax().item()
 
-            if next_token == tokenizer.eos_token_id:
-                break
+        logging.info(
+            f"Step {i}, Next Token: {next_token}, Logit: {next_token_logits[next_token].item()}"
+        )
+
+        generated_tokens.append(next_token)
+        all_logits.append(next_token_logits)
+
+        if next_token == tokenizer.eos_token_id:
+            break
 
     generated_tokens = torch.tensor(generated_tokens)
     generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
@@ -208,6 +312,8 @@ def process_and_generate(
     number_of_simon_says_tokens: int,
     n_steps: int,
     lr: float,
+    model_name: str,
+    add_eos_token: bool,
 ) -> tuple[str, str]:
     """
     Optimize the Simon Says prompt, display the optimization process, and generate text based on the input text.
@@ -217,19 +323,27 @@ def process_and_generate(
     number_of_simon_says_tokens (int): Number of Simon Says tokens to optimize.
     n_steps (int): Number of optimization steps.
     lr (float): Learning rate for the optimization process.
+    model_name (str): The name of the model to load.
+    add_eos_token (bool): Whether to add an EOS token to the input text.
 
     Returns:
     tuple: The optimized Simon Says prompt and the greedy-decoded text.
     """
+    tokenizer, model = load_model_and_tokenizer(model_name)
     optimized_prompt = optimize_simon_says_prompt(
         input_text=input_text,
         number_of_simon_says_tokens=number_of_simon_says_tokens,
         n_steps=n_steps,
         lr=lr,
+        model=model,
+        tokenizer=tokenizer,
+        add_eos_token=add_eos_token,
    )
 
     # Generate text using the optimized prompt
-    generated_text: str = greedy_decode_with_ss_prompt(optimized_prompt)
+    generated_text: str = greedy_decode_with_simon_says_prompt(
+        optimized_prompt, model, tokenizer
+    )
 
     return (
         generated_text,
@@ -238,7 +352,7 @@ def process_and_generate(
 
 
 def process_with_uploaded_tensor(
-    input_text: str, uploaded_tensor: torch.Tensor
+    input_text: str, uploaded_tensor: torch.Tensor, model_name: str
 ) -> tuple[str, str]:
     """
     Process the uploaded tensor and generate text based on the input text.
@@ -246,11 +360,15 @@ def process_with_uploaded_tensor(
     Parameters:
     input_text (str): The input text provided by the user.
     uploaded_tensor (torch.Tensor): The uploaded optimized tensor.
+    model_name (str): The name of the model to load.
 
     Returns:
     tuple: The generated text and the file path of the uploaded tensor.
     """
-    generated_text = greedy_decode_with_ss_prompt(uploaded_tensor)
+    tokenizer, model = load_model_and_tokenizer(model_name)
+    generated_text = greedy_decode_with_simon_says_prompt(
+        uploaded_tensor, model, tokenizer
+    )
     return generated_text, None
 
 
@@ -273,15 +391,22 @@ theme = gr.themes.Soft(
     ],
 )
 
-# Update the Gradio interface to include configurable parameters
+# Update the Gradio interface to include the model selection dropdown
 demo = gr.Interface(
     theme=theme,
     title="Simon Says Prompt Optimization and Text Generation",
-    fn=lambda input_text, number_of_simon_says_tokens, n_steps, lr, uploaded_file: (
-        process_with_uploaded_tensor(input_text, upload_tensor(uploaded_file))
+    fn=lambda input_text, model_name, number_of_simon_says_tokens, n_steps, lr, add_eos_token, uploaded_file: (
+        process_with_uploaded_tensor(
+            input_text, upload_tensor(uploaded_file), model_name
+        )
         if uploaded_file
         else process_and_generate(
-            input_text, number_of_simon_says_tokens, n_steps, lr
+            input_text,
+            number_of_simon_says_tokens,
+            n_steps,
+            lr,
+            model_name,
+            add_eos_token,
        )
    ),
     inputs=[
@@ -290,27 +415,63 @@ demo = gr.Interface(
             placeholder="Enter your text here...",
             label="Input Text",
             value="Hello world! I'm Aldan, happy to be here.",
+            info="Provide the text for which you want to optimize the Simon Says prompt. This text will be used as the target for generating the Simon Says Prompt.",
+        ),
+        gr.Dropdown(
+            choices=model_options,
+            value="EleutherAI/pythia-160m",
+            label="Select Model",
+            interactive=True,
+            info="Choose a pre-trained language model to use for optimization and text generation. Each model has different capabilities and sizes.",
        ),
         gr.Slider(
-            minimum=1, maximum=10, step=1, value=4, label="Number of Simon Says Tokens"
+            minimum=1,
+            maximum=10,
+            step=1,
+            value=4,
+            label="Number of Simon Says Prompt Tokens",
+            info="Specify the number of tokens to include in the Simon Says prompt. Bigger sizes may make it easier to optimize, but they take up more space.",
        ),
         gr.Slider(
             minimum=100,
             maximum=10000,
             step=100,
-            value=5000,
-            label="Number of Optimization Steps",
+            value=10000,
+            label="Patience",
+            info="Set the maximum number of steps for the optimization process. It will stop early if the optimization converges before reaching this number, but if it reaches the limit, it will stop without converging.",
        ),
         gr.Slider(
-            minimum=1e-5, maximum=1e-1, step=1e-5, value=1e-2, label="Learning Rate"
+            minimum=1e-5,
+            maximum=1e-1,
+            step=1e-5,
+            value=1e-1,
+            label="Learning Rate",
+            info="Adjust the learning rate for the optimization algorithm. This controls how quickly the optimization converges but can also lead to instability if set too high.",
+        ),
+        gr.Checkbox(
+            label="Add EOS Token",
+            value=False,
+            interactive=True,
+            info="Enable this option to append an End-Of-Sequence (EOS) token to the input text. This can help models better understand the input context.",
+        ),
+        gr.File(
+            label="Upload Optimized SS Prompt (Optional)",
+            type="binary",
+            file_count="single",
+            file_types=[".safetensors"],
        ),
-        gr.File(label="Upload Optimized Tensor (Optional)", type="binary"),
     ],
     outputs=[
-        gr.Textbox(label="Generated Text"),
-        gr.File(label="Download Optimized Tensor", type="filepath"),
+        gr.Textbox(
+            label="Generated Text",
+            info="The text generated by the model using the optimized Simon Says prompt.",
+        ),
+        gr.File(
+            label="Download Optimized SS Prompt",
+            type="filepath",
+        ),
     ],
-    description="This demo optimizes a Simon Says prompt based on your input text, displays the optimization process, and generates text using the optimized prompt. Optionally, you can upload a pre-optimized tensor for inference.",
+    description="This application allows you to optimize a Simon Says prompt based on your input text using advanced machine learning techniques. You can visualize the optimization process and generate text using the optimized prompt. Additionally, you can upload a pre-optimized tensor for direct inference (if you do, the other parameters will be ignored).",
 )
 
 # Ensure the Gradio interface is correctly launched
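Usage note: the optimized prompt offered for download is a single tensor stored in a safetensors file under the "optimized_tensor" key (see `upload_tensor` in the diff). Below is a rough, hypothetical sketch of reusing such a file outside the demo, mirroring the dummy-token workaround from `greedy_decode_with_simon_says_prompt`; the file name, model choice, and generation length are placeholders, not values from the commit:

```python
import torch
import transformers
from safetensors.torch import load_file

# Placeholder file name; the demo stores the prompt under the "optimized_tensor" key.
simon_says_prompt = load_file("optimized_prompt.safetensors")["optimized_tensor"]

model_name = "EleutherAI/pythia-160m"  # must match the model the prompt was optimized for
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForCausalLM.from_pretrained(model_name)

# Same workaround as in demo.py: prepend a zeroed, masked-out dummy embedding.
dummy = torch.zeros_like(simon_says_prompt[..., 0:1, :])
prompt = torch.cat([dummy, simon_says_prompt], dim=1)

generated: list[int] = []
with torch.no_grad():
    for _ in range(50):  # placeholder generation length
        if generated:
            embeds = model.get_input_embeddings()(torch.tensor(generated).unsqueeze(0))
            inputs = torch.cat([prompt, embeds], dim=1)
        else:
            inputs = prompt
        mask = torch.ones(inputs.shape[:2], dtype=torch.long)
        mask[:, 0] = 0  # hide the dummy position
        next_token = model(inputs_embeds=inputs, attention_mask=mask).logits[0, -1].argmax().item()
        generated.append(next_token)
        if next_token == tokenizer.eos_token_id:
            break

print(tokenizer.decode(generated, skip_special_tokens=True))
```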