oucgc1996 commited on
Commit
5144693
·
verified ·
1 Parent(s): d8190f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -9,9 +9,6 @@ import time
9
 
10
  is_stopped = False
11
 
12
- # seed = random.randint(0,100000)
13
- setup_seed(4)
14
-
15
  def temperature_sampling(logits, temperature):
16
  logits = logits / temperature
17
  probabilities = torch.softmax(logits, dim=-1)
@@ -23,7 +20,12 @@ def stop_generation():
23
  is_stopped = True
24
  return "Generation stopped."
25
 
26
- def CTXGen(X0, X3, X1, X2, τ, g_num, model_name):
 
 
 
 
 
27
  global is_stopped
28
  is_stopped = False
29
 
@@ -162,7 +164,7 @@ def CTXGen(X0, X3, X1, X2, τ, g_num, model_name):
162
  'Subtype_probability': cls_probability_all,
163
  'Potency': X2,
164
  'Potency_probability': act_probability_all,
165
- 'Random_seed': seed
166
  })
167
  out.to_csv("output.csv", index=False, encoding='utf-8-sig')
168
  count += 1
@@ -198,6 +200,7 @@ with gr.Blocks() as demo:
198
  τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
199
  g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
200
  model_name = gr.Dropdown(choices=['model_final.pt','model_C1.pt','model_C2.pt','model_C3.pt','model_C4.pt','model_C5.pt','model_mlm.pt'], label="Model")
 
201
  with gr.Row():
202
  start_button = gr.Button("Start Generation")
203
  stop_button = gr.Button("Stop Generation")
@@ -206,7 +209,7 @@ with gr.Blocks() as demo:
206
  with gr.Row():
207
  output_df = gr.DataFrame(label="Generated Conotoxins")
208
 
209
- start_button.click(CTXGen, inputs=[X0, X3, X1, X2, τ, g_num, model_name], outputs=[output_file, output_df])
210
  stop_button.click(stop_generation, outputs=None)
211
 
212
  demo.launch()
 
9
 
10
  is_stopped = False
11
 
 
 
 
12
  def temperature_sampling(logits, temperature):
13
  logits = logits / temperature
14
  probabilities = torch.softmax(logits, dim=-1)
 
20
  is_stopped = True
21
  return "Generation stopped."
22
 
23
+ def CTXGen(X0, X3, X1, X2, τ, g_num, model_name, seed):
24
+ if seed =='random'
25
+ seed = random.randint(0,100000)
26
+ setup_seed(seed)
27
+ else:
28
+ setup_seed(int(seed))
29
  global is_stopped
30
  is_stopped = False
31
 
 
164
  'Subtype_probability': cls_probability_all,
165
  'Potency': X2,
166
  'Potency_probability': act_probability_all,
167
+ 'Random_seed': int(seed)
168
  })
169
  out.to_csv("output.csv", index=False, encoding='utf-8-sig')
170
  count += 1
 
200
  τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
201
  g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
202
  model_name = gr.Dropdown(choices=['model_final.pt','model_C1.pt','model_C2.pt','model_C3.pt','model_C4.pt','model_C5.pt','model_mlm.pt'], label="Model")
203
+ seed = gr.Textbox(label="Seed", value="random")
204
  with gr.Row():
205
  start_button = gr.Button("Start Generation")
206
  stop_button = gr.Button("Stop Generation")
 
209
  with gr.Row():
210
  output_df = gr.DataFrame(label="Generated Conotoxins")
211
 
212
+ start_button.click(CTXGen, inputs=[X0, X3, X1, X2, τ, g_num, model_name, seed], outputs=[output_file, output_df])
213
  stop_button.click(stop_generation, outputs=None)
214
 
215
  demo.launch()