Make model global
app.py CHANGED
@@ -73,7 +73,7 @@ def noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=1.0):
         noised[idx] = val
     return noised
 
-def generate_diffusion_text(model, input_ids, answer_start):
+def generate_diffusion_text(input_ids, answer_start):
     with torch.no_grad():
         input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
         logits = model(input_ids=input_tensor)["logits"]
@@ -83,7 +83,7 @@ def generate_diffusion_text(model, input_ids, answer_start):
     return input_ids[:answer_start] + sampled[answer_start:]
 
 # --- Inference Wrapper ---
-def diffusion_chat(question, eot_weight, max_it, sharpness, model):
+def diffusion_chat(question, eot_weight, max_it, sharpness):
     placeholder = "What do you know about the city of New York?"
     if question.strip() == "":
         question = placeholder
@@ -106,7 +106,7 @@ def diffusion_chat(question, eot_weight, max_it, sharpness, model):
     last_tokens = []
 
     for i in range(max_it):
-        generated_tokens = generate_diffusion_text(model, current_tokens, answer_start)
+        generated_tokens = generate_diffusion_text(current_tokens, answer_start)
         current_tokens = generated_tokens
 
         decoded_ids = current_tokens[answer_start:]
@@ -144,7 +144,10 @@ def diffusion_chat(question, eot_weight, max_it, sharpness, model):
         yield f"<b>Final Output (after {i+1} iterations):</b><br>" + final_output
 
 # --- Gradio Interface ---
-
+
+print("Loading model...")
+model = load_model()
+print("✅ Model loaded.")
 
 demo = gr.Interface(
     fn=diffusion_chat,
@@ -152,8 +155,7 @@ demo = gr.Interface(
         gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of New York?"),
         gr.Slider(0, 1, value=0.4, step=0.05, label="↓ = longer answers (EOT weight)"),
         gr.Slider(1, 512, value=64, step=1, label="↑ = more iterations"),
-        gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="↓ = more noising (sharpness)"),
-        model_state
+        gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="↓ = more noising (sharpness)")
     ],
     outputs=gr.HTML(label="Diffusion Output"),
     title="Diffusion Language Model Chat",
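For context on the commit: instead of threading the model through every function signature and a `model_state` Gradio input, `app.py` now loads it once at module scope and lets the handlers read the global. Below is a minimal, self-contained sketch of that pattern, not the Space's actual code: `load_model`'s body and the `gpt2` checkpoint are hypothetical stand-ins, since the diff only shows the call site.

    # Sketch of the "global model" pattern this commit adopts.
    # Assumptions (not in the diff): load_model()'s body and the
    # checkpoint name are hypothetical stand-ins.
    import gradio as gr
    from transformers import AutoModelForCausalLM

    def load_model():
        # Hypothetical stand-in; the real checkpoint is not shown in the diff.
        return AutoModelForCausalLM.from_pretrained("gpt2")

    print("Loading model...")
    model = load_model()  # module-level global: created once at startup
    print("✅ Model loaded.")

    def chat(question):
        # The handler closes over the global `model`; no gr.State / model
        # argument is passed through the interface anymore.
        n_params = model.num_parameters()
        return f"(model with {n_params:,} parameters would answer: {question})"

    demo = gr.Interface(
        fn=chat,
        inputs=gr.Textbox(label="User Question"),
        outputs=gr.Textbox(label="Answer"),
    )

    if __name__ == "__main__":
        demo.launch()

Loading at import time means the model is instantiated once when the Space boots rather than per request, and the extra `model_state` input disappears from the Gradio wiring, which is why the last hunk drops it from the `inputs` list.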