oucgc1996 committed on
Commit
aadb062
·
verified ·
1 Parent(s): f8875db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -134
app.py CHANGED
@@ -2,14 +2,14 @@ import torch
2
  import random
3
  import pandas as pd
4
  from utils import create_vocab, setup_seed
5
- from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
6
  import gradio as gr
7
  from gradio_rangeslider import RangeSlider
8
  import time
9
 
10
  is_stopped = False
11
 
12
- seed = random.randint(0,100000)
13
  setup_seed(seed)
14
 
15
  def temperature_sampling(logits, temperature):
@@ -32,158 +32,146 @@ def CTXGen(X0, X1, X2, τ, g_num, model_name):
32
  vocab_mlm = add_tokens_to_vocab(vocab_mlm)
33
  save_path = model_name
34
  train_seqs = pd.read_csv('C0_seq.csv')
35
- train_seq = train_seqs['Seq'].tolist()
36
  model = torch.load(save_path, map_location=torch.device('cpu'))
37
  model = model.to(device)
 
38
 
39
  X3 = "X" * len(X0)
40
- msa_data = pd.read_csv('conoData_C0.csv')
41
  msa = msa_data['Sequences'].tolist()
42
  msa = [x for x in msa if x.startswith(f"{X1}|{X2}")]
43
  if not msa:
44
- X4 = ""
45
- X5 = ""
46
- X6 = ""
47
  else:
48
  msa = random.choice(msa)
49
- X4 = msa.split("|")[3]
50
- X5 = msa.split("|")[4]
51
- X6 = msa.split("|")[5]
52
- model.eval()
53
- with torch.no_grad():
54
- new_seq = None
55
- IDs = []
56
- generated_seqs = []
57
- generated_seqs_FINAL = []
58
- cls_probability_all = []
59
- act_probability_all = []
60
- count = 0
61
- gen_num = g_num
62
- NON_AA = ["B","O","U","Z","X",'<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
63
- '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>',
64
- '<Ca12>', '<Na16>', '<α6α3β2>', '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>',
65
- '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>','<α6β4>', '<α2β4>',
66
- '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>',
67
- '<α9α10>','<α6α3β4>', '<NaTTXS>', '<Na17>','<high>','<low>','[UNK]','[SEP]','[PAD]','[CLS]','[MASK]']
68
-
69
- seq_parent = [f"{X1}|{X2}|{X0}|{X4}|{X5}|{X6}"]
70
- padded_seqseq_parent, _, idx_msaseq_parent, _ = get_paded_token_idx_gen(vocab_mlm, seq_parent, new_seq)
71
- idx_msaseq_parent = torch.tensor(idx_msaseq_parent).unsqueeze(0).to(device)
72
- seqseq_parent = ["[MASK]" if i=="X" else i for i in padded_seqseq_parent]
73
-
74
- seqseq_parent[1] = "[MASK]"
75
- input_ids_parent = vocab_mlm.__getitem__(seqseq_parent)
76
- logits_parent = model(torch.tensor([input_ids_parent]).to(device), idx_msaseq_parent)
77
-
78
- cls_mask_logits_parent = logits_parent[0, 1, :]
79
- cls_probability_parent, cls_mask_probs_parent = torch.topk((torch.softmax(cls_mask_logits_parent, dim=-1)), k=85)
80
-
81
- seqseq_parent[2] = "[MASK]"
82
- input_ids_parent = vocab_mlm.__getitem__(seqseq_parent)
83
- logits_parent = model(torch.tensor([input_ids_parent]).to(device), idx_msaseq_parent)
84
- act_mask_logits_parent = logits_parent[0, 2, :]
85
- act_probability_parent, act_mask_probs_parent = torch.topk((torch.softmax(act_mask_logits_parent, dim=-1)), k=2)
86
-
87
- cls_pos_parent = vocab_mlm.to_tokens(list(cls_mask_probs_parent))
88
- act_pos_parent = vocab_mlm.to_tokens(list(act_mask_probs_parent))
89
-
90
- cls_proba_parent = cls_probability_parent[cls_pos_parent.index(X1)].item()
91
- act_proba_parent = act_probability_parent[act_pos_parent.index(X2)].item()
92
-
93
- start_time = time.time()
94
- while count < gen_num:
95
- gen_len = len(X0)
 
 
 
 
 
96
  if is_stopped:
97
  return pd.DataFrame(), "output.csv"
98
 
