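"""Gradio app for unconstrained conotoxin generation (CreoPep).

A masked language model iteratively fills [MASK] tokens to produce candidate
conotoxins, then predicts a pharmacological subtype and potency label for
each accepted sequence and streams the results to a CSV file and table.
"""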
import random
import torch
import gradio as gr
from gradio_rangeslider import RangeSlider
import pandas as pd
from utils import create_vocab, setup_seed
from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
import time

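# Module-level stop flag polled by the generation loop. Being a global, it is
# shared across all sessions of the app, so one user's Stop affects everyone.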
is_stopped = False

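# Temperature sampling: dividing the logits by τ flattens (τ > 1) or sharpens
# (τ < 1) the softmax distribution before one token ID is drawn at random.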
def temperature_sampling(logits, temperature):
    logits = logits / temperature
    probabilities = torch.softmax(logits, dim=-1)
    sampled_token = torch.multinomial(probabilities, 1)
    return sampled_token

def stop_generation():
    global is_stopped
    is_stopped = True
    return "Generation stopped."

def CTXGen(τ, g_num, length_range, model_name, seed):
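    """Generate conotoxins and stream results to the UI.

    Args:
        τ: sampling temperature; higher values yield more diverse sequences.
        g_num: number of sequences to generate.
        length_range: (min, max) peptide length, sampled uniformly per attempt.
        model_name: path of the model checkpoint to load.
        seed: integer random seed, or 'random' to draw one.

    Yields:
        (csv_path, DataFrame) after each accepted sequence.
    """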
    if seed == 'random':
        seed = random.randint(0, 100000)
    else:
        seed = int(seed)
    setup_seed(seed)
        
    global is_stopped
    is_stopped = False

    device = torch.device("cpu")
    vocab_mlm = create_vocab()
    vocab_mlm = add_tokens_to_vocab(vocab_mlm)
    save_path = model_name
    train_seqs = pd.read_csv('C0_seq.csv')
    train_seq = train_seqs['Seq'].tolist()
    model = torch.load(save_path, map_location=torch.device('cpu'))
    model = model.to(device)

    start, end = length_range
    X1 = "X"
    X2 = "X"
    X4 = ""
    X5 = ""
    X6 = ""
    model.eval()
    with torch.no_grad():
        IDs = []
        generated_seqs = []
        generated_seqs_FINAL = []
        cls_pos_all = []
        cls_probability_all = []
        act_pos_all = []
        act_probability_all = []

        count = 0
        gen_num = int(g_num)
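        # Tokens that must not appear in a finished peptide: ambiguous
        # amino-acid codes (B, O, U, Z, X) plus the vocabulary's special and
        # label tokens.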
        NON_AA = ["B", "O", "U", "Z", "X", '<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
                  '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>',
                  '<Ca12>', '<Na16>', '<α6α3β2>', '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>',
                  '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>', '<α6β4>', '<α2β4>',
                  '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>',
                  '<α9α10>', '<α6α3β4>', '<NaTTXS>', '<Na17>', '<high>', '<low>', '[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]']

        start_time = time.time()
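        # Main loop: build an all-[MASK] template of random length, fill it in
        # one position at a time, then validate and record the result.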
        while count < gen_num:
            new_seq = None
            if is_stopped:
                # A generator's return value is discarded by Gradio, so yield
                # the final state before stopping.
                yield "output.csv", pd.DataFrame()
                return

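            # Hard wall-clock limit: give up after 20 minutes.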
            if time.time() - start_time > 1200:
                break

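            # Pipe-separated template: the first two slots apparently hold the
            # (masked) subtype and potency labels, X3 is the peptide to
            # generate, and the remaining slots are left empty. "X" is pinned
            # to token ID 4, presumably so the placeholder tokenizes
            # consistently.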
            gen_len = random.randint(int(start), int(end))
            X3 = "X" * gen_len
            seq = [f"{X1}|{X2}|{X3}|{X4}|{X5}|{X6}"]
            vocab_mlm.token_to_idx["X"] = 4

            padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
            input_text = ["[MASK]" if i == "X" else i for i in padded_seq]

            gen_length = len(input_text)
            length = sum(1 for x in input_text if x == '[MASK]')

            for i in range(length):
                if is_stopped:
                    yield "output.csv", pd.DataFrame()
                    return

                _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
                idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
                idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
                attn_idx = torch.tensor(attn_idx).to(device)

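                # Pick one of the remaining [MASK] positions uniformly at
                # random, so the sequence is filled in a random order rather
                # than left to right.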
                mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
                mask_position = mask_positions[torch.randint(len(mask_positions), (1,)).item()]

                logits = model(idx_seq, idx_msa, attn_idx)
                mask_logits = logits[0, mask_position, :]

                predicted_token_id = temperature_sampling(mask_logits, τ)
                predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
                input_text[mask_position] = predicted_token
                padded_seq[mask_position] = predicted_token.strip()
                new_seq = padded_seq

            generated_seq = input_text

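            # Re-mask positions 1 and 2 (the label slots) and run one more
            # forward pass to predict the subtype and potency of the finished
            # sequence.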
            generated_seq[1] = "[MASK]"
            generated_seq[2] = "[MASK]"
            input_ids = vocab_mlm[generated_seq]
            # Unlike the generation passes above, this call passes no
            # attention mask; the model's forward presumably treats it as
            # optional.
            logits = model(torch.tensor([input_ids]).to(device), idx_msa)

            cls_mask_logits = logits[0, 1, :]
            act_mask_logits = logits[0, 2, :]

            cls_probs, cls_token_ids = torch.topk(torch.softmax(cls_mask_logits, dim=-1), k=1)
            act_probs, act_token_ids = torch.topk(torch.softmax(act_mask_logits, dim=-1), k=1)

            cls_pos = vocab_mlm.idx_to_token[cls_token_ids[0].item()]
            act_pos = vocab_mlm.idx_to_token[act_token_ids[0].item()]

            cls_probability = cls_probs[0].item()
            act_probability = act_probs[0].item()
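            # Keep only the peptide itself: the tokens between the masked
            # label slots and the trailing [SEP].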
            generated_seq = generated_seq[generated_seq.index('[MASK]') + 2:generated_seq.index('[SEP]')]

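            # Accept the sequence only if it has an even number of cysteines
            # (presumably so disulfide bonds can pair completely) and exactly
            # the requested length; then require that it is novel (absent from
            # the training set and earlier generations) and contains no
            # non-amino-acid tokens.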
            if generated_seq.count('C') % 2 == 0 and len("".join(generated_seq)) == gen_len:
                generated_seqs.append("".join(generated_seq))
                if "".join(generated_seq) not in train_seq and "".join(generated_seq) not in generated_seqs[0:-1] and all(x not in NON_AA for x in generated_seq):
                    generated_seqs_FINAL.append("".join(generated_seq))
                    cls_pos_all.append(cls_pos)
                    cls_probability_all.append(cls_probability)
                    act_pos_all.append(act_pos)
                    act_probability_all.append(act_probability)
                    IDs.append(count+1)
                    out = pd.DataFrame({
                        'ID':IDs,
                        'Generated_seq': generated_seqs_FINAL,
                        'Subtype': cls_pos_all,
                        'Subtype_probability': cls_probability_all,
                        'Potency': act_pos_all,
                        'Potency_probability': act_probability_all,
                        'random_seed': seed
                    })
                    out.to_csv("output.csv", index=False, encoding='utf-8-sig')
                    count += 1
                    yield "output.csv", out
    # The last yield already delivered the final CSV and table; a returned
    # value from a generator is discarded, and `out` may be unbound if the
    # loop timed out before producing any sequence.
    return

with gr.Blocks() as demo:
    gr.Markdown("🔗 **[Label Prediction](https://huggingface.co/spaces/oucgc1996/CreoPep_Label_Prediction)**"
                "🔗 **[Unconstrained Generation](https://huggingface.co/spaces/oucgc1996/CreoPep_Unconstrained_generation)**"
                "🔗 **[Conditional Generation](https://huggingface.co/spaces/oucgc1996/CreoPep_conditional_generation)**"
                "🔗 **[Optimization Generation](https://huggingface.co/spaces/oucgc1996/CreoPep_optimization_generation)**")
    gr.Markdown("# Conotoxin Unconstrained Generation")
    gr.Markdown("#### Input")
    gr.Markdown("✅**τ**: temperature factor that controls the diversity of the generated conotoxins. Higher values produce more diverse sequences.")
    gr.Markdown("✅**Number of generations**: how many conotoxins to generate. If generation does not finish within 1200 seconds, it stops automatically.")
    gr.Markdown("✅**Length range**: the expected length range of the generated conotoxins.")
    gr.Markdown("✅**Model**: model parameters trained at different stages of data augmentation; see the paper for details.")
    gr.Markdown("✅**Seed**: an integer random seed for reproducible results. Defaults to random.")
    with gr.Row():
        τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
        g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
        length_range = RangeSlider(minimum=8, maximum=50, step=1, value=(12, 16), label="Length range")
        model_name = gr.Dropdown(choices=['model_final.pt','model_C1.pt','model_C2.pt','model_C3.pt','model_C4.pt','model_C5.pt','model_mlm.pt'], label="Model")
        seed = gr.Textbox(label="Seed", value="random")
    with gr.Row():
        start_button = gr.Button("Start Generation")
        stop_button = gr.Button("Stop Generation")
    with gr.Row():
        output_file = gr.File(label="Download generated conotoxins")
    with gr.Row():
        output_df = gr.DataFrame(label="Generated Conotoxins")

    start_button.click(CTXGen, inputs=[τ, g_num, length_range, model_name, seed], outputs=[output_file, output_df])
    stop_button.click(stop_generation, outputs=None)

demo.launch()