Spaces:
Running on Zero

Ruurd committed on
Commit
3f5293d
·
1 Parent(s): 42c0401

Last try interface

Browse files
Files changed (1) hide show
  1. app.py +54 -50
app.py CHANGED
@@ -73,10 +73,6 @@ def noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=1.0):
73
  noised[idx] = val
74
  return noised
75
 
76
- print("Loading model...")
77
- model = load_model()
78
- print("✅ Model loaded.")
79
-
80
  def generate_diffusion_text(input_ids, answer_start):
81
  with torch.no_grad():
82
  input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
@@ -86,22 +82,33 @@ def generate_diffusion_text(input_ids, answer_start):
86
  sampled = torch.multinomial(probs, num_samples=1).squeeze().tolist()
87
  return input_ids[:answer_start] + sampled[answer_start:]
88
 
89
- # --- Diffusion Chat Function ---
 
90
  @spaces.GPU
91
- def diffusion_chat(message, system_prompt, eot_weight, max_it, sharpness):
92
- prompt = f"{system_prompt}\nUser: {message}\nAssistant:"
 
 
 
 
93
  input_ids = tokenizer.encode(prompt, add_special_tokens=False)
94
  answer_start = find_answer_start(input_ids, assistant_marker_ids)
95
  if answer_start is None:
96
- yield "<span style='color:red'><b>Error:</b> Could not find Assistant marker in input.</span>"
97
  return
98
 
99
- input_ids = (input_ids + [pad_token] * (256 - len(input_ids)))[:256]
100
- current_tokens = noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=eot_weight)
 
 
 
 
 
101
  prev_decoded_tokens = []
102
  last_tokens = []
103
 
104
  for i in range(max_it):
 
105
  generated_tokens = generate_diffusion_text(current_tokens, answer_start)
106
  current_tokens = generated_tokens
107
 
