oucgc1996 commited on
Commit
15d8fff
·
verified ·
1 Parent(s): 1c09c7c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -5
app.py CHANGED
@@ -9,9 +9,6 @@ import time
9
 
10
  is_stopped = False
11
 
12
- seed = random.randint(0,100000)
13
- setup_seed(seed)
14
-
15
  def temperature_sampling(logits, temperature):
16
  logits = logits / temperature
17
  probabilities = torch.softmax(logits, dim=-1)
@@ -23,7 +20,13 @@ def stop_generation():
23
  is_stopped = True
24
  return "Generation stopped."
25
 
26
- def CTXGen(X1, X2, τ, g_num, length_range, model_name):
 
 
 
 
 
 
27
  global is_stopped
28
  is_stopped = False
29
  start, end = length_range
@@ -158,6 +161,8 @@ with gr.Blocks() as demo:
158
  gr.Markdown("✅**Number of generations**: if it is not completed within 1200 seconds, it will automatically stop.")
159
  gr.Markdown("✅**Length range**: expected length range of conotoxins generated")
160
  gr.Markdown("✅**Model**: model parameters trained at different stages of data augmentation. Please refer to the paper for details.")
 
 
161
  with gr.Row():
162
  X1 = gr.Dropdown(choices=['<α7>','<AChBP>','<α4β2>','<α3β4>','<Ca22>','<α3β2>', '<Na12>','<α9α10>','<K16>', '<α1β1γδ>',
163
  '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<GluN2B>', '<α75HT3>', '<Na14>',
@@ -168,8 +173,10 @@ with gr.Blocks() as demo:
168
  X2 = gr.Dropdown(choices=['<high>','<low>'], label="Potency")
169
  τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
170
  g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
 
171
  length_range = RangeSlider(minimum=8, maximum=50, step=1, value=(12, 16), label="Length range")
172
  model_name = gr.Dropdown(choices=['model_final.pt','model_C1.pt','model_C2.pt','model_C3.pt','model_C4.pt','model_C5.pt','model_mlm.pt'], label="Model")
 
173
  with gr.Row():
174
  start_button = gr.Button("Start Generation")
175
  stop_button = gr.Button("Stop Generation")
@@ -178,7 +185,7 @@ with gr.Blocks() as demo:
178
  with gr.Row():
179
  output_df = gr.DataFrame(label="Generated Conotoxins")
180
 
181
- start_button.click(CTXGen, inputs=[X1, X2, τ, g_num, length_range,model_name], outputs=[output_file, output_df])
182
  stop_button.click(stop_generation, outputs=None)
183
 
184
  demo.launch()
 
9
 
10
  is_stopped = False
11
 
 
 
 
12
  def temperature_sampling(logits, temperature):
13
  logits = logits / temperature
14
  probabilities = torch.softmax(logits, dim=-1)
 
20
  is_stopped = True
21
  return "Generation stopped."
22
 
23
+ def CTXGen(X1, X2, τ, g_num, length_range, model_name, seed):
24
+ if seed =='random':
25
+ seed = random.randint(0,100000)
26
+ setup_seed(seed)
27
+ else:
28
+ setup_seed(int(seed))
29
+
30
  global is_stopped
31
  is_stopped = False
32
  start, end = length_range
 
161
  gr.Markdown("✅**Number of generations**: if it is not completed within 1200 seconds, it will automatically stop.")
162
  gr.Markdown("✅**Length range**: expected length range of conotoxins generated")
163
  gr.Markdown("✅**Model**: model parameters trained at different stages of data augmentation. Please refer to the paper for details.")
164
+ gr.Markdown("✅**Seed**: Enter an integer as the random seed to ensure reproducible results. The default is random")
165
+
166
  with gr.Row():
167
  X1 = gr.Dropdown(choices=['<α7>','<AChBP>','<α4β2>','<α3β4>','<Ca22>','<α3β2>', '<Na12>','<α9α10>','<K16>', '<α1β1γδ>',
168
  '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<GluN2B>', '<α75HT3>', '<Na14>',
 
173
  X2 = gr.Dropdown(choices=['<high>','<low>'], label="Potency")
174
  τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
175
  g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
176
+ with gr.Row():
177
  length_range = RangeSlider(minimum=8, maximum=50, step=1, value=(12, 16), label="Length range")
178
  model_name = gr.Dropdown(choices=['model_final.pt','model_C1.pt','model_C2.pt','model_C3.pt','model_C4.pt','model_C5.pt','model_mlm.pt'], label="Model")
179
+ seed = gr.Textbox(label="Seed", value="random")
180
  with gr.Row():
181
  start_button = gr.Button("Start Generation")
182
  stop_button = gr.Button("Stop Generation")
 
185
  with gr.Row():
186
  output_df = gr.DataFrame(label="Generated Conotoxins")
187
 
188
+ start_button.click(CTXGen, inputs=[X1, X2, τ, g_num, length_range,model_name,seed], outputs=[output_file, output_df])
189
  stop_button.click(stop_generation, outputs=None)
190
 
191
  demo.launch()