import torch
import random
import pandas as pd
from utils import create_vocab, setup_seed
from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
import gradio as gr
from gradio_rangeslider import RangeSlider
import time

is_stopped = False

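# Temperature sampling: dividing logits by τ flattens (τ > 1) or sharpens
# (τ < 1) the softmax distribution before drawing one token id multinomially.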
def temperature_sampling(logits, temperature):
    logits = logits / temperature
    probabilities = torch.softmax(logits, dim=-1)
    sampled_token = torch.multinomial(probabilities, 1)
    return sampled_token

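# Cooperative stop: the Stop button sets this flag, which the generation loops poll.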
def stop_generation():
    global is_stopped
    is_stopped = True
    return "Generation stopped."

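# Optimization generation (see the UI text below): X0 is the parent conotoxin,
# X3 the template with 'X' at the positions to redesign, X1 the target subtype,
# X2 the required potency, τ the sampling temperature, g_num the number of
# sequences to generate, model_name the checkpoint file, and seed the RNG seed.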
def CTXGen(X0, X3, X1, X2, τ, g_num, model_name, seed):
    if seed == 'random':
        seed = random.randint(0, 100000)
    setup_seed(int(seed))
    global is_stopped
    is_stopped = False

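    # Load the vocabulary, the chosen checkpoint, and the training sequences
    # (used later to filter out generations that merely reproduce training data).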
    device = torch.device("cpu")
    vocab_mlm = create_vocab()
    vocab_mlm = add_tokens_to_vocab(vocab_mlm)
    save_path = model_name
    train_seqs = pd.read_csv('C0_seq.csv')
    train_seq = train_seqs['Seq'].tolist()
    model = torch.load(save_path, map_location=torch.device('cpu'))
    model = model.to(device)

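    # Pick a random context row matching the requested subtype and potency;
    # pipe-separated fields 3-5 (X4-X6) supply auxiliary MSA-style context
    # sequences and are left empty when no row matches.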
    msa_data = pd.read_csv('conoData_C0.csv') 
    msa = msa_data['Sequences'].tolist()
    msa = [x for x in msa if x.startswith(f"{X1}|{X2}")]
    if not msa:
        X4 = ""
        X5 = ""
        X6 = ""
    else:
        msa = random.choice(msa)
        X4 = msa.split("|")[3]
        X5 = msa.split("|")[4]
        X6 = msa.split("|")[5]
    model.eval()
    with torch.no_grad():
        IDs = []
        generated_seqs = []
        generated_seqs_FINAL = []
        cls_probability_all = []
        act_probability_all = []
        out = pd.DataFrame()  # defined up front so the final return cannot hit an unbound name
        count = 0
        gen_num = g_num
        # Tokens that must not appear in a finished sequence (non-standard AAs and special tokens):
        NON_AA = ["B","O","U","Z","X",'<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
                        '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>', 
                        '<Ca12>', '<Na16>', '<α6α3β2>', '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>', 
                        '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>','<α6β4>', '<α2β4>',
                        '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>', 
                        '<α9α10>','<α6α3β4>', '<NaTTXS>', '<Na17>','<high>','<low>','[UNK]','[SEP]','[PAD]','[CLS]','[MASK]']
        
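        # Baseline scores for the parent sequence: mask position 1 (subtype token)
        # and position 2 (potency token) in turn and record the model's probability
        # for X1 and X2. Generated variants must match or exceed these scores.
        # Note: this assumes X1/X2 appear among the top-k tokens returned below.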
        seq_parent = [f"{X1}|{X2}|{X0}|{X4}|{X5}|{X6}"]
        padded_seqseq_parent, _, idx_msaseq_parent, _ = get_paded_token_idx_gen(vocab_mlm, seq_parent, None)
        idx_msaseq_parent = torch.tensor(idx_msaseq_parent).unsqueeze(0).to(device)
        seqseq_parent = ["[MASK]" if i=="X" else i for i in padded_seqseq_parent]

        seqseq_parent[1] = "[MASK]"
        input_ids_parent = vocab_mlm[seqseq_parent]
        logits_parent = model(torch.tensor([input_ids_parent]).to(device), idx_msaseq_parent)

        cls_mask_logits_parent = logits_parent[0, 1, :]
        cls_probability_parent, cls_mask_probs_parent = torch.topk((torch.softmax(cls_mask_logits_parent, dim=-1)), k=85)

        seqseq_parent[2] = "[MASK]"
        input_ids_parent = vocab_mlm[seqseq_parent]
        logits_parent = model(torch.tensor([input_ids_parent]).to(device), idx_msaseq_parent)
        act_mask_logits_parent = logits_parent[0, 2, :]
        act_probability_parent, act_mask_probs_parent = torch.topk((torch.softmax(act_mask_logits_parent, dim=-1)), k=2)

        cls_pos_parent = vocab_mlm.to_tokens(list(cls_mask_probs_parent))
        act_pos_parent = vocab_mlm.to_tokens(list(act_mask_probs_parent))

        cls_proba_parent = cls_probability_parent[cls_pos_parent.index(X1)].item()
        act_proba_parent = act_probability_parent[act_pos_parent.index(X2)].item()

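        # Iterative unmasking loop with the 1200-second budget noted in the UI text.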
        start_time = time.time()
        while count < gen_num:
            new_seq = None
            gen_len = len(X3)
            if is_stopped:
                return "output.csv", pd.DataFrame()

            if time.time() - start_time > 1200:
                break

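            # Rebuild the masked template for this round. Remapping 'X' to index 4
            # is presumably how placeholder positions get tokenized (assumption:
            # that index corresponds to the vocabulary's mask/unknown token).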
            seq = [f"{X1}|{X2}|{X3}|{X4}|{X5}|{X6}"]
            vocab_mlm.token_to_idx["X"] = 4

            padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
            input_text = ["[MASK]" if i=="X" else i for i in padded_seq]

            gen_length = len(input_text)
            length = input_text.count('[MASK]')
            for i in range(length):
                if is_stopped:
                    return "output.csv", pd.DataFrame()
                
                _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
                idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
                idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
                attn_idx = torch.tensor(attn_idx).to(device)

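                # Fill one randomly chosen remaining [MASK] position per step,
                # re-encoding the partially completed sequence each time.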
                mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
                mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])
                
                logits = model(idx_seq, idx_msa, attn_idx)
                mask_logits = logits[0, mask_position.item(), :]

                predicted_token_id = temperature_sampling(mask_logits, τ)

                predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
                input_text[mask_position.item()] = predicted_token
                padded_seq[mask_position.item()] = predicted_token.strip()
                new_seq = padded_seq
            generated_seq = input_text
        
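            # Score the finished candidate the same way as the parent: re-mask the
            # subtype and potency positions and read off their probabilities.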
            generated_seq[1] = "[MASK]"
            input_ids = vocab_mlm[generated_seq]
            logits = model(torch.tensor([input_ids]).to(device), idx_msa)
            cls_mask_logits = logits[0, 1, :]
            cls_probability, cls_mask_probs = torch.topk((torch.softmax(cls_mask_logits, dim=-1)), k=10)

            generated_seq[2] = "[MASK]"
            input_ids = vocab_mlm[generated_seq]
            logits = model(torch.tensor([input_ids]).to(device), idx_msa)
            act_mask_logits = logits[0, 2, :]
            act_probability, act_mask_probs = torch.topk((torch.softmax(act_mask_logits, dim=-1)), k=2)

            cls_pos = vocab_mlm.to_tokens(list(cls_mask_probs))
            act_pos = vocab_mlm.to_tokens(list(act_mask_probs))
            
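            # Accept a candidate only if X1/X2 rank among the top predictions, its
            # subtype/potency probabilities match or beat the parent's, it has an
            # even cysteine count (presumably so cysteines can pair into disulfide
            # bonds), it keeps the target length, and it is novel, non-duplicate,
            # and free of non-amino-acid tokens.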
            if X1 in cls_pos and X2 in act_pos:
                cls_proba = cls_probability[cls_pos.index(X1)].item()
                act_proba = act_probability[act_pos.index(X2)].item()
                generated_seq = generated_seq[generated_seq.index('[MASK]') + 2:generated_seq.index('[SEP]')]
                if (cls_proba >= cls_proba_parent and act_proba >= act_proba_parent
                        and generated_seq.count('C') % 2 == 0
                        and len("".join(generated_seq)) == gen_len):
                    generated_seqs.append("".join(generated_seq))
                    if "".join(generated_seq) not in train_seq and "".join(generated_seq) not in generated_seqs[0:-1] and all(x not in NON_AA for x in generated_seq):
                        generated_seqs_FINAL.append("".join(generated_seq))
                        cls_probability_all.append(cls_proba)
                        act_probability_all.append(act_proba)
                        IDs.append(count+1)
                        out = pd.DataFrame({
                            'ID':IDs,
                            'Generated_seq': generated_seqs_FINAL,
                            'Subtype': X1,
                            'Subtype_probability': cls_probability_all, 
                            'Potency': X2, 
                            'Potency_probability': act_probability_all, 
                            'Random_seed': int(seed)
                        })
                        out.to_csv("output.csv", index=False, encoding='utf-8-sig')
                        count += 1
                        yield "output.csv", out
    return "output.csv", out

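# Gradio UI: links to companion Spaces, input widgets, and Start/Stop wiring.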
with gr.Blocks() as demo:
    gr.Markdown("🔗 **[Label Prediction](https://huggingface.co/spaces/oucgc1996/CTXGen_Label_Prediction)**")
    gr.Markdown("🔗 **[Unconstrained Generation](https://huggingface.co/spaces/oucgc1996/CTXGen_Unconstrained_generation)**")
    gr.Markdown("🔗 **[Conditional Generation](https://huggingface.co/spaces/oucgc1996/CTXGen_conditional_generation)**")
    gr.Markdown("🔗 **[Optimization Generation](https://huggingface.co/spaces/oucgc1996/CTXGen_optimization_generation)**")
    gr.Markdown("# Conotoxin Optimization Generation")
    gr.Markdown("#### Input")
    gr.Markdown("✅**Conotoxin**: a conotoxin that needs to be optimized. For example, GCCSDPRCAWRC")
    gr.Markdown("✅**Positions**: the positions that need to be optimized, replaced by X. For example, GCCXXXXCAHRC")
    gr.Markdown("✅**Subtype**: subtype of action. For example, α7")
    gr.Markdown("✅**Potency**: required potency. For example, High")
    gr.Markdown("✅**τ**: temperature factor controls the diversity of conotoxins generated. The higher the value, the higher the diversity")
    gr.Markdown("✅**Number of generations**: if it is not completed within 1200 seconds, it will automatically stop.")
    gr.Markdown("✅**Model**: model parameters trained at different stages of data augmentation. Please refer to the paper for details.")
    
    with gr.Row():
        X0 = gr.Textbox(label="conotoxin")
        X3 = gr.Textbox(label="Positions that need optimization")
        X1 = gr.Dropdown(choices=['<α7>','<AChBP>','<α4β2>','<α3β4>','<Ca22>','<α3β2>', '<Na12>','<α9α10>','<K16>', '<α1β1γδ>', 
                                  '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<GluN2B>', '<α75HT3>', '<Na14>', 
                                  '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>', '<Ca12>', '<Na16>', '<α6α3β2>', 
                                  '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>', '<Kshaker>',  '<Na18>', 
                                  '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>','<α6β4>', '<α2β4>','<Na13>', '<Na15>', '<α4β4>', 
                                  '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>', '<α6α3β4>', '<NaTTXS>', '<Na17>'], label="Subtype")
        X2 = gr.Dropdown(choices=['<high>','<low>'], label="Potency")
    with gr.Row():
        τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
        g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
        model_name = gr.Dropdown(choices=['model_final.pt','model_C1.pt','model_C2.pt','model_C3.pt','model_C4.pt','model_C5.pt','model_mlm.pt'], label="Model")
        seed = gr.Textbox(label="Seed", value="random")
    with gr.Row():
        start_button = gr.Button("Start Generation")
        stop_button = gr.Button("Stop Generation")
    with gr.Row():
        output_file = gr.File(label="Download generated conotoxins")
    with gr.Row():
        output_df = gr.DataFrame(label="Generated Conotoxins")

    start_button.click(CTXGen, inputs=[X0, X3, X1, X2, τ, g_num, model_name, seed], outputs=[output_file, output_df])
    stop_button.click(stop_generation, outputs=None)

demo.launch()