# Page residue from the Hugging Face file viewer (not part of app.py itself):
#   rmdhirr's picture | Update app.py | c8496ce verified | raw | history blame | 5.04 kB
import logging
import gradio as gr
import tensorflow as tf
import numpy as np
import nltk
import pickle
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import re
from tensorflow.keras import models, optimizers
from tensorflow.keras.metrics import Precision, Recall
# Set up logging
# The root logger is configured at DEBUG so load/compile/prediction messages
# emitted below are all shown.
logging.basicConfig(level=logging.DEBUG)

# Load the model
# `model` is bound up front and stays None when loading fails, so the compile
# step below can be skipped cleanly instead of tripping over an undefined name.
model = None
try:
    model = tf.keras.models.load_model('new_phishing_detection_model.keras')
    logging.info("Model loaded successfully.")
except Exception as e:
    # Best-effort start-up: record the failure (with traceback) and carry on.
    logging.exception(f"Error loading model: {e}")

# Compile the model with standard loss and metrics
if model is None:
    logging.error("Skipping model compilation because the model was not loaded.")
else:
    try:
        model.compile(optimizer=optimizers.Adam(learning_rate=0.0005),
                      loss='binary_crossentropy',
                      metrics=['accuracy', Precision(), Recall()])
        logging.info("Model compiled successfully.")
    except Exception as e:
        logging.exception(f"Error compiling model: {e}")
# Preprocessing functions
# Fetch the NLTK data the preprocessing helpers rely on: tokenizer models, the
# English stop-word list and WordNet (used by the lemmatizer).
# 'punkt_tab' is requested in addition to 'punkt' because recent NLTK releases
# make word_tokenize look up 'punkt_tab'; without it tokenization raises
# LookupError, which the preprocessing helpers would swallow and turn into "".
for _resource in ('punkt', 'punkt_tab', 'stopwords', 'wordnet'):
    nltk.download(_resource)

# Stop words as a set for O(1) membership tests during token filtering.
STOPWORDS = set(stopwords.words('english'))
lemmatizer = WordNetLemmatizer()
def preprocess_url(url):
    """Turn a URL into a space-separated string of lemmatized tokens.

    The URL is lower-cased, the ``http://``/``https://`` scheme and every
    ``www.`` occurrence are removed, each non-alphanumeric character is
    replaced by a space and runs of whitespace are collapsed. The result is
    tokenized, English stop words are dropped and the remaining tokens are
    lemmatized.

    Returns the cleaned text, or ``""`` when any step raises (the error is
    logged).
    """
    try:
        text = url.lower()
        # Applied in order: scheme, "www.", non-alphanumerics, whitespace runs.
        substitutions = (
            (r'https?://', ''),
            (r'www\.', ''),
            (r'[^a-zA-Z0-9]', ' '),
            (r'\s+', ' '),
        )
        for pattern, replacement in substitutions:
            text = re.sub(pattern, replacement, text)
        kept = (tok for tok in word_tokenize(text.strip()) if tok not in STOPWORDS)
        return ' '.join(lemmatizer.lemmatize(tok) for tok in kept)
    except Exception as e:
        logging.error(f"Error in URL preprocessing: {e}")
        return ""
def preprocess_html(html):
    """Turn an HTML document into a space-separated string of lemmatized tokens.

    Tags are replaced by spaces, the text is lower-cased, ``http://`` and
    ``https://`` prefixes are removed, each non-alphanumeric character is
    replaced by a space and runs of whitespace are collapsed. The result is
    tokenized, English stop words are dropped and the remaining tokens are
    lemmatized.

    Returns the cleaned text, or ``""`` when any step raises (the error is
    logged).
    """
    try:
        without_tags = re.sub(r'<[^>]+>', ' ', html)
        lowered = without_tags.lower()
        without_scheme = re.sub(r'https?://', '', lowered)
        alnum_only = re.sub(r'[^a-zA-Z0-9]', ' ', without_scheme)
        collapsed = re.sub(r'\s+', ' ', alnum_only).strip()

        lemmas = []
        for token in word_tokenize(collapsed):
            if token in STOPWORDS:
                continue
            lemmas.append(lemmatizer.lemmatize(token))
        return ' '.join(lemmas)
    except Exception as e:
        logging.error(f"Error in HTML preprocessing: {e}")
        return ""
