multimodalart (HF Staff) committed
Commit 3d09f97 · verified · 1 Parent(s): fb0307e

Update app.py

Files changed (1)
  1. app.py +149 -271
app.py CHANGED
@@ -13,69 +13,40 @@ import torch.distributions as dists # Added import
  # --- START: Copied Helper functions from generation_utils.py ---
  # [Keep the copied functions: top_p_logits, top_k_logits, sample_tokens]
  def top_p_logits(logits, top_p=None):
- """ Applies top-p filtering to logits. """
- if top_p is None or top_p >= 1.0:
- return logits
  sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
  sorted_indices_to_remove = cumulative_probs > top_p
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
- sorted_indices_to_remove[..., 0] = 0
- mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
- mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
- logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
- return logits

  def top_k_logits(logits, top_k=None):
- """ Applies top-k filtering to logits. """
- if top_k is None or top_k <= 0:
- return logits
  top_k = min(top_k, logits.size(-1))
  indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
- logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
- return logits

  def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
- """ Samples tokens based on logits and calculates confidence. """
- if temperature > 0:
- safe_temp = max(temperature, 1e-6)
- logits = logits / safe_temp
- if top_p is not None and 0.0 < top_p < 1.0:
- logits = top_p_logits(logits, top_p)
- if top_k is not None and top_k > 0:
- logits = top_k_logits(logits, top_k)
  is_all_neg_inf = torch.all(logits == torch.finfo(logits.dtype).min, dim=-1, keepdim=True)
- if torch.any(is_all_neg_inf):
- uniform_logits = torch.zeros_like(logits)
- logits = torch.where(is_all_neg_inf, uniform_logits, logits)
  probs = torch.softmax(logits, dim=-1)
- probs = torch.clamp(probs, min=0.0)
- probs = probs / probs.sum(dim=-1, keepdim=True)
- probs = torch.nan_to_num(probs, nan=0.0)
  if temperature > 0:
- try:
- x0 = dists.Categorical(probs=probs).sample()
- confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
- except Exception as e:
- print(f"Warning: Error during Categorical sampling: {e}. Falling back to argmax.")
- confidence, x0 = probs.max(dim=-1)
- else:
- confidence, x0 = probs.max(dim=-1)
- if margin_confidence:
- sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
- top1_probs = sorted_probs[..., 0]
- top2_probs = sorted_probs[..., 1] if sorted_probs.shape[-1] > 1 else top1_probs
- confidence = top1_probs - top2_probs
- if neg_entropy:
- epsilon = 1e-10
- log_probs = torch.log(probs + epsilon)
- confidence = torch.sum(probs * log_probs, dim=-1)
  confidence = torch.nan_to_num(confidence, nan=0.0)
  return confidence, x0
  # --- END: Copied Helper functions ---

-
- # [Keep model loading, constants]
  config = AutoConfig.from_pretrained("Dream-org/Dream-v0-Instruct-7B", trust_remote_code=True)
  model_path = "Dream-org/Dream-v0-Instruct-7B"
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -104,10 +75,8 @@ try:
  SPECIAL_TOKEN_IDS.add(IM_END_ID)
  except KeyError: IM_START_ID, IM_END_ID = None, None

-
  # --- Helper Functions ---
  def parse_constraints(constraints_text: str) -> Dict[int, List[int]]:
- """ Parses word constraints. """
  constraints = {}
  if not constraints_text: return constraints
  parts = constraints_text.split(',')
@@ -119,55 +88,26 @@ def parse_constraints(constraints_text: str) -> Dict[int, List[int]]:
  pos = int(pos_str.strip())
  word = word.strip()
  token_ids = []
- if word:
- text_to_encode = (" " + word) if (pos > 0 and not word.startswith(" ")) else word
- token_ids = tokenizer.encode(text_to_encode, add_special_tokens=False)
  if token_ids and pos >= 0: constraints[pos] = token_ids
  elif not token_ids and word: print(f"Warning: Could not tokenize constraint word '{word}'")
  except ValueError: print(f"Warning: Invalid position '{pos_str}' in constraint part '{part}'")
  except Exception as e: print(f"Warning: Error processing constraint '{part}': {e}")
  return constraints

- def format_chat_history(history: List[List[Optional[str]]]) -> List[Dict[str, str]]:
- """
- Formats chat history [[user, bot], [user, bot]] into [{'role': 'user', 'content': ...}, ...]
- for the tokenizer's chat template.
- """
- messages = []
- # Ensure history is not empty and is properly structured
- if not history:
- return messages
- for turn in history:
- if not isinstance(turn, (list, tuple)) or len(turn) != 2:
- print(f"Warning: Skipping malformed history turn: {turn}")
- continue
- user_msg, assistant_msg = turn
- if user_msg is not None: # Check if user message exists
- # Ensure content is a string
- user_content = str(user_msg) if user_msg is not None else ""
- messages.append({"role": "user", "content": user_content})
- # Add assistant message only if it exists and is not None
- if assistant_msg is not None:
- assistant_content = str(assistant_msg) if assistant_msg is not None else ""
- messages.append({"role": "assistant", "content": assistant_content})
- # print(f"Formatted messages for template: {messages}") # Debug
- return messages

  def apply_constraints_to_state(
  x: torch.Tensor, prompt_length: int, total_length: int,
  parsed_constraints: Dict[int, List[int]], current_step: Optional[int] = None
  ) -> torch.Tensor:
- """ Applies constraints to the state tensor `x`. """
  modified_x = x.clone()
  for rel_pos, word_token_ids in parsed_constraints.items():
- abs_start_pos = prompt_length + rel_pos
- abs_end_pos = abs_start_pos + len(word_token_ids)
  if abs_start_pos < total_length and abs_end_pos <= total_length:
- try:
- constraint_tensor = torch.tensor(word_token_ids, dtype=torch.long, device=modified_x.device)
- modified_x[0, abs_start_pos:abs_end_pos] = constraint_tensor
- except IndexError: print(f"Warning (Step {current_step}): Constraint OOB: {rel_pos}")
- except Exception as e: print(f"Warning (Step {current_step}): Constraint failed {rel_pos}: {e}")
  return modified_x

@@ -176,7 +116,7 @@ def apply_constraints_to_state(
  @spaces.GPU
  @torch.no_grad()
  def generate_dream_response(
- history: List[List[Optional[str]]], # IMPORTANT: This is the *full* history from the state
  gen_length: int,
  steps: int,
  constraints_text: str,
@@ -186,37 +126,32 @@ def generate_dream_response(
  alg: str,
  alg_temp: Optional[float],
  visualization_delay: float
- ): # No return type annotation for generators in older Python? Or use -> Iterator[Tuple[...]]
  """ Generates text step-by-step and yields visualization states live. """

- # Ensure history is valid before proceeding
- if not history or not history[-1] or history[-1][0] is None:
- # Yield the current (potentially empty) history back
- yield history, [("No valid input message found.", "red")], ""
  return

  # --- 1. Preparation ---
- # Use the *entire* history received from the state for context
- messages_for_template = format_chat_history(history)
  parsed_constraints = parse_constraints(constraints_text)

  try:
  inputs = tokenizer.apply_chat_template(
- messages_for_template,
  return_tensors="pt",
  return_dict=True,
- add_generation_prompt=True # This adds the assistant prompt turn
  )
  input_ids = inputs.input_ids.to(device)
  prompt_attention_mask = inputs.attention_mask.to(device) if 'attention_mask' in inputs else torch.ones_like(input_ids)
- prompt_length = input_ids.shape[1]
- # print(f"Prompt length for model: {prompt_length}") # Debug
- # print(f"Input IDs to model (first 50): {input_ids[0, :50].tolist()}") # Debug
-
  except Exception as e:
  print(f"Error applying chat template: {e}")
- # Yield the current history back with an error message
- yield history, [("Error preparing input.", "red")], ""
  return

  eps = 1e-3
@@ -227,132 +162,111 @@ def generate_dream_response(
  # --- 2. Initialize Generation State ---
  total_length = prompt_length + gen_length
  initial_generation_part = torch.full((1, gen_length), MASK_ID, dtype=torch.long, device=device)
  x = torch.cat((input_ids, initial_generation_part), dim=1)

- # --- Prepare Attention Mask ---
  generation_attention_mask = torch.ones((1, gen_length), dtype=torch.long, device=device)
  full_attention_mask_long = torch.cat((prompt_attention_mask, generation_attention_mask), dim=1)
  attention_mask_for_model = full_attention_mask_long.to(model.dtype)
  large_neg_val = torch.finfo(model.dtype).min
  attention_mask_for_model = (1.0 - attention_mask_for_model) * large_neg_val
- attention_mask_for_model = attention_mask_for_model.unsqueeze(1).unsqueeze(2) # Shape [B, 1, 1, N]

  timesteps = torch.linspace(1, eps, steps + 1, device=device)
  x = apply_constraints_to_state(x, prompt_length, total_length, parsed_constraints, current_step=-1)

- # --- 3. Visualization & State Setup ---
  previous_tokens_vis = None
- # Use the passed-in history directly. We will modify the *last* item's assistant response.
- # No need for history_copy if we are careful. Let's try modifying `history` directly.
- # IMPORTANT: Gradio state needs the component to receive the *entire object* back if it's mutated.
- # So yielding the modified `history` list itself should work.
- history_for_yield = history # Reference the original list

  # --- 4. Initial Yield (Masked State) ---
  initial_generated_tokens = x[0, prompt_length:].cpu()
  vis_data_initial = []
  for tok_id in initial_generated_tokens.tolist():
- vis_data_initial.append((MASK_TOKEN, "#444444"))
  previous_tokens_vis = initial_generated_tokens
- # Yield the *current* history (with None for last bot msg)
- yield history_for_yield, vis_data_initial, ""
  time.sleep(visualization_delay)

  # --- 5. Step-by-Step Diffusion Loop ---
  try:
  start_time = time.time()
- current_response_text = "" # Store intermediate text
-
  for i in range(steps):
  mask_index = (x == MASK_ID)
- if not mask_index.any():
- print(f"No mask tokens left at step {i}. Stopping early.")
- break
-
- outputs = model(
- input_ids=x,
- attention_mask=attention_mask_for_model,
- position_ids=None, use_cache=False, return_dict=True
- )
  logits = outputs.logits
- logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)

  mask_logits = logits[mask_index]
- if mask_logits.numel() == 0:
- print(f"No masked tokens found for logit selection at step {i}. Stopping.")
- break

  t = timesteps[i]; s = timesteps[i + 1]
  x_new_masked_part = torch.full_like(x[mask_index], MASK_ID, device=device, dtype=torch.long)

- # [Sampling logic remains the same as previous working version]
  if alg == 'origin':
  p_transfer = (1.0 - s / t) if i < steps - 1 else 1.0
  num_masked = mask_logits.shape[0]
  transfer_indices_relative = torch.rand(num_masked, device=device) < p_transfer
  logits_to_sample = mask_logits[transfer_indices_relative]
- if logits_to_sample.numel() > 0:
- _, sampled_tokens = sample_tokens(logits_to_sample, temperature=temperature, top_p=top_p_val, top_k=top_k_val)
- x_new_masked_part[transfer_indices_relative] = sampled_tokens
- else: # Confidence-based
- use_margin = (alg == 'topk_margin'); use_entropy = (alg == 'entropy')
- confidence, x0_candidates = sample_tokens(
- mask_logits, temperature=temperature, top_p=top_p_val, top_k=top_k_val,
- margin_confidence=use_margin, neg_entropy=use_entropy
- )
  num_mask_token = mask_logits.shape[0]
  target_num_revealed_float = num_mask_token * (1.0 - s / t)
  number_transfer_tokens = int(target_num_revealed_float) if i < steps - 1 else num_mask_token
  if number_transfer_tokens > 0:
  num_samples = min(number_transfer_tokens, num_mask_token)
  if num_samples > 0:
- transfer_indices_relative = torch.tensor([], dtype=torch.long, device=device) # Init empty
- if alg_temp_val is None or alg_temp_val <= 0: # Top-k
  sort_metric = confidence if alg != 'entropy' else -confidence
  k_topk = min(num_samples, sort_metric.numel())
  if k_topk > 0: _, transfer_indices_relative = torch.topk(sort_metric, k=k_topk)
- else: # Sampled
  if confidence.numel() > 0:
- conf_probs = confidence / alg_temp_val
- conf_probs = torch.nan_to_num(conf_probs, nan=0.0, posinf=1e9, neginf=-1e9)
- conf_probs = torch.clamp(conf_probs - conf_probs.max(), min=-30)
- conf_probs = F.softmax(conf_probs, dim=-1)
- conf_probs = torch.clamp(conf_probs, min=0.0)
- conf_probs = torch.nan_to_num(conf_probs, nan=0.0)
- prob_sum = conf_probs.sum()
- target_sum_tensor = torch.tensor(1.0, device=device, dtype=prob_sum.dtype)
- if not torch.isclose(prob_sum, target_sum_tensor, atol=1e-4) and prob_sum > 0:
- safe_prob_sum = torch.max(prob_sum, torch.tensor(1e-12, device=device, dtype=prob_sum.dtype))
- conf_probs = conf_probs / safe_prob_sum
  final_prob_sum_check = conf_probs.sum()
  if conf_probs.numel() > 0 and num_samples > 0 and torch.all(conf_probs >= 0) and torch.isclose(final_prob_sum_check, target_sum_tensor, atol=1e-4):
  try: transfer_indices_relative = torch.multinomial(conf_probs, num_samples=num_samples, replacement=False)
- except RuntimeError as e: print(f"W{i}: Multinomial failed ('{e}'). Fallback.") # Fallback handled below
- if transfer_indices_relative.numel() == 0: # Fallback if sampling failed or wasn't possible
- sort_metric = confidence if alg != 'entropy' else -confidence
- k_fallback = min(num_samples, sort_metric.numel())
- if k_fallback > 0: _, transfer_indices_relative = torch.topk(sort_metric, k=k_fallback)
  # Apply transfer
  if transfer_indices_relative.numel() > 0:
- valid_indices = transfer_indices_relative < x0_candidates.shape[0]
- valid_transfer_indices = transfer_indices_relative[valid_indices]
- if valid_transfer_indices.numel() > 0 and valid_transfer_indices.max() < x_new_masked_part.shape[0]:
- x_new_masked_part[valid_transfer_indices] = x0_candidates[valid_transfer_indices].clone()

- x[mask_index] = x_new_masked_part
  x = apply_constraints_to_state(x, prompt_length, total_length, parsed_constraints, current_step=i)

  # --- Yield Visualization ---
  current_generated_tokens = x[0, prompt_length:].cpu()
  vis_data = []
- # [Visualization formatting logic remains the same]
  for j in range(gen_length):
  current_tok_id = current_generated_tokens[j].item()
  previous_tok_id = previous_tokens_vis[j].item() if previous_tokens_vis is not None and j < len(previous_tokens_vis) else MASK_ID
- try:
- decoded_token = tokenizer.decode([current_tok_id], skip_special_tokens=False, clean_up_tokenization_spaces=False)
- display_token = MASK_TOKEN if current_tok_id == MASK_ID else decoded_token
  except Exception: display_token = f"[ID:{current_tok_id}]"
  color = None; token_to_display = display_token
  if current_tok_id == MASK_ID: color = "#444444"
@@ -361,27 +275,17 @@ def generate_dream_response(
  should_hide = (PAD_ID is not None and current_tok_id == PAD_ID) or (EOS_ID is not None and current_tok_id == EOS_ID)
  if should_hide and previous_tok_id == current_tok_id: token_to_display = ""; color = None
  if token_to_display: vis_data.append((token_to_display, color))
- # ---

  previous_tokens_vis = current_generated_tokens

- # --- Update intermediate response text ---
  intermediate_response_tokens = x[0, prompt_length:]
- current_response_text = tokenizer.decode(
- intermediate_response_tokens,
- skip_special_tokens=True,
- clean_up_tokenization_spaces=True
- ).strip()
-
- # --- Update history for yield ---
- # Update the placeholder in the *last turn* of the history list
- if history_for_yield and history_for_yield[-1]:
- history_for_yield[-1][1] = current_response_text + "..." # Indicate streaming
-
- # --- Yield current state ---
- yield history_for_yield, vis_data, current_response_text
  time.sleep(visualization_delay)
- # --- End loop iteration ---

  end_time = time.time()
  print(f"Dream generation finished in {end_time - start_time:.2f} seconds.")
@@ -389,49 +293,38 @@ def generate_dream_response(
  # --- 6. Final Processing & Yield ---
  final_sequence = x[0]
  response_tokens = final_sequence[prompt_length:]
- final_response_text = tokenizer.decode(
- response_tokens,
- skip_special_tokens=True,
- clean_up_tokenization_spaces=True
- ).strip()

- # Update the history definitively with the final text
- if history_for_yield and history_for_yield[-1]:
- history_for_yield[-1][1] = final_response_text
-
- # Format final visualization
  final_generated_tokens = x[0, prompt_length:].cpu()
  vis_data_final = []
- # [Final visualization formatting logic remains the same]
  for j in range(gen_length):
- current_tok_id = final_generated_tokens[j].item()
- previous_tok_id = previous_tokens_vis[j].item() if previous_tokens_vis is not None and j < len(previous_tokens_vis) else MASK_ID
- try:
- decoded_token = tokenizer.decode([current_tok_id], skip_special_tokens=False, clean_up_tokenization_spaces=False)
- display_token = MASK_TOKEN if current_tok_id == MASK_ID else decoded_token
- except Exception: display_token = f"[ID:{current_tok_id}]"
- color = None; token_to_display = display_token
- if current_tok_id == MASK_ID: color = "#444444"
- elif previous_tok_id == MASK_ID: color = "#66CC66"
- else: color = "#6699CC"
- should_hide = (PAD_ID is not None and current_tok_id == PAD_ID) or (EOS_ID is not None and current_tok_id == EOS_ID)
- if should_hide and previous_tok_id == current_tok_id: token_to_display = ""; color = None
- if token_to_display: vis_data_final.append((token_to_display, color))
- # ---
-
- # Yield the final state
- yield history_for_yield, vis_data_final, final_response_text
  print("Visualization streaming complete.")

  except Exception as e:
  print(f"Error during generation or processing: {e}")
  import traceback
  traceback.print_exc()
- # Ensure the history state reflects the error somehow? Or just yield error vis.
- # Yield the history *as it was* when the error occurred.
- if history_for_yield and history_for_yield[-1]:
- history_for_yield[-1][1] = f"<Error: {e}>" # Put error in bot response
- yield history_for_yield, [("Error during generation.", "red")], ""
  return

@@ -448,17 +341,17 @@ def create_chatbot_demo():
  "[[Blog](https://hkunlp.github.io/blog/2025/dream/)]"
  )

- # Use a single state variable for the history list
- chat_history_state = gr.State([])

  with gr.Row():
  with gr.Column(scale=3):
  chatbot_ui = gr.Chatbot(
  label="Conversation",
  height=500,
  show_copy_button=True,
  bubble_full_width=False,
- # value=[] # Initial value set by state binding later
  )
  with gr.Group():
  with gr.Row():
@@ -474,12 +367,10 @@ def create_chatbot_demo():
  )
  with gr.Column(scale=2):
  output_vis = gr.HighlightedText(
- label="Denoising Process Visualization", combine_adjacent=True,
- show_legend=False, interactive=False
- )
- response_text_display = gr.Textbox(
- label="Generated Response (Live)", interactive=False, lines=5
  )

  with gr.Accordion("Generation Settings", open=False):
  # [Settings sliders remain the same]
@@ -497,88 +388,75 @@ def create_chatbot_demo():
  with gr.Row():
  visualization_delay = gr.Slider(minimum=0.0, maximum=0.5, value=0.03, step=0.01, label="Visualization Delay (seconds)")

-
  clear_btn = gr.Button("Clear Conversation")

- # --- Event Handler Functions ---

- def add_user_message(message: str, history: List[List[Optional[str]]]):
- """
- Adds the user message to the history state and prepares the UI
- for the bot's response (clearing previous outputs).
- """
  if not message.strip():
  gr.Warning("Please enter a message.")
- # Return unchanged history and empty outputs
- return history, history, "", [], ""
- # Append new turn with user message and None placeholder for bot response
- history.append([message, None])
- # Return updated history (for state), history (for immediate UI update),
- # empty input, empty vis, empty response text.
- return history, history, "", [], ""

  def clear_all():
- """Clears state and all relevant UI components."""
- return [], [], "", [], "" # state, chatbot, input, vis, response text

  # --- Connect UI elements ---

- # Define inputs/outputs for the generator
  generation_inputs = [
- chat_history_state, gen_length, steps, constraints_input,
  temperature, top_p, top_k, remasking_strategy, alg_temp,
  visualization_delay
  ]
- # Generator yields: history_list, vis_data, response_text
- generation_outputs = [chatbot_ui, output_vis, response_text_display]
-
- # Chain the actions: Submit/Click -> add_user_message -> generate_dream_response
-
- # 1. User submits message (Enter or Button)
- user_interaction = [user_input, chat_history_state]
- outputs_after_user_add = [
- chat_history_state, # Update the state
- chatbot_ui, # Update chatbot UI immediately
- user_input, # Clear user input box
- output_vis, # Clear visualization
- response_text_display # Clear response text box
- ]

  submit_listener = user_input.submit(
- fn=add_user_message,
- inputs=user_interaction,
- outputs=outputs_after_user_add
- ).then( # 2. Trigger generation AFTER user message is added and UI cleared
  fn=generate_dream_response,
- inputs=generation_inputs, # Pass the updated state and parameters
- outputs=generation_outputs, # Stream updates to chatbot, vis, text
  show_progress="hidden"
  )

  click_listener = send_btn.click(
- fn=add_user_message,
- inputs=user_interaction,
- outputs=outputs_after_user_add
- ).then( # 2. Trigger generation AFTER user message is added and UI cleared
  fn=generate_dream_response,
  inputs=generation_inputs,
- outputs=generation_outputs,
  show_progress="hidden"
  )

- # 3. Clear Button
  clear_btn.click(
- clear_all,
  inputs=[],
- outputs=[
- chat_history_state, chatbot_ui, user_input,
- output_vis, response_text_display
- ]
  )

  return demo

-
  # --- Launch ---
  if __name__ == "__main__":
  demo = create_chatbot_demo()
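The substantive change in this commit is the conversation format: the removed code above stored history as [user, bot] pairs and converted them with format_chat_history() before applying the chat template, while the updated file (whose side of the diff follows below) passes role/content dictionaries straight to tokenizer.apply_chat_template(). A minimal illustrative sketch of the two shapes, assuming a ChatML-style template like the one Dream-v0-Instruct uses; this is not code from app.py itself:

# Old (removed) tuple format, previously converted by format_chat_history():
old_history = [["Hi there", "Hello!"], ["Write a haiku", None]]

# New (added) messages format, accepted directly by apply_chat_template():
new_history = [
    {"role": "user", "content": "Hi there"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "Write a haiku"},
]

# add_generation_prompt=True appends the assistant turn (e.g. "<|im_start|>assistant\n")
# so generation starts in the assistant slot:
# inputs = tokenizer.apply_chat_template(new_history, return_tensors="pt",
#                                        return_dict=True, add_generation_prompt=True)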
 
  # --- START: Copied Helper functions from generation_utils.py ---
  # [Keep the copied functions: top_p_logits, top_k_logits, sample_tokens]
  def top_p_logits(logits, top_p=None):
+ if top_p is None or top_p >= 1.0: return logits
  sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
  sorted_indices_to_remove = cumulative_probs > top_p
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone(); sorted_indices_to_remove[..., 0] = 0
+ mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device).scatter_(-1, sorted_indices, sorted_indices_to_remove)
+ return logits.masked_fill(mask, torch.finfo(logits.dtype).min)

  def top_k_logits(logits, top_k=None):
+ if top_k is None or top_k <= 0: return logits
  top_k = min(top_k, logits.size(-1))
  indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+ return logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)

  def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
+ if temperature > 0: safe_temp = max(temperature, 1e-6); logits = logits / safe_temp
+ if top_p is not None and 0.0 < top_p < 1.0: logits = top_p_logits(logits, top_p)
+ if top_k is not None and top_k > 0: logits = top_k_logits(logits, top_k)
  is_all_neg_inf = torch.all(logits == torch.finfo(logits.dtype).min, dim=-1, keepdim=True)
+ if torch.any(is_all_neg_inf): uniform_logits = torch.zeros_like(logits); logits = torch.where(is_all_neg_inf, uniform_logits, logits)
  probs = torch.softmax(logits, dim=-1)
+ probs = torch.clamp(probs, min=0.0); probs = probs / probs.sum(dim=-1, keepdim=True); probs = torch.nan_to_num(probs, nan=0.0)
  if temperature > 0:
+ try: x0 = dists.Categorical(probs=probs).sample(); confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
+ except Exception as e: print(f"Warning: Sampling failed: {e}. Argmax fallback."); confidence, x0 = probs.max(dim=-1)
+ else: confidence, x0 = probs.max(dim=-1)
+ if margin_confidence: sorted_probs, _ = torch.sort(probs, dim=-1, descending=True); top1_probs = sorted_probs[..., 0]; top2_probs = sorted_probs[..., 1] if sorted_probs.shape[-1] > 1 else top1_probs; confidence = top1_probs - top2_probs
+ if neg_entropy: epsilon = 1e-10; log_probs = torch.log(probs + epsilon); confidence = torch.sum(probs * log_probs, dim=-1)
  confidence = torch.nan_to_num(confidence, nan=0.0)
  return confidence, x0
  # --- END: Copied Helper functions ---
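As a usage note for the helpers above (an illustrative sketch, not part of the committed file, and it assumes sample_tokens and the torch import are already in scope): sample_tokens() takes a [num_masked_positions, vocab_size] logits tensor and returns a per-position confidence together with the sampled token ids.

import torch

logits = torch.randn(5, 32)  # hypothetical: 5 masked positions over a 32-token vocabulary
confidence, x0 = sample_tokens(logits, temperature=0.7, top_p=0.9, top_k=50)
# confidence: shape (5,), probability of each chosen token; x0: shape (5,), sampled token ids.
# With margin_confidence=True the confidence becomes the top1-top2 probability margin;
# with neg_entropy=True it becomes the negative entropy, which the 'entropy'
# remasking strategy sorts on instead.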

+ # [Keep model loading, constants as before]
+ # Load model configuration to get special token IDs
  config = AutoConfig.from_pretrained("Dream-org/Dream-v0-Instruct-7B", trust_remote_code=True)
  model_path = "Dream-org/Dream-v0-Instruct-7B"
  device = 'cuda' if torch.cuda.is_available() else 'cpu'

  SPECIAL_TOKEN_IDS.add(IM_END_ID)
  except KeyError: IM_START_ID, IM_END_ID = None, None

  # --- Helper Functions ---
  def parse_constraints(constraints_text: str) -> Dict[int, List[int]]:
  constraints = {}
  if not constraints_text: return constraints
  parts = constraints_text.split(',')

  pos = int(pos_str.strip())
  word = word.strip()
  token_ids = []
+ if word: text_to_encode = (" " + word) if (pos > 0 and not word.startswith(" ")) else word; token_ids = tokenizer.encode(text_to_encode, add_special_tokens=False)
  if token_ids and pos >= 0: constraints[pos] = token_ids
  elif not token_ids and word: print(f"Warning: Could not tokenize constraint word '{word}'")
  except ValueError: print(f"Warning: Invalid position '{pos_str}' in constraint part '{part}'")
  except Exception as e: print(f"Warning: Error processing constraint '{part}': {e}")
  return constraints

+ # Removed format_chat_history as history will be in the correct format

  def apply_constraints_to_state(
  x: torch.Tensor, prompt_length: int, total_length: int,
  parsed_constraints: Dict[int, List[int]], current_step: Optional[int] = None
  ) -> torch.Tensor:
  modified_x = x.clone()
  for rel_pos, word_token_ids in parsed_constraints.items():
+ abs_start_pos = prompt_length + rel_pos; abs_end_pos = abs_start_pos + len(word_token_ids)
  if abs_start_pos < total_length and abs_end_pos <= total_length:
+ try: constraint_tensor = torch.tensor(word_token_ids, dtype=torch.long, device=modified_x.device); modified_x[0, abs_start_pos:abs_end_pos] = constraint_tensor
+ except IndexError: print(f"Warning (Step {current_step}): Constraint idx error at {rel_pos}")
+ except Exception as e: print(f"Warning (Step {current_step}): Constraint apply error at {rel_pos}: {e}")
  return modified_x


  @spaces.GPU
  @torch.no_grad()
  def generate_dream_response(
+ history: List[Dict[str, str]], # MODIFIED: Expect List[Dict]
  gen_length: int,
  steps: int,
  constraints_text: str,

  alg: str,
  alg_temp: Optional[float],
  visualization_delay: float
+ ): # Removed -> type hint for cleaner yield handling
  """ Generates text step-by-step and yields visualization states live. """

+ if not history or history[-1]["role"] != "user": # Check last message is from user
+ yield history, [("No user message found to respond to.", "red")]
  return

  # --- 1. Preparation ---
+ # History is already formatted for the template
  parsed_constraints = parse_constraints(constraints_text)

  try:
+ # apply_chat_template expects List[Dict[str, str]]
  inputs = tokenizer.apply_chat_template(
+ history, # Use history directly
  return_tensors="pt",
  return_dict=True,
+ add_generation_prompt=True # Crucial: Adds the "<|im_start|>assistant\n" prompt
  )
  input_ids = inputs.input_ids.to(device)
  prompt_attention_mask = inputs.attention_mask.to(device) if 'attention_mask' in inputs else torch.ones_like(input_ids)
+ prompt_length = input_ids.shape[1] # Length *after* adding the generation prompt
  except Exception as e:
  print(f"Error applying chat template: {e}")
+ # Yield current history and error message for visualization
+ yield history, [("Error preparing input.", "red")]
  return

  eps = 1e-3

  # --- 2. Initialize Generation State ---
  total_length = prompt_length + gen_length
  initial_generation_part = torch.full((1, gen_length), MASK_ID, dtype=torch.long, device=device)
+ # input_ids already includes the assistant prompt, so just append masks
  x = torch.cat((input_ids, initial_generation_part), dim=1)

+ # --- Prepare Attention Mask for SDPA ---
  generation_attention_mask = torch.ones((1, gen_length), dtype=torch.long, device=device)
+ # prompt_attention_mask corresponds to input_ids (which includes assistant prompt)
  full_attention_mask_long = torch.cat((prompt_attention_mask, generation_attention_mask), dim=1)
+
  attention_mask_for_model = full_attention_mask_long.to(model.dtype)
  large_neg_val = torch.finfo(model.dtype).min
  attention_mask_for_model = (1.0 - attention_mask_for_model) * large_neg_val
+ attention_mask_for_model = attention_mask_for_model.unsqueeze(1).unsqueeze(2) # [B, 1, 1, N]

+ # --- Timesteps ---
  timesteps = torch.linspace(1, eps, steps + 1, device=device)
+
+ # Apply initial constraints (relative to start of generation = prompt_length)
  x = apply_constraints_to_state(x, prompt_length, total_length, parsed_constraints, current_step=-1)

+ # --- 3. Visualization & History Setup ---
  previous_tokens_vis = None
+ # MODIFIED: Append placeholder assistant message to the history state *before* looping
+ history.append({"role": "assistant", "content": ""})

  # --- 4. Initial Yield (Masked State) ---
  initial_generated_tokens = x[0, prompt_length:].cpu()
  vis_data_initial = []
  for tok_id in initial_generated_tokens.tolist():
+ display_token = MASK_TOKEN; color = "#444444"
+ vis_data_initial.append((display_token, color))
+
  previous_tokens_vis = initial_generated_tokens
+ # Yield the history (which now includes the empty assistant message) and initial vis
+ yield history, vis_data_initial
  time.sleep(visualization_delay)

  # --- 5. Step-by-Step Diffusion Loop ---
  try:
  start_time = time.time()
  for i in range(steps):
  mask_index = (x == MASK_ID)
+ if not mask_index.any(): break # Stop early
+
+ outputs = model(input_ids=x, attention_mask=attention_mask_for_model, return_dict=True)
  logits = outputs.logits
+ logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1) # Align logits

  mask_logits = logits[mask_index]
+ if mask_logits.numel() == 0: break # Stop early

  t = timesteps[i]; s = timesteps[i + 1]
  x_new_masked_part = torch.full_like(x[mask_index], MASK_ID, device=device, dtype=torch.long)

+ # [Keep sampling/remasking logic ('origin' and confidence-based) exactly the same]
  if alg == 'origin':
  p_transfer = (1.0 - s / t) if i < steps - 1 else 1.0
  num_masked = mask_logits.shape[0]
  transfer_indices_relative = torch.rand(num_masked, device=device) < p_transfer
  logits_to_sample = mask_logits[transfer_indices_relative]
+ if logits_to_sample.numel() > 0: _, sampled_tokens = sample_tokens(logits_to_sample, temperature=temperature, top_p=top_p_val, top_k=top_k_val); x_new_masked_part[transfer_indices_relative] = sampled_tokens
+ else:
+ use_margin=(alg == 'topk_margin'); use_entropy=(alg == 'entropy')
+ confidence, x0_candidates = sample_tokens(mask_logits, temperature=temperature, top_p=top_p_val, top_k=top_k_val, margin_confidence=use_margin, neg_entropy=use_entropy)
  num_mask_token = mask_logits.shape[0]
  target_num_revealed_float = num_mask_token * (1.0 - s / t)
  number_transfer_tokens = int(target_num_revealed_float) if i < steps - 1 else num_mask_token
  if number_transfer_tokens > 0:
  num_samples = min(number_transfer_tokens, num_mask_token)
  if num_samples > 0:
+ transfer_indices_relative = torch.tensor([], dtype=torch.long, device=device)
+ if alg_temp_val is None or alg_temp_val <= 0: # Top-k confidence
  sort_metric = confidence if alg != 'entropy' else -confidence
  k_topk = min(num_samples, sort_metric.numel())
  if k_topk > 0: _, transfer_indices_relative = torch.topk(sort_metric, k=k_topk)
+ else: # Sample based on confidence temperature
  if confidence.numel() > 0:
+ conf_probs = confidence / alg_temp_val; conf_probs = torch.nan_to_num(conf_probs, nan=0.0, posinf=1e9, neginf=-1e9); conf_probs = torch.clamp(conf_probs - conf_probs.max(), min=-30); conf_probs = F.softmax(conf_probs, dim=-1); conf_probs = torch.clamp(conf_probs, min=0.0); conf_probs = torch.nan_to_num(conf_probs, nan=0.0)
+ prob_sum = conf_probs.sum(); target_sum_tensor = torch.tensor(1.0, device=device, dtype=prob_sum.dtype)
+ if not torch.isclose(prob_sum, target_sum_tensor, atol=1e-4) and prob_sum > 0: safe_prob_sum = torch.max(prob_sum, torch.tensor(1e-12, device=device, dtype=prob_sum.dtype)); conf_probs = conf_probs / safe_prob_sum
  final_prob_sum_check = conf_probs.sum()
  if conf_probs.numel() > 0 and num_samples > 0 and torch.all(conf_probs >= 0) and torch.isclose(final_prob_sum_check, target_sum_tensor, atol=1e-4):
  try: transfer_indices_relative = torch.multinomial(conf_probs, num_samples=num_samples, replacement=False)
+ except RuntimeError as e:
+ print(f"Warning step {i}: Multinomial failed ('{e}'). Fallback.")
+ sort_metric = confidence if alg != 'entropy' else -confidence
+ k_fallback = min(num_samples, sort_metric.numel())
+ if k_fallback > 0: _, transfer_indices_relative = torch.topk(sort_metric, k=k_fallback)
+ else:
+ sort_metric = confidence if alg != 'entropy' else -confidence
+ k_fallback = min(num_samples, sort_metric.numel())
+ if k_fallback > 0: _, transfer_indices_relative = torch.topk(sort_metric, k=k_fallback)
  # Apply transfer
  if transfer_indices_relative.numel() > 0:
+ valid_indices = transfer_indices_relative < x0_candidates.shape[0]; valid_transfer_indices = transfer_indices_relative[valid_indices]
+ if valid_transfer_indices.numel() > 0:
+ if valid_transfer_indices.max() < x_new_masked_part.shape[0]: x_new_masked_part[valid_transfer_indices] = x0_candidates[valid_transfer_indices].clone()
+ else: print(f"Warning step {i}: transfer_indices OOB for x_new_masked_part.")

+ x[mask_index] = x_new_masked_part # Update state

+ # --- Apply Constraints ---
+ # Remember prompt_length now includes the assistant prompt turn
  x = apply_constraints_to_state(x, prompt_length, total_length, parsed_constraints, current_step=i)

  # --- Yield Visualization ---
  current_generated_tokens = x[0, prompt_length:].cpu()
  vis_data = []
+ # [Keep visualization formatting logic the same]
  for j in range(gen_length):
  current_tok_id = current_generated_tokens[j].item()
  previous_tok_id = previous_tokens_vis[j].item() if previous_tokens_vis is not None and j < len(previous_tokens_vis) else MASK_ID
+ try: decoded_token = tokenizer.decode([current_tok_id], skip_special_tokens=False, clean_up_tokenization_spaces=False); display_token = MASK_TOKEN if current_tok_id == MASK_ID else decoded_token
  except Exception: display_token = f"[ID:{current_tok_id}]"
  color = None; token_to_display = display_token
  if current_tok_id == MASK_ID: color = "#444444"

  should_hide = (PAD_ID is not None and current_tok_id == PAD_ID) or (EOS_ID is not None and current_tok_id == EOS_ID)
  if should_hide and previous_tok_id == current_tok_id: token_to_display = ""; color = None
  if token_to_display: vis_data.append((token_to_display, color))

  previous_tokens_vis = current_generated_tokens

+ # MODIFIED: Update the *content* of the last history item
  intermediate_response_tokens = x[0, prompt_length:]
+ intermediate_response_text = tokenizer.decode(intermediate_response_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True).strip()
+ history[-1]["content"] = intermediate_response_text # Update last dict entry
+
+ # Yield the updated history list and current vis data
+ yield history, vis_data
  time.sleep(visualization_delay)

  end_time = time.time()
  print(f"Dream generation finished in {end_time - start_time:.2f} seconds.")

  # --- 6. Final Processing & Yield ---
  final_sequence = x[0]
  response_tokens = final_sequence[prompt_length:]
+ final_response_text = tokenizer.decode(response_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True).strip()
+ # Update the final content in the history object
+ history[-1]["content"] = final_response_text

  final_generated_tokens = x[0, prompt_length:].cpu()
  vis_data_final = []
+ # [Keep final visualization formatting logic the same]
  for j in range(gen_length):
+ current_tok_id = final_generated_tokens[j].item()
+ previous_tok_id = previous_tokens_vis[j].item() if previous_tokens_vis is not None and j < len(previous_tokens_vis) else MASK_ID
+ try: decoded_token = tokenizer.decode([current_tok_id], skip_special_tokens=False, clean_up_tokenization_spaces=False); display_token = MASK_TOKEN if current_tok_id == MASK_ID else decoded_token
+ except Exception: display_token = f"[ID:{current_tok_id}]"
+ color = None; token_to_display = display_token
+ if current_tok_id == MASK_ID: color = "#444444"
+ elif previous_tok_id == MASK_ID: color = "#66CC66"
+ else: color = "#6699CC"
+ should_hide = (PAD_ID is not None and current_tok_id == PAD_ID) or (EOS_ID is not None and current_tok_id == EOS_ID)
+ if should_hide and previous_tok_id == current_tok_id: token_to_display = ""; color = None
+ if token_to_display: vis_data_final.append((token_to_display, color))
+
+ # Yield final history and visualization
+ yield history, vis_data_final
  print("Visualization streaming complete.")

  except Exception as e:
  print(f"Error during generation or processing: {e}")
  import traceback
  traceback.print_exc()
+ # Set error message in the last history item? Or yield separate error?
+ # Let's just yield the current history and error vis
+ history[-1]["content"] = f"Error: {e}" # Put error in assistant message
+ yield history, [("Error during generation.", "red")]
  return


  "[[Blog](https://hkunlp.github.io/blog/2025/dream/)]"
  )

+ # STATE: No explicit state needed if chatbot manages it via input/output

  with gr.Row():
  with gr.Column(scale=3):
+ # MODIFIED: Use type="messages"
  chatbot_ui = gr.Chatbot(
  label="Conversation",
+ type="messages", # Use dictionary format
  height=500,
  show_copy_button=True,
  bubble_full_width=False,
  )
  with gr.Group():
  with gr.Row():

  )
  with gr.Column(scale=2):
  output_vis = gr.HighlightedText(
+ label="Denoising Process Visualization",
+ combine_adjacent=True, show_legend=False, interactive=False
  )
+ # REMOVED: Separate response text display

  with gr.Accordion("Generation Settings", open=False):
  # [Settings sliders remain the same]

  with gr.Row():
  visualization_delay = gr.Slider(minimum=0.0, maximum=0.5, value=0.03, step=0.01, label="Visualization Delay (seconds)")

  clear_btn = gr.Button("Clear Conversation")

+ # --- Event Handlers ---

+ # MODIFIED: add_user_message uses dictionary format
+ def add_user_message(message: str, history: List[Dict[str, str]]):
+ """Adds user message in dictionary format, clears input."""
  if not message.strip():
  gr.Warning("Please enter a message.")
+ return history, "" # Return unchanged history, don't clear input here
+ # Append user message as a dictionary
+ history.append({"role": "user", "content": message})
+ # Return updated history, clear input box
+ return history, ""

  def clear_all():
+ """Clears chatbot, visualization, and input."""
+ return [], [], "" # Chatbot, Vis, Input

  # --- Connect UI elements ---

+ # Define the inputs for the generation function
+ # MODIFIED: Input is chatbot_ui (provides List[Dict])
  generation_inputs = [
+ chatbot_ui, # Get history directly from chatbot component
+ gen_length, steps, constraints_input,
  temperature, top_p, top_k, remasking_strategy, alg_temp,
  visualization_delay
  ]
+ # Define the outputs for the generation function
+ # MODIFIED: Output history (List[Dict]) to chatbot_ui, vis_data to output_vis
+ generation_outputs = [chatbot_ui, output_vis]

+ # Handle Textbox Submission (Enter key)
  submit_listener = user_input.submit(
+ fn=add_user_message, # Use modified function
+ inputs=[user_input, chatbot_ui], # Pass chatbot state
+ outputs=[chatbot_ui, user_input], # Update chatbot state, clear input
+ queue=False # User message add should be quick
+ ).then(
  fn=generate_dream_response,
+ inputs=generation_inputs,
+ outputs=generation_outputs, # Stream history to chatbot, vis to output_vis
  show_progress="hidden"
  )

+ # Handle Send Button Click
  click_listener = send_btn.click(
+ fn=add_user_message, # Use modified function
+ inputs=[user_input, chatbot_ui], # Pass chatbot state
+ outputs=[chatbot_ui, user_input], # Update chatbot state, clear input
+ queue=False # User message add should be quick
+ ).then(
  fn=generate_dream_response,
  inputs=generation_inputs,
+ outputs=generation_outputs, # Stream history to chatbot, vis to output_vis
  show_progress="hidden"
  )

+ # Clear Button Action
  clear_btn.click(
+ clear_all, # Use modified clear function
  inputs=[],
+ outputs=[chatbot_ui, output_vis, user_input], # Clear chatbot, vis, input
+ queue=False
  )

  return demo

  # --- Launch ---
  if __name__ == "__main__":
  demo = create_chatbot_demo()
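The new wiring streams the generator's (history, vis_data) yields directly into the messages-format Chatbot, so no separate gr.State or live-response textbox is needed. A minimal standalone sketch of that pattern, assuming a Gradio version that supports gr.Chatbot(type="messages"); fake_stream here stands in for generate_dream_response:

import time
import gradio as gr

def add_user_message(message, history):
    history.append({"role": "user", "content": message})
    return history, ""  # updated history, cleared textbox

def fake_stream(history):
    history.append({"role": "assistant", "content": ""})
    for chunk in ["Denoising", " step", " by", " step..."]:
        history[-1]["content"] += chunk
        yield history  # each yield re-renders the chatbot
        time.sleep(0.1)

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
    box = gr.Textbox()
    box.submit(add_user_message, [box, chatbot], [chatbot, box], queue=False).then(
        fake_stream, [chatbot], [chatbot]
    )

if __name__ == "__main__":
    demo.launch()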