import streamlit as st
import streamlit.components.v1 as components
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import numpy as np
from lime.lime_text import LimeTextExplainer
import matplotlib.pyplot as plt

# Streamlit Title
st.title("Prompt Injection Detection and Prevention")
st.write("Classify prompts as malicious or valid and understand predictions using LIME.")

# Cache Model Loading
@st.cache_resource
def load_model(filepath):
    return tf.keras.models.load_model(filepath)

# Tokenizer Setup
@st.cache_resource
def setup_tokenizer():
    tokenizer = Tokenizer(num_words=5000)
    # Predefined vocabulary for demonstration purposes; replace with your actual tokenizer setup.
    tokenizer.fit_on_texts(["example prompt", "malicious attack", "valid input prompt"])
    return tokenizer
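
# NOTE: in a real deployment the tokenizer must be the one fitted on the model's
# training data, otherwise word indices will not match the embedding layer.
# A minimal sketch, assuming the training script exported it with
# tokenizer.to_json() to a hypothetical "tokenizer.json" file:
#
#   from tensorflow.keras.preprocessing.text import tokenizer_from_json
#
#   @st.cache_resource
#   def load_saved_tokenizer(path="tokenizer.json"):
#       with open(path, "r", encoding="utf-8") as f:
#           return tokenizer_from_json(f.read())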

# Preprocessing Function
def preprocess_prompt(prompt, tokenizer, max_length=100):
    sequence = tokenizer.texts_to_sequences([prompt])
    return pad_sequences(sequence, maxlen=max_length)

# Prediction Function
def detect_prompt(prompt, tokenizer, model):
    processed_prompt = preprocess_prompt(prompt, tokenizer)
    prediction = model.predict(processed_prompt)[0][0]
    class_label = 'Malicious' if prediction >= 0.5 else 'Valid'
    # Report confidence in the predicted class rather than the raw sigmoid output.
    confidence_score = prediction * 100 if prediction >= 0.5 else (1 - prediction) * 100
    return class_label, confidence_score

# LIME Explanation
def lime_explain(prompt, model, tokenizer, max_length=100):
    explainer = LimeTextExplainer(class_names=["Valid", "Malicious"])

    def predict_fn(prompts):
        sequences = tokenizer.texts_to_sequences(prompts)
        padded_sequences = pad_sequences(sequences, maxlen=max_length)
        predictions = model.predict(padded_sequences)
        # LIME expects an (n_samples, n_classes) probability array ordered like
        # class_names: column 0 = Valid, column 1 = Malicious.
        return np.hstack([1 - predictions, predictions])

    explanation = explainer.explain_instance(
        prompt,
        predict_fn,
        num_features=10
    )
    return explanation

# Load Model Section
st.subheader("Load Your Trained Model")

model = None
tokenizer = None
model_path = "deep_learning_model.h5"  # Ensure this file is in the same directory as app.py

try:
    model = load_model(model_path)
    tokenizer = setup_tokenizer()
    st.success("Model Loaded Successfully!")

    # User Prompt Input
    st.subheader("Classify Your Prompt")
    user_prompt = st.text_input("Enter a prompt to classify:")

    if user_prompt:
        class_label, confidence_score = detect_prompt(user_prompt, tokenizer, model)
        st.write(f"Predicted Class: **{class_label}**")
        st.write(f"Confidence Score: **{confidence_score:.2f}%**")

        # LIME Explanation
        st.subheader("LIME Explanation")
        explanation = lime_explain(user_prompt, model, tokenizer)
        explanation_as_html = explanation.as_html()
        components.html(explanation_as_html, height=500)
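        # Optional: also show LIME's feature weights as a static matplotlib chart
        # (this is the only use of the matplotlib import above). A minimal sketch
        # using lime's as_pyplot_figure() helper; drop it if the interactive HTML
        # view is enough.
        fig = explanation.as_pyplot_figure()
        st.pyplot(fig)
        plt.close(fig)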
except Exception as e:
    st.error(f"Error loading the model or classifying the prompt: {e}")

# Footer
st.write("---")
st.write("Developed for detecting and preventing prompt injection attacks.")