# Define maximum lengths
# Sequence lengths used when padding the URL and HTML inputs.
max_words = 10000
max_html_length = 2000
max_url_length = 180

# Load tokenizers
# NOTE(review): pickle.load can execute code embedded in the file; this is only
# acceptable because the .pkl files are presumably shipped with the app — confirm.
try:
    # The files are opened and read one after the other, so a failure on the
    # second file still leaves url_tokenizer bound.
    with open('url_tokenizer.pkl', 'rb') as url_file:
        url_tokenizer = pickle.load(url_file)
    with open('html_tokenizer.pkl', 'rb') as html_file:
        html_tokenizer = pickle.load(html_file)
    logging.info("Tokenizers loaded successfully.")
except Exception as e:
    logging.error(f"Error loading tokenizers: {e}")
def preprocess_input(input_text, tokenizer, max_length):
    """Encode one text with ``tokenizer`` and pad/truncate it to ``max_length``.

    The text is passed to ``tokenizer.texts_to_sequences`` as a one-element
    list; the resulting sequence is padded and truncated at the end
    (``'post'``) by ``pad_sequences``.

    Returns the padded batch, or an all-zero array of shape
    ``(1, max_length)`` when any step raises (the error is logged).
    """
    try:
        return pad_sequences(
            tokenizer.texts_to_sequences([input_text]),
            maxlen=max_length,
            padding='post',
            truncating='post',
        )
    except Exception as e:
        logging.error(f"Error in input preprocessing: {e}")
        return np.zeros((1, max_length))
def get_prediction(input_text, input_type):
    """Run the model once on ``input_text`` and return its score.

    The model is fed a two-element list ``[url_input, html_input]``. When
    ``input_type`` is ``"URL"`` the text is cleaned and encoded as the URL
    input and the HTML input is all zeros; for any other value the text is
    treated as HTML and the URL input is all zeros.

    Returns element ``[0][0]`` of ``model.predict``'s output, or ``0.0`` when
    anything raises (the error is logged).
    """
    try:
        if input_type == "URL":
            url_part = preprocess_input(
                preprocess_url(input_text), url_tokenizer, max_url_length
            )
            html_part = np.zeros((1, max_html_length))  # dummy HTML input
        else:
            html_part = preprocess_input(
                preprocess_html(input_text), html_tokenizer, max_html_length
            )
            url_part = np.zeros((1, max_url_length))  # dummy URL input
        scores = model.predict([url_part, html_part])
        return scores[0][0]
    except Exception as e:
        logging.error(f"Error in prediction: {e}")
        return 0.0
def ensemble_prediction(input_text, input_type, n_ensemble=5):
    """Average ``n_ensemble`` calls to ``get_prediction`` on the same input.

    Returns ``np.mean`` of the collected scores, or ``0.0`` when anything
    raises (the error is logged).
    """
    try:
        scores = []
        for _ in range(n_ensemble):
            scores.append(get_prediction(input_text, input_type))
        return np.mean(scores)
    except Exception as e:
        logging.error(f"Error in ensemble prediction: {e}")
        return 0.0
def phishing_detection(input_text, input_type):
    """Return a human-readable verdict for ``input_text``.

    The averaged score from ``ensemble_prediction`` is compared with a fixed
    threshold of 0.5: strictly greater yields the warning message, anything
    else yields the safe message. The score is appended with two decimals.
    """
    score = ensemble_prediction(input_text, input_type)
    threshold = 0.5  # Keep the threshold unchanged
    if not score > threshold:
        return f"Safe: This site is not likely a phishing site. ({score:.2f})"
    return f"Warning: This site is likely a phishing site! ({score:.2f})"
# Gradio UI: a five-line text box plus a URL/HTML selector feed
# phishing_detection, whose verdict string is shown in a single text box.
_text_input = gr.components.Textbox(lines=5, placeholder="Enter URL or HTML code")
_type_input = gr.components.Radio(["URL", "HTML"], type="value", label="Input Type")
_result_output = gr.components.Textbox(label="Phishing Detection Result")

iface = gr.Interface(
    fn=phishing_detection,
    inputs=[_text_input, _type_input],
    outputs=_result_output,
    title="Phishing Detection Model",
    description="Check if a URL or HTML is Phishing.",
    theme="default",
)

# Start the web app as soon as the module is executed.
iface.launch()