oucgc1996 committed on
Commit
20db68b
·
verified ·
1 Parent(s): e4fcf38

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +171 -171
app.py CHANGED
@@ -1,172 +1,172 @@
1
- import torch
2
- import random
3
- import pandas as pd
4
- from utils import create_vocab, setup_seed
5
- from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
6
- import gradio as gr
7
- from gradio_rangeslider import RangeSlider
8
- import time
9
-
10
- is_stopped = False
11
-
12
- seed = random.randint(0,100000)
13
- setup_seed(seed)
14
-
15
- device = torch.device("cpu")
16
- vocab_mlm = create_vocab()
17
- vocab_mlm = add_tokens_to_vocab(vocab_mlm)
18
- save_path = 'mlm-model-27.pt'
19
- train_seqs = pd.read_csv('C0_seq.csv')
20
- train_seq = train_seqs['Seq'].tolist()
21
- model = torch.load(save_path)
22
- model = model.to(device)
23
-
24
- def temperature_sampling(logits, temperature):
25
- logits = logits / temperature
26
- probabilities = torch.softmax(logits, dim=-1)
27
- sampled_token = torch.multinomial(probabilities, 1)
28
- return sampled_token
29
-
30
- def stop_generation():
31
- global is_stopped
32
- is_stopped = True
33
- return "Generation stopped."
34
-
35
- def CTXGen(X1, X2, τ, g_num, length_range):
36
- global is_stopped
37
- is_stopped = False
38
- start, end = length_range
39
-
40
- msa_data = pd.read_csv('conoData_C0.csv')
41
- msa = msa_data['Sequences'].tolist()
42
- msa = [x for x in msa if x.startswith(f"{X1}|{X2}")]
43
- msa = random.choice(msa)
44
- X4 = msa.split("|")[3]
45
- X5 = msa.split("|")[4]
46
- X6 = msa.split("|")[5]
47
-
48
- model.eval()
49
- with torch.no_grad():
50
- new_seq = None
51
- IDs = []
52
- generated_seqs = []
53
- generated_seqs_FINAL = []
54
- cls_probability_all = []
55
- act_probability_all = []
56
-
57
- count = 0
58
- gen_num = int(g_num)
59
- NON_AA = ["B","O","U","Z","X",'<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
60
- '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>',
61
- '<Ca12>', '<Na16>', '<α6α3β2>', '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>',
62
- '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>','<α6β4>', '<α2β4>',
63
- '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>',
64
- '<α9α10>','<α6α3β4>', '<NaTTXS>', '<Na17>','<high>','<low>','[UNK]','[SEP]','[PAD]','[CLS]','[MASK]']
65
- start_time = time.time()
66
- while count < gen_num:
67
- if is_stopped:
68
- return pd.DataFrame(), "output.csv"
69
-
70
- if time.time() - start_time > 1200:
71
- break
72
-
73
- gen_len = random.randint(int(start), int(end))
74
- X3 = "X" * gen_len
75
- seq = [f"{X1}|{X2}|{X3}|{X4}|{X5}|{X6}"]
76
- vocab_mlm.token_to_idx["X"] = 4
77
-
78
- padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
79
- input_text = ["[MASK]" if i=="X" else i for i in padded_seq]
80
-
81
- gen_length = len(input_text)
82
- length = gen_length - sum(1 for x in input_text if x != '[MASK]')
83
-
84
- for i in range(length):
85
- if is_stopped:
86
- return pd.DataFrame(), "output.csv"
87
-
88
- _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
89
- idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
90
- idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
91
- attn_idx = torch.tensor(attn_idx).to(device)
92
-
93
- mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
94
- mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])
95
-
96
- logits = model(idx_seq,idx_msa, attn_idx)
97
- mask_logits = logits[0, mask_position.item(), :] #
98
-
99
- predicted_token_id = temperature_sampling(mask_logits, τ)
100
-
101
- predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
102
- input_text[mask_position.item()] = predicted_token
103
- padded_seq[mask_position.item()] = predicted_token.strip()
104
- new_seq = padded_seq
105
-
106
- generated_seq = input_text
107
-
108
- generated_seq[1] = "[MASK]"
109
- input_ids = vocab_mlm.__getitem__(generated_seq)
110
- logits = model(torch.tensor([input_ids]).to(device), idx_msa)
111
- cls_mask_logits = logits[0, 1, :]
112
- cls_probability, cls_mask_probs = torch.topk((torch.softmax(cls_mask_logits, dim=-1)), k=5)
113
-
114
- generated_seq[2] = "[MASK]"
115
- input_ids = vocab_mlm.__getitem__(generated_seq)
116
- logits = model(torch.tensor([input_ids]).to(device), idx_msa)
117
- act_mask_logits = logits[0, 2, :]
118
- act_probability, act_mask_probs = torch.topk((torch.softmax(act_mask_logits, dim=-1)), k=2)
119
-
120
- cls_pos = vocab_mlm.to_tokens(list(cls_mask_probs))
121
- act_pos = vocab_mlm.to_tokens(list(act_mask_probs))
122
-
123
- if X1 in cls_pos and X2 in act_pos:
124
- cls_proba = cls_probability[cls_pos.index(X1)].item()
125
- act_proba = act_probability[act_pos.index(X2)].item()
126
- generated_seq = generated_seq[generated_seq.index('[MASK]') + 2:generated_seq.index('[SEP]')]
127
- if act_proba>=0.5 and generated_seq.count('C') % 2 == 0 and len("".join(generated_seq)) == gen_len:
128
- generated_seqs.append("".join(generated_seq))
129
- if "".join(generated_seq) not in train_seq and "".join(generated_seq) not in generated_seqs[0:-1] and all(x not in NON_AA for x in generated_seq):
130
- generated_seqs_FINAL.append("".join(generated_seq))
131
- cls_probability_all.append(cls_proba)
132
- act_probability_all.append(act_proba)
133
- IDs.append(count+1)
134
- out = pd.DataFrame({
135
- 'ID':IDs,
136
- 'Generated_seq': generated_seqs_FINAL,
137
- 'Subtype': X1,
138
- 'cls_probability': cls_probability_all,
139
- 'Potency': X2,
140
- 'Potency_probability': act_probability_all,
141
- 'random_seed': seed
142
- })
143
- out.to_csv("output.csv", index=False, encoding='utf-8-sig')
144
- count += 1
145
- yield out, "output.csv"
146
- return out, "output.csv"
147
-
148
- with gr.Blocks() as demo:
149
- gr.Markdown("# Conotoxin Generation")
150
- with gr.Row():
151
- X1 = gr.Dropdown(choices=['<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
152
- '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>',
153
- '<Ca12>', '<Na16>', '<α6α3β2>', '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>',
154
- '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>','<α6β4>', '<α2β4>',
155
- '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>',
156
- '<α9α10>','<α6α3β4>', '<NaTTXS>', '<Na17>'], label="Subtype")
157
- X2 = gr.Dropdown(choices=['<high>','<low>'], label="Potency")
158
- τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
159
- g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
160
- length_range = RangeSlider(minimum=8, maximum=50, step=1, value=(12, 16), label="Length range")
161
- with gr.Row():
162
- start_button = gr.Button("Start Generation")
163
- stop_button = gr.Button("Stop Generation")
164
- with gr.Row():
165
- output_df = gr.DataFrame(label="Generated Conotoxins")
166
- with gr.Row():
167
- output_file = gr.File(label="Download generated conotoxins")
168
-
169
- start_button.click(CTXGen, inputs=[X1, X2, τ, g_num, length_range], outputs=[output_df, output_file])
170
- stop_button.click(stop_generation, outputs=None)
171
-
172
  demo.launch()
 
