import torch
import random
import pandas as pd
from utils import create_vocab, setup_seed
from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
import gradio as gr
from gradio_rangeslider import RangeSlider
import time

is_stopped = False

def temperature_sampling(logits, temperature):
    # Scale the logits by the temperature and sample one token id from the
    # resulting softmax distribution; a higher temperature gives a flatter
    # distribution and therefore more diverse samples.
    logits = logits / temperature
    probabilities = torch.softmax(logits, dim=-1)
    sampled_token = torch.multinomial(probabilities, 1)
    return sampled_token

def stop_generation():
    global is_stopped
    is_stopped = True
    return "Generation stopped."

def CTXGen(X1, X2, τ, g_num, length_range, model_name, seed):
    if seed == 'random':
        seed = random.randint(0, 100000)
        setup_seed(seed)
    else:
        setup_seed(int(seed))
    global is_stopped
    is_stopped = False
    start, end = length_range
    device = torch.device("cpu")
    vocab_mlm = create_vocab()
    vocab_mlm = add_tokens_to_vocab(vocab_mlm)
    save_path = model_name
    train_seqs = pd.read_csv('C0_seq.csv')
    train_seq = train_seqs['Seq'].tolist()
    model = torch.load(save_path, map_location=torch.device('cpu'))
    model = model.to(device)

    # Pick a random MSA record matching the requested subtype (X1) and
    # potency (X2); fields 3-5 of the record supply the context (X4-X6).
    msa_data = pd.read_csv('conoData_C0.csv')
    msa = msa_data['Sequences'].tolist()
    msa = [x for x in msa if x.startswith(f"{X1}|{X2}")]
    msa = random.choice(msa)
    X4 = msa.split("|")[3]
    X5 = msa.split("|")[4]
    X6 = msa.split("|")[5]

    model.eval()
    with torch.no_grad():
        IDs = []
        generated_seqs = []
        generated_seqs_FINAL = []
        cls_probability_all = []
        act_probability_all = []
        count = 0
        gen_num = int(g_num)
        # Tokens that must not appear in a finished peptide: non-standard
        # residues plus subtype/special tokens.
        NON_AA = ["B", "O", "U", "Z", "X", '', '<α1β1γδ>', '', '', '', '<α1BAR>', '<α1β1ε>',
                  '<α1AAR>', '', '<α4β2>', '', '<α75HT3>', '', '<α7>', '', '', '', '<α6β3β4>',
                  '', '', '', '', '<α6α3β2>', '', '', '', '<α1β1δε>', '', '<α9>', '', '',
                  '<α3β4>', '', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>',
                  '<α6β4>', '<α2β4>', '', '', '', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '', '', '',
                  '<α9α10>', '<α6α3β4>', '', '', '', '',
                  '[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]']

        out = pd.DataFrame()  # keeps the final return defined even if no sequence is accepted
        start_time = time.time()
        while count < gen_num:
            new_seq = None
            if is_stopped:
                return "output.csv", pd.DataFrame()
            if time.time() - start_time > 1200:
                break

            # Build a template "X1|X2|XX...X|X4|X5|X6" whose payload is gen_len
            # placeholder residues; each placeholder becomes a [MASK] token.
            gen_len = random.randint(int(start), int(end))
            X3 = "X" * gen_len
            seq = [f"{X1}|{X2}|{X3}|{X4}|{X5}|{X6}"]
            vocab_mlm.token_to_idx["X"] = 4

            padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
            input_text = ["[MASK]" if i == "X" else i for i in padded_seq]
            gen_length = len(input_text)
            length = gen_length - sum(1 for x in input_text if x != '[MASK]')  # number of [MASK] slots
            for _ in range(length):
                if is_stopped:
                    return "output.csv", pd.DataFrame()
                _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
                idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
                idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
                attn_idx = torch.tensor(attn_idx).to(device)

                # Choose one remaining [MASK] position at random and fill it
                # with a temperature-sampled token.
                mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
                mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])

                logits = model(idx_seq, idx_msa, attn_idx)
                mask_logits = logits[0, mask_position.item(), :]

                predicted_token_id = temperature_sampling(mask_logits, τ)
                predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
                input_text[mask_position.item()] = predicted_token
                padded_seq[mask_position.item()] = predicted_token.strip()
                new_seq = padded_seq
            generated_seq = input_text

            # Re-mask the subtype token (position 1) and score how strongly the
            # model predicts the requested subtype for the finished sequence.
            generated_seq[1] = "[MASK]"
            input_ids = vocab_mlm[generated_seq]
            logits = model(torch.tensor([input_ids]).to(device), idx_msa)
            cls_mask_logits = logits[0, 1, :]
            cls_probability, cls_mask_probs = torch.topk(torch.softmax(cls_mask_logits, dim=-1), k=85)
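            # The potency token (position 2) is scored the same way below:
            # re-mask it and keep the top-2 softmax probabilities, which
            # appear to correspond to the two candidate potency labels.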
            generated_seq[2] = "[MASK]"
            input_ids = vocab_mlm[generated_seq]
            logits = model(torch.tensor([input_ids]).to(device), idx_msa)
            act_mask_logits = logits[0, 2, :]
            act_probability, act_mask_probs = torch.topk(torch.softmax(act_mask_logits, dim=-1), k=2)

            cls_pos = vocab_mlm.to_tokens(list(cls_mask_probs))
            act_pos = vocab_mlm.to_tokens(list(act_mask_probs))

            # Accept the sequence only if the model recovers both requested
            # labels, the potency probability is at least 0.5, the cysteine
            # count is even (so disulfide bridges can pair), and the length
            # matches the request.
            if X1 in cls_pos and X2 in act_pos:
                cls_proba = cls_probability[cls_pos.index(X1)].item()
                act_proba = act_probability[act_pos.index(X2)].item()
                generated_seq = generated_seq[generated_seq.index('[MASK]') + 2:generated_seq.index('[SEP]')]

                if act_proba >= 0.5 and generated_seq.count('C') % 2 == 0 and len("".join(generated_seq)) == gen_len:
                    generated_seqs.append("".join(generated_seq))

                    # Keep only sequences that are novel (absent from the
                    # training set and not generated before) and free of
                    # non-amino-acid tokens.
                    if "".join(generated_seq) not in train_seq and "".join(generated_seq) not in generated_seqs[0:-1] and all(x not in NON_AA for x in generated_seq):
                        generated_seqs_FINAL.append("".join(generated_seq))
                        cls_probability_all.append(cls_proba)
                        act_probability_all.append(act_proba)
                        IDs.append(count + 1)
                        out = pd.DataFrame({
                            'ID': IDs,
                            'Generated_seq': generated_seqs_FINAL,
                            'Subtype': X1,
                            'Subtype_probability': cls_probability_all,
                            'Potency': X2,
                            'Potency_probability': act_probability_all,
                            'Random_seed': seed
                        })
                        out.to_csv("output.csv", index=False, encoding='utf-8-sig')
                        count += 1
                        yield "output.csv", out
    return "output.csv", out

with gr.Blocks() as demo:
    gr.Markdown("🔗 **[Label Prediction](https://huggingface.co/spaces/oucgc1996/CreoPep_Label_Prediction)**"
                "🔗 **[Unconstrained Generation](https://huggingface.co/spaces/oucgc1996/CreoPep_Unconstrained_generation)**"
                "🔗 **[Conditional Generation](https://huggingface.co/spaces/oucgc1996/CreoPep_conditional_generation)**"
                "🔗 **[Optimization Generation](https://huggingface.co/spaces/oucgc1996/CreoPep_optimization_generation)**")
    gr.Markdown("# Conotoxin Conditional Generation")
    gr.Markdown("#### Input")
    gr.Markdown("✅**Subtype**: the desired subtype of action. For example, α7.")
    gr.Markdown("✅**Potency**: the required potency. For example, High.")
    gr.Markdown("✅**τ**: temperature factor that controls the diversity of the generated conotoxins. The higher the value, the greater the diversity.")
    gr.Markdown("✅**Number of generations**: the number of conotoxins to generate. If generation is not completed within 1200 seconds, it stops automatically.")
    gr.Markdown("✅**Length range**: the expected length range of the generated conotoxins.")
    gr.Markdown("✅**Model**: model parameters trained at different stages of data augmentation. Please refer to the paper for details.")
    gr.Markdown("✅**Seed**: enter an integer as the random seed to ensure reproducible results. The default is random.")
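    # CTXGen is a generator, so start_button.click (wired up below) streams
    # each newly accepted sequence to the file/table outputs as it is produced.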
    with gr.Row():
        X1 = gr.Dropdown(choices=['<α7>', '', '', '', '', '', '', '', '', '', '', '', '', '', '',
                                  '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '',
                                  '<α1AAR>', '<α1BAR>', '<α1β1γ>', '<α1β1γδ>', '<α1β1δ>', '<α1β1δε>',
                                  '<α1β1ε>', '<α2β2>', '<α2β4>', '<α3β2>', '<α3β4>', '<α4β2>',
                                  '<α4β4>', '<α6α3β2>', '<α6α3β2β3>', '<α6α3β4>', '<α6α3β4β3>',
                                  '<α6β3β4>', '<α6β4>', '<α7α6β2>', '<α75HT3>', '<α9>', '<α9α10>'],
                         label="Subtype")
        X2 = gr.Dropdown(choices=['', ''], label="Potency")
        τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
        g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
    with gr.Row():
        length_range = RangeSlider(minimum=8, maximum=50, step=1, value=(12, 16), label="Length range")
        model_name = gr.Dropdown(choices=['model_final.pt', 'model_C1.pt', 'model_C2.pt', 'model_C3.pt',
                                          'model_C4.pt', 'model_C5.pt', 'model_mlm.pt'],
                                 label="Model")
        seed = gr.Textbox(label="Seed", value="random")
    with gr.Row():
        start_button = gr.Button("Start Generation")
        stop_button = gr.Button("Stop Generation")
    with gr.Row():
        output_file = gr.File(label="Download generated conotoxins")
    with gr.Row():
        output_df = gr.DataFrame(label="Generated Conotoxins")

    start_button.click(CTXGen, inputs=[X1, X2, τ, g_num, length_range, model_name, seed], outputs=[output_file, output_df])
    stop_button.click(stop_generation, outputs=None)

demo.launch()
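# Usage sketch (assumptions: the checkpoints listed in the Model dropdown and
# the data files C0_seq.csv / conoData_C0.csv sit in the working directory,
# next to the utils and dataset_mlm modules; the script file name is assumed):
#
#   python app.py
#
# Gradio then serves the demo on a local URL; demo.launch(share=True) would
# additionally create a temporary public link.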