import streamlit as st
import streamlit.components.v1 as components
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import numpy as np
from lime.lime_text import LimeTextExplainer
import pickle
import matplotlib.pyplot as plt
# Streamlit Title
st.title("Prompt Injection Detection and Prevention")
st.write("Classify prompts as malicious or valid and understand predictions using LIME.")
# Cache Model Loading
@st.cache_resource
def load_model(filepath):
    return tf.keras.models.load_model(filepath)
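# st.cache_resource keeps the loaded model (and the tokenizer below) in memory across
# Streamlit reruns, so each file is only deserialized once per path.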
# Load Tokenizer Function
@st.cache_resource
def load_tokenizer(filepath):
    try:
        with open(filepath, 'rb') as handle:
            tokenizer = pickle.load(handle)
        return tokenizer
    except Exception as e:
        st.error(f"Error loading tokenizer: {e}")
        return None
# Preprocessing Function
def preprocess_prompt(prompt, tokenizer, max_length=100):
    sequence = tokenizer.texts_to_sequences([prompt])
    return pad_sequences(sequence, maxlen=max_length)
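# Note: max_length=100 is assumed to match the sequence length used when the model was
# trained; padding and truncation here must mirror the original training pipeline.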
# Prediction Function
def detect_prompt(prompt, tokenizer, model):
    processed_prompt = preprocess_prompt(prompt, tokenizer)
    prediction = model.predict(processed_prompt)[0][0]
    class_label = 'Malicious' if prediction >= 0.5 else 'Valid'
    confidence_score = prediction * 100 if prediction >= 0.5 else (1 - prediction) * 100
    return class_label, confidence_score
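# Example (hypothetical values): a sigmoid score of 0.97 is reported as "Malicious" with
# 97.00% confidence, while a score of 0.10 is reported as "Valid" with 90.00% confidence.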
# LIME Explanation
def lime_explain(prompt, model, tokenizer, max_length=100):
    explainer = LimeTextExplainer(class_names=["Valid", "Malicious"])

    def predict_fn(prompts):
        sequences = tokenizer.texts_to_sequences(prompts)
        padded_sequences = pad_sequences(sequences, maxlen=max_length)
        predictions = model.predict(padded_sequences)
        # LIME expects one probability column per class: column 0 = Valid, column 1 = Malicious
        return np.hstack([1 - predictions, predictions])

    explanation = explainer.explain_instance(
        prompt,
        predict_fn,
        num_features=10
    )
    return explanation
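# Note: explain_instance perturbs the input prompt many times and runs predict_fn on the
# perturbed batch, so producing an explanation is noticeably slower than a single prediction.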
# Load Model Section
st.subheader("Load Your Trained Model")
model_path = st.text_input("Enter the path to your trained model (.h5):")
tokenizer_path = st.text_input("Enter the path to your tokenizer file (.pickle):")
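# Example (hypothetical paths): a Keras model saved as "prompt_model.h5" and a Tokenizer
# pickled to "tokenizer.pickle" during training would be entered above.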
model = None
tokenizer = None
if model_path and tokenizer_path:
    try:
        model = load_model(model_path)
        tokenizer = load_tokenizer(tokenizer_path)
        if model and tokenizer:
            st.success("Model and Tokenizer Loaded Successfully!")

            # User Prompt Input
            st.subheader("Classify Your Prompt")
            user_prompt = st.text_input("Enter a prompt to classify:")

            if user_prompt:
                class_label, confidence_score = detect_prompt(user_prompt, tokenizer, model)
                st.write(f"Predicted Class: **{class_label}**")
                st.write(f"Confidence Score: **{confidence_score:.2f}%**")

                # LIME Explanation
                st.subheader("LIME Explanation")
                explanation = lime_explain(user_prompt, model, tokenizer)
                explanation_as_html = explanation.as_html()
                components.html(explanation_as_html, height=500)
        else:
            st.error("Failed to load model or tokenizer.")
    except Exception as e:
        st.error(f"Error Loading Model or Tokenizer: {e}")
# Footer
st.write("---")
st.write("Developed for detecting and preventing prompt injection attacks.")