Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -1,96 +1,62 @@
Previous version:

import torch

import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModel
import time
import re


try:
    print("Running in Gradio Spaces with GPU environment.")
except AttributeError:
    def gpu_check(func):
        return func

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# --- Load DREAM Model and Tokenizer ---
# Ensure sufficient VRAM, Dream 7B needs ~16GB+ VRAM in bfloat16
model_path = "Dream-org/Dream-v0-Instruct-7B"
print(f"Loading model: {model_path}...")
try:

    raise e

# --- Constants for DREAM ---
# Find the mask token and ID from the DREAM tokenizer
if tokenizer.mask_token is None:
    print("Warning: Mask token not found in tokenizer. Attempting to add '[MASK]'.")
    # This might require retraining or fine-tuning if the model didn't see this token
    num_added = tokenizer.add_special_tokens({'mask_token': '[MASK]'})
    if num_added > 0:
        print(f"Added '{tokenizer.mask_token}' to tokenizer.")
        # Resize model embeddings if vocab changed
        model.resize_token_embeddings(len(tokenizer))
        print("Resized model token embeddings.")
    else:
        # Fallback or error if adding failed or mask token still None
        # It's possible a different token serves this purpose in DREAM's training
        print("Error: Could not set a mask token. Visualization might be inaccurate.")
        # You might need to identify which token ID DREAM uses internally for masking tasks
        # For now, we'll proceed but this is a potential issue.
        MASK_TOKEN = "<?>"  # Placeholder symbol
        MASK_ID = -1  # Invalid ID indicates issue
    if tokenizer.mask_token is None:
        raise ValueError("Could not set a mask token for the tokenizer.")

MASK_TOKEN = tokenizer.mask_token
MASK_ID = tokenizer.mask_token_id
print(f"Using MASK_TOKEN='{MASK_TOKEN}' with ID={MASK_ID}")

# Identify other special tokens to potentially hide/show
eos_token_id = tokenizer.eos_token_id
pad_token_id = tokenizer.pad_token_id
special_token_ids_set = {MASK_ID}  # Start with Mask ID
if eos_token_id is not None:
    special_token_ids_set.add(eos_token_id)
    print(f"EOS token ID: {eos_token_id} ({tokenizer.decode([eos_token_id])})")
if pad_token_id is not None:
    special_token_ids_set.add(pad_token_id)
    print(f"PAD token ID: {pad_token_id} ({tokenizer.decode([pad_token_id])})")
# Add other common special tokens if needed (e.g., BOS, UNK)
if tokenizer.bos_token_id is not None:
    special_token_ids_set.add(tokenizer.bos_token_id)
    print(f"BOS token ID: {tokenizer.bos_token_id} ({tokenizer.decode([tokenizer.bos_token_id])})")
if tokenizer.unk_token_id is not None:
    special_token_ids_set.add(tokenizer.unk_token_id)
    print(f"UNK token ID: {tokenizer.unk_token_id} ({tokenizer.decode([tokenizer.unk_token_id])})")

print(f"Identified special token IDs: {special_token_ids_set}")

# --- Helper Functions (Constraint Parsing, History Formatting) ---

def parse_constraints(constraints_text):
    """Parse constraints in format: 'position:word, position:word, ...'"""
@@ -100,674 +66,494 @@ def parse_constraints(constraints_text):
    parts = constraints_text.split(',')
    for part in parts:
        part = part.strip()  # Trim whitespace
        if ':' not in part:
            continue
        try:
            pos_str, word = part.split(':', 1)
            pos = int(pos_str.strip())
            word = word.strip()
            # Allow empty words if needed? Forcing empty seems odd. Let's require a word.
            if word and pos >= 0:
        except ValueError:
            continue

    return constraints

def format_chat_history(history):
    """
    Format chat history for the

    Args:
        history: List of [user_message, assistant_message] pairs

    Returns:
        Formatted
    """
    messages = []
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    return messages

# --- Core Generation Logic

    constraints=None,
    temperature=0.6,  # Default
    top_p=0.95,
    alg="entropy",
    alg_temp=0.1,
):
    """
    Generate text with

    Args:
        messages: List of message dictionaries with 'role' and 'content'
        steps:
        constraints: Dictionary mapping positions (relative to response start) to
        temperature: Sampling temperature
        top_p: Nucleus sampling p
        alg: Remasking algorithm ('origin', '
        alg_temp: Temperature for confidence-based algorithms

    Returns:
        Tuple: (List of visualization states, final generated text string)
    """
    print(f"Params: len={gen_length}, steps={steps}, temp={temperature}, top_p={top_p}, alg='{alg}', alg_temp={alg_temp}")
    print(f"Constraints: {constraints}")

    # --- Input Preparation ---
    if constraints is None:
        constraints = {}

            if not tokens:
                print(f"  Warning: Could not tokenize constraint word '{word}' at position {pos}. Skipping.")
                continue
            print(f"  Pos {pos}, Word '{word}' -> Tokens {tokens} ({tokenizer.convert_ids_to_tokens(tokens)})")
            constraint_token_lengths[pos] = len(tokens)
            for i, token_id in enumerate(tokens):
                target_pos = pos + i
                if target_pos in processed_constraints:
                    print(f"  Warning: Overlapping constraint token at position {target_pos}. Keeping first constraint's token ({processed_constraints[target_pos]}).")
                else:
                    processed_constraints[target_pos] = token_id

    # Prepare the prompt using chat template
    try:
        inputs = tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
        )
        input_ids = inputs.input_ids.to(device=
        # print(f"Attention Mask: {attention_mask}")  # Verify mask covers prompt
    except Exception as e:
        print(f"Error applying chat template: {e}")

        gen_length = model_max_length - prompt_length
        if gen_length <= 0:
            print("Error: Prompt is already too long.")
            return [([("Prompt too long.", "red")],)], "Error: Prompt too long."

    # --- State for Visualization Hook ---
    visualization_states = []
    last_x = None  # Store the full sequence (prompt + generation) from the previous step

    # Initial state: Prompt + all masks for generation part
    initial_gen_part = torch.full((1, gen_length), MASK_ID, dtype=torch.long, device=device)
    # Apply initial constraints to the masked part *before* the first visualization state
    for pos, token_id in processed_constraints.items():
        absolute_pos = pos  # Position relative to start of generation
        if 0 <= absolute_pos < gen_length:
            initial_gen_part[0, absolute_pos] = token_id

    # Create the first visualization state (only the generation part)
    initial_state_vis = []
    for i in range(gen_length):
        token_id = initial_gen_part[0, i].item()
        if token_id == MASK_ID:
            initial_state_vis.append((MASK_TOKEN, "#444444"))  # Mask color
        else:
            # This must be a constraint applied initially
            # Decode without skipping special to see raw constraint if needed
            token_str = tokenizer.decode([token_id], skip_special_tokens=False).strip()
            initial_state_vis.append((token_str if token_str else "?", "#800080"))  # Constraint color (purple)
    visualization_states.append(initial_state_vis)

    # --- Define the Hook Function ---
    def generation_tokens_hook_func(step, x, logits):
            is_constrained = i in processed_constraints
            is_special = current_token_id in special_token_ids_set
            is_mask = current_token_id == MASK_ID
            was_mask = last_token_id == MASK_ID or last_x is None  # Treat first step as coming from mask

            display_token = ""
            color = ""

            # Determine display token and color based on state transitions
            if is_mask:
                display_token = MASK_TOKEN
                color = "#444444"  # Dark Gray
            elif is_constrained and processed_constraints[i] == current_token_id:
                # Always show the constrained token, color purple
                # Decide whether to show raw special tokens when constrained
                raw_decode = tokenizer.decode([current_token_id], skip_special_tokens=False).strip()
                display_token = raw_decode if raw_decode else "?"
                color = "#800080"  # Purple
            elif is_special:
                if was_mask:
                    # Newly revealed special token: Show its representation once
                    display_token = tokenizer.decode([current_token_id], skip_special_tokens=False).strip()  # Show raw special
                    color = "#FF8C00"  # DarkOrange
            else:

    try:
        # It should represent the state *before* the first diffusion step.
        initial_full_x = torch.cat([input_ids, initial_gen_part], dim=1)
        last_x = initial_full_x.clone()  # Initialize last_x with prompt + initial masked/constrained gen part

        output = model.diffusion_generate(
            attention_mask=
            max_new_tokens=
            output_history=False,
            return_dict_in_generate=True,
            steps=steps,
            temperature=temperature,
            top_p=top_p,
            alg=alg,
            alg_temp=alg_temp if alg != "origin" else 0.0,
            generation_tokens_hook_func=generation_tokens_hook_func
            # Ensure generation doesn't run past eos_token if not desired
            # eos_token_id=eos_token_id,  # This might stop generation early
            # pad_token_id=tokenizer.eos_token_id  # Often pad is same as eos for LLMs
        )
        print("model.diffusion_generate finished.")

        # Extract final generated sequence (response part only)
        # The hook ensures the returned sequence has constraints applied
        final_sequence = output.sequences[0]
        # Handle potential length mismatch if generation stopped early
        actual_gen_len = final_sequence.shape[0] - prompt_length
        response_token_ids = final_sequence[prompt_length:]

        # Decode the final response, skipping special tokens like EOS/PAD
        final_text = tokenizer.decode(
            response_token_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True  # Recommended for cleaner output
        ).strip()
        print(f"Final generated text: '{final_text}'")

        # Add the very final state to visualization if the hook didn't capture it
        # (Mainly a safeguard, hook should run 'steps' times or until completion)
        if len(visualization_states) <= steps:  # Hook might run 'steps' times
            final_state_vis = []
            final_gen_part = response_token_ids  # Use the extracted response tokens

            for i in range(len(final_gen_part)):  # Iterate over actual generated tokens
                token_id = final_gen_part[i].item()
                is_constrained = i in processed_constraints
                is_special = token_id in special_token_ids_set
                is_mask = token_id == MASK_ID  # Should not happen in final output

                display_token = ""
                color = ""

                if is_mask: color = "#444444"; display_token = MASK_TOKEN
                elif is_constrained and processed_constraints.get(i) == token_id:
                    raw_decode = tokenizer.decode([token_id], skip_special_tokens=False).strip()
                    display_token = raw_decode if raw_decode else "?"; color = "#800080"  # Purple
                elif is_special:
                    # Hide special tokens in the *final* display state for cleaner look
                    display_token = " "; color = "#6699CC"  # Hide as 'Old' blue
                else:
                    display_token = tokenizer.decode([token_id], skip_special_tokens=True).strip()
                    color = "#6699CC"  # Final state uses 'Old' blue

                if not display_token: display_token = "?"  # Fallback
                final_state_vis.append((display_token, color))

            # Pad the final state visualization if actual gen len < requested gen_length
            # This shouldn't be necessary if HighlightedText handles shorter lists
            # while len(final_state_vis) < gen_length:
            #     final_state_vis.append((" ", "#FFFFFF"))  # Add empty space

            if final_state_vis:  # Only append if we generated something
                visualization_states.append(final_state_vis)

    except Exception as e:
        import traceback
        traceback.print_exc()

    if not isinstance(final_text, str): final_text = str(final_text)

    return visualization_states, final_text


css = '''
.category-legend{display:none}
body, .gradio-container { font-size: 105%; }
/* Make buttons slightly larger */
/* button { min-height: 40px; } */
.small_btn {
    min-width: 60px; /* Adjust as needed */
    max-width: 100px;
    height: 42px; /* Adjust height */
    flex-grow: 0 !important; /* Prevent button from growing */
    margin-left: 5px !important; /* Add some space */
    font-size: 100%; /* Match button font size */
    padding: 0 10px !important; /* Adjust padding */
}
.chat-input-row {
    display: flex;
    align-items: center; /* Vertically align items */
    margin-top: 10px; /* Add space above input row */
}
/* Ensure Textbox takes up space */
.chat-input-row .gr-textbox {
    flex-grow: 1;
    margin-right: 5px;
}
/* Chatbot styling */
.gr-chatbot .message {
    font-size: 100%; /* Ensure chat message font size is reasonable */
    padding: 10px !important;
    border-radius: 8px !important;
}
.gr-chatbot .message.user { background-color: #E0F7FA !important; align-self: flex-end; } /* Light cyan for user */
.gr-chatbot .message.bot { background-color: #F1F8E9 !important; align-self: flex-start; } /* Light green for bot */
/* HighlightedText styling */
.gr-highlightedtext span {
    padding: 1px 2px; /* Minimal padding */
    margin: 0 1px; /* Minimal margin */
    border-radius: 3px;
    font-family: monospace; /* Use monospace font for better alignment */
    font-size: 95%; /* Slightly smaller font for dense vis */
    line-height: 1.4; /* Adjust line spacing */
}
.gr-highlightedtext {
    padding: 10px;
    border: 1px solid #E0E0E0;
    border-radius: 5px;
    background-color: #FAFAFA; /* Light background for the container */
}
/* Legend Styling */
.legend {
    font-size: 90%;
    margin-top: 5px;
    color: #555;
}
.legend span {
    display: inline-block; /* Keep legend items inline */
    margin-right: 10px;
    white-space: nowrap; /* Prevent wrapping */
}
.legend span::before { /* Style the color square */
    content: '■';
    display: inline-block;
    margin-right: 4px;
    font-size: 120%; /* Make square slightly larger */
    vertical-align: middle; /* Align square with text */
}
'''
def create_chatbot_demo():
    with gr.Blocks(css=css

        with gr.Row():
            gr.Markdown("[Model Card](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)", scale=1)
            gr.Markdown("[Blog Post](https://hkunlp.github.io/blog/2025/dream/)", scale=1)

        # STATE MANAGEMENT
        chat_history = gr.State([])

        # UI
        with gr.Row():
            # Left Column: Chat Interface
            with gr.Column(scale=3):
                chatbot_ui = gr.Chatbot(
                    label="Conversation",
                    height=550,
                    bubble_full_width=False,
                    show_copy_button=True,
                    render=False  # Rendered explicitly later for streaming
                )
                chatbot_ui.render()  # Manually render after setting parameters

                # Message input
                    user_input = gr.Textbox(
                        label="Your Message",
                        placeholder="Type your message
                        show_label=False
                    )
                    send_btn = gr.Button("Send", scale=1

                constraints_input = gr.Textbox(
                    label="Word Constraints (Optional)",
                )

            # Right Column: Visualization and Settings
            with gr.Column(scale=2):
                gr.Markdown("### Denoising Process Visualization")
                output_vis = gr.HighlightedText(
                    show_label=False,  # Label provided by Markdown above
                    combine_adjacent=False,
                    color_map={
                        "Mask": "#444444",
                        "New": "#66CC66",
                        "Old": "#6699CC",
                        "Constraint": "#800080",
                        "Special (New)": "#FF8C00",
                        "Error": "red"
                    }
                )
                    <span style='color:#6699CC;'>■ Old</span> |
                    <span style='color:#800080;'>■ Constraint</span>
                    </div>
                    """,
                    elem_id="legend-html"
                )

                    minimum=0.0, maximum=1.0, value=0.1, step=0.05,
                    label="Order Randomness (alg_temp)",
                    info="Adds randomness to confidence-based strategies (Entropy, MaskGit+, TopK). Ignored for Random."
                )
                visualization_delay = gr.Slider(
                    minimum=0.0, maximum=0.5, value=0.05, step=0.01,
                    label="Visualization Delay (sec)", info="Pause between steps in visualization."
                )

        # Clear button Row
        with gr.Row():
            clear_btn = gr.Button("Clear Conversation", variant="stop", icon="🗑️")

        # --- Event Handlers ---

        # Helper to add message to history state
        def add_message_to_history(history_state, user_message, bot_message):
            # history_state is the raw list from gr.State
            history_state.append([user_message, bot_message])
            return history_state

        # Function when user submits message (Enter or Send button)
        def handle_user_message(message, history_state):
            print(f"User submitted: '{message}'")
            if not message or not message.strip():
                print("Empty message submitted, doing nothing.")
                # Return unchanged state if message is empty
                # Need to return values for all outputs of the .submit/.click
                # history_state, chatbot_ui, user_input, output_vis
                return history_state, history_state, "", []  # No change to chatbot UI yet

            # Add user message to history state (with None for bot response initially)
            updated_history_state = add_message_to_history(history_state, message, None)

            # Prepare updated history for display in Chatbot UI
            # We only display the user message now, bot response comes later
            chatbot_display = updated_history_state.copy()

            # Clear the input textbox and visualization
            return updated_history_state, chatbot_display, "", []

        # Function to generate bot response (triggered after user message is handled)
        # Uses yield for streaming visualization updates
        def generate_bot_response(
            history_state,  # The current state [[user, None], ...]
            gen_length_val, steps_val, constraints_text, delay_val,
            temperature_val, top_p_val, alg_val, alg_temp_val
        ):
            print("\n--- Streaming Bot Response ---")
            if not history_state or history_state[-1][1] is not None:
                print("History empty or last message already has response. Skipping generation.")
                # Yield current state if called unnecessarily
                yield history_state, []  # Chatbot UI, Visualization
                return

            # Get the conversation history in the format the model expects
            messages_for_model = format_chat_history(history_state)  # Includes the latest user query

            # Parse constraints from the textbox
            parsed_constraints = parse_constraints(constraints_text)

            # Generate response with visualization (this function handles the core logic)
            vis_states, response_text = dream_generate_response_with_visualization(
                messages_for_model,
                gen_length=gen_length_val,
                steps=steps_val,
                constraints=parsed_constraints,
                temperature=temperature_val,
                top_p=top_p_val,
                alg=alg_val,
                alg_temp=alg_temp_val
            )

            else:
                # Handle case where generation failed or produced no visualization
                print("Warning: No visualization states generated.")
                yield history_state, [("No visualization generated.", "orange")]  # Update chatbot UI, show warning in vis

            print("--- Streaming Complete ---")

        # Function to clear everything
        def clear_conversation_state():
            print("Clearing conversation.")
            # Reset state and UI components
            return [], [], "", []  # chat_history (State), chatbot_ui, user_input, output_vis

        # --- Wire UI elements to functions ---

        # Define shared inputs for generation to avoid repetition
        generation_inputs = [
            chat_history, gen_length, steps, constraints_input, visualization_delay,
            temperature, top_p, remasking_strategy, alg_temp
        ]
        # Define shared outputs for streaming
        streaming_outputs = [chatbot_ui, output_vis]

        # Typing in Textbox and pressing Enter
        user_input.submit(
            fn=handle_user_message,
            inputs=[user_input, chat_history],
            outputs=[chat_history, chatbot_ui, user_input, output_vis],  # Update history state, chatbot display, clear input, clear vis
            queue=False  # Process user input immediately
        ).then(
            fn=generate_bot_response,
            inputs=generation_inputs,
            outputs=streaming_outputs,  # Stream updates to chatbot and visualization
            # api_name="generate_stream"  # Optional: Name for API endpoint
        )

        clear_btn.click(
            inputs=[],
            outputs=[chat_history, chatbot_ui, user_input, output_vis],
            queue=False  # Clearing should be instant
        )

    return demo

if __name__ == "__main__":
    # share=True generates a public link (useful for Colab/Spaces)
    # debug=True provides helpful Gradio logs in the console
    gradio_demo.queue().launch(share=True, debug=False)  # Set debug=True for more verbose logs if needed
Updated version:

# dream_app.py
import torch
import numpy as np
import gradio as gr
import spaces
import time
import re
from transformers import AutoModel, AutoTokenizer
from threading import Lock
from queue import Queue, Empty  # Empty is the exception raised by get_nowait()

# --- Configuration ---
MODEL_PATH = "Dream-org/Dream-v0-Instruct-7B"
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {DEVICE}")

# --- Load Model and Tokenizer ---
print("Loading model and tokenizer...")
# Configuration files are needed for trust_remote_code:
# make sure config.json, configuration_dream.py, modeling_dream.py,
# generation_utils.py and generation_config.json are in the same directory
# or accessible in the Hugging Face cache.
model = AutoModel.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
).to(DEVICE).eval()
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True
)
print("Model and tokenizer loaded.")

# --- Constants ---
# Get IDs from the tokenizer/config if possible, otherwise hardcode from the provided files
MASK_TOKEN = tokenizer.mask_token  # Should be "<|mask|>"
try:
    MASK_ID = tokenizer.mask_token_id  # Should be 151666
    if MASK_ID is None: raise AttributeError  # Handle case where it might not be set directly
except AttributeError:
    print("Warning: Could not directly get mask_token_id, using hardcoded value 151666.")
    MASK_ID = 151666

try:
    EOS_ID = tokenizer.eos_token_id  # Should be 151643
    PAD_ID = tokenizer.pad_token_id  # Should be 151643
    if EOS_ID is None or PAD_ID is None: raise AttributeError
except AttributeError:
    print("Warning: Could not directly get eos/pad_token_id, using hardcoded value 151643.")
    EOS_ID = 151643
    PAD_ID = 151643

# Ensure MASK_TOKEN and MASK_ID are valid
if MASK_TOKEN is None or MASK_ID is None:
    raise ValueError("Mask token or ID is not defined correctly.")
if EOS_ID is None or PAD_ID is None:
    raise ValueError("EOS/PAD token or ID is not defined correctly.")
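# Example (illustrative sketch, not part of the committed file): the hardcoded
# fallback IDs above can be sanity-checked against the loaded tokenizer, e.g.
#
#     assert tokenizer.convert_tokens_to_ids("<|mask|>") == MASK_ID
#     assert tokenizer.eos_token_id in (None, EOS_ID)
#
# The literal "<|mask|>" spelling comes from the comments above; the actual
# special-token strings are defined by the Dream tokenizer configuration.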
# --- Helper Functions ---

def parse_constraints(constraints_text):
    """Parse constraints in format: 'position:word, position:word, ...'"""
    constraints = {}
    if not constraints_text:
        return constraints

    parts = constraints_text.split(',')
    for part in parts:
        if ':' not in part:
            continue
        try:
            pos_str, word = part.split(':', 1)
            pos = int(pos_str.strip())
            word = word.strip()
            if word and pos >= 0:
                # Tokenize the word - handle potential multi-token words.
                # Add a space prefix for consistency, similar to how the model might see words mid-sentence.
                tokens = tokenizer.encode(" " + word, add_special_tokens=False)
                for i, token_id in enumerate(tokens):
                    constraints[pos + i] = token_id
        except ValueError:
            continue
        except Exception as e:
            print(f"Error parsing constraint part '{part}': {e}")
            continue

    return constraints

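# Example (illustrative, assuming the Dream tokenizer loaded above): a constraint
# string like "0:Once, 5:upon" maps 0-indexed response positions to token IDs,
# spreading multi-token words across consecutive positions, e.g.
#
#     parse_constraints("0:Once, 5:upon")
#     # -> {0: <id of " Once">, 5: <id of " upon">}  (exact IDs depend on the vocabulary)
#
# Malformed parts such as "five:upon" or a bare "upon" are silently skipped.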
def format_chat_history(history):
    """
    Format chat history for the Dream model using its chat template logic.

    Args:
        history: List of [user_message, assistant_message] pairs

    Returns:
        Formatted list of message dictionaries for the model
    """
    messages = []
    # Add system prompt if history is empty or doesn't start with system
    if not history or history[0][0].lower() != 'system':
        # The template provided in tokenizer_config.json handles adding a default one
        pass  # Let apply_chat_template handle the default system message

    for user_msg, assistant_msg in history:
        if user_msg:  # Handle potential initial system message possibility if needed
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg is not None:  # Skip if None (for the latest user message)
            messages.append({"role": "assistant", "content": assistant_msg})

    return messages

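# Example (illustrative): with one completed turn and a pending user message,
#
#     format_chat_history([["Hi there", "Hello! How can I help?"], ["Tell me a story", None]])
#     # -> [{"role": "user", "content": "Hi there"},
#     #     {"role": "assistant", "content": "Hello! How can I help?"},
#     #     {"role": "user", "content": "Tell me a story"}]
#
# The trailing None is dropped so the model is prompted to produce the next assistant turn.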
# --- Core Generation Logic with Visualization ---

# Use a thread-safe queue to pass visualization states from the hook
vis_queue = Queue()
# Lock to prevent race conditions when accessing shared state like previous_x
state_lock = Lock()
# Store the previous state for comparison in the hook
previous_x_shared = None

@spaces.GPU
def generate_response_with_visualization(
    messages,            # List of message dicts from format_chat_history
    max_new_tokens=64,
    steps=64,            # Default steps based on README example
    constraints=None,
    temperature=0.6,     # Default from demo_token_control
    top_p=0.95,          # Default from demos
    alg="entropy",       # Default from demos
    alg_temp=0.1,        # Default from demo_multiturn_chat
):
    """
    Generate text with the Dream model and capture visualization states using a hook.

    Args:
        messages: List of message dictionaries with 'role' and 'content'.
        max_new_tokens: Max tokens to generate.
        steps: Diffusion steps.
        constraints: Dictionary mapping positions (relative to response start) to token IDs.
        temperature: Sampling temperature.
        top_p: Nucleus sampling p.
        alg: Remasking algorithm ('origin', 'entropy', 'maskgit_plus', 'topk_margin').
        alg_temp: Temperature for confidence-based algorithms.

    Returns:
        Tuple: (List of visualization states, final generated text string)
    """
    global previous_x_shared, vis_queue
    if constraints is None:
        constraints = {}

    visualization_states = []

    # Clear the queue for a new generation
    while not vis_queue.empty():
        try:
            vis_queue.get_nowait()
        except Empty:
            break

    # Prepare the prompt using the chat template.
    # The template automatically adds the generation prompt like "<|im_start|>assistant\n".
    try:
        inputs = tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
            add_generation_prompt=True,
            return_dict=True
        )
        input_ids = inputs.input_ids.to(device=DEVICE)
        # Dream doesn't seem to explicitly use attention_mask in simple demos,
        # but it's good practice if padding were involved.
        # For now, assume no padding in this interactive demo.
        attention_mask = inputs.attention_mask.to(device=DEVICE) if 'attention_mask' in inputs else None

    except Exception as e:
        print(f"Error applying chat template: {e}")
        # Provide a fallback or error state
        error_state = [("Error in chat formatting.", "red")]
        return [error_state], f"Error: Could not format chat history. {e}"

    prompt_length = input_ids.shape[1]
    total_length = prompt_length + max_new_tokens

    # --- Define the Hook Function ---
    def generation_tokens_hook_func(step, x, logits):
        global previous_x_shared, vis_queue
        with state_lock:  # Ensure thread safety if needed, though hooks might run sequentially
            current_x = x.clone()  # Shape: (batch_size, total_length)

            # --- Apply Constraints ---
            # Constraints are relative to the start of the *response*
            for rel_pos, token_id in constraints.items():
                abs_pos = prompt_length + rel_pos
                if 0 <= abs_pos < current_x.shape[1]:
                    # Apply the constraint for the first batch element (batch size is 1 here)
                    current_x[0, abs_pos] = token_id

            # --- Create Visualization State ---
            current_vis_state = []
            x_response = current_x[0, prompt_length:]  # Get the response part for batch 0
            prev_x_response = previous_x_shared[0, prompt_length:] if previous_x_shared is not None else None

            for i in range(max_new_tokens):
                current_token_id = x_response[i].item()
                token_str = tokenizer.decode([current_token_id], skip_special_tokens=False)  # Keep special tokens for vis

                # Clean up the visual representation of special tokens
                if token_str == tokenizer.eos_token or token_str == tokenizer.pad_token:
                    token_str = "[EOS/PAD]"  # Make it visually distinct
                elif token_str == tokenizer.mask_token:
                    token_str = "[MASK]"
                elif token_str.strip() == "":  # Handle empty strings from decoding potentially odd tokens
                    token_str = "[UNK/SPACE]"

                color = "#DDDDDD"  # Default background

                if current_token_id == MASK_ID:
                    color = "#444444"  # Dark gray for masks
                elif prev_x_response is not None and prev_x_response[i].item() == MASK_ID:
                    # Token was a mask and is revealed in this step
                    color = "#66CC66"  # Light green for newly revealed
                else:
                    # Token was already revealed in a previous step, or it is a constraint
                    is_constraint = i in constraints and constraints[i] == current_token_id
                    if is_constraint:
                        color = "#FFD700"  # Gold for constraints
                    else:
                        color = "#6699CC"  # Light blue for previously revealed

                current_vis_state.append((token_str, color))

            # --- Update shared state and put the vis state in the queue ---
            previous_x_shared = current_x.clone()  # Update for the *next* step's comparison
            vis_queue.put(current_vis_state)

        # The hook must return the potentially modified tensor `x`
        return current_x
    # --- End of Hook Function ---

    # Initialize previous_x_shared before generation starts:
    # create the initial masked state for visualization.
    initial_x = input_ids.clone()
    if initial_x.shape[1] < total_length:
        padding = torch.full((1, total_length - initial_x.shape[1]), MASK_ID, dtype=torch.long, device=DEVICE)
        initial_x = torch.cat([initial_x, padding], dim=1)
    else:
        initial_x = initial_x[:, :total_length]  # Truncate if the prompt is too long

    # Apply initial constraints to the starting state
    for rel_pos, token_id in constraints.items():
        abs_pos = prompt_length + rel_pos
        if 0 <= abs_pos < initial_x.shape[1]:
            initial_x[0, abs_pos] = token_id

    with state_lock:
        previous_x_shared = initial_x.clone()

    # Add the initial all-masked state (or with constraints) to the visualization queue
    initial_vis_state = []
    initial_x_response = initial_x[0, prompt_length:]
    for i in range(max_new_tokens):
        token_id = initial_x_response[i].item()
        if token_id == MASK_ID:
            initial_vis_state.append((MASK_TOKEN, "#444444"))
        else:
            # Must be a pre-applied constraint
            token_str = tokenizer.decode([token_id], skip_special_tokens=False)
            if token_str == tokenizer.eos_token or token_str == tokenizer.pad_token:
                token_str = "[EOS/PAD]"
            elif token_str.strip() == "":
                token_str = "[UNK/SPACE]"
            initial_vis_state.append((token_str, "#FFD700"))  # Gold for constraints
    vis_queue.put(initial_vis_state)

    # --- Run Generation ---
    try:
        # output_history=False because the hook handles state capture;
        # return_dict_in_generate=True to get the GenerationOutput object.
        output = model.diffusion_generate(
            initial_x,  # Start with the potentially constraint-applied tensor
            attention_mask=None,  # Assuming no padding is needed for interactive use
            max_new_tokens=max_new_tokens,  # Might not be strictly needed if total_length is fixed
            output_history=False,
            return_dict_in_generate=True,
            steps=steps,
            temperature=temperature,
            top_p=top_p,
            alg=alg,
            alg_temp=alg_temp if alg != 'origin' else None,  # alg_temp only for confidence algs
            generation_tokens_hook_func=generation_tokens_hook_func
        )

        final_sequence = output.sequences[0]  # Batch size 1

        # Decode the final response text, cleaning up special tokens
        response_tokens = final_sequence[prompt_length:]
        # Filter out EOS/PAD tokens for the final text display
        response_tokens_filtered = [tok for tok in response_tokens.tolist() if tok != EOS_ID and tok != PAD_ID]
        final_text = tokenizer.decode(response_tokens_filtered,
                                      skip_special_tokens=True,
                                      clean_up_tokenization_spaces=True)  # Standard cleanup

    except Exception as e:
        print(f"Error during generation: {e}")
        import traceback
        traceback.print_exc()
        # Provide an error state
        error_state = [("Generation Error.", "red")]
        visualization_states.append(error_state)
        final_text = f"Error: Generation failed. {e}"
        # Add any states captured before the error
        while not vis_queue.empty():
            try:
                visualization_states.append(vis_queue.get_nowait())
            except Empty:
                break
        return visualization_states, final_text

    # Retrieve all visualization states captured by the hook
    while not vis_queue.empty():
        try:
            visualization_states.append(vis_queue.get_nowait())
        except Empty:
            break

    # If somehow no states were captured, add the initial one
    if not visualization_states:
        visualization_states.append(initial_vis_state)

    return visualization_states, final_text.strip()

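# Example (illustrative sketch, not part of the committed file): the function above
# can be exercised without the Gradio UI. Each visualization state it returns is a
# list of (token_string, color) tuples, which is the format gr.HighlightedText accepts.
# The message content below is made up for the example.
#
#     msgs = [{"role": "user", "content": "Write a haiku about rain."}]
#     states, text = generate_response_with_visualization(msgs, max_new_tokens=32, steps=32)
#     print(len(states), "states; final text:", text)
#     # states[0] is all masks, e.g. [("<|mask|>", "#444444"), ...]; later states mix
#     # green (newly revealed), blue (previously revealed) and gold (constraint) entries.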
# --- Gradio UI ---

css = '''
.category-legend{display:none}
button{height: 60px}
'''
def create_chatbot_demo():
    with gr.Blocks(css=css) as demo:
        gr.Markdown("# Dream 7B - Diffusion Language Model Demo")
        gr.Markdown("Chat with the Dream 7B Instruct model and visualize the diffusion generation process.")
        gr.Markdown("Model: [Dream-org/Dream-v0-Instruct-7B](https://huggingface.co/Dream-org/Dream-v0-Instruct-7B)")

        # STATE MANAGEMENT
        chat_history = gr.State([])

        # UI COMPONENTS
        with gr.Row():
            with gr.Column(scale=3):
                chatbot_ui = gr.Chatbot(label="Conversation", height=500, avatar_images=["user.png", "robot.png"])

                # Message input
                with gr.Group():
                    with gr.Row():
                        user_input = gr.Textbox(
                            label="Your Message",
                            placeholder="Type your message here...",
                            show_label=False,
                            scale=9
                        )
                        send_btn = gr.Button("Send", scale=1)

                constraints_input = gr.Textbox(
                    label="Word Constraints (Optional)",
                    info="Place words at specific positions (0-indexed from response start). Format: 'pos:word, pos:word,...'. Example: '0:Once, 5:upon, 10:a'",
                    placeholder="0:Once, 5:upon, 10:a",
                    value=""
                )
            with gr.Column(scale=2):
                output_vis = gr.HighlightedText(
                    label="Diffusion Process Visualization",
                    combine_adjacent=False,
                    show_legend=True,  # Built-in legend is hidden via the CSS above if desired
                    height=560  # Adjust height to match the chatbot area
                )
                # Legend (colors defined in generate_response_with_visualization)
                gr.Markdown(
                    "<small>Color Legend: <span style='background-color:#444444; color:white;'>[MASK]</span>"
                    " <span style='background-color:#66CC66;'>Newly Revealed</span>"
                    " <span style='background-color:#6699CC;'>Previously Revealed</span>"
                    " <span style='background-color:#FFD700;'>Constraint</span>"
                    " <span style='background-color:#DDDDDD;'>[EOS/PAD/UNK]</span></small>"
                )

        # Advanced generation settings
        with gr.Accordion("Generation Settings", open=False):
            max_new_tokens_slider = gr.Slider(
                minimum=16, maximum=512, value=128, step=16,  # Increased default/max
                label="Max New Tokens (Generation Length)"
            )
            steps_slider = gr.Slider(
                minimum=8, maximum=512, value=128, step=8,  # Increased default/max
                label="Diffusion Steps"
            )
            temp_slider = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.6, step=0.05,  # Finer steps for temperature
                label="Temperature"
            )
            top_p_slider = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.95, step=0.05,
                label="Top-P (Nucleus Sampling)"
            )
            alg_radio = gr.Radio(
                # Choices from the README
                choices=['origin', 'entropy', 'maskgit_plus', 'topk_margin'],
                value='entropy',
                label="Remasking Algorithm"
            )
            alg_temp_slider = gr.Slider(
                minimum=0.0, maximum=1.0, value=0.1, step=0.05,
                label="Algorithm Temperature (for confidence-based algs)"
            )
            vis_delay_slider = gr.Slider(
                minimum=0.0, maximum=0.5, value=0.03, step=0.01,  # Faster default delay
                label="Visualization Delay (seconds)"
            )

        # Clear button
        clear_btn = gr.Button("Clear Conversation")

        # HELPER FUNCTIONS (UI Logic)
        def add_message_to_history(history, message, response):
            """Add a message pair to the history state"""
            new_history = history + [[message, response]]
            return new_history

        def user_message_submitted(message, history):
            """Handle the user sending a message: update history, clear input"""
            if not message or message.strip() == "":
                return history, history, "", []  # No change if empty

            # Add the user message; the response is initially None
            new_history = add_message_to_history(history, message, None)

            # Prepare the display version (immediately shows the user message)
            display_history = new_history

            # Clear the input box and the visualization
            message_out = ""
            vis_out = []

            return new_history, display_history, message_out, vis_out

        def bot_response_generator(history, constraints_str, max_tokens, steps, temp, top_p, alg, alg_temp, delay):
            """Generator function to stream the bot response and visualization"""
            if not history or history[-1][1] is not None:  # Ensure there's a user msg waiting for a response
                print("Warning: Bot response triggered without pending user message.")
                yield history, []
                return

            # Get the full conversation history formatted for the model
            last_user_message = history[-1][0]
            messages_for_model = format_chat_history(history[:-1])  # History *before* the last user msg
            messages_for_model.append({"role": "user", "content": last_user_message})

            # Parse constraints
            try:
                parsed_constraints = parse_constraints(constraints_str)
            except Exception as e:
                print(f"Error parsing constraints: {e}")
                yield history, [("Constraint Error", "red")]
                return

            # Generate the response and visualization states
            try:
                vis_states, final_response_text = generate_response_with_visualization(
                    messages_for_model,
                    max_new_tokens=max_tokens,
                    steps=steps,
                    constraints=parsed_constraints,
                    temperature=temp,
                    top_p=top_p,
                    alg=alg,
                    alg_temp=alg_temp
                )
            except Exception as e:
                print(f"Error in generate_response_with_visualization: {e}")
                import traceback
                traceback.print_exc()
                yield history, [("Generation Error", "red")]
                return

            # Update the history state with the final response *once*
            history[-1][1] = final_response_text  # Replace the None placeholder

            # Yield the initial state immediately
            if vis_states:
                yield history, vis_states[0]
            else:
                yield history, []  # Should not happen if generation worked

            # Stream the intermediate visualization states
            for state in vis_states[1:]:
                time.sleep(delay)
                yield history, state

            # The last state in vis_states is already yielded by the loop,
            # which also leaves the chatbot UI with the complete history.

        def clear_conversation():
            """Clear the conversation history and visualization"""
            return [], [], "", []  # history, chatbot_ui, user_input, output_vis

        # EVENT HANDLERS

        # User presses Enter or the Send button
        submit_args = {
            "fn": user_message_submitted,
            "inputs": [user_input, chat_history],
            "outputs": [chat_history, chatbot_ui, user_input, output_vis]
        }
        user_input.submit(**submit_args)
        send_btn.click(**submit_args)

        # After the user message is submitted, trigger bot response generation
        generate_args = {
            "fn": bot_response_generator,
            "inputs": [
                chat_history, constraints_input, max_new_tokens_slider, steps_slider,
                temp_slider, top_p_slider, alg_radio, alg_temp_slider, vis_delay_slider
            ],
            "outputs": [chatbot_ui, output_vis]  # Update the chatbot history and visualization
        }
        # Trigger generation after submit OR click
        user_input.submit(None, None, None, queue=True).then(**generate_args)
        send_btn.click(None, None, None, queue=True).then(**generate_args)

        # Clear button handler
        clear_btn.click(
            fn=clear_conversation,
            inputs=[],
            outputs=[chat_history, chatbot_ui, user_input, output_vis]
        )

    return demo

# Launch the demo
if __name__ == "__main__":
    demo = create_chatbot_demo()
    # queue() allows streaming and handling multiple users
    # share=True creates a public link (use with caution)
    demo.queue().launch(share=True, debug=True)
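A quick local-run note (an assumption about typical usage, not stated in the commit): outside of Spaces, the same module can be launched without a public share link. Assuming the file is saved as dream_app.py, as its header comment suggests, a hypothetical local entry point would be:

    # run_local.py (hypothetical)
    from dream_app import create_chatbot_demo

    demo = create_chatbot_demo()
    demo.queue().launch(share=False)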