"""Gradio app for unconstrained conotoxin generation with CreoPep.

Iteratively fills [MASK] positions by temperature sampling, then asks the
model to predict each finished peptide's subtype and potency.
"""
import random
import time

import pandas as pd
import torch

import gradio as gr
from gradio_rangeslider import RangeSlider

from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
from utils import create_vocab, setup_seed

# Flag polled by the generation loop; set to True by the Stop button.
is_stopped = False


def temperature_sampling(logits, temperature):
    """Sample one token id from logits scaled by a temperature factor."""
    logits = logits / temperature
    probabilities = torch.softmax(logits, dim=-1)
    sampled_token = torch.multinomial(probabilities, 1)
    return sampled_token
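
# For example, temperature_sampling(mask_logits, 1.0) follows the model's raw
# distribution, while values toward 2.0 flatten it, trading likelihood for
# diversity in the generated peptides.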


def stop_generation():
    """Signal CTXGen to stop after its current sampling step."""
    global is_stopped
    is_stopped = True
    return "Generation stopped."
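
# Stopping is cooperative: the flag is only checked between sampling steps,
# so a click takes effect at the next step rather than mid-forward-pass.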


def CTXGen(τ, g_num, length_range, model_name, seed):
    # "random" draws a fresh seed so repeated runs differ; an explicit
    # integer makes the run reproducible.
    if seed == 'random':
        seed = random.randint(0, 100000)
    setup_seed(int(seed))

    global is_stopped
    is_stopped = False

    device = torch.device("cpu")
    vocab_mlm = create_vocab()
    vocab_mlm = add_tokens_to_vocab(vocab_mlm)
    save_path = model_name
    # Training sequences, used below to filter out non-novel generations.
    train_seqs = pd.read_csv('C0_seq.csv')
    train_seq = train_seqs['Seq'].tolist()
    model = torch.load(save_path, map_location=device)
    model = model.to(device)

    start, end = length_range
    # '|'-separated template: X1 and X2 are label slots (masked, later
    # predicted as subtype and potency), X3 is the peptide scaffold, and
    # X4-X6 are unused here.
    X1 = "X"
    X2 = "X"
    X4 = ""
    X5 = ""
    X6 = ""
    model.eval()
    with torch.no_grad():
        IDs = []
        generated_seqs = []
        generated_seqs_FINAL = []
        cls_pos_all = []
        cls_probability_all = []
        act_pos_all = []
        act_probability_all = []

        count = 0
        gen_num = int(g_num)
        # Tokens that must not appear in a finished peptide: non-standard
        # amino-acid letters plus vocabulary label/control tokens.
        NON_AA = ["B", "O", "U", "Z", "X", '<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
                  '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>',
                  '<Ca12>', '<Na16>', '<α6α3β2>', '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>',
                  '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>', '<α6β4>', '<α2β4>',
                  '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>',
                  '<α9α10>', '<α6α3β4>', '<NaTTXS>', '<Na17>', '<high>', '<low>', '[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]']

        start_time = time.time()
        while count < gen_num:
            new_seq = None
            if is_stopped:
                yield "output.csv", pd.DataFrame()
                return

            # Wall-clock cap: give up after 1200 seconds even if fewer than
            # gen_num sequences have been accepted.
            if time.time() - start_time > 1200:
                break

            # Build a fresh masked template with a random target length.
            gen_len = random.randint(int(start), int(end))
            X3 = "X" * gen_len
            seq = [f"{X1}|{X2}|{X3}|{X4}|{X5}|{X6}"]
            vocab_mlm.token_to_idx["X"] = 4

            padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
            input_text = ["[MASK]" if i == "X" else i for i in padded_seq]

            gen_length = len(input_text)
            # Number of positions still to fill, i.e. the count of [MASK] tokens.
            length = sum(1 for x in input_text if x == '[MASK]')

            # Iterative demasking: fill one randomly chosen [MASK] per step.
            for _ in range(length):
                if is_stopped:
                    yield "output.csv", pd.DataFrame()
                    return

                _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
                idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
                idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
                attn_idx = torch.tensor(attn_idx).to(device)

                mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
                mask_position = mask_positions[torch.randint(len(mask_positions), (1,)).item()]

                logits = model(idx_seq, idx_msa, attn_idx)
                mask_logits = logits[0, mask_position, :]

                predicted_token_id = temperature_sampling(mask_logits, τ)
                predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
                input_text[mask_position] = predicted_token
                padded_seq[mask_position] = predicted_token.strip()
                new_seq = padded_seq
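
            # Every [MASK] has now been filled; input_text holds the complete
            # candidate sequence, including its label and separator tokens.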
            generated_seq = input_text

            # Re-mask the two label slots and let the model predict the
            # peptide's subtype (position 1) and potency (position 2).
            generated_seq[1] = "[MASK]"
            generated_seq[2] = "[MASK]"
            input_ids = vocab_mlm[generated_seq]
            logits = model(torch.tensor([input_ids]).to(device), idx_msa)

            cls_mask_logits = logits[0, 1, :]
            act_mask_logits = logits[0, 2, :]

            # Top-1 token and its probability for each label slot.
            cls_probability, cls_mask_probs = torch.topk(torch.softmax(cls_mask_logits, dim=-1), k=1)
            act_probability, act_mask_probs = torch.topk(torch.softmax(act_mask_logits, dim=-1), k=1)

            cls_pos = vocab_mlm.idx_to_token[cls_mask_probs[0].item()]
            act_pos = vocab_mlm.idx_to_token[act_mask_probs[0].item()]

            cls_probability = cls_probability[0].item()
            act_probability = act_probability[0].item()
            # Keep only the peptide itself: the tokens between the label
            # slots and the first [SEP].
            generated_seq = generated_seq[generated_seq.index('[MASK]') + 2:generated_seq.index('[SEP]')]

            # Accept only peptides whose cysteines can pair up (even count)
            # and that reached exactly the requested length.
            if generated_seq.count('C') % 2 == 0 and len("".join(generated_seq)) == gen_len:
                generated_seqs.append("".join(generated_seq))
                # Keep it only if it is novel (absent from the training set
                # and from earlier generations) and free of NON_AA tokens.
                if "".join(generated_seq) not in train_seq and "".join(generated_seq) not in generated_seqs[0:-1] and all(x not in NON_AA for x in generated_seq):
                    generated_seqs_FINAL.append("".join(generated_seq))
                    cls_pos_all.append(cls_pos)
                    cls_probability_all.append(cls_probability)
                    act_pos_all.append(act_pos)
                    act_probability_all.append(act_probability)
                    IDs.append(count + 1)
                    out = pd.DataFrame({
                        'ID': IDs,
                        'Generated_seq': generated_seqs_FINAL,
                        'Subtype': cls_pos_all,
                        'Subtype_probability': cls_probability_all,
                        'Potency': act_pos_all,
                        'Potency_probability': act_probability_all,
                        'random_seed': seed
                    })
                    out.to_csv("output.csv", index=False, encoding='utf-8-sig')
                    count += 1
                    # Stream the updated CSV path and table to the UI.
                    yield "output.csv", out
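
# CTXGen is a generator, so Gradio streams every yielded (file, DataFrame)
# pair: the download link and table refresh as sequences are accepted.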


with gr.Blocks() as demo:
    gr.Markdown("🔗 **[Label Prediction](https://huggingface.co/spaces/oucgc1996/CreoPep_Label_Prediction)**"
                "🔗 **[Unconstrained Generation](https://huggingface.co/spaces/oucgc1996/CreoPep_Unconstrained_generation)**"
                "🔗 **[Conditional Generation](https://huggingface.co/spaces/oucgc1996/CreoPep_conditional_generation)**"
                "🔗 **[Optimization Generation](https://huggingface.co/spaces/oucgc1996/CreoPep_optimization_generation)**")
    gr.Markdown("# Conotoxin Unconstrained Generation")
    gr.Markdown("#### Input")
    gr.Markdown("✅**τ**: temperature factor that controls the diversity of the generated conotoxins; the higher the value, the greater the diversity.")
    gr.Markdown("✅**Number of generations**: how many conotoxins to generate; generation stops automatically if it has not finished within 1200 seconds.")
    gr.Markdown("✅**Length range**: expected length range of the generated conotoxins.")
    gr.Markdown("✅**Model**: model parameters trained at different stages of data augmentation; see the paper for details.")
    gr.Markdown("✅**Seed**: an integer random seed to make results reproducible; defaults to random.")
    with gr.Row():
        τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
        g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
        length_range = RangeSlider(minimum=8, maximum=50, step=1, value=(12, 16), label="Length range")
        model_name = gr.Dropdown(choices=['model_final.pt', 'model_C1.pt', 'model_C2.pt', 'model_C3.pt', 'model_C4.pt', 'model_C5.pt', 'model_mlm.pt'], label="Model")
        seed = gr.Textbox(label="Seed", value="random")
    with gr.Row():
        start_button = gr.Button("Start Generation")
        stop_button = gr.Button("Stop Generation")
    with gr.Row():
        output_file = gr.File(label="Download generated conotoxins")
    with gr.Row():
        output_df = gr.DataFrame(label="Generated Conotoxins")

    start_button.click(CTXGen, inputs=[τ, g_num, length_range, model_name, seed], outputs=[output_file, output_df])
    stop_button.click(stop_generation, outputs=None)

demo.launch()