import streamlit as st
import streamlit.components.v1 as components
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import numpy as np
from lime.lime_text import LimeTextExplainer
import pickle

# Streamlit Title
st.title("Prompt Injection Detection and Prevention")
st.write("Classify prompts as malicious or valid and understand predictions using LIME.")

# Cache Model Loading
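# st.cache_resource keeps the loaded model in memory across Streamlit reruns,
# so it is not reloaded from disk every time the user interacts with the page.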
@st.cache_resource
def load_model(filepath):
    return tf.keras.models.load_model(filepath)

# Load Tokenizer Function
@st.cache_resource
def load_tokenizer(filepath):
    try:
        with open(filepath, 'rb') as handle:
            tokenizer = pickle.load(handle)
        return tokenizer
    except Exception as e:
        st.error(f"Error loading tokenizer: {e}")
        return None

# Preprocessing Function
def preprocess_prompt(prompt, tokenizer, max_length=100):
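    # Tokenize and pad to max_length; this must match the sequence length the model was trained with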
    sequence = tokenizer.texts_to_sequences([prompt])
    return pad_sequences(sequence, maxlen=max_length)

# Prediction Function
def detect_prompt(prompt, tokenizer, model):
    processed_prompt = preprocess_prompt(prompt, tokenizer)
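    # The model outputs a single sigmoid score, interpreted here as P(Malicious)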
    prediction = model.predict(processed_prompt)[0][0]
    class_label = 'Malicious' if prediction >= 0.5 else 'Valid'
    confidence_score = prediction * 100 if prediction >= 0.5 else (1 - prediction) * 100
    return class_label, confidence_score

# LIME Explanation
def lime_explain(prompt, model, tokenizer, max_length=100):
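    # Class names are ordered to match the probability columns returned by predict_fn: [Valid, Malicious]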
    explainer = LimeTextExplainer(class_names=["Valid", "Malicious"])
    
    def predict_fn(prompts):
        sequences = tokenizer.texts_to_sequences(prompts)
        padded_sequences = pad_sequences(sequences, maxlen=max_length)
        predictions = model.predict(padded_sequences)
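        # LIME expects class probabilities of shape (n_samples, n_classes):
        # column 0 = P(Valid), column 1 = P(Malicious) from the single sigmoid output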
        return np.hstack([1 - predictions, predictions])

    explanation = explainer.explain_instance(
        prompt,
        predict_fn,
        num_features=10
    )
    return explanation

# Load Model Section
st.subheader("Load Your Trained Model")
model_path = st.text_input("Enter the path to your trained model (.h5):")
tokenizer_path = st.text_input("Enter the path to your tokenizer file (.pickle):")
model = None
tokenizer = None

if model_path and tokenizer_path:
    try:
        model = load_model(model_path)
        tokenizer = load_tokenizer(tokenizer_path)
        
        if model and tokenizer:
            st.success("Model and Tokenizer Loaded Successfully!")
            
            # User Prompt Input
            st.subheader("Classify Your Prompt")
            user_prompt = st.text_input("Enter a prompt to classify:")

            if user_prompt:
                class_label, confidence_score = detect_prompt(user_prompt, tokenizer, model)
                st.write(f"Predicted Class: **{class_label}**")
                st.write(f"Confidence Score: **{confidence_score:.2f}%**")

                # LIME Explanation
                st.subheader("LIME Explanation")
                explanation = lime_explain(user_prompt, model, tokenizer)
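                # Render LIME's interactive HTML report inside the Streamlit page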
                explanation_as_html = explanation.as_html()
                components.html(explanation_as_html, height=500)
        else:
            st.error("Failed to load model or tokenizer.")
    
    except Exception as e:
        st.error(f"Error Loading Model or Tokenizer: {e}")

# Footer
st.write("---")
st.write("Developed for detecting and preventing prompt injection attacks.")