deepugaur committed
Commit 23ddfce · verified · 1 Parent(s): 90db606

Create app.py

Files changed (1)
  1. app.py +90 -0
app.py ADDED
@@ -0,0 +1,90 @@
+ import streamlit as st
+ import streamlit.components.v1 as components
+ import tensorflow as tf
+ from tensorflow.keras.preprocessing.text import Tokenizer
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
+ import numpy as np
+ from lime.lime_text import LimeTextExplainer
+
+ # Streamlit Title
+ st.title("Prompt Injection Detection and Prevention")
+ st.write("Classify prompts as malicious or valid and understand predictions using LIME.")
+
+ # Cache Model Loading
+ @st.cache_resource
+ def load_model(filepath):
+     return tf.keras.models.load_model(filepath)
+
+ # Tokenizer Setup
+ @st.cache_resource
+ def setup_tokenizer():
+     tokenizer = Tokenizer(num_words=5000)
+     # Placeholder vocabulary for demonstration only; replace with the tokenizer
+     # fitted on the model's training data, otherwise predictions will not be meaningful.
+     tokenizer.fit_on_texts(["example prompt", "malicious attack", "valid input prompt"])
+     return tokenizer
+
+ # Preprocessing Function
+ def preprocess_prompt(prompt, tokenizer, max_length=100):
+     sequence = tokenizer.texts_to_sequences([prompt])
+     return pad_sequences(sequence, maxlen=max_length)
+
+ # Prediction Function
+ def detect_prompt(prompt, tokenizer, model):
+     processed_prompt = preprocess_prompt(prompt, tokenizer)
+     prediction = model.predict(processed_prompt)[0][0]
+     class_label = 'Malicious' if prediction >= 0.5 else 'Valid'
+     # Report confidence for the predicted class.
+     confidence_score = prediction * 100 if prediction >= 0.5 else (1 - prediction) * 100
+     return class_label, confidence_score
+
+ # LIME Explanation
+ def lime_explain(prompt, model, tokenizer, max_length=100):
+     explainer = LimeTextExplainer(class_names=["Valid", "Malicious"])
+
+     def predict_fn(prompts):
+         sequences = tokenizer.texts_to_sequences(prompts)
+         padded_sequences = pad_sequences(sequences, maxlen=max_length)
+         predictions = model.predict(padded_sequences)
+         # LIME expects one probability per class: [P(Valid), P(Malicious)].
+         return np.hstack([1 - predictions, predictions])
+
+     explanation = explainer.explain_instance(
+         prompt,
+         predict_fn,
+         num_features=10
+     )
+     return explanation
+
+ # Load Model Section
+ st.subheader("Load Your Trained Model")
+ model_path = st.text_input("Enter the path to your trained model (.h5):")
+ model = None
+ tokenizer = None
+
+ if model_path:
+     try:
+         model = load_model(model_path)
+         tokenizer = setup_tokenizer()
+         st.success("Model Loaded Successfully!")
+
+         # User Prompt Input
+         st.subheader("Classify Your Prompt")
+         user_prompt = st.text_input("Enter a prompt to classify:")
+
+         if user_prompt:
+             class_label, confidence_score = detect_prompt(user_prompt, tokenizer, model)
+             st.write(f"Predicted Class: **{class_label}**")
+             st.write(f"Confidence Score: **{confidence_score:.2f}%**")
+
+             # LIME Explanation
+             st.subheader("LIME Explanation")
+             explanation = lime_explain(user_prompt, model, tokenizer)
+             explanation_as_html = explanation.as_html()
+             components.html(explanation_as_html, height=500)
+
+     except Exception as e:
+         st.error(f"Error Loading Model: {e}")
+
+ # Footer
+ st.write("---")
+ st.write("Developed for detecting and preventing prompt injection attacks.")