"""Gradio app that fills in masked conotoxin labels with a masked-LM model."""

import torch
import gradio as gr
from utils import create_vocab, setup_seed
from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab

# Call the project's seeding helper once, at import time, with seed 4.
setup_seed(4)

# Choices offered by the "Subtype" dropdown. "X" means "predict this field".
# The 29 empty-string entries are kept exactly as they appear in the file.
# NOTE(review): they look like labels that were lost at some point (every
# surviving label contains a non-ASCII character) -- confirm against the
# original choice list before relying on them.
_SUBTYPE_CHOICES = [
    "X",
    *([""] * 29),
    "<α1AAR>", "<α1BAR>", "<α1β1γ>", "<α1β1γδ>", "<α1β1δ>", "<α1β1δε>",
    "<α1β1ε>", "<α2β2>", "<α2β4>", "<α3β2>", "<α3β4>", "<α4β2>", "<α4β4>",
    "<α6α3β2>", "<α6α3β2β3>", "<α6α3β4>", "<α6α3β4β3>", "<α6β3β4>", "<α6β4>",
    "<α7>", "<α7α6β2>", "<α75HT3>", "<α9>", "<α9α10>",
]

# Choices offered by the "Potency" dropdown (same caveat about empty entries).
_POTENCY_CHOICES = ["X", "", ""]

# Checkpoint files selectable in the "Model" dropdown.
_MODEL_CHOICES = [
    "model_final.pt",
    "model_C1.pt",
    "model_C2.pt",
    "model_C3.pt",
    "model_C4.pt",
    "model_C5.pt",
    "model_mlm.pt",
]

# Markdown shown under the title. Built from adjacent literals so the text is
# exactly the single-line string present in the file (leading/trailing space
# included).
_DESCRIPTION = (
    " 🔗 **[Label Prediction](https://huggingface.co/spaces/oucgc1996/CreoPep_Label_Prediction)**"
    " 🔗 **[Unconstrained Generation](https://huggingface.co/spaces/oucgc1996/CreoPep_Unconstrained_generation)**"
    " 🔗 **[Conditional Generation](https://huggingface.co/spaces/oucgc1996/CreoPep_conditional_generation)**"
    " 🔗 **[Optimization Generation](https://huggingface.co/spaces/oucgc1996/CreoPep_optimization_generation)**"
    " ✅ **Subtype**: X if needs to be predicted."
    " ✅ **Potency**: X if needs to be predicted."
    " ✅ **Conotoxin**: conotoxin needs to be predicted."
    " ✅ **Model**: model parameters trained at different stages of data augmentation."
    " Please refer to the paper for details. "
)


def CTXGen(X1, X2, X3, model_name):
    """Predict the masked label(s) of one conotoxin record.

    The three fields are joined as ``"{X1}|{X2}|{X3}|||"`` and tokenised with
    ``get_paded_token_idx_gen``. Every resulting token equal to ``"X"`` is
    replaced by ``"[MASK]"`` and filled in, one position at a time and in
    order, with the model's most probable token; each filled token stays in
    the sequence for the following predictions.

    Args:
        X1: Subtype label, or ``"X"`` to have it predicted.
        X2: Potency label, or ``"X"`` to have it predicted.
        X3: Conotoxin sequence text.
        model_name: Path of the checkpoint handed to ``torch.load``.

    Returns:
        A ``(Subtype, Potency, Topk)`` tuple. A label that was supplied is
        returned unchanged; a predicted label is a ``(token, probability)``
        pair. ``Topk`` holds the top-5 tokens of the first masked position,
        or ``X1`` itself when the subtype was supplied.

    Raises:
        ValueError: If the tokenised input contains no ``"X"`` token.
    """
    device = torch.device("cpu")
    vocab_mlm = add_tokens_to_vocab(create_vocab())

    # NOTE(security): weights_only=False makes torch.load perform a full
    # unpickle, which can execute arbitrary code. Only load checkpoint files
    # that are trusted.
    model = torch.load(model_name, weights_only=False, map_location=device)
    model = model.to(device)
    model.eval()

    best_probabilities = []  # top-1 probability of each filled position
    topk_ids = []  # top-5 token ids of each filled position

    with torch.no_grad():
        seq = [f"{X1}|{X2}|{X3}|||"]
        # Register the placeholder token "X" under index 4 before tokenising.
        vocab_mlm.token_to_idx["X"] = 4
        padded_seq, _, idx_msa, _ = get_paded_token_idx_gen(vocab_mlm, seq, None)
        idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)

        mask_positions = [i for i, token in enumerate(padded_seq) if token == "X"]
        if not mask_positions:
            raise ValueError("Nothing found in the sequence to predict.")

        for position in mask_positions:
            padded_seq[position] = "[MASK]"
            input_ids = torch.tensor([vocab_mlm[padded_seq]]).to(device)
            logits = model(input_ids, idx_msa)
            probabilities = torch.softmax(logits[0, position, :], dim=-1)
            top_probabilities, top_ids = torch.topk(probabilities, k=5)

            topk_ids.append(top_ids)
            best_probabilities.append(top_probabilities[0].item())
            # Commit the best guess so later positions are predicted with it.
            padded_seq[position] = vocab_mlm.idx_to_token[top_ids[0].item()]

    first_topk_tokens = vocab_mlm.to_tokens(list(topk_ids[0]))

    # padded_seq[1] and padded_seq[2] are read as the subtype and potency
    # tokens respectively.
    if X1 != "X":
        Topk = X1
        Subtype = X1
        if X2 != "X":
            # Both labels were supplied, so best_probabilities[0] belongs to
            # some other masked position; echo the potency back as given
            # rather than pairing it with an unrelated probability.
            Potency = X2
        else:
            Potency = padded_seq[2], best_probabilities[0]
    elif X2 != "X":
        Topk = first_topk_tokens
        Subtype = padded_seq[1], best_probabilities[0]
        Potency = X2
    else:
        Topk = first_topk_tokens
        Subtype = padded_seq[1], best_probabilities[0]
        Potency = padded_seq[2], best_probabilities[1]
    return Subtype, Potency, Topk


iface = gr.Interface(
    fn=CTXGen,
    inputs=[
        gr.Dropdown(choices=_SUBTYPE_CHOICES, label="Subtype"),
        gr.Dropdown(choices=_POTENCY_CHOICES, label="Potency"),
        gr.Textbox(label="Conotoxin"),
        gr.Dropdown(choices=_MODEL_CHOICES, label="Model"),
    ],
    outputs=[
        gr.Textbox(label="Subtype"),
        gr.Textbox(label="Potency"),
        gr.Textbox(label="Top5"),
    ],
    title="Conotoxin Label Prediction",
    description=_DESCRIPTION,
)
iface.launch()