# AllerTrans / app.py
import torch
import gradio as gr
import spaces
from transformers import T5Tokenizer, T5EncoderModel, AutoTokenizer, AutoModel

from inference import load_models, predict_ensemble

# Load trained models
model_protT5, model_cat = load_models()
# Load ProtT5 model
tokenizer_t5 = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)
model_t5 = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")
model_t5 = model_t5.eval()
# Load the ESM2 tokenizer and model
model_name = "facebook/esm2_t33_650M_UR50D"
tokenizer_esm = AutoTokenizer.from_pretrained(model_name)
esm_model = AutoModel.from_pretrained(model_name)
def extract_prott5_embedding(sequence):
    """Return the mean-pooled ProtT5 embedding for a single protein sequence."""
    sequence = sequence.replace(" ", "")
    seq = " ".join(list(sequence))  # ProtT5 expects space-separated residues
    ids = tokenizer_t5(seq, return_tensors="pt", padding=True)
    with torch.no_grad():
        embedding = model_t5(**ids).last_hidden_state
    return torch.mean(embedding, dim=1)
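# Illustrative usage (not part of the original app): assuming the ProtT5-XL
# encoder's hidden size of 1024, the pooled embedding should have shape (1, 1024), e.g.
#   emb = extract_prott5_embedding("MKTFFV")  # -> torch.Size([1, 1024])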
# Extract ESM2 embedding
def extract_esm_embedding(sequence):
    # Tokenize the sequence
    inputs = tokenizer_esm(sequence, return_tensors="pt", padding=True, truncation=True)
    # Forward pass through the model
    with torch.no_grad():
        outputs = esm_model(**inputs)
    # last_hidden_state holds the output of the final (33rd) transformer layer
    token_representations = outputs.last_hidden_state
    # Average over the residue tokens (position 0 is the BOS/CLS token)
    return torch.mean(token_representations[0, 1:len(sequence)+1], dim=0).unsqueeze(0)
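# Illustrative usage (not part of the original app): the esm2_t33_650M_UR50D
# checkpoint uses a 1280-dimensional hidden state, so the pooled embedding
# should have shape (1, 1280), e.g.
#   emb = extract_esm_embedding("MKTFFV")  # -> torch.Size([1, 1280])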
def estimate_duration(sequence):
    # Estimate duration based on sequence length
    base_time = 30  # Base time in seconds
    time_per_residue = 0.5  # Estimated time per residue
    estimated_time = base_time + len(sequence) * time_per_residue
    return min(int(estimated_time), 300)  # Cap at 300 seconds
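# Illustrative arithmetic: a 200-residue sequence gives 30 + 200 * 0.5 = 130 s,
# while very long sequences are capped at 300 s.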
@spaces.GPU(duration=120)
def classify(sequence):
    protT5_emb = extract_prott5_embedding(sequence)
    esm_emb = extract_esm_embedding(sequence)
    # Concatenate the ESM2 and ProtT5 embeddings for the fused model
    concat = torch.cat((esm_emb, protT5_emb), dim=1)
    pred = predict_ensemble(protT5_emb, concat, model_protT5, model_cat)
    return "Potential Allergen" if pred.item() == 1 else "Non-Allergen"
description_md = """
### 📌 **About AllerTrans – An Allergenicity Prediction Tool for Protein Sequences**
**🧬 Input Format – FASTA Sequences**
This tool accepts protein sequences in FASTA format, for example:
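```
>example_protein  (illustrative, hypothetical entry)
MKWVTFISLLLLFSSAYSRGVFRR
```
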
**🧾 Output Explanation**
AllerTrans classifies your input sequence into one of the following categories:
- 🟢 **Non-Allergen:** The protein is unlikely to cause an allergic reaction and can be considered safe in terms of allergenicity.
- 🔴 **Potential Allergen:** The protein has the potential to trigger an allergic response or exhibit cross-reactivity in certain individuals. While not all individuals may experience reactions, these proteins cannot be considered safe.
**💡 Accepted Proteins**
- Natural and recombinant proteins
🔎 **Note of Caution**:
While our model demonstrates promising performance, particularly with recombinant proteins (as evidenced by an additional evaluation on a recombinant protein dataset from UniProt), **we advise caution when generalizing the results to all constructs and modifications of recombinant proteins**. The generalizability of the model to the full range of recombinant scenarios has not been fully explored.
**⚠️ Disclaimer**
Although AllerTrans provides highly accurate predictions, it is intended as a screening tool. For clinical or regulatory decisions, always confirm results with experimental validation.
"""
with gr.Blocks() as demo:
    gr.Interface(
        fn=classify,
        inputs=gr.Textbox(lines=3, placeholder="Enter protein sequence..."),
        outputs=gr.Label(label="Prediction"),
    )
    gr.Markdown(description_md)
if __name__ == "__main__":
    demo.launch()