Ruurd commited on
Commit
42c0401
·
1 Parent(s): 55b43fa

Change interface

Browse files
Files changed (1) hide show
  1. app.py +24 -12
app.py CHANGED
@@ -136,19 +136,31 @@ def diffusion_chat(message, system_prompt, eot_weight, max_it, sharpness):
136
  final_output = tokenizer.convert_tokens_to_string(final_tokens)
137
  yield f"<div style='padding:0.5em'><b>Final Output:</b><br><div style='background:#e0ffe0;padding:0.5em;border-radius:0.5em'>{final_output}</div></div>"
138
 
139
- # --- Chat Interface ---
140
- demo = gr.ChatInterface(
141
- diffusion_chat,
142
- additional_inputs=[
143
- gr.Textbox(value="You are a helpful assistant.", label="System message"),
144
- gr.Slider(0, 1, value=0.4, step=0.05, label="EOT token weight (lower = longer output)"),
145
- gr.Slider(1, 512, value=64, step=1, label="Max Iterations"),
146
- gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Noising sharpness (lower = more noise)")
147
- ],
148
- title="Diffusion Language Model Chat",
149
- description="Iterative denoising chat interface using a fine-tuned LLaMA model."
150
- )
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
  if __name__ == "__main__":
154
  demo.launch()
 
136
  final_output = tokenizer.convert_tokens_to_string(final_tokens)
137
  yield f"<div style='padding:0.5em'><b>Final Output:</b><br><div style='background:#e0ffe0;padding:0.5em;border-radius:0.5em'>{final_output}</div></div>"
138
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
+ with gr.Blocks() as demo:
141
+ gr.Markdown("## Diffusion Language Model Chat")
142
+ with gr.Row():
143
+ with gr.Column(scale=3):
144
+ chatbot = gr.Chatbot()
145
+ message = gr.Textbox(label="User Message")
146
+ submit = gr.Button("Send")
147
+ with gr.Column(scale=1):
148
+ system_prompt = gr.Textbox(value="You are a helpful assistant.", label="System Message")
149
+ eot_weight = gr.Slider(0, 1, value=0.4, step=0.05, label="EOT token weight")
150
+ max_it = gr.Slider(1, 512, value=64, step=1, label="Max Iterations")
151
+ sharpness = gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Noising Sharpness")
152
+
153
+ def wrapped_chat(message, history, system_prompt, eot_weight, max_it, sharpness):
154
+ history = history or []
155
+ for update in diffusion_chat(message, system_prompt, eot_weight, max_it, sharpness):
156
+ yield history + [(message, update)]
157
+
158
+ submit.click(
159
+ fn=wrapped_chat,
160
+ inputs=[message, chatbot, system_prompt, eot_weight, max_it, sharpness],
161
+ outputs=chatbot,
162
+ )
163
+
164
 
165
  if __name__ == "__main__":
166
  demo.launch()