Spaces:
Running on Zero

Ruurd committed on
Commit
0e1a415
·
1 Parent(s): 84a6c46

Change interface

Browse files
Files changed (1) hide show
  1. app.py +29 -47
app.py CHANGED
@@ -73,6 +73,10 @@ def noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=1.0):
73
  noised[idx] = val
74
  return noised
75
 
 
 
 
 
76
  def generate_diffusion_text(input_ids, answer_start):
77
  with torch.no_grad():
78
  input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
@@ -82,34 +86,20 @@ def generate_diffusion_text(input_ids, answer_start):
82
  sampled = torch.multinomial(probs, num_samples=1).squeeze().tolist()
83
  return input_ids[:answer_start] + sampled[answer_start:]
84
 
85
- # --- Inference Wrapper ---
86
-
87
-
88
-
89
- # --- Gradio Interface ---
90
-
91
- print("Loading model...")
92
- model = load_model()
93
- print("✅ Model loaded.")
94
-
95
- # --- Generation logic ---
96
  @spaces.GPU
97
- def run_diffusion_loop(question, eot_weight, max_it, sharpness):
98
- placeholder = "What do you know about the city of New York?"
99
- if question.strip() == "":
100
- question = placeholder
101
-
102
- prompt = f"User: {question}\nAssistant:"
103
  input_ids = tokenizer.encode(prompt, add_special_tokens=False)
104
  answer_start = find_answer_start(input_ids, assistant_marker_ids)
105
  if answer_start is None:
106
- return [], "Error: Could not find Assistant marker in input."
 
107
 
108
  input_ids = (input_ids + [pad_token] * (256 - len(input_ids)))[:256]
109
  current_tokens = noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=eot_weight)
110
  prev_decoded_tokens = []
111
  last_tokens = []
112
- history = ["**User:** " + question]
113
 
114
  for i in range(max_it):
115
  generated_tokens = generate_diffusion_text(current_tokens, answer_start)
@@ -129,8 +119,12 @@ def run_diffusion_loop(question, eot_weight, max_it, sharpness):
129
  highlighted.append(text)
130
 
131
  prev_decoded_tokens = decoded_tokens
 
 
 
132
  last_tokens.append(generated_tokens)
133
  if len(last_tokens) == 3 and all(t == last_tokens[0] for t in last_tokens):
 
134
  break
135
 
136
  threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
@@ -140,31 +134,19 @@ def run_diffusion_loop(question, eot_weight, max_it, sharpness):
140
  final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
141
  final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
142
  final_output = tokenizer.convert_tokens_to_string(final_tokens)
143
- history.append("**Assistant:** " + final_output)
144
- return history, final_output
145
-
146
- # --- UI Layout ---
147
- css = ".category-legend{display:none}"
148
- with gr.Blocks(css=css) as demo:
149
- gr.Markdown("# Tini Diffusion LLM 🌀")
150
- with gr.Row():
151
- with gr.Column(scale=3):
152
- chatbox = gr.Chatbot(label="Conversation", value=[], height=400)
153
- question_input = gr.Textbox(label="Your Question", placeholder="What do you want to ask?", scale=8)
154
- send_btn = gr.Button("Generate")
155
- with gr.Column(scale=2):
156
- eot_weight = gr.Slider(0, 1, value=0.4, step=0.05, label="EOT weight")
157
- max_iters = gr.Slider(1, 512, value=64, step=1, label="Iterations")
158
- sharpness = gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Sharpness")
159
-
160
- def handle_submit(question, eot, max_it, sharp):
161
- history, _ = run_diffusion_loop(question, eot, max_it, sharp)
162
- return history
163
-
164
- send_btn.click(
165
- fn=handle_submit,
166
- inputs=[question_input, eot_weight, max_iters, sharpness],
167
- outputs=[chatbox]
168
- )
169
-
170
- demo.queue().launch()
 
73
  noised[idx] = val
74
  return noised
75
 
76
+ print("Loading model...")
77
+ model = load_model()
78
+ print("✅ Model loaded.")
79
+
80
  def generate_diffusion_text(input_ids, answer_start):
81
  with torch.no_grad():
82
  input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
 
86
  sampled = torch.multinomial(probs, num_samples=1).squeeze().tolist()
87
  return input_ids[:answer_start] + sampled[answer_start:]
88
 
89
+ # --- Diffusion Chat Function ---
 
 
 
 
 
 
 
 
 
 
90
  @spaces.GPU
91
+ def diffusion_chat(message, system_prompt, eot_weight, max_it, sharpness):
92
+ prompt = f"{system_prompt}\nUser: {message}\nAssistant:"
 
 
 
 
93
  input_ids = tokenizer.encode(prompt, add_special_tokens=False)
94
  answer_start = find_answer_start(input_ids, assistant_marker_ids)
95
  if answer_start is None:
96
+ yield "<span style='color:red'><b>Error:</b> Could not find Assistant marker in input.</span>"
97
+ return
98
 
99
  input_ids = (input_ids + [pad_token] * (256 - len(input_ids)))[:256]
100
  current_tokens = noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=eot_weight)
101
  prev_decoded_tokens = []
102
  last_tokens = []
 
103
 
104
  for i in range(max_it):
105
  generated_tokens = generate_diffusion_text(current_tokens, answer_start)
 
119
  highlighted.append(text)
120
 
121
  prev_decoded_tokens = decoded_tokens
122
+ yield ("<div style='padding:0.5em'><b>Iteration {}</b><br>"
123
+ "<div style='background:#f5f5f5;padding:0.5em;border-radius:0.5em'>{}</div></div>").format(i+1, ''.join(highlighted))
124
+
125
  last_tokens.append(generated_tokens)
126
  if len(last_tokens) == 3 and all(t == last_tokens[0] for t in last_tokens):
127
+ yield f"<div style='color:gray'><i>Stopped early after {i+1} iterations (converged).</i></div>"
128
  break
129
 
130
  threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
 
134
  final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
135
  final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
136
  final_output = tokenizer.convert_tokens_to_string(final_tokens)
137
+ yield f"<div style='padding:0.5em'><b>Final Output:</b><br><div style='background:#e0ffe0;padding:0.5em;border-radius:0.5em'>{final_output}</div></div>"
138
+
139
+ # --- Chat Interface ---
140
+ demo = gr.ChatInterface(
141
+ diffusion_chat,
142
+ additional_inputs=[
143
+ gr.Textbox(value="You are a helpful assistant.", label="System message"),
144
+ gr.Slider(0, 1, value=0.4, step=0.05, label="EOT token weight (lower = longer output)"),
145
+ gr.Slider(1, 512, value=64, step=1, label="Max Iterations"),
146
+ gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Noising sharpness (lower = more noise)")
147
+ ],
148
+ title="Diffusion Language Model Chat",
149
+ description="Iterative denoising chat interface using a fine-tuned LLaMA model."
150
+ )
151
+
152
+ demo.launch()