99
- if time.time() - start_time > 1200:
100
- break
101
-
102
- seq = [f"{X1}|{X2}|{X3}|{X4}|{X5}|{X6}"]
103
- vocab_mlm.token_to_idx["X"] = 4
104
-
105
- padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
106
- input_text = ["[MASK]" if i=="X" else i for i in padded_seq]
107
-
108
- gen_length = len(input_text)
109
- length = gen_length - sum(1 for x in input_text if x != '[MASK]')
110
- for i in range(length):
111
- if is_stopped:
112
- return pd.DataFrame(), "output.csv"
113
-
114
- _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
115
- idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
116
- idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
117
- attn_idx = torch.tensor(attn_idx).to(device)
118
-
119
- mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
120
- mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])
121
-
122
- logits = model(idx_seq,idx_msa, attn_idx)
123
- mask_logits = logits[0, mask_position.item(), :]
124
-
125
- predicted_token_id = temperature_sampling(mask_logits, τ)
126
-
127
- predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
128
- input_text[mask_position.item()] = predicted_token
129
- padded_seq[mask_position.item()] = predicted_token.strip()
130
- new_seq = padded_seq
131
- generated_seq = input_text
132
-
133
- generated_seq[1] = "[MASK]"
134
- input_ids = vocab_mlm.__getitem__(generated_seq)
135
- logits = model(torch.tensor([input_ids]).to(device), idx_msa)
136
- cls_mask_logits = logits[0, 1, :]
137
- cls_probability, cls_mask_probs = torch.topk((torch.softmax(cls_mask_logits, dim=-1)), k=10)
138
-
139
- generated_seq[2] = "[MASK]"
140
- input_ids = vocab_mlm.__getitem__(generated_seq)
141
- logits = model(torch.tensor([input_ids]).to(device), idx_msa)
142
- act_mask_logits = logits[0, 2, :]
143
- act_probability, act_mask_probs = torch.topk((torch.softmax(act_mask_logits, dim=-1)), k=2)
144
-
145
- cls_pos = vocab_mlm.to_tokens(list(cls_mask_probs))
146
- act_pos = vocab_mlm.to_tokens(list(act_mask_probs))
147
-
148
- if X1 in cls_pos and X2 in act_pos:
149
- cls_proba = cls_probability[cls_pos.index(X1)].item()
150
- act_proba = act_probability[act_pos.index(X2)].item()
151
- generated_seq = generated_seq[generated_seq.index('[MASK]') + 2:generated_seq.index('[SEP]')]
152
- if cls_proba>=cls_proba_parent and act_proba>=act_proba_parent and generated_seq.count('C') % 2 == 0 and len("".join(generated_seq)) == gen_len:
153
- generated_seqs.append("".join(generated_seq))
154
- if "".join(generated_seq) not in train_seq and "".join(generated_seq) not in generated_seqs[0:-1] and all(x not in NON_AA for x in generated_seq):
155
- generated_seqs_FINAL.append("".join(generated_seq))
156
- cls_probability_all.append(cls_proba)
157
- act_probability_all.append(act_proba)
158
- IDs.append(count+1)
159
- out = pd.DataFrame({
160
- 'ID':IDs,
161
- 'Generated_seq': generated_seqs_FINAL,
162
- 'Subtype': X1,
163
- 'Subtype_probability': cls_probability_all,
164
- 'Potency': X2,
165
- 'Potency_probability': act_probability_all,
166
- 'Random_seed': seed
167
- })
168
- out.to_csv("output.csv", index=False, encoding='utf-8-sig')
169
- count += 1
170
- yield out, "output.csv"
171
  return out, "output.csv"
172
 
173
  with gr.Blocks() as demo:
174
  gr.Markdown("# Conotoxin Optimization Generation")
175
  with gr.Row():
176
  X0 = gr.Textbox(label="conotoxin")
177
- X1 = gr.Dropdown(choices=['<α7>','<AChBP>','<α4β2>','<α3β4>','<Ca22>','<α3β2>', '<Na12>','<α9α10>','<K16>', '<α1β1γδ>',
178
- '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<GluN2B>', '<α75HT3>', '<Na14>',
179
- '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>', '<Ca12>', '<Na16>', '<α6α3β2>',
180
- '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>', '<Kshaker>', '<Na18>',
181
- '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>','<α6β4>', '<α2β4>','<Na13>', '<Na15>', '<α4β4>',
182
  '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>', '<α6α3β4>', '<NaTTXS>', '<Na17>'], label="Subtype")
183
- X2 = gr.Dropdown(choices=['<high>','<low>'], label="Potency")
184
  τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
185
  g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
186
- model_name = gr.Dropdown(choices=['model_final.pt','model_C1.pt','model_C2.pt','model_C3.pt','model_C4.pt','model_C5.pt','model_mlm.pt'], label="Model")
187
  with gr.Row():
188
  start_button = gr.Button("Start Generation")
189
  stop_button = gr.Button("Stop Generation")
@@ -191,7 +179,7 @@ with gr.Blocks() as demo:
191
  output_df = gr.DataFrame(label="Generated Conotoxins")
192
  with gr.Row():
193
  output_file = gr.File(label="Download generated conotoxins")
