import pickle

import numpy as np
import streamlit as st
import streamlit.components.v1 as components
import tensorflow as tf
from lime.lime_text import LimeTextExplainer
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

# Streamlit Title
st.title("Prompt Injection Detection and Prevention")
st.write("Classify prompts as malicious or valid and understand predictions using LIME.")


# Cache Model Loading
@st.cache_resource
def load_model(filepath):
    return tf.keras.models.load_model(filepath)


# Load Tokenizer Function
@st.cache_resource
def load_tokenizer(filepath):
    try:
        with open(filepath, 'rb') as handle:
            tokenizer = pickle.load(handle)
        return tokenizer
    except Exception as e:
        st.error(f"Error loading tokenizer: {e}")
        return None


# Preprocessing Function: tokenize the prompt and pad it to the model's fixed input length
def preprocess_prompt(prompt, tokenizer, max_length=100):
    sequence = tokenizer.texts_to_sequences([prompt])
    return pad_sequences(sequence, maxlen=max_length)


# Prediction Function: the model outputs a single sigmoid probability of "Malicious"
def detect_prompt(prompt, tokenizer, model):
    processed_prompt = preprocess_prompt(prompt, tokenizer)
    prediction = model.predict(processed_prompt)[0][0]
    class_label = 'Malicious' if prediction >= 0.5 else 'Valid'
    confidence_score = prediction * 100 if prediction >= 0.5 else (1 - prediction) * 100
    return class_label, confidence_score


# LIME Explanation
def lime_explain(prompt, model, tokenizer, max_length=100):
    explainer = LimeTextExplainer(class_names=["Valid", "Malicious"])

    # LIME expects per-class probabilities, so stack [P(Valid), P(Malicious)]
    def predict_fn(prompts):
        sequences = tokenizer.texts_to_sequences(prompts)
        padded_sequences = pad_sequences(sequences, maxlen=max_length)
        predictions = model.predict(padded_sequences)
        return np.hstack([1 - predictions, predictions])

    explanation = explainer.explain_instance(
        prompt,
        predict_fn,
        num_features=10
    )
    return explanation


# Load Model Section
st.subheader("Load Your Trained Model")
model_path = st.text_input("Enter the path to your trained model (.h5):")
tokenizer_path = st.text_input("Enter the path to your tokenizer file (.pickle):")

model = None
tokenizer = None

if model_path and tokenizer_path:
    try:
        model = load_model(model_path)
        tokenizer = load_tokenizer(tokenizer_path)
        if model is not None and tokenizer is not None:
            st.success("Model and Tokenizer Loaded Successfully!")

            # User Prompt Input
            st.subheader("Classify Your Prompt")
            user_prompt = st.text_input("Enter a prompt to classify:")

            if user_prompt:
                class_label, confidence_score = detect_prompt(user_prompt, tokenizer, model)
                st.write(f"Predicted Class: **{class_label}**")
                st.write(f"Confidence Score: **{confidence_score:.2f}%**")

                # LIME Explanation
                st.subheader("LIME Explanation")
                explanation = lime_explain(user_prompt, model, tokenizer)
                explanation_as_html = explanation.as_html()
                components.html(explanation_as_html, height=500)
        else:
            st.error("Failed to load model or tokenizer.")
    except Exception as e:
        st.error(f"Error Loading Model or Tokenizer: {e}")

# Footer
st.write("---")
st.write("Developed for detecting and preventing prompt injection attacks.")
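
# ---------------------------------------------------------------------------
# Note on the expected artifacts: the app assumes a binary (sigmoid-output)
# Keras classifier saved as .h5 and a fitted Keras Tokenizer pickled to disk.
# A minimal sketch of saving such artifacts after training is shown below;
# `trained_model`, `train_texts`, and the file names are placeholders for
# illustration, not part of this app.
#
#     from tensorflow.keras.preprocessing.text import Tokenizer
#     import pickle
#
#     tokenizer = Tokenizer(num_words=10000, oov_token="<OOV>")
#     tokenizer.fit_on_texts(train_texts)              # fit on the training corpus
#
#     with open("tokenizer.pickle", "wb") as handle:   # save the fitted tokenizer
#         pickle.dump(tokenizer, handle, protocol=pickle.HIGHEST_PROTOCOL)
#
#     trained_model.save("prompt_injection_model.h5")  # save the Keras model
#
# Launch the app with Streamlit's CLI, e.g. `streamlit run app.py`
# (the script name is whatever you saved this file as).
# ---------------------------------------------------------------------------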