Ruurd committed on
Commit
7aaa1c3
·
1 Parent(s): bba0c8d

Make model global

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -73,7 +73,7 @@ def noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=1.0):
73
  noised[idx] = val
74
  return noised
75
 
76
- def generate_diffusion_text(model, input_ids, answer_start):
77
  with torch.no_grad():
78
  input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
79
  logits = model(input_ids=input_tensor)["logits"]
@@ -83,7 +83,7 @@ def generate_diffusion_text(model, input_ids, answer_start):
83
  return input_ids[:answer_start] + sampled[answer_start:]
84
 
85
  # --- Inference Wrapper ---
86
- def diffusion_chat(question, eot_weight, max_it, sharpness, model):
87
  placeholder = "What do you know about the city of New York?"
88
  if question.strip() == "":
89
  question = placeholder
@@ -106,7 +106,7 @@ def diffusion_chat(question, eot_weight, max_it, sharpness, model):
106
  last_tokens = []
107
 
108
  for i in range(max_it):
109
- generated_tokens = generate_diffusion_text(model, current_tokens, answer_start)
110
  current_tokens = generated_tokens
111
 
112
  decoded_ids = current_tokens[answer_start:]
@@ -144,7 +144,10 @@ def diffusion_chat(question, eot_weight, max_it, sharpness, model):
144
  yield f"<b>Final Output (after {i+1} iterations):</b><br>" + final_output
145
 
146
  # --- Gradio Interface ---
147
- model_state = gr.State(value=load_model()) # this just stores the object
 
 
 
148
 
149
  demo = gr.Interface(
150
  fn=diffusion_chat,
@@ -152,8 +155,7 @@ demo = gr.Interface(
152
  gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of New York?"),
153
  gr.Slider(0, 1, value=0.4, step=0.05, label="↓ = longer answers (EOT weight)"),
154
  gr.Slider(1, 512, value=64, step=1, label="↑ = more iterations"),
155
- gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="↓ = more noising (sharpness)"),
156
- model_state
157
  ],
158
  outputs=gr.HTML(label="Diffusion Output"),
159
  title="Diffusion Language Model Chat",
 
73
  noised[idx] = val
74
  return noised
75
 
76
+ def generate_diffusion_text(input_ids, answer_start):
77
  with torch.no_grad():
78
  input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
79
  logits = model(input_ids=input_tensor)["logits"]
 
83
  return input_ids[:answer_start] + sampled[answer_start:]
84
 
85
  # --- Inference Wrapper ---
86
+ def diffusion_chat(question, eot_weight, max_it, sharpness):
87
  placeholder = "What do you know about the city of New York?"
88
  if question.strip() == "":
89
  question = placeholder
 
106
  last_tokens = []
107
 
108
  for i in range(max_it):
109
+ generated_tokens = generate_diffusion_text(current_tokens, answer_start)
110
  current_tokens = generated_tokens
111
 
112
  decoded_ids = current_tokens[answer_start:]
 
144
  yield f"<b>Final Output (after {i+1} iterations):</b><br>" + final_output
145
 
146
  # --- Gradio Interface ---
147
+
148
+ print("Loading model...")
149
+ model = load_model()
150
+ print("✅ Model loaded.")
151
 
152
  demo = gr.Interface(
153
  fn=diffusion_chat,
 
155
  gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of New York?"),
156
  gr.Slider(0, 1, value=0.4, step=0.05, label="↓ = longer answers (EOT weight)"),
157
  gr.Slider(1, 512, value=64, step=1, label="↑ = more iterations"),
158
+ gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="↓ = more noising (sharpness)")
 
159
  ],
160
  outputs=gr.HTML(label="Diffusion Output"),
161
  title="Diffusion Language Model Chat",