multimodalart HF Staff committed on
Commit ce90309 · verified · 1 Parent(s): 47fc4a0

Update app.py

Files changed (1)
  1. app.py +435 -649
app.py CHANGED
@@ -1,96 +1,62 @@
 
1
  import torch
2
- # import numpy as np # Not strictly needed anymore
3
  import gradio as gr
4
  import spaces
5
- from transformers import AutoTokenizer, AutoModel
6
  import time
7
- import re # Keep for parsing constraints
8
-
9
- # Use try-except for space deployment vs local
 
 
 
 
10
  try:
11
- # Used for spaces deployment with GPU
12
- gpu_check = spaces.GPU
13
- print("Running in Gradio Spaces with GPU environment.")
14
  except AttributeError:
15
- # Fallback for local execution or environments without spaces.GPU
16
- print("Running in local environment or without spaces.GPU.")
17
- # Define a dummy decorator if spaces.GPU is not available
18
- def gpu_check(func):
19
- return func
20
-
21
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
22
- print(f"Using device: {device}")
23
-
24
- # --- Load DREAM Model and Tokenizer ---
25
- # Ensure sufficient VRAM, Dream 7B needs ~16GB+ VRAM in bfloat16
26
- model_path = "Dream-org/Dream-v0-Instruct-7B"
27
- print(f"Loading model: {model_path}...")
28
  try:
29
- model = AutoModel.from_pretrained(
30
- model_path,
31
- torch_dtype=torch.bfloat16, # Use bfloat16 for efficiency
32
- trust_remote_code=True,
33
- # device_map='auto' # Consider if running into OOM errors, might split across GPUs/CPU
34
- ).to(device).eval()
35
- tokenizer = AutoTokenizer.from_pretrained(
36
- model_path,
37
- trust_remote_code=True
38
- )
39
- print("Model and tokenizer loaded successfully.")
40
- except Exception as e:
41
- print(f"Error loading model or tokenizer: {e}")
42
- print("Please ensure you have enough GPU memory and the model files are accessible.")
43
- # Exit or raise if loading fails
44
- raise e
45
-
46
-
47
- # --- Constants for DREAM ---
48
- # Find the mask token and ID from the DREAM tokenizer
49
- if tokenizer.mask_token is None:
50
- print("Warning: Mask token not found in tokenizer. Attempting to add '[MASK]'.")
51
- # This might require retraining or fine-tuning if the model didn't see this token
52
- num_added = tokenizer.add_special_tokens({'mask_token': '[MASK]'})
53
- if num_added > 0:
54
- print(f"Added '{tokenizer.mask_token}' to tokenizer.")
55
- # Resize model embeddings if vocab changed
56
- model.resize_token_embeddings(len(tokenizer))
57
- print("Resized model token embeddings.")
58
- else:
59
- # Fallback or error if adding failed or mask token still None
60
- # It's possible a different token serves this purpose in DREAM's training
61
- print("Error: Could not set a mask token. Visualization might be inaccurate.")
62
- # You might need to identify which token ID DREAM uses internally for masking tasks
63
- # For now, we'll proceed but this is a potential issue.
64
- MASK_TOKEN = "<?>" # Placeholder symbol
65
- MASK_ID = -1 # Invalid ID indicates issue
66
- if tokenizer.mask_token is None:
67
- raise ValueError("Could not set a mask token for the tokenizer.")
68
-
69
- MASK_TOKEN = tokenizer.mask_token
70
- MASK_ID = tokenizer.mask_token_id
71
- print(f"Using MASK_TOKEN='{MASK_TOKEN}' with ID={MASK_ID}")
72
-
73
- # Identify other special tokens to potentially hide/show
74
- eos_token_id = tokenizer.eos_token_id
75
- pad_token_id = tokenizer.pad_token_id
76
- special_token_ids_set = {MASK_ID} # Start with Mask ID
77
- if eos_token_id is not None:
78
- special_token_ids_set.add(eos_token_id)
79
- print(f"EOS token ID: {eos_token_id} ({tokenizer.decode([eos_token_id])})")
80
- if pad_token_id is not None:
81
- special_token_ids_set.add(pad_token_id)
82
- print(f"PAD token ID: {pad_token_id} ({tokenizer.decode([pad_token_id])})")
83
- # Add other common special tokens if needed (e.g., BOS, UNK)
84
- if tokenizer.bos_token_id is not None:
85
- special_token_ids_set.add(tokenizer.bos_token_id)
86
- print(f"BOS token ID: {tokenizer.bos_token_id} ({tokenizer.decode([tokenizer.bos_token_id])})")
87
- if tokenizer.unk_token_id is not None:
88
- special_token_ids_set.add(tokenizer.unk_token_id)
89
- print(f"UNK token ID: {tokenizer.unk_token_id} ({tokenizer.decode([tokenizer.unk_token_id])})")
90
-
91
- print(f"Identified special token IDs: {special_token_ids_set}")
92
-
93
- # --- Helper Functions (Constraint Parsing, History Formatting) ---
94
 
95
  def parse_constraints(constraints_text):
96
  """Parse constraints in format: 'position:word, position:word, ...'"""
@@ -100,674 +66,494 @@ def parse_constraints(constraints_text):
100
 
101
  parts = constraints_text.split(',')
102
  for part in parts:
103
- part = part.strip() # Trim whitespace
104
  if ':' not in part:
105
  continue
106
  try:
107
  pos_str, word = part.split(':', 1)
108
  pos = int(pos_str.strip())
109
  word = word.strip()
110
- # Allow empty words if needed? Forcing empty seems odd. Let's require a word.
111
  if word and pos >= 0:
112
- constraints[pos] = word
 
 
 
 
113
  except ValueError:
114
- print(f"Warning: Could not parse constraint part: '{part}'")
 
 
115
  continue
116
 
117
  return constraints
118
 
119
  def format_chat_history(history):
120
  """
121
- Format chat history for the DREAM model (standard messages format)
122
 
123
  Args:
124
  history: List of [user_message, assistant_message] pairs
125
 
126
  Returns:
127
- Formatted conversation for the model (list of dictionaries)
128
  """
129
  messages = []
130
- # Check if a system prompt is appropriate for Dream-Instruct
131
- # From demo_completion.py example, it seems it uses system prompt via template
132
- # messages.append({"role": "system", "content": "You are a helpful assistant."})
 
 
 
133
  for user_msg, assistant_msg in history:
134
- if user_msg is not None: # Handle potential None message if clearing failed
135
  messages.append({"role": "user", "content": user_msg})
136
- if assistant_msg: # Skip if None (for the latest user message awaiting response)
137
  messages.append({"role": "assistant", "content": assistant_msg})
138
 
139
  return messages
140
 
141
- # --- Core Generation Logic for DREAM with Visualization ---
142
 
143
- @gpu_check # Use the potentially dummy decorator
144
- @torch.no_grad() # Disable gradient calculations for inference
145
- def dream_generate_response_with_visualization(
146
- messages,
147
- gen_length=128,
148
- steps=128, # Default based on DREAM examples
 
 
 
 
 
 
149
  constraints=None,
150
- temperature=0.6, # Default based on DREAM examples
151
- top_p=0.95, # Default based on DREAM examples
152
- alg="entropy", # Default based on DREAM examples
153
- alg_temp=0.1, # Default based on DREAM examples
154
  ):
155
  """
156
- Generate text with DREAM model with visualization using the generation hook.
157
 
158
  Args:
159
- messages: List of message dictionaries with 'role' and 'content'
160
- gen_length: Length of text to generate (max_new_tokens)
161
- steps: Number of diffusion steps
162
- constraints: Dictionary mapping positions (relative to response start) to words
163
- temperature: Sampling temperature
164
- top_p: Nucleus sampling p
165
- alg: Remasking algorithm ('origin', 'maskgit_plus', 'topk_margin', 'entropy')
166
- alg_temp: Temperature for confidence-based algorithms
167
 
168
  Returns:
169
  Tuple: (List of visualization states, final generated text string)
170
  """
171
- print("\n--- Starting DREAM Generation ---")
172
- print(f"Params: len={gen_length}, steps={steps}, temp={temperature}, top_p={top_p}, alg='{alg}', alg_temp={alg_temp}")
173
- print(f"Constraints: {constraints}")
174
-
175
- # --- Input Preparation ---
176
  if constraints is None:
177
  constraints = {}
178
 
179
- # Convert word constraints to token IDs (handle multi-token words)
180
- processed_constraints = {}
181
- constraint_token_lengths = {} # Store length for multi-token constraints
182
- print("Processing constraints:")
183
- for pos, word in constraints.items():
184
- # Prepend space for potentially better tokenization consistency
185
- # (though apply_chat_template should handle spacing)
186
- tokens = tokenizer.encode(" " + word, add_special_tokens=False)
187
- if not tokens:
188
- print(f" Warning: Could not tokenize constraint word '{word}' at position {pos}. Skipping.")
189
- continue
190
- print(f" Pos {pos}, Word '{word}' -> Tokens {tokens} ({tokenizer.convert_ids_to_tokens(tokens)})")
191
- constraint_token_lengths[pos] = len(tokens)
192
- for i, token_id in enumerate(tokens):
193
- target_pos = pos + i
194
- if target_pos in processed_constraints:
195
- print(f" Warning: Overlapping constraint token at position {target_pos}. Keeping first constraint's token ({processed_constraints[target_pos]}).")
196
- else:
197
- processed_constraints[target_pos] = token_id
198
 
199
  # Prepare the prompt using chat template
 
200
  try:
201
  inputs = tokenizer.apply_chat_template(
202
  messages,
203
  return_tensors="pt",
204
- return_dict=True,
205
- add_generation_prompt=True # Crucial for Dream-Instruct
206
  )
207
- input_ids = inputs.input_ids.to(device=device)
208
- # Use the attention mask generated by the template
209
- attention_mask = inputs.attention_mask.to(device=device)
210
- prompt_length = input_ids.shape[1]
211
- print(f"Input prompt length: {prompt_length}")
212
- # print(f"Input IDs: {input_ids}")
213
- # print(f"Attention Mask: {attention_mask}") # Verify mask covers prompt
214
  except Exception as e:
215
  print(f"Error applying chat template: {e}")
216
- return [([("Error applying chat template.", "red")],)], f"Error: {e}"
217
-
218
- # Check context length (DREAM uses 2048 default)
219
- model_max_length = getattr(model.config, 'max_position_embeddings', 2048)
220
- if prompt_length + gen_length > model_max_length:
221
- print(f"Warning: Requested length ({prompt_length + gen_length}) exceeds model max length ({model_max_length}). Truncating gen_length.")
222
- gen_length = model_max_length - prompt_length
223
- if gen_length <= 0:
224
- print("Error: Prompt is already too long.")
225
- return [([("Prompt too long.", "red")],)], "Error: Prompt too long."
226
-
227
- # --- State for Visualization Hook ---
228
- visualization_states = []
229
- last_x = None # Store the full sequence (prompt + generation) from the previous step
230
-
231
- # Initial state: Prompt + all masks for generation part
232
- initial_gen_part = torch.full((1, gen_length), MASK_ID, dtype=torch.long, device=device)
233
- # Apply initial constraints to the masked part *before* the first visualization state
234
- for pos, token_id in processed_constraints.items():
235
- absolute_pos = pos # Position relative to start of generation
236
- if 0 <= absolute_pos < gen_length:
237
- initial_gen_part[0, absolute_pos] = token_id
238
-
239
- # Create the first visualization state (only the generation part)
240
- initial_state_vis = []
241
- for i in range(gen_length):
242
- token_id = initial_gen_part[0, i].item()
243
- if token_id == MASK_ID:
244
- initial_state_vis.append((MASK_TOKEN, "#444444")) # Mask color
245
- else:
246
- # This must be a constraint applied initially
247
- # Decode without skipping special to see raw constraint if needed
248
- token_str = tokenizer.decode([token_id], skip_special_tokens=False).strip()
249
- initial_state_vis.append((token_str if token_str else "?", "#800080")) # Constraint color (purple)
250
- visualization_states.append(initial_state_vis)
251
 
252
  # --- Define the Hook Function ---
253
  def generation_tokens_hook_func(step, x, logits):
254
- nonlocal last_x, visualization_states # Allow modification of outer scope variables
255
- # print(f"Hook step {step}") # Keep console less noisy
256
-
257
- current_x = x.clone() # Full sequence (prompt + generation) at this step
258
-
259
- # 1. Apply Constraints to the current sequence
260
- constrained_x = current_x.clone()
261
- current_prompt_len = current_x.shape[1] - gen_length # Recalculate prompt length based on current x
262
- if current_prompt_len < 0:
263
- print(f"Warning: prompt_len {current_prompt_len} negative in hook step {step}, skipping constraints/vis.")
264
- return current_x # Return unmodified if something is wrong
265
-
266
- for pos, token_id in processed_constraints.items():
267
- # pos is relative to the start of the *generation* part
268
- absolute_pos = current_prompt_len + pos
269
- # Ensure position is within the bounds of the *current* sequence 'x'
270
- if current_prompt_len <= absolute_pos < current_x.shape[1]:
271
- if constrained_x[0, absolute_pos] != token_id:
272
- constrained_x[0, absolute_pos] = token_id
273
- # print(f" Constraint enforced at pos {pos} ({absolute_pos}) -> {token_id}")
274
-
275
-
276
- # 2. Generate Visualization State for *this* step (generation part only)
277
- current_state_vis = []
278
- # Compare current_x (before explicit constraint application in *this* hook call)
279
- # with last_x (state from *previous* hook call / initial state)
280
- gen_part_current = current_x[0, current_prompt_len:]
281
- # Ensure last_x exists and has the same shape for comparison
282
- gen_part_last = last_x[0, current_prompt_len:] if (last_x is not None and last_x.shape == current_x.shape) else None
283
-
284
- for i in range(gen_length):
285
- # Ensure index i is valid for the current generation part
286
- if i >= gen_part_current.shape[0]:
287
- print(f"Warning: Index {i} out of bounds for gen_part_current (shape {gen_part_current.shape}) in step {step}.")
288
- continue # Skip if index is invalid
289
-
290
- current_token_id = gen_part_current[i].item()
291
- # Handle case where last_x was None or had different shape
292
- last_token_id = gen_part_last[i].item() if gen_part_last is not None and i < gen_part_last.shape[0] else MASK_ID # Assume mask initially
293
-
294
- is_constrained = i in processed_constraints
295
- is_special = current_token_id in special_token_ids_set
296
- is_mask = current_token_id == MASK_ID
297
- was_mask = last_token_id == MASK_ID or last_x is None # Treat first step as coming from mask
298
-
299
- display_token = ""
300
- color = ""
301
-
302
- # Determine display token and color based on state transitions
303
- if is_mask:
304
- display_token = MASK_TOKEN
305
- color = "#444444" # Dark Gray
306
- elif is_constrained and processed_constraints[i] == current_token_id:
307
- # Always show the constrained token, color purple
308
- # Decide whether to show raw special tokens when constrained
309
- raw_decode = tokenizer.decode([current_token_id], skip_special_tokens=False).strip()
310
- display_token = raw_decode if raw_decode else "?"
311
- color = "#800080" # Purple
312
- elif is_special:
313
- if was_mask:
314
- # Newly revealed special token: Show its representation once
315
- display_token = tokenizer.decode([current_token_id], skip_special_tokens=False).strip() # Show raw special
316
- color = "#FF8C00" # DarkOrange
317
  else:
318
- # Already revealed special token: Hide it by showing a space
319
- display_token = " " # Effectively hides it
320
- color = "#6699CC" # Use 'Old' color (Light Blue) but content is hidden space
321
- elif was_mask:
322
- # Newly revealed normal token
323
- display_token = tokenizer.decode([current_token_id], skip_special_tokens=True).strip()
324
- color = "#66CC66" # Light Green
325
- else:
326
- # Previously revealed normal token
327
- display_token = tokenizer.decode([current_token_id], skip_special_tokens=True).strip()
328
- color = "#6699CC" # Light Blue
329
-
330
- # Fallback for empty decodes of non-special, non-mask tokens
331
- if not display_token and not is_mask and not (is_special and not was_mask):
332
- display_token = "?" # Use question mark for unexpected empty decodes
333
-
334
- current_state_vis.append((display_token, color))
335
-
336
- visualization_states.append(current_state_vis)
337
-
338
- # 3. Update last_x for the *next* step's comparison
339
- # Store the state *after* applying constraints for accurate comparison next time
340
- last_x = constrained_x.clone()
341
-
342
- # 4. Return the sequence with constraints applied for the model's next step
343
- return constrained_x # Return the sequence with constraints enforced
344
-
345
-
346
- # --- Run DREAM Generation ---
 
 
 
 
347
  try:
348
- print("Calling model.diffusion_generate...")
349
- # Make sure last_x is initialized correctly before the first hook call
350
- # It should represent the state *before* the first diffusion step.
351
- initial_full_x = torch.cat([input_ids, initial_gen_part], dim=1)
352
- last_x = initial_full_x.clone() # Initialize last_x with prompt + initial masked/constrained gen part
353
-
354
  output = model.diffusion_generate(
355
- input_ids=input_ids,
356
- attention_mask=attention_mask, # Pass the correct attention mask
357
- max_new_tokens=gen_length,
358
- output_history=False, # We build history in the hook
359
  return_dict_in_generate=True,
360
  steps=steps,
361
  temperature=temperature,
362
  top_p=top_p,
363
  alg=alg,
364
- # alg_temp is only relevant for confidence-based algs (not 'origin')
365
- alg_temp=alg_temp if alg != "origin" else 0.0,
366
  generation_tokens_hook_func=generation_tokens_hook_func
367
- # Ensure generation doesn't run past eos_token if not desired
368
- # eos_token_id=eos_token_id, # This might stop generation early
369
- # pad_token_id=tokenizer.eos_token_id # Often pad is same as eos for LLMs
370
  )
371
- print("model.diffusion_generate finished.")
372
-
373
- # Extract final generated sequence (response part only)
374
- # The hook ensures the returned sequence has constraints applied
375
- final_sequence = output.sequences[0]
376
- # Handle potential length mismatch if generation stopped early
377
- actual_gen_len = final_sequence.shape[0] - prompt_length
378
- response_token_ids = final_sequence[prompt_length:]
379
-
380
- # Decode the final response, skipping special tokens like EOS/PAD
381
- final_text = tokenizer.decode(
382
- response_token_ids,
383
- skip_special_tokens=True,
384
- clean_up_tokenization_spaces=True # Recommended for cleaner output
385
- ).strip()
386
- print(f"Final generated text: '{final_text}'")
387
-
388
- # Add the very final state to visualization if the hook didn't capture it
389
- # (Mainly a safeguard, hook should run 'steps' times or until completion)
390
- if len(visualization_states) <= steps: # Hook might run 'steps' times
391
- final_state_vis = []
392
- final_gen_part = response_token_ids # Use the extracted response tokens
393
-
394
- for i in range(len(final_gen_part)): # Iterate over actual generated tokens
395
- token_id = final_gen_part[i].item()
396
- is_constrained = i in processed_constraints
397
- is_special = token_id in special_token_ids_set
398
- is_mask = token_id == MASK_ID # Should not happen in final output
399
-
400
- display_token = ""
401
- color = ""
402
-
403
- if is_mask: color = "#444444"; display_token = MASK_TOKEN
404
- elif is_constrained and processed_constraints.get(i) == token_id:
405
- raw_decode = tokenizer.decode([token_id], skip_special_tokens=False).strip()
406
- display_token = raw_decode if raw_decode else "?"; color = "#800080" # Purple
407
- elif is_special:
408
- # Hide special tokens in the *final* display state for cleaner look
409
- display_token = " "; color = "#6699CC" # Hide as 'Old' blue
410
- else:
411
- display_token = tokenizer.decode([token_id], skip_special_tokens=True).strip()
412
- color = "#6699CC" # Final state uses 'Old' blue
413
-
414
- if not display_token: display_token = "?" # Fallback
415
- final_state_vis.append((display_token, color))
416
-
417
- # Pad the final state visualization if actual gen len < requested gen_length
418
- # This shouldn't be necessary if HighlightedText handles shorter lists
419
- # while len(final_state_vis) < gen_length:
420
- # final_state_vis.append((" ", "#FFFFFF")) # Add empty space
421
-
422
- if final_state_vis: # Only append if we generated something
423
- visualization_states.append(final_state_vis)
424
 
 
 
 
425
 
426
  except Exception as e:
427
- print(f"\n--- Error during generation ---")
428
  import traceback
429
  traceback.print_exc()
430
- # Add error message to visualization
431
- error_msg = f"Generation Error: Check Logs"
432
- # Append error to visualization states if possible
433
- visualization_states.append([("Error", "red")])
434
- final_text = f"Generation failed: {e}"
 
 
 
 
435
 
436
- print("--- DREAM Generation Finished ---\n")
437
- # Ensure we always return a list (even if empty) and a string
438
- if not isinstance(visualization_states, list): visualization_states = []
439
- if not isinstance(final_text, str): final_text = str(final_text)
440
 
441
- return visualization_states, final_text
442
 
 
443
 
444
- # --- Gradio UI Setup ---
 
445
 
446
  css = '''
447
  .category-legend{display:none}
448
- /* Increase overall base font size */
449
- body, .gradio-container { font-size: 105%; }
450
- /* Make buttons slightly larger */
451
- /* button { min-height: 40px; } */
452
- .small_btn {
453
- min-width: 60px; /* Adjust as needed */
454
- max-width: 100px;
455
- height: 42px; /* Adjust height */
456
- flex-grow: 0 !important; /* Prevent button from growing */
457
- margin-left: 5px !important; /* Add some space */
458
- font-size: 100%; /* Match button font size */
459
- padding: 0 10px !important; /* Adjust padding */
460
- }
461
- .chat-input-row {
462
- display: flex;
463
- align-items: center; /* Vertically align items */
464
- margin-top: 10px; /* Add space above input row */
465
- }
466
- /* Ensure Textbox takes up space */
467
- .chat-input-row .gr-textbox {
468
- flex-grow: 1;
469
- margin-right: 5px;
470
- }
471
- /* Chatbot styling */
472
- .gr-chatbot .message {
473
- font-size: 100%; /* Ensure chat message font size is reasonable */
474
- padding: 10px !important;
475
- border-radius: 8px !important;
476
- }
477
- .gr-chatbot .message.user { background-color: #E0F7FA !important; align-self: flex-end; } /* Light cyan for user */
478
- .gr-chatbot .message.bot { background-color: #F1F8E9 !important; align-self: flex-start; } /* Light green for bot */
479
- /* HighlightedText styling */
480
- .gr-highlightedtext span {
481
- padding: 1px 2px; /* Minimal padding */
482
- margin: 0 1px; /* Minimal margin */
483
- border-radius: 3px;
484
- font-family: monospace; /* Use monospace font for better alignment */
485
- font-size: 95%; /* Slightly smaller font for dense vis */
486
- line-height: 1.4; /* Adjust line spacing */
487
- }
488
- .gr-highlightedtext {
489
- padding: 10px;
490
- border: 1px solid #E0E0E0;
491
- border-radius: 5px;
492
- background-color: #FAFAFA; /* Light background for the container */
493
- }
494
- /* Legend Styling */
495
- .legend {
496
- font-size: 90%;
497
- margin-top: 5px;
498
- color: #555;
499
- }
500
- .legend span {
501
- display: inline-block; /* Keep legend items inline */
502
- margin-right: 10px;
503
- white-space: nowrap; /* Prevent wrapping */
504
- }
505
- .legend span::before { /* Style the color square */
506
- content: '■';
507
- display: inline-block;
508
- margin-right: 4px;
509
- font-size: 120%; /* Make square slightly larger */
510
- vertical-align: middle; /* Align square with text */
511
- }
512
  '''
513
  def create_chatbot_demo():
514
- with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
515
- gr.Markdown("## Dream 7B - Diffusion Language Model Demo")
516
- gr.Markdown("Interact with the Dream 7B instruction-tuned model and watch the diffusion process unfold step-by-step. "
517
- "You can optionally constrain specific words at certain positions.")
518
- with gr.Row():
519
- gr.Markdown("[Model Card](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)", scale=1)
520
- gr.Markdown("[Blog Post](https://hkunlp.github.io/blog/2025/dream/)", scale=1)
521
 
522
  # STATE MANAGEMENT
523
- chat_history = gr.State([]) # Stores conversation [[user, bot], ...]
524
 
525
- # UI LAYOUT
526
  with gr.Row():
527
- # Left Column: Chat Interface
528
  with gr.Column(scale=3):
529
- chatbot_ui = gr.Chatbot(
530
- label="Conversation",
531
- height=550,
532
- bubble_full_width=False,
533
- show_copy_button=True,
534
- render=False # Rendered explicitly later for streaming
535
- )
536
- chatbot_ui.render() # Manually render after setting parameters
537
 
538
- # Message input Row
539
- with gr.Row(elem_classes="chat-input-row"):
 
540
  user_input = gr.Textbox(
541
  label="Your Message",
542
- placeholder="Type your message and press Enter, or click Send...",
543
- scale=4, # Give textbox more space relative to button
544
- container=False,
545
- show_label=False
546
  )
547
- send_btn = gr.Button("Send", scale=1, elem_classes="small_btn", variant="primary")
548
 
549
  constraints_input = gr.Textbox(
550
  label="Word Constraints (Optional)",
551
- info="Force words at positions (0-indexed from response start). Format: 'pos:word, pos:word'. Example: '0:Once, 5:upon, 10:time'",
552
- placeholder="e.g., 0:Hello, 6:world",
553
- lines=1
554
  )
555
-
556
- # Right Column: Visualization and Settings
557
  with gr.Column(scale=2):
558
- gr.Markdown("### Denoising Process Visualization")
559
  output_vis = gr.HighlightedText(
560
- label="Generation Steps",
561
- show_label=False, # Label provided by Markdown above
562
  combine_adjacent=False,
563
- show_legend=False, # Using custom HTML legend below
564
- # color_map is not directly used due to show_legend=False, but useful for reference
565
- color_map={
566
- "Mask": "#444444",
567
- "New": "#66CC66",
568
- "Old": "#6699CC",
569
- "Constraint": "#800080",
570
- "Special (New)": "#FF8C00",
571
- "Error": "red"
572
- }
573
  )
574
- # Custom HTML Legend
575
- gr.HTML(
576
- """
577
- <div class='legend'>
578
- <span style="color:#444444;">■ Mask</span> |
579
- <span style='color:#66CC66;'>■ New</span> |
580
- <span style='color:#FF8C00;'>■ Special (New)</span> |
581
- <span style='color:#6699CC;'>■ Old</span> |
582
- <span style='color:#800080;'>■ Constraint</span>
583
- </div>
584
- """,
585
- elem_id="legend-html"
586
  )
587
 
588
- # Generation Settings Accordion
589
- with gr.Accordion("Generation Settings", open=False):
590
- gen_length = gr.Slider(
591
- minimum=16, maximum=512, value=128, step=16,
592
- label="Max New Tokens", info="Max response length."
593
- )
594
- steps = gr.Slider(
595
- minimum=8, maximum=512, value=128, step=8,
596
- label="Diffusion Steps", info="More steps = finer generation (potentially slower)."
597
- )
598
- temperature = gr.Slider(
599
- minimum=0.0, maximum=1.5, value=0.6, step=0.05,
600
- label="Temperature", info="Controls randomness. Lower=more deterministic."
601
- )
602
- top_p = gr.Slider(
603
- minimum=0.0, maximum=1.0, value=0.95, step=0.05,
604
- label="Top-P (Nucleus)", info="Filters vocabulary probabilistically. Lower=less diverse."
605
- )
606
- # Map UI choices to DREAM's alg parameters
607
- remasking_strategy = gr.Radio(
608
- choices=[
609
- ("Random", "origin"), # User friendly name -> actual param
610
- ("Entropy", "entropy"),
611
- ("MaskGit+", "maskgit_plus"),
612
- ("TopK Margin", "topk_margin"),
613
- ],
614
- value="entropy", # Default
615
- label="Generation Order Strategy (alg)",
616
- info="How the model decides which tokens to generate first."
617
- )
618
- alg_temp = gr.Slider(
619
- minimum=0.0, maximum=1.0, value=0.1, step=0.05,
620
- label="Order Randomness (alg_temp)" ,
621
- info="Adds randomness to confidence-based strategies (Entropy, MaskGit+, TopK). Ignored for Random."
622
- )
623
- visualization_delay = gr.Slider(
624
- minimum=0.0, maximum=0.5, value=0.05, step=0.01,
625
- label="Visualization Delay (sec)", info="Pause between steps in visualization."
626
- )
627
-
628
- # Clear button Row
629
- with gr.Row():
630
- clear_btn = gr.Button("Clear Conversation", variant="stop", icon="🗑️")
631
-
632
-
633
- # --- Event Handlers ---
634
-
635
- # Helper to add message to history state
636
- def add_message_to_history(history_state, user_message, bot_message):
637
- # history_state is the raw list from gr.State
638
- history_state.append([user_message, bot_message])
639
- return history_state
640
-
641
- # Function when user submits message (Enter or Send button)
642
- def handle_user_message(message, history_state):
643
- print(f"User submitted: '{message}'")
644
- if not message or not message.strip():
645
- print("Empty message submitted, doing nothing.")
646
- # Return unchanged state if message is empty
647
- # Need to return values for all outputs of the .submit/.click
648
- # history_state, chatbot_ui, user_input, output_vis
649
- return history_state, history_state, "", [] # No change to chatbot UI yet
650
-
651
- # Add user message to history state (with None for bot response initially)
652
- updated_history_state = add_message_to_history(history_state, message, None)
653
-
654
- # Prepare updated history for display in Chatbot UI
655
- # We only display the user message now, bot response comes later
656
- chatbot_display = updated_history_state.copy()
657
-
658
- # Clear the input textbox and visualization
659
- return updated_history_state, chatbot_display, "", []
660
-
661
- # Function to generate bot response (triggered after user message is handled)
662
- # Uses yield for streaming visualization updates
663
- def generate_bot_response(
664
- history_state, # The current state [[user, None], ...]
665
- gen_length_val, steps_val, constraints_text, delay_val,
666
- temperature_val, top_p_val, alg_val, alg_temp_val
667
- ):
668
- print("\n--- Streaming Bot Response ---")
669
- if not history_state or history_state[-1][1] is not None:
670
- print("History empty or last message already has response. Skipping generation.")
671
- # Yield current state if called unnecessarily
672
- yield history_state, [] # Chatbot UI, Visualization
673
- return
674
-
675
- # Get the conversation history in the format the model expects
676
- messages_for_model = format_chat_history(history_state) # Includes the latest user query
677
-
678
- # Parse constraints from the textbox
679
- parsed_constraints = parse_constraints(constraints_text)
680
-
681
- # Generate response with visualization (this function handles the core logic)
682
- vis_states, response_text = dream_generate_response_with_visualization(
683
- messages_for_model,
684
- gen_length=gen_length_val,
685
- steps=steps_val,
686
- constraints=parsed_constraints,
687
- temperature=temperature_val,
688
- top_p=top_p_val,
689
- alg=alg_val,
690
- alg_temp=alg_temp_val
691
  )
692
 
693
- # Update the history state with the final bot response (critical!)
694
- history_state[-1][1] = response_text.strip()
695
 
696
- # Stream the updates
697
- if vis_states:
698
- # Yield the initial visualization state first
699
- yield history_state, vis_states[0] # Update chatbot UI (implicitly via history), update visualization
 
700
 
701
- # Then animate through the rest of the visualization states
702
- for state in vis_states[1:]:
703
- time.sleep(delay_val)
704
- yield history_state, state # Update chatbot UI, update visualization
705
- else:
706
- # Handle case where generation failed or produced no visualization
707
- print("Warning: No visualization states generated.")
708
- yield history_state, [("No visualization generated.", "orange")] # Update chatbot UI, show warning in vis
709
-
710
- print("--- Streaming Complete ---")
711
-
712
-
713
- # Function to clear everything
714
- def clear_conversation_state():
715
- print("Clearing conversation.")
716
- # Reset state and UI components
717
- return [], [], "", [] # chat_history (State), chatbot_ui, user_input, output_vis
718
-
719
- # --- Wire UI elements to functions ---
720
-
721
- # Define shared inputs for generation to avoid repetition
722
- generation_inputs = [
723
- chat_history, gen_length, steps, constraints_input, visualization_delay,
724
- temperature, top_p, remasking_strategy, alg_temp
725
- ]
726
- # Define shared outputs for streaming
727
- streaming_outputs = [chatbot_ui, output_vis]
728
-
729
- # Typing in Textbox and pressing Enter
730
- user_input.submit(
731
- fn=handle_user_message,
732
- inputs=[user_input, chat_history],
733
- outputs=[chat_history, chatbot_ui, user_input, output_vis], # Update history state, chatbot display, clear input, clear vis
734
- queue=False # Process user input immediately
735
- ).then(
736
- fn=generate_bot_response,
737
- inputs=generation_inputs,
738
- outputs=streaming_outputs, # Stream updates to chatbot and visualization
739
- #api_name="generate_stream" # Optional: Name for API endpoint
740
- )
741
 
742
- # Clicking the Send button
743
- send_btn.click(
744
- fn=handle_user_message,
745
- inputs=[user_input, chat_history],
746
- outputs=[chat_history, chatbot_ui, user_input, output_vis],
747
- queue=False
748
- ).then(
749
- fn=generate_bot_response,
750
- inputs=generation_inputs,
751
- outputs=streaming_outputs,
752
- # api_name="generate_stream_click" # Optional
753
- )
 
 
 
 
754
 
755
- # Clicking the Clear button
 
 
 
 
756
  clear_btn.click(
757
- fn=clear_conversation_state,
758
  inputs=[],
759
- outputs=[chat_history, chatbot_ui, user_input, output_vis],
760
- queue=False # Clearing should be instant
761
  )
762
 
763
  return demo
764
 
765
- # --- Launch the Gradio App ---
766
  if __name__ == "__main__":
767
- print("Creating Gradio demo...")
768
- gradio_demo = create_chatbot_demo()
769
- print("Launching Gradio demo...")
770
- # Use queue() for handling concurrent users and potentially long generation times
771
- # share=True generates a public link (useful for Colab/Spaces)
772
- # debug=True provides helpful Gradio logs in the console
773
- gradio_demo.queue().launch(share=True, debug=False) # Set debug=True for more verbose logs if needed
 
1
+ # dream_app.py
2
  import torch
3
+ import numpy as np
4
  import gradio as gr
5
  import spaces
 
6
  import time
7
+ import re
8
+ from transformers import AutoModel, AutoTokenizer
9
+ from threading import Lock
10
+ from queue import Queue, Empty  # Empty is raised by Queue.get_nowait() on an empty queue
11
+
12
+ # --- Configuration ---
13
+ MODEL_PATH = "Dream-org/Dream-v0-Instruct-7B"
14
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
15
+ print(f"Using device: {DEVICE}")
16
+
17
+ # --- Load Model and Tokenizer ---
18
+ print("Loading model and tokenizer...")
19
+ # Need configuration files for trust_remote_code
20
+ # Make sure config.json, configuration_dream.py, modeling_dream.py,
21
+ # generation_utils.py, generation_config.json are in the same directory
22
+ # or accessible in the Hugging Face cache.
23
+ model = AutoModel.from_pretrained(
24
+ MODEL_PATH,
25
+ torch_dtype=torch.bfloat16,
26
+ trust_remote_code=True
27
+ ).to(DEVICE).eval()
28
+ tokenizer = AutoTokenizer.from_pretrained(
29
+ MODEL_PATH,
30
+ trust_remote_code=True
31
+ )
32
+ print("Model and tokenizer loaded.")
33
+
34
+ # --- Constants ---
35
+ # Get IDs from tokenizer/config if possible, otherwise hardcode from provided files
36
+ MASK_TOKEN = tokenizer.mask_token # Should be "<|mask|>"
37
  try:
38
+ MASK_ID = tokenizer.mask_token_id # Should be 151666
39
+ if MASK_ID is None: raise AttributeError # Handle case where it might not be set directly
 
40
  except AttributeError:
41
+ print("Warning: Could not directly get mask_token_id, using hardcoded value 151666.")
42
+ MASK_ID = 151666
43
+
 
 
 
 
44
  try:
45
+ EOS_ID = tokenizer.eos_token_id # Should be 151643
46
+ PAD_ID = tokenizer.pad_token_id # Should be 151643
47
+ if EOS_ID is None or PAD_ID is None: raise AttributeError
48
+ except AttributeError:
49
+ print("Warning: Could not directly get eos/pad_token_id, using hardcoded value 151643.")
50
+ EOS_ID = 151643
51
+ PAD_ID = 151643
52
+
53
+ # Ensure MASK_TOKEN and MASK_ID are valid
54
+ if MASK_TOKEN is None or MASK_ID is None:
55
+ raise ValueError("Mask token or ID is not defined correctly.")
56
+ if EOS_ID is None or PAD_ID is None:
57
+ raise ValueError("EOS/PAD token or ID is not defined correctly.")
58
+
59
+ # --- Helper Functions ---
 
 
 
 
60
 
61
  def parse_constraints(constraints_text):
62
  """Parse constraints in format: 'position:word, position:word, ...'"""
 
66
 
67
  parts = constraints_text.split(',')
68
  for part in parts:
 
69
  if ':' not in part:
70
  continue
71
  try:
72
  pos_str, word = part.split(':', 1)
73
  pos = int(pos_str.strip())
74
  word = word.strip()
 
75
  if word and pos >= 0:
76
+ # Tokenize the word - handle potential multi-token words
77
+ # Add space prefix for consistency, similar to how model might see words mid-sentence
78
+ tokens = tokenizer.encode(" " + word, add_special_tokens=False)
79
+ for i, token_id in enumerate(tokens):
80
+ constraints[pos + i] = token_id
81
  except ValueError:
82
+ continue
83
+ except Exception as e:
84
+ print(f"Error parsing constraint part '{part}': {e}")
85
  continue
86
 
87
  return constraints
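For illustration, the format described in the docstring maps each constrained word onto one or more consecutive response positions; a sketch of the expected output (the actual token IDs depend on the Dream tokenizer loaded above):

# Hypothetical call; the IDs shown are placeholders, not real vocabulary entries.
constraints = parse_constraints("0:Once, 5:upon")
# -> {0: <id of " Once">, 5: <id of " upon">}
# A word that tokenizes into two pieces at position 3 would fill positions 3 and 4.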
88
 
89
  def format_chat_history(history):
90
  """
91
+ Format chat history for the Dream model using its chat template logic.
92
 
93
  Args:
94
  history: List of [user_message, assistant_message] pairs
95
 
96
  Returns:
97
+ Formatted list of message dictionaries for the model
98
  """
99
  messages = []
100
+ # Add system prompt if history is empty or doesn't start with system
101
+ if not history or history[0][0].lower() != 'system':
102
+ # Check if the tokenizer's template expects an explicit system message
103
+ # The template provided in tokenizer_config.json handles adding a default one
104
+ pass # Let apply_chat_template handle the default system message
105
+
106
  for user_msg, assistant_msg in history:
107
+ if user_msg: # Handle potential initial system message possibility if needed
108
  messages.append({"role": "user", "content": user_msg})
109
+ if assistant_msg is not None: # Skip if None (for the latest user message)
110
  messages.append({"role": "assistant", "content": assistant_msg})
111
 
112
  return messages
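As a concrete example of the structure returned above (the pending turn whose assistant reply is still None is passed through as a bare user message):

history = [["Hello", "Hi! How can I help?"], ["Write a haiku", None]]
format_chat_history(history)
# -> [{"role": "user", "content": "Hello"},
#     {"role": "assistant", "content": "Hi! How can I help?"},
#     {"role": "user", "content": "Write a haiku"}]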
113
 
114
+ # --- Core Generation Logic with Visualization ---
115
 
116
+ # Use a thread-safe queue to pass visualization states from the hook
117
+ vis_queue = Queue()
118
+ # Lock to prevent race conditions when accessing shared state like previous_x
119
+ state_lock = Lock()
120
+ # Store the previous state for comparison in the hook
121
+ previous_x_shared = None
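In outline, this is a small producer/consumer arrangement: the generation hook pushes one visualization snapshot per diffusion step, and the caller drains the queue once generation returns.

# Producer, inside the hook:            vis_queue.put(current_vis_state)
# Consumer, after diffusion_generate:
#     while not vis_queue.empty():
#         visualization_states.append(vis_queue.get_nowait())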
122
+
123
+ @spaces.GPU
124
+ def generate_response_with_visualization(
125
+ messages, # List of message dicts from format_chat_history
126
+ max_new_tokens=64,
127
+ steps=64, # Default steps based on README example
128
  constraints=None,
129
+ temperature=0.6, # Default from demo_token_control
130
+ top_p=0.95, # Default from demos
131
+ alg="entropy", # Default from demos
132
+ alg_temp=0.1, # Default from demo_multiturn_chat
133
  ):
134
  """
135
+ Generate text with Dream model and capture visualization states using a hook.
136
 
137
  Args:
138
+ messages: List of message dictionaries with 'role' and 'content'.
139
+ max_new_tokens: Max tokens to generate.
140
+ steps: Diffusion steps.
141
+ constraints: Dictionary mapping positions (relative to response start) to token IDs.
142
+ temperature: Sampling temperature.
143
+ top_p: Nucleus sampling p.
144
+ alg: Remasking algorithm ('origin', 'entropy', 'maskgit_plus', 'topk_margin').
145
+ alg_temp: Temperature for confidence-based algorithms.
146
 
147
  Returns:
148
  Tuple: (List of visualization states, final generated text string)
149
  """
150
+ global previous_x_shared, vis_queue
 
 
 
 
151
  if constraints is None:
152
  constraints = {}
153
 
154
+ visualization_states = []
155
+
156
+ # Clear the queue for a new generation
157
+ while not vis_queue.empty():
158
+ try:
159
+ vis_queue.get_nowait()
160
+ except Queue.Empty:
161
+ break
 
 
 
 
162
 
163
  # Prepare the prompt using chat template
164
+ # The template automatically adds the generation prompt like "<|im_start|>assistant\n"
165
  try:
166
  inputs = tokenizer.apply_chat_template(
167
  messages,
168
  return_tensors="pt",
169
+ add_generation_prompt=True,
170
+ return_dict=True
171
  )
172
+ input_ids = inputs.input_ids.to(device=DEVICE)
173
+ # Dream doesn't seem to explicitly use attention_mask in simple demos,
174
+ # but it's good practice if padding were involved.
175
+ # For now, assume no padding in this interactive demo.
176
+ attention_mask = inputs.attention_mask.to(device=DEVICE) if 'attention_mask' in inputs else None
177
+
 
178
  except Exception as e:
179
  print(f"Error applying chat template: {e}")
180
+ # Provide a fallback or error state
181
+ error_state = [("Error in chat formatting.", "red")]
182
+ return [error_state], f"Error: Could not format chat history. {e}"
183
+
184
+ prompt_length = input_ids.shape[1]
185
+ total_length = prompt_length + max_new_tokens
 
 
 
 
186
 
187
  # --- Define the Hook Function ---
188
  def generation_tokens_hook_func(step, x, logits):
189
+ global previous_x_shared, vis_queue
190
+ with state_lock: # Ensure thread safety if needed, though hooks might run sequentially
191
+ current_x = x.clone() # Shape: (batch_size, total_length)
192
+
193
+ # --- Apply Constraints ---
194
+ # Constraints are relative to the start of the *response*
195
+ for rel_pos, token_id in constraints.items():
196
+ abs_pos = prompt_length + rel_pos
197
+ if 0 <= abs_pos < current_x.shape[1]:
198
+ # Ensure constraint application doesn't go out of bounds
199
+ # Apply constraint for the first batch element (batch size is 1 here)
200
+ current_x[0, abs_pos] = token_id
201
+
202
+ # --- Create Visualization State ---
203
+ current_vis_state = []
204
+ x_response = current_x[0, prompt_length:] # Get the response part for batch 0
205
+ prev_x_response = previous_x_shared[0, prompt_length:] if previous_x_shared is not None else None
206
+
207
+ for i in range(max_new_tokens):
208
+ current_token_id = x_response[i].item()
209
+ token_str = tokenizer.decode([current_token_id], skip_special_tokens=False) # Keep special tokens for vis
210
+
211
+ # Clean up visual representation of special tokens
212
+ if token_str == tokenizer.eos_token or token_str == tokenizer.pad_token:
213
+ token_str = "[EOS/PAD]" # Make it visually distinct
214
+ elif token_str == tokenizer.mask_token:
215
+ token_str = "[MASK]"
216
+ elif token_str.strip() == "": # Handle empty strings from decoding potentially odd tokens
217
+ token_str = "[UNK/SPACE]"
218
+
219
+
220
+ color = "#DDDDDD" # Default background
221
+
222
+ if current_token_id == MASK_ID:
223
+ color = "#444444" # Dark gray for masks
224
+ elif prev_x_response is not None and prev_x_response[i].item() == MASK_ID:
225
+ # Token was mask, now it's revealed in this step
226
+ # Use green for newly revealed
227
+ color = "#66CC66" # Light green
 
 
 
 
228
  else:
229
+ # Token was already revealed in a previous step or is a constraint
230
+ # Check if it's a constraint applied *now*
231
+ is_constraint = i in constraints and \
232
+ constraints[i] == current_token_id
233
+
234
+ if is_constraint:
235
+ color = "#FFD700" # Gold for constraints
236
+ else:
237
+ color = "#6699CC" # Light blue for previously revealed
238
+
239
+ current_vis_state.append((token_str, color))
240
+
241
+ # --- Update shared state and put vis state in queue ---
242
+ previous_x_shared = current_x.clone() # Update for the *next* step's comparison
243
+ vis_queue.put(current_vis_state)
244
+
245
+ # The hook must return the potentially modified tensor `x`
246
+ return current_x
247
+ # --- End of Hook Function ---
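The hook contract assumed here (based on how it is used in this file) is simply: receive the step index, the current token tensor x, and the logits, and return a tensor of the same shape; a do-nothing hook would be:

# Minimal no-op hook, shown only to illustrate the expected signature and return value.
def identity_hook(step, x, logits):
    return x  # returning x unchanged leaves the diffusion step untouched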
248
+
249
+ # Initialize previous_x_shared before generation starts
250
+ # Create initial masked state for visualization
251
+ initial_x = input_ids.clone()
252
+ if initial_x.shape[1] < total_length:
253
+ padding = torch.full((1, total_length - initial_x.shape[1]), MASK_ID, dtype=torch.long, device=DEVICE)
254
+ initial_x = torch.cat([initial_x, padding], dim=1)
255
+ else:
256
+ initial_x = initial_x[:, :total_length] # Truncate if prompt is too long
257
+
258
+ # Apply initial constraints to the starting state
259
+ for rel_pos, token_id in constraints.items():
260
+ abs_pos = prompt_length + rel_pos
261
+ if 0 <= abs_pos < initial_x.shape[1]:
262
+ initial_x[0, abs_pos] = token_id
263
+
264
+ with state_lock:
265
+ previous_x_shared = initial_x.clone()
266
+
267
+ # Add the initial all-masked state (or with constraints) to the visualization queue
268
+ initial_vis_state = []
269
+ initial_x_response = initial_x[0, prompt_length:]
270
+ for i in range(max_new_tokens):
271
+ token_id = initial_x_response[i].item()
272
+ if token_id == MASK_ID:
273
+ initial_vis_state.append((MASK_TOKEN, "#444444"))
274
+ else:
275
+ # Must be a pre-applied constraint
276
+ token_str = tokenizer.decode([token_id], skip_special_tokens=False)
277
+ if token_str == tokenizer.eos_token or token_str == tokenizer.pad_token:
278
+ token_str = "[EOS/PAD]"
279
+ elif token_str.strip() == "":
280
+ token_str = "[UNK/SPACE]"
281
+ initial_vis_state.append((token_str, "#FFD700")) # Gold for constraints
282
+ vis_queue.put(initial_vis_state)
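Each item placed on the queue is just a list of (token_text, color) tuples, which is the value format gr.HighlightedText receives downstream; for a four-token response after one step it might look like:

# Illustrative snapshot only; real states contain max_new_tokens entries.
[("[MASK]", "#444444"), ("Once", "#66CC66"), ("[MASK]", "#444444"), ("[MASK]", "#444444")]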
283
+
284
+
285
+ # --- Run Generation ---
286
  try:
287
+ # output_history=False because the hook handles state capture
288
+ # return_dict_in_generate=True to get the GenerationOutput object
 
 
 
 
289
  output = model.diffusion_generate(
290
+ initial_x, # Start with the potentially constraint-applied tensor
291
+ attention_mask=None, # Assuming no padding needed for interactive use
292
+ max_new_tokens=max_new_tokens, # This might not be strictly needed if total_length is fixed
293
+ output_history=False,
294
  return_dict_in_generate=True,
295
  steps=steps,
296
  temperature=temperature,
297
  top_p=top_p,
298
  alg=alg,
299
+ alg_temp=alg_temp if alg != 'origin' else None, # alg_temp only for confidence algs
 
300
  generation_tokens_hook_func=generation_tokens_hook_func
 
 
 
301
  )
 
 
 
 
302
 
303
+ final_sequence = output.sequences[0] # Batch size 1
304
+
305
+ # Decode the final response text, cleaning up special tokens
306
+ response_tokens = final_sequence[prompt_length:]
307
+ # Filter out EOS/PAD tokens for the final text display
308
+ response_tokens_filtered = [tok for tok in response_tokens.tolist() if tok != EOS_ID and tok != PAD_ID]
309
+ final_text = tokenizer.decode(response_tokens_filtered,
310
+ skip_special_tokens=True,
311
+ clean_up_tokenization_spaces=True) # Standard cleanup
312
 
313
  except Exception as e:
314
+ print(f"Error during generation: {e}")
315
  import traceback
316
  traceback.print_exc()
317
+ # Provide error state
318
+ error_state = [("Generation Error.", "red")]
319
+ visualization_states.append(error_state)
320
+ final_text = f"Error: Generation failed. {e}"
321
+ # Add any states captured before the error
322
+ while not vis_queue.empty():
323
+ try:
324
+ visualization_states.append(vis_queue.get_nowait())
325
+ except Empty:
326
+ break
327
+ return visualization_states, final_text
328
+
329
+ # Retrieve all visualization states captured by the hook
330
+ while not vis_queue.empty():
331
+ try:
332
+ visualization_states.append(vis_queue.get_nowait())
333
+ except Empty:
334
+ break
335
 
336
+ # If somehow no states were captured, add the initial one
337
+ if not visualization_states:
338
+ visualization_states.append(initial_vis_state)
 
339
 
 
340
 
341
+ return visualization_states, final_text.strip()
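A usage sketch for this function outside the UI (assumes the model and tokenizer loaded above and a GPU with enough memory):

# msgs = [{"role": "user", "content": "Write a haiku about spring."}]
# states, text = generate_response_with_visualization(msgs, max_new_tokens=64, steps=64)
# len(states)  -> roughly one snapshot per diffusion step, plus the initial masked state
# text         -> the final decoded response string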
342
 
343
+
344
+ # --- Gradio UI ---
345
 
346
  css = '''
347
  .category-legend{display:none}
348
+ button{height: 60px}
 
 
 
 
349
  '''
350
  def create_chatbot_demo():
351
+ with gr.Blocks(css=css) as demo:
352
+ gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
353
+ gr.Markdown("Chat with the Dream 7B Instruct model and visualize the diffusion generation process.")
354
+ gr.Markdown("Model: [Dream-org/Dream-v0-Instruct-7B](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)")
 
 
 
355
 
356
  # STATE MANAGEMENT
357
+ chat_history = gr.State([])
358
 
359
+ # UI COMPONENTS
360
  with gr.Row():
 
361
  with gr.Column(scale=3):
362
+ chatbot_ui = gr.Chatbot(label="Conversation", height=500, avatar_images=["user.png", "robot.png"])
 
 
 
 
 
 
 
363
 
364
+ # Message input
365
+ with gr.Group():
366
+ with gr.Row():
367
  user_input = gr.Textbox(
368
  label="Your Message",
369
+ placeholder="Type your message here...",
370
+ show_label=False,
371
+ scale=9
 
372
  )
373
+ send_btn = gr.Button("Send", scale=1)
374
 
375
  constraints_input = gr.Textbox(
376
  label="Word Constraints (Optional)",
377
+ info="Place words at specific positions (0-indexed from response start). Format: 'pos:word, pos:word,...'. Example: '0:Once, 5:upon, 10:a'",
378
+ placeholder="0:Once, 5:upon, 10:a",
379
+ value=""
380
  )
 
 
381
  with gr.Column(scale=2):
 
382
  output_vis = gr.HighlightedText(
383
+ label="Diffusion Process Visualization",
 
384
  combine_adjacent=False,
385
+ show_legend=True, # Keep legend hidden via CSS if desired
386
+ height=560 # Adjust height to match chatbot area
 
 
 
 
 
 
 
 
387
  )
388
+ # Legend (colors defined in generate_response_with_visualization)
389
+ gr.Markdown(
390
+ "<small>Color Legend: <span style='background-color:#444444; color:white;'>[MASK]</span>"
391
+ " <span style='background-color:#66CC66;'>Newly Revealed</span>"
392
+ " <span style='background-color:#6699CC;'>Previously Revealed</span>"
393
+ " <span style='background-color:#FFD700;'>Constraint</span>"
394
+ " <span style='background-color:#DDDDDD;'>[EOS/PAD/UNK]</span></small>"
 
 
 
 
 
395
  )
396
 
397
+ # Advanced generation settings
398
+ with gr.Accordion("Generation Settings", open=False):
399
+ max_new_tokens_slider = gr.Slider(
400
+ minimum=16, maximum=512, value=128, step=16, # Increased default/max
401
+ label="Max New Tokens (Generation Length)"
402
+ )
403
+ steps_slider = gr.Slider(
404
+ minimum=8, maximum=512, value=128, step=8, # Increased default/max
405
+ label="Diffusion Steps"
406
+ )
407
+ temp_slider = gr.Slider(
408
+ minimum=0.0, maximum=1.0, value=0.6, step=0.05, # Finer steps for temp
409
+ label="Temperature"
410
+ )
411
+ top_p_slider = gr.Slider(
412
+ minimum=0.0, maximum=1.0, value=0.95, step=0.05,
413
+ label="Top-P (Nucleus Sampling)"
414
+ )
415
+ alg_radio = gr.Radio(
416
+ # Choices from README
417
+ choices=['origin', 'entropy', 'maskgit_plus', 'topk_margin'],
418
+ value='entropy',
419
+ label="Remasking Algorithm"
420
+ )
421
+ alg_temp_slider = gr.Slider(
422
+ minimum=0.0, maximum=1.0, value=0.1, step=0.05,
423
+ label="Algorithm Temperature (for confidence-based algs)"
424
+ )
425
+ vis_delay_slider = gr.Slider(
426
+ minimum=0.0, maximum=0.5, value=0.03, step=0.01, # Faster default delay
427
+ label="Visualization Delay (seconds)"
 
 
 
 
428
  )
429
 
430
+ # Clear button
431
+ clear_btn = gr.Button("Clear Conversation")
432
 
433
+ # HELPER FUNCTIONS (UI Logic)
434
+ def add_message_to_history(history, message, response):
435
+ """Add a message pair to the history state"""
436
+ new_history = history + [[message, response]]
437
+ return new_history
438
 
439
+ def user_message_submitted(message, history):
440
+ """ Handle user sending a message: update history, clear input """
441
+ if not message or message.strip() == "":
442
+ return history, history, "", [] # No change if empty
 
 
 
 
443
 
444
+ # Add user message, response is initially None
445
+ new_history = add_message_to_history(history, message, None)
446
+
447
+ # Prepare display version (immediately shows user message)
448
+ display_history = new_history
449
+
450
+ # Clear input box
451
+ message_out = ""
452
+
453
+ # Clear visualization
454
+ vis_out = []
455
+
456
+ return new_history, display_history, message_out, vis_out
457
+
458
+ def bot_response_generator(history, constraints_str, max_tokens, steps, temp, top_p, alg, alg_temp, delay):
459
+ """ Generator function to stream bot response and visualization """
460
+ if not history or history[-1][1] is not None: # Ensure there's a user msg waiting for response
461
+ print("Warning: Bot response triggered without pending user message.")
462
+ yield history, [], "Error: No user message to respond to." # Send error state back?
463
+ return
464
+
465
+ # Get the full conversation history formatted for the model
466
+ last_user_message = history[-1][0]
467
+ messages_for_model = format_chat_history(history[:-1]) # History *before* the last user msg
468
+ messages_for_model.append({"role": "user", "content": last_user_message})
469
+
470
+ # Parse constraints
471
+ try:
472
+ parsed_constraints = parse_constraints(constraints_str)
473
+ except Exception as e:
474
+ print(f"Error parsing constraints: {e}")
475
+ yield history, [("Constraint Error", "red")], f"Error: Failed to parse constraints: {e}"
476
+ return
477
+
478
+ # Generate response and visualization states
479
+ try:
480
+ vis_states, final_response_text = generate_response_with_visualization(
481
+ messages_for_model,
482
+ max_new_tokens=max_tokens,
483
+ steps=steps,
484
+ constraints=parsed_constraints,
485
+ temperature=temp,
486
+ top_p=top_p,
487
+ alg=alg,
488
+ alg_temp=alg_temp
489
+ )
490
+ except Exception as e:
491
+ print(f"Error in generate_response_with_visualization: {e}")
492
+ import traceback
493
+ traceback.print_exc()
494
+ yield history, [("Generation Error", "red")], f"Error: Generation failed: {e}"
495
+ return
496
 
497
+ # Update the history state with the final response *once*
498
+ history[-1][1] = final_response_text # Update the None placeholder
499
+
500
+ # Yield initial state immediately
501
+ if vis_states:
502
+ yield history, vis_states[0]
503
+ else:
504
+ yield history, [] # Should not happen if generation worked
505
+
506
+ # Stream intermediate visualization states
507
+ for state in vis_states[1:]:
508
+ time.sleep(delay)
509
+ yield history, state
510
+
511
+ # Final yield ensures the chatbot UI has the complete history
512
+ # The last state in vis_states should already be yielded by the loop
513
+ # yield history, vis_states[-1] if vis_states else []
514
+
515
+
516
+ def clear_conversation():
517
+ """Clear the conversation history and visualization"""
518
+ return [], [], "", [] # history, chatbot_ui, user_input, output_vis
519
+
520
+ # EVENT HANDLERS
521
+
522
+ # User presses Enter or Send button
523
+ submit_args = {
524
+ "fn": user_message_submitted,
525
+ "inputs": [user_input, chat_history],
526
+ "outputs": [chat_history, chatbot_ui, user_input, output_vis]
527
+ }
528
+ user_input.submit(**submit_args)
529
+ send_btn.click(**submit_args)
530
+
531
+ # After user message is submitted, trigger bot response generation
532
+ generate_args = {
533
+ "fn": bot_response_generator,
534
+ "inputs": [
535
+ chat_history, constraints_input, max_new_tokens_slider, steps_slider,
536
+ temp_slider, top_p_slider, alg_radio, alg_temp_slider, vis_delay_slider
537
+ ],
538
+ "outputs": [chatbot_ui, output_vis] # Update chatbot history and visualization
539
+ }
540
+ # Trigger generation after submit OR click
541
+ user_input.submit(None, None, None, queue=True).then(**generate_args)
542
+ send_btn.click(None, None, None, queue=True).then(**generate_args)
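For reference, a more common Gradio pattern chains the generator directly off the same listener so the two steps always run in order; a sketch only, not what this file does:

# user_input.submit(
#     user_message_submitted,
#     inputs=[user_input, chat_history],
#     outputs=[chat_history, chatbot_ui, user_input, output_vis],
# ).then(
#     bot_response_generator,
#     inputs=[chat_history, constraints_input, max_new_tokens_slider, steps_slider,
#             temp_slider, top_p_slider, alg_radio, alg_temp_slider, vis_delay_slider],
#     outputs=[chatbot_ui, output_vis],
# )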
543
+
544
+
545
+ # Clear button handler
546
  clear_btn.click(
547
+ fn=clear_conversation,
548
  inputs=[],
549
+ outputs=[chat_history, chatbot_ui, user_input, output_vis]
 
550
  )
551
 
552
  return demo
553
 
554
+ # Launch the demo
555
  if __name__ == "__main__":
556
+ demo = create_chatbot_demo()
557
+ # queue() allows streaming and handling multiple users
558
+ # share=True creates a public link (use with caution)
559
+ demo.queue().launch(share=True, debug=True)