@@ -110,21 +117,24 @@ def diffusion_chat(message, system_prompt, eot_weight, max_it, sharpness):
110
  filtered_tokens = [tok for tok in decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
111
  filtered_prev_tokens = [tok for tok in prev_decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id] if prev_decoded_tokens else []
112
 
113
- highlighted = []
114
- for tok_new, tok_old in zip(filtered_tokens, filtered_prev_tokens):
115
- text = tokenizer.convert_tokens_to_string([tok_new])
116
- if tok_new != tok_old:
117
- highlighted.append(f"<span style='color:green'>{text}</span>")
118
- else:
119
- highlighted.append(text)
 
 
120
 
121
  prev_decoded_tokens = decoded_tokens
122
- yield ("<div style='padding:0.5em'><b>Iteration {}</b><br>"
123
- "<div style='background:#f5f5f5;padding:0.5em;border-radius:0.5em'>{}</div></div>").format(i+1, ''.join(highlighted))
124
 
125
  last_tokens.append(generated_tokens)
126
- if len(last_tokens) == 3 and all(t == last_tokens[0] for t in last_tokens):
127
- yield f"<div style='color:gray'><i>Stopped early after {i+1} iterations (converged).</i></div>"
 
 
128
  break
129
 
130
  threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
@@ -134,33 +144,27 @@ def diffusion_chat(message, system_prompt, eot_weight, max_it, sharpness):
134
  final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
135
  final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
136
  final_output = tokenizer.convert_tokens_to_string(final_tokens)
137
- yield f"<div style='padding:0.5em'><b>Final Output:</b><br><div style='background:#e0ffe0;padding:0.5em;border-radius:0.5em'>{final_output}</div></div>"
138
-
139
-
140
- with gr.Blocks() as demo:
141
- gr.Markdown("## Diffusion Language Model Chat")
142
- with gr.Row():
143
- with gr.Column(scale=3):
144
- chatbot = gr.Chatbot()
145
- message = gr.Textbox(label="User Message")
146
- submit = gr.Button("Send")
147
- with gr.Column(scale=1):
148
- system_prompt = gr.Textbox(value="You are a helpful assistant.", label="System Message")
149
- eot_weight = gr.Slider(0, 1, value=0.4, step=0.05, label="EOT token weight")
150
- max_it = gr.Slider(1, 512, value=64, step=1, label="Max Iterations")
151
- sharpness = gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Noising Sharpness")
152
-
153
- def wrapped_chat(message, history, system_prompt, eot_weight, max_it, sharpness):
154
- history = history or []
155
- for update in diffusion_chat(message, system_prompt, eot_weight, max_it, sharpness):
156
- yield history + [(message, update)]
157
-
158
- submit.click(
159
- fn=wrapped_chat,
160
- inputs=[message, chatbot, system_prompt, eot_weight, max_it, sharpness],
161
- outputs=chatbot,
162
- )
163
-
164
 
165
- if __name__ == "__main__":
166
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  noised[idx] = val
74
  return noised
75
 
 
 
 
 
76
  def generate_diffusion_text(input_ids, answer_start):
77
  with torch.no_grad():
78
  input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
 
82
  sampled = torch.multinomial(probs, num_samples=1).squeeze().tolist()
83
  return input_ids[:answer_start] + sampled[answer_start:]
84
 
85
+ # --- Inference Wrapper ---
86
+
87
  @spaces.GPU
88
+ def diffusion_chat(question, eot_weight, max_it, sharpness):
89
+ placeholder = "What do you know about the city of New York?"
90
+ if question.strip() == "":
91
+ question = placeholder
92
+
93
+ prompt = f"User: {question}\nAssistant:"
94
  input_ids = tokenizer.encode(prompt, add_special_tokens=False)
95
  answer_start = find_answer_start(input_ids, assistant_marker_ids)
96
  if answer_start is None:
97
+ yield "Error: Could not find Assistant marker in input."
98
  return
99
 
100
+ if len(input_ids) < 256:
101
+ input_ids += [pad_token] * (256 - len(input_ids))
102
+ else:
103
+ input_ids = input_ids[:256]
104
+
105
+ ori_input_tokens = input_ids
106
+ current_tokens = noisify_answer(ori_input_tokens, answer_start, threshold=1.0, eot_weight=eot_weight)
107
  prev_decoded_tokens = []
108
  last_tokens = []
109
 
110
  for i in range(max_it):
111
+ print('Generating output')
112
  generated_tokens = generate_diffusion_text(current_tokens, answer_start)
113
  current_tokens = generated_tokens
114
 
 
117
  filtered_tokens = [tok for tok in decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
118
  filtered_prev_tokens = [tok for tok in prev_decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id] if prev_decoded_tokens else []
119
 
120
+ if filtered_prev_tokens:
121
+ highlighted = []
122
+ for tok_new, tok_old in zip(filtered_tokens, filtered_prev_tokens):
123
+ if tok_new != tok_old:
124
+ highlighted.append(f'<span style="color:green">{tokenizer.convert_tokens_to_string([tok_new])}</span>')
125
+ else:
126
+ highlighted.append(tokenizer.convert_tokens_to_string([tok_new]))
127
+ else:
128
+ highlighted = [tokenizer.convert_tokens_to_string([tok]) for tok in filtered_tokens]
129
 
130
  prev_decoded_tokens = decoded_tokens
131
+ yield f"<b>Iteration {i+1}/{max_it} (running):</b><br>" + "".join(highlighted)
 
132
 
133
  last_tokens.append(generated_tokens)
134
+ if len(last_tokens) > 3:
135
+ last_tokens.pop(0)
136
+ if len(last_tokens) == 3 and last_tokens[0] == last_tokens[1] == last_tokens[2]:
137
+ yield f"<b>Stopped early after {i+1} iterations.</b>"
138
  break
139
 
140
  threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
 
144
  final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
145
  final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
146
  final_output = tokenizer.convert_tokens_to_string(final_tokens)
147
+ print(final_output)
148
+ yield f"<b>Final Output (after {i+1} iterations):</b><br>" + final_output
149
+
150
+ # --- Gradio Interface ---
151
+
152
+ print("Loading model...")
153
+ model = load_model()
154
+ print("✅ Model loaded.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
+ demo = gr.Interface(
157
+ fn=diffusion_chat,
158
+ inputs=[
159
+ gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of New York?"),
160
+ gr.Slider(0, 1, value=0.4, step=0.05, label="↓ = longer answers (EOT weight)"),
161
+ gr.Slider(1, 512, value=64, step=1, label="↑ = more iterations"),
162
+ gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="↓ = more noising (sharpness)")
163
+ ],
164
+ outputs=[gr.HTML(label="Diffusion Output")],
165
+ title="Diffusion Language Model Chat",
166
+ theme="default",
167
+ description="This interface runs a diffusion-based language model to generate answers progressively."
168
+ )
169
+
170
+ demo.launch(share=True)