oucgc1996 committed
Commit 3fce019 · verified · 1 Parent(s): bf2755b

Update app.py

Files changed (1)
  1. app.py +133 -123
app.py CHANGED
@@ -2,14 +2,14 @@ import torch
 import random
 import pandas as pd
 from utils import create_vocab, setup_seed
-from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
+from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
 import gradio as gr
 from gradio_rangeslider import RangeSlider
 import time
 
 is_stopped = False
 
-seed = random.randint(0, 100000)
+seed = random.randint(0,100000)
 setup_seed(seed)
 
 def temperature_sampling(logits, temperature):
@@ -32,148 +32,158 @@ def CTXGen(X0, X1, X2, τ, g_num, model_name):
     vocab_mlm = add_tokens_to_vocab(vocab_mlm)
     save_path = model_name
     train_seqs = pd.read_csv('C0_seq.csv')
-    train_seq = set(train_seqs['Seq'].tolist()) # use a set to speed up lookups
+    train_seq = train_seqs['Seq'].tolist()
     model = torch.load(save_path, map_location=torch.device('cpu'))
     model = model.to(device)
-    model.eval()
 
     X3 = "X" * len(X0)
-    msa_data = pd.read_csv('conoData_C0.csv')
+    msa_data = pd.read_csv('conoData_C0.csv')
     msa = msa_data['Sequences'].tolist()
     msa = [x for x in msa if x.startswith(f"{X1}|{X2}")]
     if not msa:
-        X4, X5, X6 = "", "", ""
+        X4 = ""
+        X5 = ""
+        X6 = ""
     else:
         msa = random.choice(msa)
-        X4, X5, X6 = msa.split("|")[3:6]
-
-    NON_AA = {"B", "O", "U", "Z", "X", '<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
-              '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>', '<Ca12>', '<Na16>', '<α6α3β2>',
-              '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>', '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>',
-              '<α6α3β4β3>', '<α2β2>', '<α6β4>', '<α2β4>', '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>',
-              '<α9α10>', '<α6α3β4>', '<NaTTXS>', '<Na17>', '<high>', '<low>', '[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'}
-
-    seq_parent = [f"{X1}|{X2}|{X0}|{X4}|{X5}|{X6}"]
-    padded_seqseq_parent, _, idx_msaseq_parent, _ = get_paded_token_idx_gen(vocab_mlm, seq_parent, None)
-    idx_msaseq_parent = torch.tensor(idx_msaseq_parent).unsqueeze(0).to(device)
-    seqseq_parent = ["[MASK]" if i == "X" else i for i in padded_seqseq_parent]
-
-    seqseq_parent[1] = "[MASK]"
-    input_ids_parent = vocab_mlm.__getitem__(seqseq_parent)
-    logits_parent = model(torch.tensor([input_ids_parent]).to(device), idx_msaseq_parent)
-    cls_mask_logits_parent = logits_parent[0, 1, :]
-    cls_probability_parent, cls_mask_probs_parent = torch.topk((torch.softmax(cls_mask_logits_parent, dim=-1)), k=85)
-
-    seqseq_parent[2] = "[MASK]"
-    input_ids_parent = vocab_mlm.__getitem__(seqseq_parent)
-    logits_parent = model(torch.tensor([input_ids_parent]).to(device), idx_msaseq_parent)
-    act_mask_logits_parent = logits_parent[0, 2, :]
-    act_probability_parent, act_mask_probs_parent = torch.topk((torch.softmax(act_mask_logits_parent, dim=-1)), k=2)
-
-    cls_pos_parent = vocab_mlm.to_tokens(list(cls_mask_probs_parent))
-    act_pos_parent = vocab_mlm.to_tokens(list(act_mask_probs_parent))
-
-    cls_proba_parent = cls_probability_parent[cls_pos_parent.index(X1)].item()
-    act_proba_parent = act_probability_parent[act_pos_parent.index(X2)].item()
-
-    start_time = time.time()
-    count = 0
-    new_seq = None
-    generated_seqs_FINAL = []
-    cls_probability_all = []
-    act_probability_all = []
-    IDs = []
-
-    while count < g_num:
-        if is_stopped:
-            return pd.DataFrame(), "output.csv"
-
-        if time.time() - start_time > 1200:
-            break
-
-        seq = [f"{X1}|{X2}|{X3}|{X4}|{X5}|{X6}"]
-        padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, None)
-        input_text = ["[MASK]" if i == "X" else i for i in padded_seq]
-
-        gen_length = len(input_text)
-        length = gen_length - sum(1 for x in input_text if x != '[MASK]')
-        for i in range(length):
+        X4 = msa.split("|")[3]
+        X5 = msa.split("|")[4]
+        X6 = msa.split("|")[5]
+    model.eval()
+    with torch.no_grad():
+        IDs = []
+        generated_seqs = []
+        generated_seqs_FINAL = []
+        cls_probability_all = []
+        act_probability_all = []
+        count = 0
+        gen_num = g_num
+        NON_AA = ["B","O","U","Z","X",'<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
+                  '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>',
+                  '<Ca12>', '<Na16>', '<α6α3β2>', '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>',
+                  '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>','<α6β4>', '<α2β4>',
+                  '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>',
+                  '<α9α10>','<α6α3β4>', '<NaTTXS>', '<Na17>','<high>','<low>','[UNK]','[SEP]','[PAD]','[CLS]','[MASK]']
+
+        seq_parent = [f"{X1}|{X2}|{X0}|{X4}|{X5}|{X6}"]
+        padded_seqseq_parent, _, idx_msaseq_parent, _ = get_paded_token_idx_gen(vocab_mlm, seq_parent, None)
+        idx_msaseq_parent = torch.tensor(idx_msaseq_parent).unsqueeze(0).to(device)
+        seqseq_parent = ["[MASK]" if i=="X" else i for i in padded_seqseq_parent]
+
+        seqseq_parent[1] = "[MASK]"
+        input_ids_parent = vocab_mlm.__getitem__(seqseq_parent)
+        logits_parent = model(torch.tensor([input_ids_parent]).to(device), idx_msaseq_parent)
+
+        cls_mask_logits_parent = logits_parent[0, 1, :]
+        cls_probability_parent, cls_mask_probs_parent = torch.topk((torch.softmax(cls_mask_logits_parent, dim=-1)), k=85)
+
+        seqseq_parent[2] = "[MASK]"
+        input_ids_parent = vocab_mlm.__getitem__(seqseq_parent)
+        logits_parent = model(torch.tensor([input_ids_parent]).to(device), idx_msaseq_parent)
+        act_mask_logits_parent = logits_parent[0, 2, :]
+        act_probability_parent, act_mask_probs_parent = torch.topk((torch.softmax(act_mask_logits_parent, dim=-1)), k=2)
+
+        cls_pos_parent = vocab_mlm.to_tokens(list(cls_mask_probs_parent))
+        act_pos_parent = vocab_mlm.to_tokens(list(act_mask_probs_parent))
+
+        cls_proba_parent = cls_probability_parent[cls_pos_parent.index(X1)].item()
+        act_proba_parent = act_probability_parent[act_pos_parent.index(X2)].item()
+
+        start_time = time.time()
+        while count < gen_num:
+            new_seq = None
+            gen_len = len(X0)
             if is_stopped:
                 return pd.DataFrame(), "output.csv"
 
-            _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
-            idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
-            idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
-            attn_idx = torch.tensor(attn_idx).to(device)
-
-            mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
-            mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])
-
-            logits = model(idx_seq, idx_msa, attn_idx)
-            mask_logits = logits[0, mask_position.item(), :]
-
-            predicted_token_id = temperature_sampling(mask_logits, τ)
-            predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
-            input_text[mask_position.item()] = predicted_token
-            padded_seq[mask_position.item()] = predicted_token.strip()
-            new_seq = padded_seq
-
-        generated_seq = input_text
-
-        generated_seq[1] = "[MASK]"
-        input_ids = vocab_mlm.__getitem__(generated_seq)
-        logits = model(torch.tensor([input_ids]).to(device), idx_msa)
-        cls_mask_logits = logits[0, 1, :]
-        cls_probability, cls_mask_probs = torch.topk((torch.softmax(cls_mask_logits, dim=-1)), k=10)
-
-        generated_seq[2] = "[MASK]"
-        input_ids = vocab_mlm.__getitem__(generated_seq)
-        logits = model(torch.tensor([input_ids]).to(device), idx_msa)
-        act_mask_logits = logits[0, 2, :]
-        act_probability, act_mask_probs = torch.topk((torch.softmax(act_mask_logits, dim=-1)), k=2)
-
-        cls_pos = vocab_mlm.to_tokens(list(cls_mask_probs))
-        act_pos = vocab_mlm.to_tokens(list(act_mask_probs))
-
-        if X1 in cls_pos and X2 in act_pos:
-            cls_proba = cls_probability[cls_pos.index(X1)].item()
-            act_proba = act_probability[act_pos.index(X2)].item()
-            generated_seq = generated_seq[generated_seq.index('[MASK]') + 2:generated_seq.index('[SEP]')]
-            if cls_proba >= cls_proba_parent and act_proba >= act_proba_parent and generated_seq.count('C') % 2 == 0 and len("".join(generated_seq)) == gen_len:
-                generated_seq_str = "".join(generated_seq)
-                if generated_seq_str not in train_seq and generated_seq_str not in generated_seqs_FINAL and not any(x in NON_AA for x in generated_seq):
-                    generated_seqs_FINAL.append(generated_seq_str)
-                    cls_probability_all.append(cls_proba)
-                    act_probability_all.append(act_proba)
-                    IDs.append(count + 1)
-                    out = pd.DataFrame({
-                        'ID': IDs,
-                        'Generated_seq': generated_seqs_FINAL,
-                        'Subtype': X1,
-                        'Subtype_probability': cls_probability_all,
-                        'Potency': X2,
-                        'Potency_probability': act_probability_all,
-                        'Random_seed': seed
-                    })
-                    out.to_csv("output.csv", index=False, encoding='utf-8-sig')
-                    count += 1
-                    yield out, "output.csv"
+            if time.time() - start_time > 1200:
+                break
+
+            seq = [f"{X1}|{X2}|{X3}|{X4}|{X5}|{X6}"]
+            vocab_mlm.token_to_idx["X"] = 4
+
+            padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
+            input_text = ["[MASK]" if i=="X" else i for i in padded_seq]
+
+            gen_length = len(input_text)
+            length = gen_length - sum(1 for x in input_text if x != '[MASK]')
+            for i in range(length):
+                if is_stopped:
+                    return pd.DataFrame(), "output.csv"
+
+                _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
+                idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
+                idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
+                attn_idx = torch.tensor(attn_idx).to(device)
+
+                mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
+                mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])
+
+                logits = model(idx_seq,idx_msa, attn_idx)
+                mask_logits = logits[0, mask_position.item(), :]
+
+                predicted_token_id = temperature_sampling(mask_logits, τ)
+
+                predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
+                input_text[mask_position.item()] = predicted_token
+                padded_seq[mask_position.item()] = predicted_token.strip()
+                new_seq = padded_seq
+            generated_seq = input_text
+
+            generated_seq[1] = "[MASK]"
+            input_ids = vocab_mlm.__getitem__(generated_seq)
+            logits = model(torch.tensor([input_ids]).to(device), idx_msa)
+            cls_mask_logits = logits[0, 1, :]
+            cls_probability, cls_mask_probs = torch.topk((torch.softmax(cls_mask_logits, dim=-1)), k=10)
+
+            generated_seq[2] = "[MASK]"
+            input_ids = vocab_mlm.__getitem__(generated_seq)
+            logits = model(torch.tensor([input_ids]).to(device), idx_msa)
+            act_mask_logits = logits[0, 2, :]
+            act_probability, act_mask_probs = torch.topk((torch.softmax(act_mask_logits, dim=-1)), k=2)
+
+            cls_pos = vocab_mlm.to_tokens(list(cls_mask_probs))
+            act_pos = vocab_mlm.to_tokens(list(act_mask_probs))
+
+            if X1 in cls_pos and X2 in act_pos:
+                cls_proba = cls_probability[cls_pos.index(X1)].item()
+                act_proba = act_probability[act_pos.index(X2)].item()
+                generated_seq = generated_seq[generated_seq.index('[MASK]') + 2:generated_seq.index('[SEP]')]
+                if cls_proba>=cls_proba_parent and act_proba>=act_proba_parent and generated_seq.count('C') % 2 == 0 and len("".join(generated_seq)) == gen_len:
+                    generated_seqs.append("".join(generated_seq))
+                    if "".join(generated_seq) not in train_seq and "".join(generated_seq) not in generated_seqs[0:-1] and all(x not in NON_AA for x in generated_seq):
+                        generated_seqs_FINAL.append("".join(generated_seq))
+                        cls_probability_all.append(cls_proba)
+                        act_probability_all.append(act_proba)
+                        IDs.append(count+1)
+                        out = pd.DataFrame({
+                            'ID':IDs,
+                            'Generated_seq': generated_seqs_FINAL,
+                            'Subtype': X1,
+                            'Subtype_probability': cls_probability_all,
+                            'Potency': X2,
+                            'Potency_probability': act_probability_all,
+                            'Random_seed': seed
+                        })
+                        out.to_csv("output.csv", index=False, encoding='utf-8-sig')
+                        count += 1
+                        yield out, "output.csv"
     return out, "output.csv"
 
 with gr.Blocks() as demo:
     gr.Markdown("# Conotoxin Optimization Generation")
     with gr.Row():
         X0 = gr.Textbox(label="conotoxin")
-        X1 = gr.Dropdown(choices=['<α7>', '<AChBP>', '<α4β2>', '<α3β4>', '<Ca22>', '<α3β2>', '<Na12>', '<α9α10>', '<K16>', '<α1β1γδ>',
-                                  '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<GluN2B>', '<α75HT3>', '<Na14>',
-                                  '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>', '<Ca12>', '<Na16>', '<α6α3β2>',
-                                  '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>', '<Kshaker>', '<Na18>',
-                                  '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>', '<α6β4>', '<α2β4>', '<Na13>', '<Na15>', '<α4β4>',
+        X1 = gr.Dropdown(choices=['<α7>','<AChBP>','<α4β2>','<α3β4>','<Ca22>','<α3β2>', '<Na12>','<α9α10>','<K16>', '<α1β1γδ>',
+                                  '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<GluN2B>', '<α75HT3>', '<Na14>',
+                                  '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>', '<Ca12>', '<Na16>', '<α6α3β2>',
+                                  '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>', '<Kshaker>', '<Na18>',
+                                  '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>','<α6β4>', '<α2β4>','<Na13>', '<Na15>', '<α4β4>',
                                   '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>', '<α6α3β4>', '<NaTTXS>', '<Na17>'], label="Subtype")
-        X2 = gr.Dropdown(choices=['<high>', '<low>'], label="Potency")
+        X2 = gr.Dropdown(choices=['<high>','<low>'], label="Potency")
         τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
         g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
-        model_name = gr.Dropdown(choices=['model_final.pt', 'model_C1.pt', 'model_C2.pt', 'model_C3.pt', 'model_C4.pt', 'model_C5.pt', 'model_mlm.pt'], label="Model")
+        model_name = gr.Dropdown(choices=['model_final.pt','model_C1.pt','model_C2.pt','model_C3.pt','model_C4.pt','model_C5.pt','model_mlm.pt'], label="Model")
     with gr.Row():
         start_button = gr.Button("Start Generation")
         stop_button = gr.Button("Stop Generation")
@@ -181,7 +191,7 @@ with gr.Blocks() as demo:
         output_df = gr.DataFrame(label="Generated Conotoxins")
     with gr.Row():
         output_file = gr.File(label="Download generated conotoxins")
-
+
     start_button.click(CTXGen, inputs=[X0, X1, X2, τ, g_num, model_name], outputs=[output_df, output_file])
     stop_button.click(stop_generation, outputs=None)
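Both sides of the first hunk stop at the `def temperature_sampling(logits, temperature):` signature, yet the generation loop calls it as `temperature_sampling(mask_logits, τ)`. The body is outside the changed lines, so what follows is only a sketch of what such a helper conventionally does (an assumption, not the repository's code): scale the logits by the temperature, renormalize, and draw one token id.

```python
import torch

def temperature_sampling(logits, temperature):
    # Hedged reconstruction: dividing logits by the temperature (τ > 1
    # flattens the distribution), softmaxing back to probabilities, and
    # sampling a single token id is the standard pattern.
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)
```

The loop then converts the result with `int(predicted_token_id)` before `vocab_mlm.to_tokens(...)`, which works on the one-element tensor returned here.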
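`setup_seed(seed)` is imported from `utils` and likewise falls outside the diff. A common shape for such a helper, assuming it seeds every RNG source the app touches (a guess, not the actual code):

```python
import random
import numpy as np
import torch

def setup_seed(seed):
    # Hypothetical sketch of utils.setup_seed: seed Python, NumPy, and
    # PyTorch (CPU and CUDA) so the 'Random_seed' column written to
    # output.csv makes a run repeatable.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
```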
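The heart of the rewritten `CTXGen` is iterative unmasking: the template `X3 = "X" * len(X0)` becomes a row of `[MASK]` tokens, and each pass of the inner `for` loop picks one still-masked position at random, samples a replacement, and feeds the partially filled sequence back through `get_paded_token_idx_gen` as `new_seq`. A self-contained toy with `random.choice` standing in for the model shows just the control flow:

```python
import random

AMINO_ACIDS = list("ACDEFGHIKLMNPQRSTVWY")  # the 20 standard residues

tokens = ["[MASK]"] * 12  # analogous to X3 = "X" * len(X0)
while "[MASK]" in tokens:
    # pick one masked slot at random, as the mask_position logic does
    pos = random.choice([i for i, t in enumerate(tokens) if t == "[MASK]"])
    tokens[pos] = random.choice(AMINO_ACIDS)  # stand-in for temperature_sampling
print("".join(tokens))
```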
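Finally, `stop_button.click(stop_generation, outputs=None)` wires up a `stop_generation` callback that the diff never shows. Since `CTXGen` polls the module-level `is_stopped` flag between sampling steps, a plausible definition is a one-line flag flip (again an assumption):

```python
def stop_generation():
    # Hypothetical: set the module-level flag that CTXGen checks inside its
    # while/for loops; the next check returns an empty DataFrame and exits.
    global is_stopped
    is_stopped = True
```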