multimodalart (HF Staff) committed
Commit 47fc4a0 · verified · 1 Parent(s): 9b91020

Update app.py

Files changed (1)
  1. app.py +637 -131
app.py CHANGED
@@ -1,20 +1,159 @@
1
- # Replace the existing dream_generate_response_with_visualization function
2
- # in the previous code block with this updated version.
3
 
4
- @gpu_check
5
  def dream_generate_response_with_visualization(
6
  messages,
7
- gen_length=64,
8
- steps=64,
9
  constraints=None,
10
- temperature=0.6,
11
- top_p=0.95,
12
- alg="entropy",
13
- alg_temp=0.0,
14
  ):
15
  """
16
- Generate text with DREAM model with visualization using the generation hook,
17
- including special token handling (show once, then hide).
18
 
19
  Args:
20
  messages: List of message dictionaries with 'role' and 'content'
@@ -29,239 +168,606 @@ def dream_generate_response_with_visualization(
29
  Returns:
30
  Tuple: (List of visualization states, final generated text string)
31
  """
32
- print("--- Starting DREAM Generation ---")
33
- print(f"Parameters: gen_length={gen_length}, steps={steps}, temperature={temperature}, top_p={top_p}, alg='{alg}', alg_temp={alg_temp}")
34
  print(f"Constraints: {constraints}")
35
 
36
  # --- Input Preparation ---
37
  if constraints is None:
38
  constraints = {}
39
 
 
40
  processed_constraints = {}
 
41
  print("Processing constraints:")
42
  for pos, word in constraints.items():
 
 
43
  tokens = tokenizer.encode(" " + word, add_special_tokens=False)
44
  if not tokens:
45
  print(f" Warning: Could not tokenize constraint word '{word}' at position {pos}. Skipping.")
46
  continue
47
- print(f" Pos {pos}, Word '{word}' -> Tokens {tokens}")
 
48
  for i, token_id in enumerate(tokens):
49
- if pos + i not in processed_constraints:
50
- processed_constraints[pos + i] = token_id
 
51
  else:
52
- print(f" Warning: Overlapping constraint at position {pos+i}. Keeping first.")
53
 
 
54
  try:
55
  inputs = tokenizer.apply_chat_template(
56
  messages,
57
  return_tensors="pt",
58
  return_dict=True,
59
- add_generation_prompt=True
60
  )
61
  input_ids = inputs.input_ids.to(device=device)
 
62
  attention_mask = inputs.attention_mask.to(device=device)
63
  prompt_length = input_ids.shape[1]
64
  print(f"Input prompt length: {prompt_length}")
65
- # print(f"Input IDs: {input_ids}") # Verbose
 
66
  except Exception as e:
67
  print(f"Error applying chat template: {e}")
68
  return [([("Error applying chat template.", "red")],)], f"Error: {e}"
69
 
70
- if prompt_length + gen_length > 2048:
71
- print(f"Warning: Requested length ({prompt_length + gen_length}) exceeds model max length (2048). Truncating gen_length.")
72
- gen_length = 2048 - prompt_length
 
 
73
  if gen_length <= 0:
74
  print("Error: Prompt is already too long.")
75
  return [([("Prompt too long.", "red")],)], "Error: Prompt too long."
76
 
77
  # --- State for Visualization Hook ---
78
  visualization_states = []
79
- last_x = None
80
-
81
- # Define special token IDs to hide after first reveal
82
- # Using a set for efficient lookup. Add others if needed (e.g., pad_token_id).
83
- special_token_ids_to_hide = {
84
- tokenizer.eos_token_id,
85
- tokenizer.pad_token_id,
86
- # tokenizer.bos_token_id # Usually not generated mid-sequence
87
- }
88
- # Filter out None values if some special tokens aren't defined
89
- special_token_ids_to_hide = {tid for tid in special_token_ids_to_hide if tid is not None}
90
- print(f"Special token IDs to hide visually after reveal: {special_token_ids_to_hide}")
91
 
92
-
93
- initial_x_part = torch.full((1, gen_length), MASK_ID, dtype=torch.long, device=device)
 
94
  for pos, token_id in processed_constraints.items():
95
- absolute_pos = pos
96
  if 0 <= absolute_pos < gen_length:
97
- initial_x_part[0, absolute_pos] = token_id
98
 
 
99
  initial_state_vis = []
100
  for i in range(gen_length):
101
- token_id = initial_x_part[0, i].item()
102
  if token_id == MASK_ID:
103
  initial_state_vis.append((MASK_TOKEN, "#444444")) # Mask color
104
  else:
105
- token_str = tokenizer.decode([token_id], skip_special_tokens=True)
 
 
106
  initial_state_vis.append((token_str if token_str else "?", "#800080")) # Constraint color (purple)
107
  visualization_states.append(initial_state_vis)
108
 
109
  # --- Define the Hook Function ---
110
  def generation_tokens_hook_func(step, x, logits):
111
  nonlocal last_x, visualization_states # Allow modification of outer scope variables
112
- # print(f"Hook called for step {step}") # Verbose
 
 
113
 
114
- current_x = x.clone()
115
  constrained_x = current_x.clone()
116
- prompt_len = current_x.shape[1] - gen_length
117
- if prompt_len < 0:
118
- print("Warning: prompt_len negative in hook, skipping constraints/vis.")
119
- return current_x
120
 
121
- # 1. Apply Constraints
122
  for pos, token_id in processed_constraints.items():
123
- absolute_pos = prompt_len + pos
124
- if prompt_len <= absolute_pos < current_x.shape[1]:
 
 
125
  if constrained_x[0, absolute_pos] != token_id:
126
  constrained_x[0, absolute_pos] = token_id
 
127
 
128
- # 2. Generate Visualization State for *this* step
 
129
  current_state_vis = []
130
- gen_part_current = current_x[0, prompt_len:]
131
- gen_part_last = last_x[0, prompt_len:] if last_x is not None else None
 
 
 
132
 
133
  for i in range(gen_length):
134
- current_token_id = gen_part_current[i].item()
135
 
136
- # Basic check for safety, though unlikely needed with prompt_len check
137
- if current_token_id is None:
138
- current_state_vis.append((MASK_TOKEN, "#444444"))
139
- continue
140
 
141
  is_constrained = i in processed_constraints
142
- is_special = current_token_id in special_token_ids_to_hide
143
-
144
- # Decode carefully: don't skip specials initially for display text
145
- raw_token_str = tokenizer.decode([current_token_id], skip_special_tokens=False).strip()
146
- # Use MASK_TOKEN string for MASK_ID, otherwise use decoded string or '?'
147
- display_token = MASK_TOKEN if current_token_id == MASK_ID else (raw_token_str if raw_token_str else "?")
148
 
149
- # Determine the state based on current and previous token
150
- previous_token_id = gen_part_last[i].item() if gen_part_last is not None else None
151
 
152
- # Determine color and potentially modify display_token for hiding
153
- if current_token_id == MASK_ID:
154
- color = "#444444" # Dark Gray
155
  display_token = MASK_TOKEN
156
- elif is_constrained and processed_constraints.get(i) == current_token_id:
157
- color = "#800080" # Purple - keep constraint visible
158
- # Even if special, show the constraint for clarity
159
- elif previous_token_id == MASK_ID or previous_token_id is None:
160
- # --- Newly revealed in this step ---
161
- if is_special:
162
- # Newly revealed special token: Show it this time
163
- color = "#FF8C00" # Dark Orange (distinct color for first reveal)
164
- # display_token is already the raw special token string
165
- else:
166
- # Newly revealed regular token
167
- color = "#66CC66" # Light Green
168
- # display_token is already the regular token string
169
  elif is_special:
170
- # --- Was revealed previously AND is special: Hide it now ---
171
- color = "#FFFFFF" # White background / Transparent conceptually
172
- display_token = "" # Make it disappear visually
173
- # Alternative: Subtle placeholder
174
- # display_token = "."
175
- # color = "#EEEEEE"
176
  else:
177
- # --- Previously revealed regular token ---
 
178
  color = "#6699CC" # Light Blue
179
- # display_token is already the regular token string
180
 
181
  current_state_vis.append((display_token, color))
182
 
183
  visualization_states.append(current_state_vis)
184
 
185
  # 3. Update last_x for the *next* step's comparison
 
186
  last_x = constrained_x.clone()
187
 
188
- # 4. Return the sequence with constraints applied
189
- return constrained_x
190
 
191
 
192
  # --- Run DREAM Generation ---
193
  try:
194
  print("Calling model.diffusion_generate...")
195
- initial_full_x = torch.cat([input_ids, initial_x_part], dim=1)
196
- last_x = initial_full_x.clone() # Initialize last_x for the first hook call
 
 
197
 
198
  output = model.diffusion_generate(
199
- input_ids,
200
- attention_mask=attention_mask,
201
  max_new_tokens=gen_length,
202
- output_history=False,
203
  return_dict_in_generate=True,
204
  steps=steps,
205
  temperature=temperature,
206
  top_p=top_p,
207
  alg=alg,
 
208
  alg_temp=alg_temp if alg != "origin" else 0.0,
209
  generation_tokens_hook_func=generation_tokens_hook_func
210
  )
211
  print("model.diffusion_generate finished.")
212
 
 
 
213
  final_sequence = output.sequences[0]
 
 
214
  response_token_ids = final_sequence[prompt_length:]
215
 
216
- # Decode final text *skipping* special tokens for the chatbot display
217
  final_text = tokenizer.decode(
218
  response_token_ids,
219
  skip_special_tokens=True,
220
- clean_up_tokenization_spaces=True
221
  ).strip()
222
- print(f"Final generated text (cleaned): {final_text}")
223
 
224
- # Add the very final state to visualization if needed (safeguard)
225
- # This uses the same logic as the hook for consistency
226
- if len(visualization_states) <= steps:
227
- print("Adding final visualization state manually (safeguard).")
228
  final_state_vis = []
229
- final_gen_part = final_sequence[prompt_length:]
230
- # Need the state *before* this final one to know what was 'new'
231
- gen_part_last_final = last_x[0, prompt_len:] if last_x is not None else None
232
 
233
- for i in range(gen_length):
234
- current_token_id = final_gen_part[i].item()
235
  is_constrained = i in processed_constraints
236
- is_special = current_token_id in special_token_ids_to_hide
237
- raw_token_str = tokenizer.decode([current_token_id], skip_special_tokens=False).strip()
238
- display_token = MASK_TOKEN if current_token_id == MASK_ID else (raw_token_str if raw_token_str else "?")
239
- previous_token_id = gen_part_last_final[i].item() if gen_part_last_final is not None else None
240
-
241
- if current_token_id == MASK_ID:
242
- color = "#444444"
243
- display_token = MASK_TOKEN
244
- elif is_constrained and processed_constraints.get(i) == current_token_id:
245
- color = "#800080"
246
- elif previous_token_id == MASK_ID or previous_token_id is None: # Newly revealed
247
- color = "#FF8C00" if is_special else "#66CC66"
248
- elif is_special: # Previously revealed special
249
- color = "#FFFFFF"
250
- display_token = ""
251
- else: # Previously revealed regular
252
- color = "#6699CC"
253
-
254
  final_state_vis.append((display_token, color))
255
- visualization_states.append(final_state_vis)
256
 
257
 
258
  except Exception as e:
259
- print(f"Error during generation: {e}")
260
  import traceback
261
  traceback.print_exc()
262
- error_msg = f"Error during generation: {str(e)}"
 
 
263
  visualization_states.append([("Error", "red")])
264
  final_text = f"Generation failed: {e}"
265
 
266
- print("--- DREAM Generation Finished ---")
267
- return visualization_states, final_text
1
+ import torch
2
+ # import numpy as np # Not strictly needed anymore
3
+ import gradio as gr
4
+ import spaces
5
+ from transformers import AutoTokenizer, AutoModel
6
+ import time
7
+ import re # Keep for parsing constraints
8
+
9
+ # Use try-except for space deployment vs local
10
+ try:
11
+ # Used for spaces deployment with GPU
12
+ gpu_check = spaces.GPU
13
+ print("Running in Gradio Spaces with GPU environment.")
14
+ except AttributeError:
15
+ # Fallback for local execution or environments without spaces.GPU
16
+ print("Running in local environment or without spaces.GPU.")
17
+ # Define a dummy decorator if spaces.GPU is not available
18
+ def gpu_check(func):
19
+ return func
20
+
21
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
22
+ print(f"Using device: {device}")
23
+
24
+ # --- Load DREAM Model and Tokenizer ---
25
+ # Ensure sufficient VRAM: Dream 7B needs roughly 16 GB+ of VRAM in bfloat16
26
+ model_path = "Dream-org/Dream-v0-Instruct-7B"
27
+ print(f"Loading model: {model_path}...")
28
+ try:
29
+ model = AutoModel.from_pretrained(
30
+ model_path,
31
+ torch_dtype=torch.bfloat16, # Use bfloat16 for efficiency
32
+ trust_remote_code=True,
33
+ # device_map='auto' # Consider if running into OOM errors, might split across GPUs/CPU
34
+ ).to(device).eval()
35
+ tokenizer = AutoTokenizer.from_pretrained(
36
+ model_path,
37
+ trust_remote_code=True
38
+ )
39
+ print("Model and tokenizer loaded successfully.")
40
+ except Exception as e:
41
+ print(f"Error loading model or tokenizer: {e}")
42
+ print("Please ensure you have enough GPU memory and the model files are accessible.")
43
+ # Exit or raise if loading fails
44
+ raise e
45
+
46
+
47
+ # --- Constants for DREAM ---
48
+ # Find the mask token and ID from the DREAM tokenizer
49
+ if tokenizer.mask_token is None:
50
+ print("Warning: Mask token not found in tokenizer. Attempting to add '[MASK]'.")
51
+ # This might require retraining or fine-tuning if the model didn't see this token
52
+ num_added = tokenizer.add_special_tokens({'mask_token': '[MASK]'})
53
+ if num_added > 0:
54
+ print(f"Added '{tokenizer.mask_token}' to tokenizer.")
55
+ # Resize model embeddings if vocab changed
56
+ model.resize_token_embeddings(len(tokenizer))
57
+ print("Resized model token embeddings.")
58
+ else:
59
+ # Fallback or error if adding failed or mask token still None
60
+ # It's possible a different token serves this purpose in DREAM's training
61
+ print("Error: Could not set a mask token. Visualization might be inaccurate.")
62
+ # You might need to identify which token ID DREAM uses internally for masking tasks
63
+ # For now, we'll proceed but this is a potential issue.
64
+ MASK_TOKEN = "<?>" # Placeholder symbol
65
+ MASK_ID = -1 # Invalid ID indicates issue
66
+ if tokenizer.mask_token is None:
67
+ raise ValueError("Could not set a mask token for the tokenizer.")
68
+
69
+ MASK_TOKEN = tokenizer.mask_token
70
+ MASK_ID = tokenizer.mask_token_id
71
+ print(f"Using MASK_TOKEN='{MASK_TOKEN}' with ID={MASK_ID}")
72
+
73
+ # Identify other special tokens to potentially hide/show
74
+ eos_token_id = tokenizer.eos_token_id
75
+ pad_token_id = tokenizer.pad_token_id
76
+ special_token_ids_set = {MASK_ID} # Start with Mask ID
77
+ if eos_token_id is not None:
78
+ special_token_ids_set.add(eos_token_id)
79
+ print(f"EOS token ID: {eos_token_id} ({tokenizer.decode([eos_token_id])})")
80
+ if pad_token_id is not None:
81
+ special_token_ids_set.add(pad_token_id)
82
+ print(f"PAD token ID: {pad_token_id} ({tokenizer.decode([pad_token_id])})")
83
+ # Add other common special tokens if needed (e.g., BOS, UNK)
84
+ if tokenizer.bos_token_id is not None:
85
+ special_token_ids_set.add(tokenizer.bos_token_id)
86
+ print(f"BOS token ID: {tokenizer.bos_token_id} ({tokenizer.decode([tokenizer.bos_token_id])})")
87
+ if tokenizer.unk_token_id is not None:
88
+ special_token_ids_set.add(tokenizer.unk_token_id)
89
+ print(f"UNK token ID: {tokenizer.unk_token_id} ({tokenizer.decode([tokenizer.unk_token_id])})")
90
+
91
+ print(f"Identified special token IDs: {special_token_ids_set}")
92
+
93
+ # --- Helper Functions (Constraint Parsing, History Formatting) ---
94
+
95
+ def parse_constraints(constraints_text):
96
+ """Parse constraints in format: 'position:word, position:word, ...'"""
97
+ constraints = {}
98
+ if not constraints_text:
99
+ return constraints
100
+
101
+ parts = constraints_text.split(',')
102
+ for part in parts:
103
+ part = part.strip() # Trim whitespace
104
+ if ':' not in part:
105
+ continue
106
+ try:
107
+ pos_str, word = part.split(':', 1)
108
+ pos = int(pos_str.strip())
109
+ word = word.strip()
110
+ # Require a non-empty word; forcing an empty string at a position would be meaningless.
111
+ if word and pos >= 0:
112
+ constraints[pos] = word
113
+ except ValueError:
114
+ print(f"Warning: Could not parse constraint part: '{part}'")
115
+ continue
116
+
117
+ return constraints
118
+
119
+ def format_chat_history(history):
120
+ """
121
+ Format chat history for the DREAM model (standard messages format)
122
+
123
+ Args:
124
+ history: List of [user_message, assistant_message] pairs
125
 
126
+ Returns:
127
+ Formatted conversation for the model (list of dictionaries)
128
+ """
129
+ messages = []
130
+ # Check if a system prompt is appropriate for Dream-Instruct
131
+ # Based on the demo_completion.py example, the system prompt appears to be injected via the chat template
132
+ # messages.append({"role": "system", "content": "You are a helpful assistant."})
133
+ for user_msg, assistant_msg in history:
134
+ if user_msg is not None: # Handle potential None message if clearing failed
135
+ messages.append({"role": "user", "content": user_msg})
136
+ if assistant_msg: # Skip if None (for the latest user message awaiting response)
137
+ messages.append({"role": "assistant", "content": assistant_msg})
138
+
139
+ return messages
140
+
141
+ # --- Core Generation Logic for DREAM with Visualization ---
142
+
143
+ @gpu_check # Use the potentially dummy decorator
144
+ @torch.no_grad() # Disable gradient calculations for inference
145
  def dream_generate_response_with_visualization(
146
  messages,
147
+ gen_length=128,
148
+ steps=128, # Default based on DREAM examples
149
  constraints=None,
150
+ temperature=0.6, # Default based on DREAM examples
151
+ top_p=0.95, # Default based on DREAM examples
152
+ alg="entropy", # Default based on DREAM examples
153
+ alg_temp=0.1, # Default based on DREAM examples
154
  ):
155
  """
156
+ Generate text with DREAM model with visualization using the generation hook.
 
157
 
158
  Args:
159
  messages: List of message dictionaries with 'role' and 'content'
 
168
  Returns:
169
  Tuple: (List of visualization states, final generated text string)
170
  """
171
+ print("\n--- Starting DREAM Generation ---")
172
+ print(f"Params: len={gen_length}, steps={steps}, temp={temperature}, top_p={top_p}, alg='{alg}', alg_temp={alg_temp}")
173
  print(f"Constraints: {constraints}")
174
 
175
  # --- Input Preparation ---
176
  if constraints is None:
177
  constraints = {}
178
 
179
+ # Convert word constraints to token IDs (handle multi-token words)
180
  processed_constraints = {}
181
+ constraint_token_lengths = {} # Store length for multi-token constraints
182
  print("Processing constraints:")
183
  for pos, word in constraints.items():
184
+ # Prepend space for potentially better tokenization consistency
185
+ # (though apply_chat_template should handle spacing)
186
  tokens = tokenizer.encode(" " + word, add_special_tokens=False)
187
  if not tokens:
188
  print(f" Warning: Could not tokenize constraint word '{word}' at position {pos}. Skipping.")
189
  continue
190
+ print(f" Pos {pos}, Word '{word}' -> Tokens {tokens} ({tokenizer.convert_ids_to_tokens(tokens)})")
191
+ constraint_token_lengths[pos] = len(tokens)
192
  for i, token_id in enumerate(tokens):
193
+ target_pos = pos + i
194
+ if target_pos in processed_constraints:
195
+ print(f" Warning: Overlapping constraint token at position {target_pos}. Keeping first constraint's token ({processed_constraints[target_pos]}).")
196
  else:
197
+ processed_constraints[target_pos] = token_id
198
 
199
+ # Prepare the prompt using chat template
200
  try:
201
  inputs = tokenizer.apply_chat_template(
202
  messages,
203
  return_tensors="pt",
204
  return_dict=True,
205
+ add_generation_prompt=True # Crucial for Dream-Instruct
206
  )
207
  input_ids = inputs.input_ids.to(device=device)
208
+ # Use the attention mask generated by the template
209
  attention_mask = inputs.attention_mask.to(device=device)
210
  prompt_length = input_ids.shape[1]
211
  print(f"Input prompt length: {prompt_length}")
212
+ # print(f"Input IDs: {input_ids}")
213
+ # print(f"Attention Mask: {attention_mask}") # Verify mask covers prompt
214
  except Exception as e:
215
  print(f"Error applying chat template: {e}")
216
  return [([("Error applying chat template.", "red")],)], f"Error: {e}"
217
 
218
+ # Check context length (DREAM uses 2048 default)
219
+ model_max_length = getattr(model.config, 'max_position_embeddings', 2048)
220
+ if prompt_length + gen_length > model_max_length:
221
+ print(f"Warning: Requested length ({prompt_length + gen_length}) exceeds model max length ({model_max_length}). Truncating gen_length.")
222
+ gen_length = model_max_length - prompt_length
223
  if gen_length <= 0:
224
  print("Error: Prompt is already too long.")
225
  return [([("Prompt too long.", "red")],)], "Error: Prompt too long."
226
 
227
  # --- State for Visualization Hook ---
228
  visualization_states = []
229
+ last_x = None # Store the full sequence (prompt + generation) from the previous step
230
 
231
+ # Initial state: Prompt + all masks for generation part
232
+ initial_gen_part = torch.full((1, gen_length), MASK_ID, dtype=torch.long, device=device)
233
+ # Apply initial constraints to the masked part *before* the first visualization state
234
  for pos, token_id in processed_constraints.items():
235
+ absolute_pos = pos # Position relative to start of generation
236
  if 0 <= absolute_pos < gen_length:
237
+ initial_gen_part[0, absolute_pos] = token_id
238
 
239
+ # Create the first visualization state (only the generation part)
240
  initial_state_vis = []
241
  for i in range(gen_length):
242
+ token_id = initial_gen_part[0, i].item()
243
  if token_id == MASK_ID:
244
  initial_state_vis.append((MASK_TOKEN, "#444444")) # Mask color
245
  else:
246
+ # This must be a constraint applied initially
247
+ # Decode without skipping special to see raw constraint if needed
248
+ token_str = tokenizer.decode([token_id], skip_special_tokens=False).strip()
249
  initial_state_vis.append((token_str if token_str else "?", "#800080")) # Constraint color (purple)
250
  visualization_states.append(initial_state_vis)
251
 
252
  # --- Define the Hook Function ---
253
  def generation_tokens_hook_func(step, x, logits):
254
  nonlocal last_x, visualization_states # Allow modification of outer scope variables
255
+ # print(f"Hook step {step}") # Keep console less noisy
256
+
257
+ current_x = x.clone() # Full sequence (prompt + generation) at this step
258
 
259
+ # 1. Apply Constraints to the current sequence
260
  constrained_x = current_x.clone()
261
+ current_prompt_len = current_x.shape[1] - gen_length # Recalculate prompt length based on current x
262
+ if current_prompt_len < 0:
263
+ print(f"Warning: prompt_len {current_prompt_len} negative in hook step {step}, skipping constraints/vis.")
264
+ return current_x # Return unmodified if something is wrong
265
 
 
266
  for pos, token_id in processed_constraints.items():
267
+ # pos is relative to the start of the *generation* part
268
+ absolute_pos = current_prompt_len + pos
269
+ # Ensure position is within the bounds of the *current* sequence 'x'
270
+ if current_prompt_len <= absolute_pos < current_x.shape[1]:
271
  if constrained_x[0, absolute_pos] != token_id:
272
  constrained_x[0, absolute_pos] = token_id
273
+ # print(f" Constraint enforced at pos {pos} ({absolute_pos}) -> {token_id}")
274
 
275
+
276
+ # 2. Generate Visualization State for *this* step (generation part only)
277
  current_state_vis = []
278
+ # Compare current_x (before explicit constraint application in *this* hook call)
279
+ # with last_x (state from *previous* hook call / initial state)
280
+ gen_part_current = current_x[0, current_prompt_len:]
281
+ # Ensure last_x exists and has the same shape for comparison
282
+ gen_part_last = last_x[0, current_prompt_len:] if (last_x is not None and last_x.shape == current_x.shape) else None
283
 
284
  for i in range(gen_length):
285
+ # Ensure index i is valid for the current generation part
286
+ if i >= gen_part_current.shape[0]:
287
+ print(f"Warning: Index {i} out of bounds for gen_part_current (shape {gen_part_current.shape}) in step {step}.")
288
+ continue # Skip if index is invalid
289
 
290
+ current_token_id = gen_part_current[i].item()
291
+ # Handle case where last_x was None or had different shape
292
+ last_token_id = gen_part_last[i].item() if gen_part_last is not None and i < gen_part_last.shape[0] else MASK_ID # Assume mask initially
 
293
 
294
  is_constrained = i in processed_constraints
295
+ is_special = current_token_id in special_token_ids_set
296
+ is_mask = current_token_id == MASK_ID
297
+ was_mask = last_token_id == MASK_ID or last_x is None # Treat first step as coming from mask
 
 
299
+ display_token = ""
300
+ color = ""
301
 
302
+ # Determine display token and color based on state transitions
303
+ if is_mask:
 
304
  display_token = MASK_TOKEN
305
+ color = "#444444" # Dark Gray
306
+ elif is_constrained and processed_constraints[i] == current_token_id:
307
+ # Always show the constrained token, color purple
308
+ # Decide whether to show raw special tokens when constrained
309
+ raw_decode = tokenizer.decode([current_token_id], skip_special_tokens=False).strip()
310
+ display_token = raw_decode if raw_decode else "?"
311
+ color = "#800080" # Purple
312
  elif is_special:
313
+ if was_mask:
314
+ # Newly revealed special token: Show its representation once
315
+ display_token = tokenizer.decode([current_token_id], skip_special_tokens=False).strip() # Show raw special
316
+ color = "#FF8C00" # DarkOrange
317
+ else:
318
+ # Already revealed special token: Hide it by showing a space
319
+ display_token = " " # Effectively hides it
320
+ color = "#6699CC" # Use 'Old' color (Light Blue) but content is hidden space
321
+ elif was_mask:
322
+ # Newly revealed normal token
323
+ display_token = tokenizer.decode([current_token_id], skip_special_tokens=True).strip()
324
+ color = "#66CC66" # Light Green
325
  else:
326
+ # Previously revealed normal token
327
+ display_token = tokenizer.decode([current_token_id], skip_special_tokens=True).strip()
328
  color = "#6699CC" # Light Blue
329
+
330
+ # Fallback for empty decodes of non-special, non-mask tokens
331
+ if not display_token and not is_mask and not (is_special and not was_mask):
332
+ display_token = "?" # Use question mark for unexpected empty decodes
333
 
334
  current_state_vis.append((display_token, color))
335
 
336
  visualization_states.append(current_state_vis)
337
 
338
  # 3. Update last_x for the *next* step's comparison
339
+ # Store the state *after* applying constraints for accurate comparison next time
340
  last_x = constrained_x.clone()
341
 
342
+ # 4. Return the sequence with constraints applied for the model's next step
343
+ return constrained_x # Return the sequence with constraints enforced
344
 
345
 
346
  # --- Run DREAM Generation ---
347
  try:
348
  print("Calling model.diffusion_generate...")
349
+ # Make sure last_x is initialized correctly before the first hook call
350
+ # It should represent the state *before* the first diffusion step.
351
+ initial_full_x = torch.cat([input_ids, initial_gen_part], dim=1)
352
+ last_x = initial_full_x.clone() # Initialize last_x with prompt + initial masked/constrained gen part
353
 
354
  output = model.diffusion_generate(
355
+ input_ids=input_ids,
356
+ attention_mask=attention_mask, # Pass the correct attention mask
357
  max_new_tokens=gen_length,
358
+ output_history=False, # We build history in the hook
359
  return_dict_in_generate=True,
360
  steps=steps,
361
  temperature=temperature,
362
  top_p=top_p,
363
  alg=alg,
364
+ # alg_temp is only relevant for confidence-based algs (not 'origin')
365
  alg_temp=alg_temp if alg != "origin" else 0.0,
366
  generation_tokens_hook_func=generation_tokens_hook_func
367
+ # Ensure generation doesn't run past eos_token if not desired
368
+ # eos_token_id=eos_token_id, # This might stop generation early
369
+ # pad_token_id=tokenizer.eos_token_id # Often pad is same as eos for LLMs
370
  )
371
  print("model.diffusion_generate finished.")
372
 
373
+ # Extract final generated sequence (response part only)
374
+ # The hook ensures the returned sequence has constraints applied
375
  final_sequence = output.sequences[0]
376
+ # Handle potential length mismatch if generation stopped early
377
+ actual_gen_len = final_sequence.shape[0] - prompt_length
378
  response_token_ids = final_sequence[prompt_length:]
379
 
380
+ # Decode the final response, skipping special tokens like EOS/PAD
381
  final_text = tokenizer.decode(
382
  response_token_ids,
383
  skip_special_tokens=True,
384
+ clean_up_tokenization_spaces=True # Recommended for cleaner output
385
  ).strip()
386
+ print(f"Final generated text: '{final_text}'")
387
 
388
+ # Add the very final state to visualization if the hook didn't capture it
389
+ # (Mainly a safeguard, hook should run 'steps' times or until completion)
390
+ if len(visualization_states) <= steps: # Hook might run 'steps' times
 
391
  final_state_vis = []
392
+ final_gen_part = response_token_ids # Use the extracted response tokens
 
 
393
 
394
+ for i in range(len(final_gen_part)): # Iterate over actual generated tokens
395
+ token_id = final_gen_part[i].item()
396
  is_constrained = i in processed_constraints
397
+ is_special = token_id in special_token_ids_set
398
+ is_mask = token_id == MASK_ID # Should not happen in final output
399
+
400
+ display_token = ""
401
+ color = ""
402
+
403
+ if is_mask: color = "#444444"; display_token = MASK_TOKEN
404
+ elif is_constrained and processed_constraints.get(i) == token_id:
405
+ raw_decode = tokenizer.decode([token_id], skip_special_tokens=False).strip()
406
+ display_token = raw_decode if raw_decode else "?"; color = "#800080" # Purple
407
+ elif is_special:
408
+ # Hide special tokens in the *final* display state for cleaner look
409
+ display_token = " "; color = "#6699CC" # Hide as 'Old' blue
410
+ else:
411
+ display_token = tokenizer.decode([token_id], skip_special_tokens=True).strip()
412
+ color = "#6699CC" # Final state uses 'Old' blue
413
+
414
+ if not display_token: display_token = "?" # Fallback
415
  final_state_vis.append((display_token, color))
416
+
417
+ # Pad the final state visualization if actual gen len < requested gen_length
418
+ # This shouldn't be necessary if HighlightedText handles shorter lists
419
+ # while len(final_state_vis) < gen_length:
420
+ # final_state_vis.append((" ", "#FFFFFF")) # Add empty space
421
+
422
+ if final_state_vis: # Only append if we generated something
423
+ visualization_states.append(final_state_vis)
424
 
425
 
426
  except Exception as e:
427
+ print(f"\n--- Error during generation ---")
428
  import traceback
429
  traceback.print_exc()
430
+ # Add error message to visualization
431
+ error_msg = f"Generation Error: Check Logs"
432
+ # Append error to visualization states if possible
433
  visualization_states.append([("Error", "red")])
434
  final_text = f"Generation failed: {e}"
435
 
436
+ print("--- DREAM Generation Finished ---\n")
437
+ # Ensure we always return a list (even if empty) and a string
438
+ if not isinstance(visualization_states, list): visualization_states = []
439
+ if not isinstance(final_text, str): final_text = str(final_text)
440
+
441
+ return visualization_states, final_text
442
+
443
+
444
+ # --- Gradio UI Setup ---
445
+
446
+ css = '''
447
+ .category-legend{display:none}
448
+ /* Increase overall base font size */
449
+ body, .gradio-container { font-size: 105%; }
450
+ /* Make buttons slightly larger */
451
+ /* button { min-height: 40px; } */
452
+ .small_btn {
453
+ min-width: 60px; /* Adjust as needed */
454
+ max-width: 100px;
455
+ height: 42px; /* Adjust height */
456
+ flex-grow: 0 !important; /* Prevent button from growing */
457
+ margin-left: 5px !important; /* Add some space */
458
+ font-size: 100%; /* Match button font size */
459
+ padding: 0 10px !important; /* Adjust padding */
460
+ }
461
+ .chat-input-row {
462
+ display: flex;
463
+ align-items: center; /* Vertically align items */
464
+ margin-top: 10px; /* Add space above input row */
465
+ }
466
+ /* Ensure Textbox takes up space */
467
+ .chat-input-row .gr-textbox {
468
+ flex-grow: 1;
469
+ margin-right: 5px;
470
+ }
471
+ /* Chatbot styling */
472
+ .gr-chatbot .message {
473
+ font-size: 100%; /* Ensure chat message font size is reasonable */
474
+ padding: 10px !important;
475
+ border-radius: 8px !important;
476
+ }
477
+ .gr-chatbot .message.user { background-color: #E0F7FA !important; align-self: flex-end; } /* Light cyan for user */
478
+ .gr-chatbot .message.bot { background-color: #F1F8E9 !important; align-self: flex-start; } /* Light green for bot */
479
+ /* HighlightedText styling */
480
+ .gr-highlightedtext span {
481
+ padding: 1px 2px; /* Minimal padding */
482
+ margin: 0 1px; /* Minimal margin */
483
+ border-radius: 3px;
484
+ font-family: monospace; /* Use monospace font for better alignment */
485
+ font-size: 95%; /* Slightly smaller font for dense vis */
486
+ line-height: 1.4; /* Adjust line spacing */
487
+ }
488
+ .gr-highlightedtext {
489
+ padding: 10px;
490
+ border: 1px solid #E0E0E0;
491
+ border-radius: 5px;
492
+ background-color: #FAFAFA; /* Light background for the container */
493
+ }
494
+ /* Legend Styling */
495
+ .legend {
496
+ font-size: 90%;
497
+ margin-top: 5px;
498
+ color: #555;
499
+ }
500
+ .legend span {
501
+ display: inline-block; /* Keep legend items inline */
502
+ margin-right: 10px;
503
+ white-space: nowrap; /* Prevent wrapping */
504
+ }
505
+ .legend span::before { /* Style the color square */
506
+ content: '■';
507
+ display: inline-block;
508
+ margin-right: 4px;
509
+ font-size: 120%; /* Make square slightly larger */
510
+ vertical-align: middle; /* Align square with text */
511
+ }
512
+ '''
513
+ def create_chatbot_demo():
514
+ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
515
+ gr.Markdown("## Dream 7B - Diffusion Language Model Demo")
516
+ gr.Markdown("Interact with the Dream 7B instruction-tuned model and watch the diffusion process unfold step-by-step. "
517
+ "You can optionally constrain specific words at certain positions.")
518
+ with gr.Row():
519
+ gr.Markdown("[Model Card](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)", scale=1)
520
+ gr.Markdown("[Blog Post](https://hkunlp.github.io/blog/2025/dream/)", scale=1)
521
+
522
+ # STATE MANAGEMENT
523
+ chat_history = gr.State([]) # Stores conversation [[user, bot], ...]
524
+
525
+ # UI LAYOUT
526
+ with gr.Row():
527
+ # Left Column: Chat Interface
528
+ with gr.Column(scale=3):
529
+ chatbot_ui = gr.Chatbot(
530
+ label="Conversation",
531
+ height=550,
532
+ bubble_full_width=False,
533
+ show_copy_button=True,
534
+ render=False # Rendered explicitly later for streaming
535
+ )
536
+ chatbot_ui.render() # Manually render after setting parameters
537
+
538
+ # Message input Row
539
+ with gr.Row(elem_classes="chat-input-row"):
540
+ user_input = gr.Textbox(
541
+ label="Your Message",
542
+ placeholder="Type your message and press Enter, or click Send...",
543
+ scale=4, # Give textbox more space relative to button
544
+ container=False,
545
+ show_label=False
546
+ )
547
+ send_btn = gr.Button("Send", scale=1, elem_classes="small_btn", variant="primary")
548
+
549
+ constraints_input = gr.Textbox(
550
+ label="Word Constraints (Optional)",
551
+ info="Force words at positions (0-indexed from response start). Format: 'pos:word, pos:word'. Example: '0:Once, 5:upon, 10:time'",
552
+ placeholder="e.g., 0:Hello, 6:world",
553
+ lines=1
554
+ )
555
+
556
+ # Right Column: Visualization and Settings
557
+ with gr.Column(scale=2):
558
+ gr.Markdown("### Denoising Process Visualization")
559
+ output_vis = gr.HighlightedText(
560
+ label="Generation Steps",
561
+ show_label=False, # Label provided by Markdown above
562
+ combine_adjacent=False,
563
+ show_legend=False, # Using custom HTML legend below
564
+ # color_map is not directly used due to show_legend=False, but useful for reference
565
+ color_map={
566
+ "Mask": "#444444",
567
+ "New": "#66CC66",
568
+ "Old": "#6699CC",
569
+ "Constraint": "#800080",
570
+ "Special (New)": "#FF8C00",
571
+ "Error": "red"
572
+ }
573
+ )
574
+ # Custom HTML Legend
575
+ gr.HTML(
576
+ """
577
+ <div class='legend'>
578
+ <span style="color:#444444;">■ Mask</span> |
579
+ <span style='color:#66CC66;'>■ New</span> |
580
+ <span style='color:#FF8C00;'>■ Special (New)</span> |
581
+ <span style='color:#6699CC;'>■ Old</span> |
582
+ <span style='color:#800080;'>■ Constraint</span>
583
+ </div>
584
+ """,
585
+ elem_id="legend-html"
586
+ )
587
+
588
+ # Generation Settings Accordion
589
+ with gr.Accordion("Generation Settings", open=False):
590
+ gen_length = gr.Slider(
591
+ minimum=16, maximum=512, value=128, step=16,
592
+ label="Max New Tokens", info="Max response length."
593
+ )
594
+ steps = gr.Slider(
595
+ minimum=8, maximum=512, value=128, step=8,
596
+ label="Diffusion Steps", info="More steps = finer generation (potentially slower)."
597
+ )
598
+ temperature = gr.Slider(
599
+ minimum=0.0, maximum=1.5, value=0.6, step=0.05,
600
+ label="Temperature", info="Controls randomness. Lower=more deterministic."
601
+ )
602
+ top_p = gr.Slider(
603
+ minimum=0.0, maximum=1.0, value=0.95, step=0.05,
604
+ label="Top-P (Nucleus)", info="Filters vocabulary probabilistically. Lower=less diverse."
605
+ )
606
+ # Map UI choices to DREAM's alg parameters
607
+ remasking_strategy = gr.Radio(
608
+ choices=[
609
+ ("Random", "origin"), # User friendly name -> actual param
610
+ ("Entropy", "entropy"),
611
+ ("MaskGit+", "maskgit_plus"),
612
+ ("TopK Margin", "topk_margin"),
613
+ ],
614
+ value="entropy", # Default
615
+ label="Generation Order Strategy (alg)",
616
+ info="How the model decides which tokens to generate first."
617
+ )
618
+ alg_temp = gr.Slider(
619
+ minimum=0.0, maximum=1.0, value=0.1, step=0.05,
620
+ label="Order Randomness (alg_temp)" ,
621
+ info="Adds randomness to confidence-based strategies (Entropy, MaskGit+, TopK). Ignored for Random."
622
+ )
623
+ visualization_delay = gr.Slider(
624
+ minimum=0.0, maximum=0.5, value=0.05, step=0.01,
625
+ label="Visualization Delay (sec)", info="Pause between steps in visualization."
626
+ )
627
+
628
+ # Clear button Row
629
+ with gr.Row():
630
+ clear_btn = gr.Button("Clear Conversation", variant="stop", icon="🗑️")
631
+
632
+
633
+ # --- Event Handlers ---
634
+
635
+ # Helper to add message to history state
636
+ def add_message_to_history(history_state, user_message, bot_message):
637
+ # history_state is the raw list from gr.State
638
+ history_state.append([user_message, bot_message])
639
+ return history_state
640
+
641
+ # Function when user submits message (Enter or Send button)
642
+ def handle_user_message(message, history_state):
643
+ print(f"User submitted: '{message}'")
644
+ if not message or not message.strip():
645
+ print("Empty message submitted, doing nothing.")
646
+ # Return unchanged state if message is empty
647
+ # Need to return values for all outputs of the .submit/.click
648
+ # history_state, chatbot_ui, user_input, output_vis
649
+ return history_state, history_state, "", [] # No change to chatbot UI yet
650
+
651
+ # Add user message to history state (with None for bot response initially)
652
+ updated_history_state = add_message_to_history(history_state, message, None)
653
+
654
+ # Prepare updated history for display in Chatbot UI
655
+ # We only display the user message now, bot response comes later
656
+ chatbot_display = updated_history_state.copy()
657
+
658
+ # Clear the input textbox and visualization
659
+ return updated_history_state, chatbot_display, "", []
660
+
661
+ # Function to generate bot response (triggered after user message is handled)
662
+ # Uses yield for streaming visualization updates
663
+ def generate_bot_response(
664
+ history_state, # The current state [[user, None], ...]
665
+ gen_length_val, steps_val, constraints_text, delay_val,
666
+ temperature_val, top_p_val, alg_val, alg_temp_val
667
+ ):
668
+ print("\n--- Streaming Bot Response ---")
669
+ if not history_state or history_state[-1][1] is not None:
670
+ print("History empty or last message already has response. Skipping generation.")
671
+ # Yield current state if called unnecessarily
672
+ yield history_state, [] # Chatbot UI, Visualization
673
+ return
674
+
675
+ # Get the conversation history in the format the model expects
676
+ messages_for_model = format_chat_history(history_state) # Includes the latest user query
677
+
678
+ # Parse constraints from the textbox
679
+ parsed_constraints = parse_constraints(constraints_text)
680
+
681
+ # Generate response with visualization (this function handles the core logic)
682
+ vis_states, response_text = dream_generate_response_with_visualization(
683
+ messages_for_model,
684
+ gen_length=gen_length_val,
685
+ steps=steps_val,
686
+ constraints=parsed_constraints,
687
+ temperature=temperature_val,
688
+ top_p=top_p_val,
689
+ alg=alg_val,
690
+ alg_temp=alg_temp_val
691
+ )
692
+
693
+ # Update the history state with the final bot response (critical!)
694
+ history_state[-1][1] = response_text.strip()
695
+
696
+ # Stream the updates
697
+ if vis_states:
698
+ # Yield the initial visualization state first
699
+ yield history_state, vis_states[0] # Update chatbot UI (implicitly via history), update visualization
700
+
701
+ # Then animate through the rest of the visualization states
702
+ for state in vis_states[1:]:
703
+ time.sleep(delay_val)
704
+ yield history_state, state # Update chatbot UI, update visualization
705
+ else:
706
+ # Handle case where generation failed or produced no visualization
707
+ print("Warning: No visualization states generated.")
708
+ yield history_state, [("No visualization generated.", "orange")] # Update chatbot UI, show warning in vis
709
+
710
+ print("--- Streaming Complete ---")
711
+
712
+
713
+ # Function to clear everything
714
+ def clear_conversation_state():
715
+ print("Clearing conversation.")
716
+ # Reset state and UI components
717
+ return [], [], "", [] # chat_history (State), chatbot_ui, user_input, output_vis
718
+
719
+ # --- Wire UI elements to functions ---
720
+
721
+ # Define shared inputs for generation to avoid repetition
722
+ generation_inputs = [
723
+ chat_history, gen_length, steps, constraints_input, visualization_delay,
724
+ temperature, top_p, remasking_strategy, alg_temp
725
+ ]
726
+ # Define shared outputs for streaming
727
+ streaming_outputs = [chatbot_ui, output_vis]
728
+
729
+ # Typing in Textbox and pressing Enter
730
+ user_input.submit(
731
+ fn=handle_user_message,
732
+ inputs=[user_input, chat_history],
733
+ outputs=[chat_history, chatbot_ui, user_input, output_vis], # Update history state, chatbot display, clear input, clear vis
734
+ queue=False # Process user input immediately
735
+ ).then(
736
+ fn=generate_bot_response,
737
+ inputs=generation_inputs,
738
+ outputs=streaming_outputs, # Stream updates to chatbot and visualization
739
+ #api_name="generate_stream" # Optional: Name for API endpoint
740
+ )
741
+
742
+ # Clicking the Send button
743
+ send_btn.click(
744
+ fn=handle_user_message,
745
+ inputs=[user_input, chat_history],
746
+ outputs=[chat_history, chatbot_ui, user_input, output_vis],
747
+ queue=False
748
+ ).then(
749
+ fn=generate_bot_response,
750
+ inputs=generation_inputs,
751
+ outputs=streaming_outputs,
752
+ # api_name="generate_stream_click" # Optional
753
+ )
754
+
755
+ # Clicking the Clear button
756
+ clear_btn.click(
757
+ fn=clear_conversation_state,
758
+ inputs=[],
759
+ outputs=[chat_history, chatbot_ui, user_input, output_vis],
760
+ queue=False # Clearing should be instant
761
+ )
762
+
763
+ return demo
764
+
765
+ # --- Launch the Gradio App ---
766
+ if __name__ == "__main__":
767
+ print("Creating Gradio demo...")
768
+ gradio_demo = create_chatbot_demo()
769
+ print("Launching Gradio demo...")
770
+ # Use queue() for handling concurrent users and potentially long generation times
771
+ # share=True generates a public link (useful for Colab/Spaces)
772
+ # debug=True provides helpful Gradio logs in the console
773
+ gradio_demo.queue().launch(share=True, debug=False) # Set debug=True for more verbose logs if needed
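
For reference, a minimal standalone sketch of the constraint format this version of app.py expects. It simply mirrors the parse_constraints helper added above (the function is redefined here only so the snippet runs on its own), and the example string is the one suggested in the constraints textbox hint; it is illustrative, not part of the commit.

# Mirrors the parse_constraints helper from app.py; illustrative only.
def parse_constraints(constraints_text):
    """Parse constraints in format: 'position:word, position:word, ...'"""
    constraints = {}
    if not constraints_text:
        return constraints
    for part in constraints_text.split(','):
        part = part.strip()
        if ':' not in part:
            continue
        try:
            pos_str, word = part.split(':', 1)
            pos = int(pos_str.strip())
            word = word.strip()
            if word and pos >= 0:
                constraints[pos] = word
        except ValueError:
            continue
    return constraints

# Example using the format suggested in the UI:
print(parse_constraints("0:Once, 5:upon, 10:time"))
# -> {0: 'Once', 5: 'upon', 10: 'time'}

The resulting position -> word mapping is what generate_bot_response passes as the constraints argument of dream_generate_response_with_visualization, where each word is further expanded into one token ID per generation position.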