194
-
195
  start_button.click(CTXGen, inputs=[X0, X1, X2, τ, g_num, model_name], outputs=[output_df, output_file])
196
  stop_button.click(stop_generation, outputs=None)
197
 
 
2
  import random
3
  import pandas as pd
4
  from utils import create_vocab, setup_seed
5
+ from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
6
  import gradio as gr
7
  from gradio_rangeslider import RangeSlider
8
  import time
9
 
10
  is_stopped = False
11
 
12
+ seed = random.randint(0, 100000)
13
  setup_seed(seed)
14
 
15
  def temperature_sampling(logits, temperature):
 
32
  vocab_mlm = add_tokens_to_vocab(vocab_mlm)
33
  save_path = model_name
34
  train_seqs = pd.read_csv('C0_seq.csv')
35
+ train_seq = set(train_seqs['Seq'].tolist()) # 使用集合加快查找速度
36
  model = torch.load(save_path, map_location=torch.device('cpu'))
37
  model = model.to(device)
38
+ model.eval()
39
 
40
  X3 = "X" * len(X0)
41
+ msa_data = pd.read_csv('conoData_C0.csv')
42
  msa = msa_data['Sequences'].tolist()
43
  msa = [x for x in msa if x.startswith(f"{X1}|{X2}")]
44
  if not msa:
45
+ X4, X5, X6 = "", "", ""
 
 
46
  else:
47
  msa = random.choice(msa)
48
+ X4, X5, X6 = msa.split("|")[3:6]
49
+
50
+ NON_AA = {"B", "O", "U", "Z", "X", '<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
51
+ '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>', '<Ca12>', '<Na16>', '<α6α3β2>',
52
+ '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>', '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>',
53
+ '<α6α3β4β3>', '<α2β2>', '<α6β4>', '<α2β4>', '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>',
54
+ '<α9α10>', '<α6α3β4>', '<NaTTXS>', '<Na17>', '<high>', '<low>', '[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'}
55
+
56
+ seq_parent = [f"{X1}|{X2}|{X0}|{X4}|{X5}|{X6}"]
57
+ padded_seqseq_parent, _, idx_msaseq_parent, _ = get_paded_token_idx_gen(vocab_mlm, seq_parent, None)
58
+ idx_msaseq_parent = torch.tensor(idx_msaseq_parent).unsqueeze(0).to(device)
59
+ seqseq_parent = ["[MASK]" if i == "X" else i for i in padded_seqseq_parent]
60
+
61
+ seqseq_parent[1] = "[MASK]"
62
+ input_ids_parent = vocab_mlm.__getitem__(seqseq_parent)
63
+ logits_parent = model(torch.tensor([input_ids_parent]).to(device), idx_msaseq_parent)
64
+ cls_mask_logits_parent = logits_parent[0, 1, :]
65
+ cls_probability_parent, cls_mask_probs_parent = torch.topk((torch.softmax(cls_mask_logits_parent, dim=-1)), k=85)
66
+
67
+ seqseq_parent[2] = "[MASK]"
68
+ input_ids_parent = vocab_mlm.__getitem__(seqseq_parent)
69
+ logits_parent = model(torch.tensor([input_ids_parent]).to(device), idx_msaseq_parent)
70
+ act_mask_logits_parent = logits_parent[0, 2, :]
71
+ act_probability_parent, act_mask_probs_parent = torch.topk((torch.softmax(act_mask_logits_parent, dim=-1)), k=2)
72
+
73
+ cls_pos_parent = vocab_mlm.to_tokens(list(cls_mask_probs_parent))
74
+ act_pos_parent = vocab_mlm.to_tokens(list(act_mask_probs_parent))
75
+
76
+ cls_proba_parent = cls_probability_parent[cls_pos_parent.index(X1)].item()
77
+ act_proba_parent = act_probability_parent[act_pos_parent.index(X2)].item()
78
+
79
+ start_time = time.time()
80
+ count = 0
81
+ generated_seqs_FINAL = []
82
+ cls_probability_all = []
83
+ act_probability_all = []
84
+ IDs = []
85
+
86
+ while count < g_num:
87
+ if is_stopped:
88
+ return pd.DataFrame(), "output.csv"
89
+
90
+ if time.time() - start_time > 1200:
91
+ break
92
+
93
+ seq = [f"{X1}|{X2}|{X3}|{X4}|{X5}|{X6}"]
94
+ padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, None)
95
+ input_text = ["[MASK]" if i == "X" else i for i in padded_seq]
96
+
97
+ gen_length = len(input_text)
98
+ length = gen_length - sum(1 for x in input_text if x != '[MASK]')
99
+ for i in range(length):
100
  if is_stopped:
101
  return pd.DataFrame(), "output.csv"
102
 
