Update app.py
Browse files
app.py
CHANGED
@@ -9,9 +9,6 @@ import time
|
|
9 |
|
10 |
is_stopped = False
|
11 |
|
12 |
-
seed = random.randint(0,100000)
|
13 |
-
setup_seed(seed)
|
14 |
-
|
15 |
def temperature_sampling(logits, temperature):
|
16 |
logits = logits / temperature
|
17 |
probabilities = torch.softmax(logits, dim=-1)
|
@@ -23,7 +20,13 @@ def stop_generation():
|
|
23 |
is_stopped = True
|
24 |
return "Generation stopped."
|
25 |
|
26 |
-
def CTXGen(X1, X2, τ, g_num, length_range, model_name):
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
global is_stopped
|
28 |
is_stopped = False
|
29 |
start, end = length_range
|
@@ -158,6 +161,8 @@ with gr.Blocks() as demo:
|
|
158 |
gr.Markdown("✅**Number of generations**: if it is not completed within 1200 seconds, it will automatically stop.")
|
159 |
gr.Markdown("✅**Length range**: expected length range of conotoxins generated")
|
160 |
gr.Markdown("✅**Model**: model parameters trained at different stages of data augmentation. Please refer to the paper for details.")
|
|
|
|
|
161 |
with gr.Row():
|
162 |
X1 = gr.Dropdown(choices=['<α7>','<AChBP>','<α4β2>','<α3β4>','<Ca22>','<α3β2>', '<Na12>','<α9α10>','<K16>', '<α1β1γδ>',
|
163 |
'<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<GluN2B>', '<α75HT3>', '<Na14>',
|
@@ -168,8 +173,10 @@ with gr.Blocks() as demo:
|
|
168 |
X2 = gr.Dropdown(choices=['<high>','<low>'], label="Potency")
|
169 |
τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
|
170 |
g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
|
|
|
171 |
length_range = RangeSlider(minimum=8, maximum=50, step=1, value=(12, 16), label="Length range")
|
172 |
model_name = gr.Dropdown(choices=['model_final.pt','model_C1.pt','model_C2.pt','model_C3.pt','model_C4.pt','model_C5.pt','model_mlm.pt'], label="Model")
|
|
|
173 |
with gr.Row():
|
174 |
start_button = gr.Button("Start Generation")
|
175 |
stop_button = gr.Button("Stop Generation")
|
@@ -178,7 +185,7 @@ with gr.Blocks() as demo:
|
|
178 |
with gr.Row():
|
179 |
output_df = gr.DataFrame(label="Generated Conotoxins")
|
180 |
|
181 |
-
start_button.click(CTXGen, inputs=[X1, X2, τ, g_num, length_range,model_name], outputs=[output_file, output_df])
|
182 |
stop_button.click(stop_generation, outputs=None)
|
183 |
|
184 |
demo.launch()
|
|
|
9 |
|
10 |
is_stopped = False
|
11 |
|
|
|
|
|
|
|
12 |
def temperature_sampling(logits, temperature):
|
13 |
logits = logits / temperature
|
14 |
probabilities = torch.softmax(logits, dim=-1)
|
|
|
20 |
is_stopped = True
|
21 |
return "Generation stopped."
|
22 |
|
23 |
+
def CTXGen(X1, X2, τ, g_num, length_range, model_name, seed):
|
24 |
+
if seed =='random':
|
25 |
+
seed = random.randint(0,100000)
|
26 |
+
setup_seed(seed)
|
27 |
+
else:
|
28 |
+
setup_seed(int(seed))
|
29 |
+
|
30 |
global is_stopped
|
31 |
is_stopped = False
|
32 |
start, end = length_range
|
|
|
161 |
gr.Markdown("✅**Number of generations**: if it is not completed within 1200 seconds, it will automatically stop.")
|
162 |
gr.Markdown("✅**Length range**: expected length range of conotoxins generated")
|
163 |
gr.Markdown("✅**Model**: model parameters trained at different stages of data augmentation. Please refer to the paper for details.")
|
164 |
+
gr.Markdown("✅**Seed**: Enter an integer as the random seed to ensure reproducible results. The default is random")
|
165 |
+
|
166 |
with gr.Row():
|
167 |
X1 = gr.Dropdown(choices=['<α7>','<AChBP>','<α4β2>','<α3β4>','<Ca22>','<α3β2>', '<Na12>','<α9α10>','<K16>', '<α1β1γδ>',
|
168 |
'<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<GluN2B>', '<α75HT3>', '<Na14>',
|
|
|
173 |
X2 = gr.Dropdown(choices=['<high>','<low>'], label="Potency")
|
174 |
τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
|
175 |
g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
|
176 |
+
with gr.Row():
|
177 |
length_range = RangeSlider(minimum=8, maximum=50, step=1, value=(12, 16), label="Length range")
|
178 |
model_name = gr.Dropdown(choices=['model_final.pt','model_C1.pt','model_C2.pt','model_C3.pt','model_C4.pt','model_C5.pt','model_mlm.pt'], label="Model")
|
179 |
+
seed = gr.Textbox(label="Seed", value="random")
|
180 |
with gr.Row():
|
181 |
start_button = gr.Button("Start Generation")
|
182 |
stop_button = gr.Button("Stop Generation")
|
|
|
185 |
with gr.Row():
|
186 |
output_df = gr.DataFrame(label="Generated Conotoxins")
|
187 |
|
188 |
+
start_button.click(CTXGen, inputs=[X1, X2, τ, g_num, length_range,model_name,seed], outputs=[output_file, output_df])
|
189 |
stop_button.click(stop_generation, outputs=None)
|
190 |
|
191 |
demo.launch()
|