1
+ import torch
2
+ import random
3
+ import pandas as pd
4
+ from utils import create_vocab, setup_seed
5
+ from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
6
+ import gradio as gr
7
+ from gradio_rangeslider import RangeSlider
8
+ import time
9
+
10
+ is_stopped = False
11
+
12
+ seed = random.randint(0,100000)
13
+ setup_seed(seed)
14
+
15
+ device = torch.device("cpu")
16
+ vocab_mlm = create_vocab()
17
+ vocab_mlm = add_tokens_to_vocab(vocab_mlm)
18
+ save_path = 'mlm-model-27.pt'
19
+ train_seqs = pd.read_csv('C0_seq.csv')
20
+ train_seq = train_seqs['Seq'].tolist()
21
+ model = torch.load(save_path, map_location=torch.device('cpu'))
22
+ model = model.to(device)
23
+
24
+ def temperature_sampling(logits, temperature):
25
+ logits = logits / temperature
26
+ probabilities = torch.softmax(logits, dim=-1)
27
+ sampled_token = torch.multinomial(probabilities, 1)
28
+ return sampled_token
29
+
30
+ def stop_generation():
31
+ global is_stopped
32
+ is_stopped = True
33
+ return "Generation stopped."
34
+
35
+ def CTXGen(X1, X2, τ, g_num, length_range):
36
+ global is_stopped
37
+ is_stopped = False
38
+ start, end = length_range
39
+
40
+ msa_data = pd.read_csv('conoData_C0.csv')
41
+ msa = msa_data['Sequences'].tolist()
42
+ msa = [x for x in msa if x.startswith(f"{X1}|{X2}")]
43
+ msa = random.choice(msa)
44
+ X4 = msa.split("|")[3]
45
+ X5 = msa.split("|")[4]
46
+ X6 = msa.split("|")[5]
47
+
48
+ model.eval()
49
+ with torch.no_grad():
50
+ new_seq = None
51
+ IDs = []
52
+ generated_seqs = []
53
+ generated_seqs_FINAL = []
54
+ cls_probability_all = []
55
+ act_probability_all = []
56
+
57
+ count = 0
58
+ gen_num = int(g_num)
59
+ NON_AA = ["B","O","U","Z","X",'<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
60
+ '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>',
61
+ '<Ca12>', '<Na16>', '<α6α3β2>', '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>',
62
+ '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>','<α6β4>', '<α2β4>',
63
+ '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>',
64
+ '<α9α10>','<α6α3β4>', '<NaTTXS>', '<Na17>','<high>','<low>','[UNK]','[SEP]','[PAD]','[CLS]','[MASK]']
65
+ start_time = time.time()
66
+ while count < gen_num:
67
+ if is_stopped:
68
+ return pd.DataFrame(), "output.csv"
69
+
70
+ if time.time() - start_time > 1200:
71
+ break
72
+
73
+ gen_len = random.randint(int(start), int(end))
74
+ X3 = "X" * gen_len
75
+ seq = [f"{X1}|{X2}|{X3}|{X4}|{X5}|{X6}"]
76
+ vocab_mlm.token_to_idx["X"] = 4
77
+
78
+ padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
79
+ input_text = ["[MASK]" if i=="X" else i for i in padded_seq]
80
+
81
+ gen_length = len(input_text)
82
+ length = gen_length - sum(1 for x in input_text if x != '[MASK]')
83
+
84
+ for i in range(length):
85
+ if is_stopped:
86
+ return pd.DataFrame(), "output.csv"
87
+
88
+ _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
89
+ idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
90
+ idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
91
+ attn_idx = torch.tensor(attn_idx).to(device)
92
+
93
+ mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
94
+ mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])
95
+
96
+ logits = model(idx_seq,idx_msa, attn_idx)
97
+ mask_logits = logits[0, mask_position.item(), :] #
98
+
99
+ predicted_token_id = temperature_sampling(mask_logits, τ)
100
+
101
+ predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
102
+ input_text[mask_position.item()] = predicted_token
103
+ padded_seq[mask_position.item()] = predicted_token.strip()
104
+ new_seq = padded_seq
105
+
106
+ generated_seq = input_text
107
+
108
+ generated_seq[1] = "[MASK]"
109
+ input_ids = vocab_mlm.__getitem__(generated_seq)
110
+ logits = model(torch.tensor([input_ids]).to(device), idx_msa)
111
+ cls_mask_logits = logits[0, 1, :]
112
+ cls_probability, cls_mask_probs = torch.topk((torch.softmax(cls_mask_logits, dim=-1)), k=5)
113
+
114
+ generated_seq[2] = "[MASK]"
115
+ input_ids = vocab_mlm.__getitem__(generated_seq)
116
+ logits = model(torch.tensor([input_ids]).to(device), idx_msa)
117
+ act_mask_logits = logits[0, 2, :]
118
+ act_probability, act_mask_probs = torch.topk((torch.softmax(act_mask_logits, dim=-1)), k=2)
119
+
120
+ cls_pos = vocab_mlm.to_tokens(list(cls_mask_probs))
121
+ act_pos = vocab_mlm.to_tokens(list(act_mask_probs))
122
+
123
+ if X1 in cls_pos and X2 in act_pos:
124
+ cls_proba = cls_probability[cls_pos.index(X1)].item()
125
+ act_proba = act_probability[act_pos.index(X2)].item()
126
+ generated_seq = generated_seq[generated_seq.index('[MASK]') + 2:generated_seq.index('[SEP]')]
127
+ if act_proba>=0.5 and generated_seq.count('C') % 2 == 0 and len("".join(generated_seq)) == gen_len:
128
+ generated_seqs.append("".join(generated_seq))
129
+ if "".join(generated_seq) not in train_seq and "".join(generated_seq) not in generated_seqs[0:-1] and all(x not in NON_AA for x in generated_seq):
130
+ generated_seqs_FINAL.append("".join(generated_seq))
131
+ cls_probability_all.append(cls_proba)
132
+ act_probability_all.append(act_proba)
133
+ IDs.append(count+1)
134
+ out = pd.DataFrame({
135
+ 'ID':IDs,
136
+ 'Generated_seq': generated_seqs_FINAL,
137
+ 'Subtype': X1,
138
+ 'cls_probability': cls_probability_all,
139
+ 'Potency': X2,
140
+ 'Potency_probability': act_probability_all,
141
+ 'random_seed': seed
142
+ })
143
+ out.to_csv("output.csv", index=False, encoding='utf-8-sig')
144
+ count += 1
145
+ yield out, "output.csv"
146
+ return out, "output.csv"
147
+
148
+ with gr.Blocks() as demo:
149
+ gr.Markdown("# Conotoxin Generation")
150
+ with gr.Row():
151
+ X1 = gr.Dropdown(choices=['<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
152
+ '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>',
153
+ '<Ca12>', '<Na16>', '<α6α3β2>', '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>',
154
+ '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>','<α6β4>', '<α2β4>',
155
+ '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>',
156
+ '<α9α10>','<α6α3β4>', '<NaTTXS>', '<Na17>'], label="Subtype")
157
+ X2 = gr.Dropdown(choices=['<high>','<low>'], label="Potency")
158
+ τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
159
+ g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
160
+ length_range = RangeSlider(minimum=8, maximum=50, step=1, value=(12, 16), label="Length range")
161
+ with gr.Row():
162
+ start_button = gr.Button("Start Generation")
163
+ stop_button = gr.Button("Stop Generation")
164
+ with gr.Row():
165
+ output_df = gr.DataFrame(label="Generated Conotoxins")
166
+ with gr.Row():
167
+ output_file = gr.File(label="Download generated conotoxins")
168
+
169
+ start_button.click(CTXGen, inputs=[X1, X2, τ, g_num, length_range], outputs=[output_df, output_file])
170
+ stop_button.click(stop_generation, outputs=None)
171
+
172
  demo.launch()