multimodalart committed · Commit 9b91020 · verified · 1 Parent(s): 69595ed

Update app.py

Files changed (1)
  1. app.py +115 -331
app.py CHANGED
@@ -1,118 +1,5 @@
-# llada_app.py -> dream_app.py (v2)
-
-import torch
-import numpy as np
-import gradio as gr
-import spaces
-# import torch.nn.functional as F # Not needed for DREAM's basic visualization
-from transformers import AutoTokenizer, AutoModel
-import time
-import re # Keep for parsing constraints
-
-# Use try-except for space deployment vs local
-try:
-    gpu_check = spaces.GPU
-    print("Running in Gradio Spaces with GPU environment.")
-except AttributeError:
-    print("Running in local environment or without spaces.GPU.")
-    def gpu_check(func): return func
-
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
-print(f"Using device: {device}")
-
-# --- Load DREAM Model and Tokenizer ---
-model_path = "Dream-org/Dream-v0-Instruct-7B"
-print(f"Loading model: {model_path}")
-try:
-    model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True).to(device).eval()
-    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-    print("Model and tokenizer loaded.")
-except Exception as e:
-    print(f"FATAL: Could not load model/tokenizer. Error: {e}")
-    # Optionally exit or raise
-    raise SystemExit(f"Failed to load model: {e}")
-
-
-# --- Constants for DREAM ---
-# Find mask token and ID
-if tokenizer.mask_token is None:
-    print("Warning: Mask token not explicitly set in tokenizer. Trying to add '[MASK]'.")
-    # This might require retraining/fine-tuning if the model didn't see it.
-    # Check if it exists first before adding
-    if '[MASK]' not in tokenizer.get_vocab():
-        tokenizer.add_special_tokens({'mask_token': '[MASK]'})
-        model.resize_token_embeddings(len(tokenizer)) # Resize model embeddings
-        print("Added '[MASK]' and resized embeddings.")
-    else:
-        tokenizer.mask_token = '[MASK]' # Set it if it exists but wasn't assigned
-        print("Found existing '[MASK]', assigned as mask_token.")
-
-MASK_TOKEN = tokenizer.mask_token
-MASK_ID = tokenizer.mask_token_id
-if MASK_ID is None:
-    raise ValueError("Failed to get MASK_ID after attempting to set mask_token.")
-print(f"Using MASK_TOKEN='{MASK_TOKEN}' with ID={MASK_ID}")
-
-# Get EOS and PAD token IDs
-EOS_TOKEN_ID = tokenizer.eos_token_id
-PAD_TOKEN_ID = tokenizer.pad_token_id
-print(f"Using EOS_TOKEN_ID={EOS_TOKEN_ID}, PAD_TOKEN_ID={PAD_TOKEN_ID}")
-# Handle cases where they might be None (though unlikely for most models)
-if EOS_TOKEN_ID is None:
-    print("Warning: EOS token ID not found.")
-if PAD_TOKEN_ID is None:
-    print("Warning: PAD token ID not found. Using EOS ID as fallback for hiding.")
-    PAD_TOKEN_ID = EOS_TOKEN_ID # Use EOS as a fallback for hiding logic if PAD is missing
-
-
-# --- Helper Functions (Constraint Parsing, History Formatting) ---
-# (Keep parse_constraints and format_chat_history functions as they were)
-def parse_constraints(constraints_text):
-    """Parse constraints in format: 'position:word, position:word, ...'"""
-    constraints = {}
-    if not constraints_text:
-        return constraints
-
-    parts = constraints_text.split(',')
-    for part in parts:
-        part = part.strip() # Trim whitespace
-        if ':' not in part:
-            continue
-        try:
-            pos_str, word = part.split(':', 1)
-            pos = int(pos_str.strip())
-            word = word.strip()
-            # Allow empty words if needed, but usually we want a word
-            if word and pos >= 0:
-                constraints[pos] = word
-        except ValueError:
-            print(f"Warning: Could not parse constraint part: '{part}'")
-            continue
-
-    return constraints
-
-def format_chat_history(history):
-    """
-    Format chat history for the DREAM model (standard messages format)
-
-    Args:
-        history: List of [user_message, assistant_message] pairs
-
-    Returns:
-        Formatted conversation for the model (list of dictionaries)
-    """
-    messages = []
-    # Add system prompt if desired (check DREAM examples/recommendations)
-    # messages.append({"role": "system", "content": "You are a helpful assistant."}) # Optional
-    for user_msg, assistant_msg in history:
-        if user_msg: # Handle potential None message if clearing failed
-            messages.append({"role": "user", "content": user_msg})
-        if assistant_msg: # Skip if None (for the latest user message awaiting response)
-            messages.append({"role": "assistant", "content": assistant_msg})
-
-    return messages
-
-# --- Core Generation Logic for DREAM with Visualization ---
+# Replace the existing dream_generate_response_with_visualization function
+# in the previous code block with this updated version.
 
 @gpu_check
 def dream_generate_response_with_visualization(
@@ -126,15 +13,29 @@ def dream_generate_response_with_visualization(
     alg_temp=0.0,
 ):
     """
-    Generate text with DREAM model with visualization using the generation hook.
-    Hides special tokens (EOS, PAD) and uses labels for coloring.
+    Generate text with DREAM model with visualization using the generation hook,
+    including special token handling (show once, then hide).
+
+    Args:
+        messages: List of message dictionaries with 'role' and 'content'
+        gen_length: Length of text to generate (max_new_tokens)
+        steps: Number of diffusion steps
+        constraints: Dictionary mapping positions (relative to response start) to words
+        temperature: Sampling temperature
+        top_p: Nucleus sampling p
+        alg: Remasking algorithm ('origin', 'maskgit_plus', 'topk_margin', 'entropy')
+        alg_temp: Temperature for confidence-based algorithms
+
+    Returns:
+        Tuple: (List of visualization states, final generated text string)
     """
     print("--- Starting DREAM Generation ---")
     print(f"Parameters: gen_length={gen_length}, steps={steps}, temperature={temperature}, top_p={top_p}, alg='{alg}', alg_temp={alg_temp}")
     print(f"Constraints: {constraints}")
 
     # --- Input Preparation ---
-    if constraints is None: constraints = {}
+    if constraints is None:
+        constraints = {}
 
     processed_constraints = {}
     print("Processing constraints:")
@@ -152,29 +53,43 @@ def dream_generate_response_with_visualization(
 
     try:
         inputs = tokenizer.apply_chat_template(
-            messages, return_tensors="pt", return_dict=True, add_generation_prompt=True
+            messages,
+            return_tensors="pt",
+            return_dict=True,
+            add_generation_prompt=True
        )
         input_ids = inputs.input_ids.to(device=device)
         attention_mask = inputs.attention_mask.to(device=device)
         prompt_length = input_ids.shape[1]
         print(f"Input prompt length: {prompt_length}")
+        # print(f"Input IDs: {input_ids}") # Verbose
     except Exception as e:
         print(f"Error applying chat template: {e}")
-        return [([("Error applying chat template.", "Error")],)], f"Error: {e}" # Use 'Error' label
+        return [([("Error applying chat template.", "red")],)], f"Error: {e}"
 
-    # Check context length (DREAM uses 2048)
     if prompt_length + gen_length > 2048:
         print(f"Warning: Requested length ({prompt_length + gen_length}) exceeds model max length (2048). Truncating gen_length.")
         gen_length = 2048 - prompt_length
         if gen_length <= 0:
             print("Error: Prompt is already too long.")
-            return [([("Prompt too long.", "Error")],)], "Error: Prompt too long."
+            return [([("Prompt too long.", "red")],)], "Error: Prompt too long."
 
     # --- State for Visualization Hook ---
    visualization_states = []
    last_x = None
 
-    # Initial state: Prompt + all masks + initial constraints
+    # Define special token IDs to hide after first reveal
+    # Using a set for efficient lookup. Add others if needed (e.g., pad_token_id).
+    special_token_ids_to_hide = {
+        tokenizer.eos_token_id,
+        tokenizer.pad_token_id,
+        # tokenizer.bos_token_id # Usually not generated mid-sequence
+    }
+    # Filter out None values if some special tokens aren't defined
+    special_token_ids_to_hide = {tid for tid in special_token_ids_to_hide if tid is not None}
+    print(f"Special token IDs to hide visually after reveal: {special_token_ids_to_hide}")
+
+
    initial_x_part = torch.full((1, gen_length), MASK_ID, dtype=torch.long, device=device)
    for pos, token_id in processed_constraints.items():
        absolute_pos = pos
@@ -185,25 +100,16 @@ def dream_generate_response_with_visualization(
    for i in range(gen_length):
        token_id = initial_x_part[0, i].item()
        if token_id == MASK_ID:
-            initial_state_vis.append((MASK_TOKEN, "Mask"))
-        elif token_id == EOS_TOKEN_ID or token_id == PAD_TOKEN_ID:
-            initial_state_vis.append(("", None)) # Hide special tokens
-        elif i in processed_constraints and processed_constraints[i] == token_id:
-            token_str = tokenizer.decode([token_id], skip_special_tokens=True).strip()
-            display_token = token_str if token_str else "?"
-            initial_state_vis.append((display_token, "Constraint"))
+            initial_state_vis.append((MASK_TOKEN, "#444444")) # Mask color
        else:
-            # Should only be constraints here, but add fallback
-            token_str = tokenizer.decode([token_id], skip_special_tokens=True).strip()
-            display_token = token_str if token_str else "?"
-            initial_state_vis.append((display_token, "Old")) # Treat unexpected initial non-masks as 'Old'
+            token_str = tokenizer.decode([token_id], skip_special_tokens=True)
+            initial_state_vis.append((token_str if token_str else "?", "#800080")) # Constraint color (purple)
    visualization_states.append(initial_state_vis)
 
-
    # --- Define the Hook Function ---
    def generation_tokens_hook_func(step, x, logits):
-        nonlocal last_x, visualization_states
-        # print(f"Hook called for step {step}") # Verbose logging
+        nonlocal last_x, visualization_states # Allow modification of outer scope variables
+        # print(f"Hook called for step {step}") # Verbose
 
        current_x = x.clone()
        constrained_x = current_x.clone()
@@ -213,13 +119,11 @@ def dream_generate_response_with_visualization(
            return current_x
 
        # 1. Apply Constraints
-        constraints_applied_this_step = False
        for pos, token_id in processed_constraints.items():
            absolute_pos = prompt_len + pos
            if prompt_len <= absolute_pos < current_x.shape[1]:
                if constrained_x[0, absolute_pos] != token_id:
                    constrained_x[0, absolute_pos] = token_id
-                    constraints_applied_this_step = True
 
        # 2. Generate Visualization State for *this* step
        current_state_vis = []
@@ -229,33 +133,52 @@ def dream_generate_response_with_visualization(
        for i in range(gen_length):
            current_token_id = gen_part_current[i].item()
 
-            # --- Logic to Hide Special Tokens ---
-            if current_token_id == EOS_TOKEN_ID or current_token_id == PAD_TOKEN_ID:
-                # Maybe show on first appearance? For now, always hide.
-                # LLaDA's behavior: "shown once and then disappear"
-                # Let's implement the simpler "always hide" first.
-                current_state_vis.append(("", None)) # Append empty string, no label -> hidden
-                continue # Move to next token
-
-            # --- Decode and Determine Label ---
-            token_str = tokenizer.decode([current_token_id], skip_special_tokens=True).strip()
-            display_token = token_str if token_str else MASK_TOKEN if current_token_id == MASK_ID else "?" # Use MASK_TOKEN if decode fails
+            # Basic check for safety, though unlikely needed with prompt_len check
+            if current_token_id is None:
+                current_state_vis.append((MASK_TOKEN, "#444444"))
+                continue
 
-            label = None # Default label (no color)
            is_constrained = i in processed_constraints
+            is_special = current_token_id in special_token_ids_to_hide
+
+            # Decode carefully: don't skip specials initially for display text
+            raw_token_str = tokenizer.decode([current_token_id], skip_special_tokens=False).strip()
+            # Use MASK_TOKEN string for MASK_ID, otherwise use decoded string or '?'
+            display_token = MASK_TOKEN if current_token_id == MASK_ID else (raw_token_str if raw_token_str else "?")
+
+            # Determine the state based on current and previous token
+            previous_token_id = gen_part_last[i].item() if gen_part_last is not None else None
 
+            # Determine color and potentially modify display_token for hiding
            if current_token_id == MASK_ID:
-                label = "Mask"
-            elif is_constrained and processed_constraints[i] == current_token_id:
-                label = "Constraint"
-            elif gen_part_last is None or gen_part_last[i].item() == MASK_ID or gen_part_last[i].item() == EOS_TOKEN_ID or gen_part_last[i].item() == PAD_TOKEN_ID:
-                # Newly revealed (was mask or hidden special token in previous step)
-                label = "New"
+                color = "#444444" # Dark Gray
+                display_token = MASK_TOKEN
+            elif is_constrained and processed_constraints.get(i) == current_token_id:
+                color = "#800080" # Purple - keep constraint visible
+                # Even if special, show the constraint for clarity
+            elif previous_token_id == MASK_ID or previous_token_id is None:
+                # --- Newly revealed in this step ---
+                if is_special:
+                    # Newly revealed special token: Show it this time
+                    color = "#FF8C00" # Dark Orange (distinct color for first reveal)
+                    # display_token is already the raw special token string
+                else:
+                    # Newly revealed regular token
+                    color = "#66CC66" # Light Green
+                    # display_token is already the regular token string
+            elif is_special:
+                # --- Was revealed previously AND is special: Hide it now ---
+                color = "#FFFFFF" # White background / Transparent conceptually
+                display_token = "" # Make it disappear visually
+                # Alternative: Subtle placeholder
+                # display_token = "."
+                # color = "#EEEEEE"
            else:
-                # Previously revealed and not masked/hidden/constrained
-                label = "Old"
+                # --- Previously revealed regular token ---
+                color = "#6699CC" # Light Blue
+                # display_token is already the regular token string
 
-            current_state_vis.append((display_token, label))
+            current_state_vis.append((display_token, color))
 
        visualization_states.append(current_state_vis)
 
@@ -265,11 +188,12 @@ def dream_generate_response_with_visualization(
        # 4. Return the sequence with constraints applied
        return constrained_x
 
+
    # --- Run DREAM Generation ---
    try:
        print("Calling model.diffusion_generate...")
        initial_full_x = torch.cat([input_ids, initial_x_part], dim=1)
-        last_x = initial_full_x.clone() # Initialize last_x *before* the call
+        last_x = initial_full_x.clone() # Initialize last_x for the first hook call
 
        output = model.diffusion_generate(
            input_ids,
@@ -289,33 +213,45 @@ def dream_generate_response_with_visualization(
        final_sequence = output.sequences[0]
        response_token_ids = final_sequence[prompt_length:]
 
-        # Decode final text, skipping special tokens
+        # Decode final text *skipping* special tokens for the chatbot display
        final_text = tokenizer.decode(
            response_token_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        ).strip()
-        print(f"Final generated text: {final_text}")
+        print(f"Final generated text (cleaned): {final_text}")
 
-        # Safeguard: Add final state visualization if needed (using the new label logic)
+        # Add the very final state to visualization if needed (safeguard)
+        # This uses the same logic as the hook for consistency
        if len(visualization_states) <= steps:
+            print("Adding final visualization state manually (safeguard).")
            final_state_vis = []
            final_gen_part = final_sequence[prompt_length:]
+            # Need the state *before* this final one to know what was 'new'
+            gen_part_last_final = last_x[0, prompt_len:] if last_x is not None else None
+
            for i in range(gen_length):
-                token_id = final_gen_part[i].item()
-                if token_id == EOS_TOKEN_ID or token_id == PAD_TOKEN_ID:
-                    final_state_vis.append(("", None))
-                    continue
-
-                token_str = tokenizer.decode([token_id], skip_special_tokens=True).strip()
-                display_token = token_str if token_str else MASK_TOKEN if token_id == MASK_ID else "?"
-                label = None
+                current_token_id = final_gen_part[i].item()
                is_constrained = i in processed_constraints
-
-                if token_id == MASK_ID: label = "Mask"
-                elif is_constrained and processed_constraints[i] == token_id: label = "Constraint"
-                else: label = "Old" # Default to 'Old' for final state non-masked tokens
-                final_state_vis.append((display_token, label))
+                is_special = current_token_id in special_token_ids_to_hide
+                raw_token_str = tokenizer.decode([current_token_id], skip_special_tokens=False).strip()
+                display_token = MASK_TOKEN if current_token_id == MASK_ID else (raw_token_str if raw_token_str else "?")
+                previous_token_id = gen_part_last_final[i].item() if gen_part_last_final is not None else None
+
+                if current_token_id == MASK_ID:
+                    color = "#444444"
+                    display_token = MASK_TOKEN
+                elif is_constrained and processed_constraints.get(i) == current_token_id:
+                    color = "#800080"
+                elif previous_token_id == MASK_ID or previous_token_id is None: # Newly revealed
+                    color = "#FF8C00" if is_special else "#66CC66"
+                elif is_special: # Previously revealed special
+                    color = "#FFFFFF"
+                    display_token = ""
+                else: # Previously revealed regular
+                    color = "#6699CC"
+
+                final_state_vis.append((display_token, color))
            visualization_states.append(final_state_vis)
 
 
@@ -324,160 +260,8 @@ def dream_generate_response_with_visualization(
        import traceback
        traceback.print_exc()
        error_msg = f"Error during generation: {str(e)}"
-        # Use 'Error' label for color mapping
-        visualization_states.append([("Error", "Error")])
+        visualization_states.append([("Error", "red")])
        final_text = f"Generation failed: {e}"
 
    print("--- DREAM Generation Finished ---")
-    return visualization_states, final_text
-
-
-# --- Gradio UI Setup ---
-
-css = '''
-.category-legend{display:none}
-/* button{height: 60px} */
-.small_btn {max-width: 100px; height: 40px; flex-grow: 0; margin-left: 5px;}
-.chat-input-row {display: flex; align-items: center;}
-.chat-input-row > * {margin-right: 5px;}
-.chat-input-row > *:last-child {margin-right: 0;}
-'''
-def create_chatbot_demo():
-    with gr.Blocks(css=css) as demo:
-        gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
-        gr.Markdown("Watch the text generate step-by-step. Special tokens (EOS, PAD) are hidden.")
-        gr.Markdown("[Model Card](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B) - [Blog Post](https://hkunlp.github.io/blog/2025/dream/)")
-
-        # STATE MANAGEMENT
-        chat_history = gr.State([])
-
-        # UI COMPONENTS
-        with gr.Row():
-            with gr.Column(scale=3):
-                chatbot_ui = gr.Chatbot(
-                    label="Conversation", height=500, bubble_full_width=False
-                )
-                with gr.Row(elem_classes="chat-input-row"):
-                    user_input = gr.Textbox(
-                        label="Your Message", placeholder="Type your message...",
-                        scale=4, container=False, show_label=False
-                    )
-                    send_btn = gr.Button("Send", scale=1, elem_classes="small_btn")
-
-                constraints_input = gr.Textbox(
-                    label="Word Constraints (Optional)",
-                    info="Format: 'pos:word, pos:word'. Example: '0:Once, 5:upon'",
-                    placeholder="e.g., 0:Hello, 6:world", value=""
-                )
-            with gr.Column(scale=2):
-                # --- Updated HighlightedText with color_map ---
-                output_vis = gr.HighlightedText(
-                    label="Denoising Process Visualization",
-                    combine_adjacent=True, # Combine adjacent tokens with same label
-                    show_legend=False, # Keep legend off
-                    color_map={ # Map labels to colors
-                        "Mask": "#A0A0A0", # Lighter Gray for Mask
-                        "New": "#66CC66", # Light Green
-                        "Old": "#6699CC", # Light Blue
-                        "Constraint": "#B266FF", # Lighter Purple/Violet
-                        "Error": "#FF6666" # Light Red
-                    }
-                )
-                gr.Markdown(
-                    # Update legend text to match labels
-                    "**Color Legend:** <span style='color:#A0A0A0'>■ Mask</span> | <span style='color:#66CC66'>■ New</span> | <span style='color:#6699CC'>■ Old</span> | <span style='color:#B266FF'>■ Constraint</span>"
-                )
-
-
-        # Advanced generation settings (Keep as before)
-        with gr.Accordion("Generation Settings", open=False):
-            with gr.Row():
-                gen_length = gr.Slider(minimum=16, maximum=512, value=128, step=8, label="Max New Tokens")
-                steps = gr.Slider(minimum=8, maximum=512, value=128, step=8, label="Diffusion Steps")
-            with gr.Row():
-                temperature = gr.Slider(minimum=0.0, maximum=1.5, value=0.6, step=0.05, label="Temperature")
-                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, label="Top-P (Nucleus Sampling)")
-            with gr.Row():
-                remasking_strategy = gr.Radio(
-                    choices=[("Random", "origin"), ("Entropy", "entropy"), ("MaskGit+", "maskgit_plus"), ("TopK Margin", "topk_margin")],
-                    value="entropy", label="Generation Order Strategy (alg)"
-                )
-                alg_temp = gr.Slider(
-                    minimum=0.0, maximum=1.0, value=0.1, step=0.05, label="Order Randomness (alg_temp)",
-                    info="Adds randomness to non-Random strategies. Ignored for Random."
-                )
-            with gr.Row():
-                visualization_delay = gr.Slider(minimum=0.0, maximum=0.5, value=0.05, step=0.01, label="Visualization Delay (seconds)")
-
-        clear_btn = gr.Button("Clear Conversation")
-
-        # --- Event Handlers (Keep as before) ---
-        def add_message_to_history(history, message, response):
-            history = history.copy(); history.append([message, response]); return history
-
-        def user_message_submitted(message, history):
-            print(f"User submitted: '{message}'")
-            if not message or not message.strip():
-                print("Empty message submitted, doing nothing."); return history, history, "", []
-            history = add_message_to_history(history, message, None)
-            history_for_display = history.copy()
-            message_out = ""; vis_clear = []
-            return history, history_for_display, message_out, vis_clear
-
-        def bot_response_generator(
-            history, gen_length, steps, constraints_text, delay,
-            temperature, top_p, alg, alg_temp
-        ):
-            print("--- Generating Bot Response ---")
-            if not history or history[-1][1] is not None:
-                print("History empty or last message already has response. Skipping generation.")
-                yield history, [], "No response generated." # Yield current state if called unnecessarily
-                return
-
-            messages = format_chat_history(history)
-            parsed_constraints = parse_constraints(constraints_text)
-
-            try:
-                vis_states, response_text = dream_generate_response_with_visualization(
-                    messages, gen_length=gen_length, steps=steps, constraints=parsed_constraints,
-                    temperature=temperature, top_p=top_p, alg=alg, alg_temp=alg_temp
-                )
-                history[-1][1] = response_text.strip() # Update history state
-
-                if vis_states:
-                    # Yield initial state first
-                    yield history, vis_states[0] # Update chatbot, update visualization
-                    # Animate remaining states
-                    for state in vis_states[1:]:
-                        time.sleep(delay)
-                        yield history, state # Update chatbot (implicitly), update visualization
-                else:
-                    yield history, [("Generation failed.", "Error")] # Use label
-
-            except Exception as e:
-                print(f"Error in bot_response_generator: {e}")
-                import traceback; traceback.print_exc()
-                error_msg = f"Error: {str(e)}"
-                error_vis = [(error_msg, "Error")] # Use label
-                yield history, error_vis
-
-        def clear_conversation():
-            print("Clearing conversation."); return [], [], "", []
-
-        # --- Wire UI elements (Keep as before) ---
-        user_input.submit(fn=user_message_submitted, inputs=[user_input, chat_history], outputs=[chat_history, chatbot_ui, user_input, output_vis], queue=False)\
-            .then(fn=bot_response_generator, inputs=[history, gen_length, steps, constraints_input, visualization_delay, temperature, top_p, remasking_strategy, alg_temp], outputs=[chatbot_ui, output_vis])
-
-        send_btn.click(fn=user_message_submitted, inputs=[user_input, chat_history], outputs=[chat_history, chatbot_ui, user_input, output_vis], queue=False)\
-            .then(fn=bot_response_generator, inputs=[history, gen_length, steps, constraints_input, visualization_delay, temperature, top_p, remasking_strategy, alg_temp], outputs=[chatbot_ui, output_vis])
-
-        clear_btn.click(fn=clear_conversation, inputs=[], outputs=[chat_history, chatbot_ui, user_input, output_vis], queue=False)
-
-    return demo
-
-# --- Launch the Gradio App ---
-if __name__ == "__main__":
-    print("Creating Gradio demo...")
-    demo = create_chatbot_demo()
-    print("Launching Gradio demo...")
-    demo.queue().launch(share=True, debug=True)
+    return visualization_states, final_text
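
For reference, the display rule the updated hook applies at every generated position can be read off the diff as a small state machine: masked positions stay dark gray, constrained positions stay purple, a token revealed this step is green (or orange if it is a special token, which is shown exactly once), and a special token revealed on an earlier step is hidden. Below is a minimal standalone sketch of that rule; classify_token, the toy ids, and the decode callable are illustrative stand-ins, not names from app.py, while the hex colors are the ones used in the commit.

# Sketch only: mirrors the coloring logic of the updated hook, assuming
# integer token ids and a decode(id) -> str-or-None callable.
MASK_ID = 0
MASK_TOKEN = "[MASK]"
SPECIAL_IDS = {1, 2}  # stand-ins for eos/pad ids

def classify_token(i, current_id, previous_id, constraints, decode):
    """Return (display_token, color) for generated position i."""
    text = MASK_TOKEN if current_id == MASK_ID else (decode(current_id) or "?")
    if current_id == MASK_ID:
        return MASK_TOKEN, "#444444"      # still masked: dark gray
    if constraints.get(i) == current_id:
        return text, "#800080"            # constrained: purple, always visible
    if previous_id is None or previous_id == MASK_ID:
        # first reveal: specials shown once in orange, regular tokens in green
        return text, "#FF8C00" if current_id in SPECIAL_IDS else "#66CC66"
    if current_id in SPECIAL_IDS:
        return "", "#FFFFFF"              # special revealed earlier: hide it
    return text, "#6699CC"                # previously revealed regular: blue

# Toy step: position 0 is constrained, 1 reveals a special token,
# 2 is a previously revealed word, 3 hides an already-seen special.
decode = {3: "Hello", 4: "world"}.get
prev = [MASK_ID, MASK_ID, 3, 1]
curr = [3, 1, 3, 1]
print([classify_token(i, c, p, {0: 3}, decode) for i, (c, p) in enumerate(zip(curr, prev))])
# -> [('Hello', '#800080'), ('?', '#FF8C00'), ('Hello', '#6699CC'), ('', '#FFFFFF')]

In app.py the same decision runs inside generation_tokens_hook_func with real tokenizer ids, so each diffusion step appends one such list of (token, color) pairs to visualization_states for the HighlightedText animation.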