Spaces:
Running on Zero

Ruurd committed on
Commit
84a6c46
·
1 Parent(s): 16563e8

Change interface

Browse files
Files changed (1) hide show
  1. app.py +49 -49
app.py CHANGED
@@ -84,8 +84,17 @@ def generate_diffusion_text(input_ids, answer_start):
84
 
85
  # --- Inference Wrapper ---
86
 
 
 
 
 
 
 
 
 
 
87
  @spaces.GPU
88
- def diffusion_chat(question, eot_weight, max_it, sharpness):
89
  placeholder = "What do you know about the city of New York?"
90
  if question.strip() == "":
91
  question = placeholder
@@ -94,21 +103,15 @@ def diffusion_chat(question, eot_weight, max_it, sharpness):
94
  input_ids = tokenizer.encode(prompt, add_special_tokens=False)
95
  answer_start = find_answer_start(input_ids, assistant_marker_ids)
96
  if answer_start is None:
97
- yield "Error: Could not find Assistant marker in input."
98
- return
99
-
100
- if len(input_ids) < 256:
101
- input_ids += [pad_token] * (256 - len(input_ids))
102
- else:
103
- input_ids = input_ids[:256]
104
 
105
- ori_input_tokens = input_ids
106
- current_tokens = noisify_answer(ori_input_tokens, answer_start, threshold=1.0, eot_weight=eot_weight)
107
  prev_decoded_tokens = []
108
  last_tokens = []
 
109
 
110
  for i in range(max_it):
111
- print('Generating output')
112
  generated_tokens = generate_diffusion_text(current_tokens, answer_start)
113
  current_tokens = generated_tokens
114
 
@@ -117,24 +120,17 @@ def diffusion_chat(question, eot_weight, max_it, sharpness):
117
  filtered_tokens = [tok for tok in decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
118
  filtered_prev_tokens = [tok for tok in prev_decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id] if prev_decoded_tokens else []
119
 
120
- if filtered_prev_tokens:
121
- highlighted = []
122
- for tok_new, tok_old in zip(filtered_tokens, filtered_prev_tokens):
123
- if tok_new != tok_old:
124
- highlighted.append(f'<span style="color:green">{tokenizer.convert_tokens_to_string([tok_new])}</span>')
125
- else:
126
- highlighted.append(tokenizer.convert_tokens_to_string([tok_new]))
127
- else:
128
- highlighted = [tokenizer.convert_tokens_to_string([tok]) for tok in filtered_tokens]
129
 
130
  prev_decoded_tokens = decoded_tokens
131
- yield f"<b>Iteration {i+1}/{max_it} (running):</b><br>" + "".join(highlighted)
132
-
133
  last_tokens.append(generated_tokens)
134
- if len(last_tokens) > 3:
135
- last_tokens.pop(0)
136
- if len(last_tokens) == 3 and last_tokens[0] == last_tokens[1] == last_tokens[2]:
137
- yield f"<b>Stopped early after {i+1} iterations.</b>"
138
  break
139
 
140
  threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
@@ -144,27 +140,31 @@ def diffusion_chat(question, eot_weight, max_it, sharpness):
144
  final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
145
  final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
146
  final_output = tokenizer.convert_tokens_to_string(final_tokens)
147
- print(final_output)
148
- yield f"<b>Final Output (after {i+1} iterations):</b><br>" + final_output
149
-
150
- # --- Gradio Interface ---
151
-
152
- print("Loading model...")
153
- model = load_model()
154
- print("✅ Model loaded.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
- demo = gr.Interface(
157
- fn=diffusion_chat,
158
- inputs=[
159
- gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of New York?"),
160
- gr.Slider(0, 1, value=0.4, step=0.05, label="↓ = longer answers (EOT weight)"),
161
- gr.Slider(1, 512, value=64, step=1, label="↑ = more iterations"),
162
- gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="↓ = more noising (sharpness)")
163
- ],
164
- outputs=gr.HTML(label="Diffusion Output"),
165
- title="Diffusion Language Model Chat",
166
- theme="default",
167
- description="This interface runs a diffusion-based language model to generate answers progressively."
168
- )
169
-
170
- demo.launch()
 
84
 
85
  # --- Inference Wrapper ---
86
 
87
+
88
+
89
+ # --- Gradio Interface ---
90
+
91
+ print("Loading model...")
92
+ model = load_model()
93
+ print("✅ Model loaded.")
94
+
95
+ # --- Generation logic ---
96
  @spaces.GPU
97
+ def run_diffusion_loop(question, eot_weight, max_it, sharpness):
98
  placeholder = "What do you know about the city of New York?"
99
  if question.strip() == "":
100
  question = placeholder
 
103
  input_ids = tokenizer.encode(prompt, add_special_tokens=False)
104
  answer_start = find_answer_start(input_ids, assistant_marker_ids)
105
  if answer_start is None:
106
+ return [], "Error: Could not find Assistant marker in input."
 
 
 
 
 
 
107
 
108
+ input_ids = (input_ids + [pad_token] * (256 - len(input_ids)))[:256]
109
+ current_tokens = noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=eot_weight)
110
  prev_decoded_tokens = []
111
  last_tokens = []
112
+ history = ["**User:** " + question]
113
 
114
  for i in range(max_it):
 
115
  generated_tokens = generate_diffusion_text(current_tokens, answer_start)
116
  current_tokens = generated_tokens
117
 
 
120
  filtered_tokens = [tok for tok in decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
121
  filtered_prev_tokens = [tok for tok in prev_decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id] if prev_decoded_tokens else []
122
 
123
+ highlighted = []
124
+ for tok_new, tok_old in zip(filtered_tokens, filtered_prev_tokens):
125
+ text = tokenizer.convert_tokens_to_string([tok_new])
126
+ if tok_new != tok_old:
127
+ highlighted.append(f"<span style='color:green'>{text}</span>")
128
+ else:
129
+ highlighted.append(text)
 
 
130
 
131
  prev_decoded_tokens = decoded_tokens
 
 
132
  last_tokens.append(generated_tokens)
133
+ if len(last_tokens) == 3 and all(t == last_tokens[0] for t in last_tokens):
 
 
 
134
  break
135
 
136
  threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
 
140
  final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
141
  final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
142
  final_output = tokenizer.convert_tokens_to_string(final_tokens)
143
+ history.append("**Assistant:** " + final_output)
144
+ return history, final_output
145
+
146
+ # --- UI Layout ---
147
+ css = ".category-legend{display:none}"
148
+ with gr.Blocks(css=css) as demo:
149
+ gr.Markdown("# Tini Diffusion LLM 🌀")
150
+ with gr.Row():
151
+ with gr.Column(scale=3):
152
+ chatbox = gr.Chatbot(label="Conversation", value=[], height=400)
153
+ question_input = gr.Textbox(label="Your Question", placeholder="What do you want to ask?", scale=8)
154
+ send_btn = gr.Button("Generate")
155
+ with gr.Column(scale=2):
156
+ eot_weight = gr.Slider(0, 1, value=0.4, step=0.05, label="EOT weight")
157
+ max_iters = gr.Slider(1, 512, value=64, step=1, label="Iterations")
158
+ sharpness = gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Sharpness")
159
+
160
+ def handle_submit(question, eot, max_it, sharp):
161
+ history, _ = run_diffusion_loop(question, eot, max_it, sharp)
162
+ return history
163
+
164
+ send_btn.click(
165
+ fn=handle_submit,
166
+ inputs=[question_input, eot_weight, max_iters, sharpness],
167
+ outputs=[chatbox]
168
+ )
169
 
170
+ demo.queue().launch()