103
+ _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, None)
104
+ idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
105
+ idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
106
+ attn_idx = torch.tensor(attn_idx).to(device)
107
+
108
+ mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
109
+ mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])
110
+
111
+ logits = model(idx_seq, idx_msa, attn_idx)
112
+ mask_logits = logits[0, mask_position.item(), :]
113
+
114
+ predicted_token_id = temperature_sampling(mask_logits, τ)
115
+ predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
116
+ input_text[mask_position.item()] = predicted_token
117
+ padded_seq[mask_position.item()] = predicted_token.strip()
118
+
119
+ generated_seq = input_text
120
+
121
+ generated_seq[1] = "[MASK]"
122
+ input_ids = vocab_mlm.__getitem__(generated_seq)
123
+ logits = model(torch.tensor([input_ids]).to(device), idx_msa)
124
+ cls_mask_logits = logits[0, 1, :]
125
+ cls_probability, cls_mask_probs = torch.topk((torch.softmax(cls_mask_logits, dim=-1)), k=10)
126
+
127
+ generated_seq[2] = "[MASK]"
128
+ input_ids = vocab_mlm.__getitem__(generated_seq)
129
+ logits = model(torch.tensor([input_ids]).to(device), idx_msa)
130
+ act_mask_logits = logits[0, 2, :]
131
+ act_probability, act_mask_probs = torch.topk((torch.softmax(act_mask_logits, dim=-1)), k=2)
132
+
133
+ cls_pos = vocab_mlm.to_tokens(list(cls_mask_probs))
134
+ act_pos = vocab_mlm.to_tokens(list(act_mask_probs))
135
+
136
+ if X1 in cls_pos and X2 in act_pos:
137
+ cls_proba = cls_probability[cls_pos.index(X1)].item()
138
+ act_proba = act_probability[act_pos.index(X2)].item()
139
+ generated_seq = generated_seq[generated_seq.index('[MASK]') + 2:generated_seq.index('[SEP]')]
140
+ if cls_proba >= cls_proba_parent and act_proba >= act_proba_parent and generated_seq.count('C') % 2 == 0 and len("".join(generated_seq)) == gen_len:
141
+ generated_seq_str = "".join(generated_seq)
142
+ if generated_seq_str not in train_seq and generated_seq_str not in generated_seqs_FINAL and not any(x in NON_AA for x in generated_seq):
143
+ generated_seqs_FINAL.append(generated_seq_str)
144
+ cls_probability_all.append(cls_proba)
145
+ act_probability_all.append(act_proba)
146
+ IDs.append(count + 1)
147
+ out = pd.DataFrame({
148
+ 'ID': IDs,
149
+ 'Generated_seq': generated_seqs_FINAL,
150
+ 'Subtype': X1,
151
+ 'Subtype_probability': cls_probability_all,
152
+ 'Potency': X2,
153
+ 'Potency_probability': act_probability_all,
154
+ 'Random_seed': seed
155
+ })
156
+ out.to_csv("output.csv", index=False, encoding='utf-8-sig')
157
+ count += 1
158
+ yield out, "output.csv"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  return out, "output.csv"
160
 
161
  with gr.Blocks() as demo:
162
  gr.Markdown("# Conotoxin Optimization Generation")
163
  with gr.Row():
164
  X0 = gr.Textbox(label="conotoxin")
165
+ X1 = gr.Dropdown(choices=['<α7>', '<AChBP>', '<α4β2>', '<α3β4>', '<Ca22>', '<α3β2>', '<Na12>', '<α9α10>', '<K16>', '<α1β1γδ>',
166
+ '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<GluN2B>', '<α75HT3>', '<Na14>',
167
+ '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>', '<Ca12>', '<Na16>', '<α6α3β2>',
168
+ '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>', '<Kshaker>', '<Na18>',
169
+ '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>', '<α6β4>', '<α2β4>', '<Na13>', '<Na15>', '<α4β4>',
170
  '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>', '<α6α3β4>', '<NaTTXS>', '<Na17>'], label="Subtype")
171
+ X2 = gr.Dropdown(choices=['<high>', '<low>'], label="Potency")
172
  τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
173
  g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
174
+ model_name = gr.Dropdown(choices=['model_final.pt', 'model_C1.pt', 'model_C2.pt', 'model_C3.pt', 'model_C4.pt', 'model_C5.pt', 'model_mlm.pt'], label="Model")
175
  with gr.Row():
176
  start_button = gr.Button("Start Generation")
177
  stop_button = gr.Button("Stop Generation")
 
179
  output_df = gr.DataFrame(label="Generated Conotoxins")
180
  with gr.Row():
181
  output_file = gr.File(label="Download generated conotoxins")
182
+
183
  start_button.click(CTXGen, inputs=[X0, X1, X2, τ, g_num, model_name], outputs=[output_df, output_file])
184
  stop_button.click(stop_generation, outputs=None)
185