import re
from multiprocessing import cpu_count
from keras.src.saving import load_model
import pandas as pd
from keras.src.utils import set_random_seed
from numpy import int64
from pandarallel import pandarallel
from sklearn.preprocessing import RobustScaler
import gradio as gr
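
# Reproducibility and parallelism setup: seed the Keras/NumPy RNGs, start one
# pandarallel worker per CPU core, and load the pretrained SafeSQL classifier.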
set_random_seed(65536)
pandarallel.initialize(use_memory_fs=True, nb_workers=cpu_count())
model = load_model('./sqid.keras')
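
# Normalise a raw SQL string into a whitespace-separated token stream: pad
# operators/punctuation with spaces, keep known SQL keywords as-is, and replace
# everything else with placeholder tokens (#NUMBER#, #IDENTIFIER#, #TIMESTAMP#).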
def sql_tokenize(sql_query):
    sql_query = sql_query.replace('`', ' ').replace('%20', ' ').replace('=', ' = ') \
        .replace('((', ' (( ').replace('))', ' )) ').replace('(', ' ( ').replace(')', ' ) ') \
        .replace('||', ' || ').replace(',', '').replace('--', ' -- ').replace(':', ' : ') \
        .replace('%23', ' # ').replace('+', ' + ').replace('!=', ' != ').replace('"', ' " ') \
        .replace('%26', ' and ').replace('$', ' $ ').replace('%28', ' ( ').replace('%2A', ' * ') \
        .replace('%7C', ' | ').replace('&', ' & ').replace(']', ' ] ').replace('[', ' [ ') \
        .replace(';', ' ; ').replace('/*', ' /* ')
    sql_reserved = {
        'SELECT', 'FROM', 'WHERE', 'AND', 'OR', 'NOT', 'IN', 'LIKE', 'ORDER', 'BY', 'GROUP', 'HAVING',
        'LIMIT', 'BETWEEN', 'IS', 'NULL', '%', 'MIN', 'MAX', 'AS', 'UPPER', 'LOWER', 'TO_DATE',
        '=', '>', '<', '>=', '<=', '!=', '<>', 'EXISTS', 'JOIN', 'UNION', 'ALL', 'ASC', 'DESC', '||',
        'AVG', 'EXCEPT', 'INTERSECT', 'CASE', 'WHEN', 'THEN', 'IF', 'ANY', 'CAST', 'CONVERT',
        'COALESCE', 'NULLIF', 'INNER', 'OUTER', 'LEFT', 'RIGHT', 'FULL', 'CROSS', 'OVER', 'PARTITION',
        'SUM', 'COUNT', 'WITH', 'INTERVAL', 'WINDOW', 'ROW_NUMBER', 'RANK', 'DENSE_RANK', 'NTILE',
        'FIRST_VALUE', 'LAST_VALUE', 'LAG', 'LEAD', 'DISTINCT', 'COMMENT', 'INSERT', 'UPDATE',
        'DELETED', 'MERGE', '*', 'generate_series', 'char', 'chr', 'substr', 'lpad', 'extract',
        'year', 'month', 'day', 'timestamp', 'number', 'string', 'concat', 'INFORMATION_SCHEMA',
        'SQLITE_MASTER', 'TABLES', 'COLUMNS', 'CUBE', 'ROLLUP', 'RECURSIVE', 'FILTER', 'EXCLUDE',
        'AUTOINCREMENT', 'WITHOUT', 'ROWID', 'VIRTUAL', 'INDEXED', 'UNINDEXED', 'SERIAL', 'DO',
        'RETURNING', 'ILIKE', 'ARRAY', 'ANYARRAY', 'JSONB', 'TSQUERY', 'SEQUENCE', 'SYNONYM',
        'CONNECT', 'START', 'LEVEL', 'ROWNUM', 'NOCOPY', 'MINUS', 'AUTO_INCREMENT', 'BINARY', 'ENUM',
        'REPLACE', 'SET', 'SHOW', 'DESCRIBE', 'USE', 'EXPLAIN', 'STORED', 'RLIKE', 'MD5', 'SLEEP',
        'BENCHMARK', '@@VERSION', 'VERSION', '@VERSION', 'NVARCHAR', '#', '##', 'INJECTX', 'DELAY',
        'WAITFOR', 'RAND',
    }
    tokens = sql_query.split()
    tokens = [re.sub(r"""[^*\w\s.=\-><_|()!"']""", '', token) for token in tokens]
    for i, token in enumerate(tokens):
        if token.strip().upper() in sql_reserved:
            continue
        if token.strip().isnumeric():
            tokens[i] = '#NUMBER#'
        elif re.match(r'^[a-zA-Z_.|][a-zA-Z0-9_.|]*$', token.strip()):
            tokens[i] = '#IDENTIFIER#'
        elif re.match(r'^[\d:]*$', token.strip()):
            tokens[i] = '#TIMESTAMP#'
        elif '%' in token.strip():
            tokens[i] = ' '.join(j.strip() if j.strip() in ('%', "'") else '#IDENTIFIER#'
                                 for j in token.strip().split('%'))
    return ' '.join(tokens)
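
# Derive the hand-crafted numeric features the model expects from the tokenized
# query: counts of tables, columns, literals, parentheses, joins and special
# characters, plus flags for UNIONs, mismatched quotes and tautologies.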
def add_features(x):
    x['Query'] = x['Query'].copy().parallel_apply(sql_tokenize)
    x['num_tables'] = x['Query'].str.lower().str.count(r'FROM\s+#IDENTIFIER#', flags=re.I)
    x['num_columns'] = x['Query'].str.lower().str.count(r'SELECT\s+#IDENTIFIER#', flags=re.I)
    x['num_literals'] = x['Query'].str.lower().str.count("'[^']*'", flags=re.I) \
        + x['Query'].str.lower().str.count('"[^"]"', flags=re.I)
    x['num_parentheses'] = x['Query'].str.lower().str.count(r'\(', flags=re.I) \
        + x['Query'].str.lower().str.count(r'\)', flags=re.I)
    x['has_union'] = x['Query'].str.lower().str.count(' union |union all', flags=re.I) > 0
    x['has_union'] = x['has_union'].astype(int64)
    x['depth_nested_queries'] = x['Query'].str.lower().str.count(r'\(', flags=re.I)
    x['num_join'] = x['Query'].str.lower().str.count(
        ' join |inner join|outer join|full outer join|full inner join|cross join|left join|right join',
        flags=re.I)
    x['num_sp_chars'] = x['Query'].parallel_apply(lambda a: len(re.findall(r'[\'";\-*/%=><|#]', a)))
    x['has_mismatched_quotes'] = x['Query'].parallel_apply(
        lambda sql_query: 1 if re.search(r"'.*[^']$|\".*[^\"]$", sql_query) else 0)
    x['has_tautology'] = x['Query'].parallel_apply(
        lambda sql_query: 1 if re.search(r"'[\s]*=[\s]*'", sql_query) else 0)
    return x
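
# Score a single query: build a one-row DataFrame, compute the numeric features,
# scale them, and run the Keras model on the tokenized text plus the features;
# report "Malicious" when the predicted probability exceeds the chosen threshold.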
def is_malicious_sql(sql, threshold):
    input_df = pd.DataFrame([sql], columns=['Query'])
    input_df = add_features(input_df)
    numeric_features = ['num_tables', 'num_columns', 'num_literals', 'num_parentheses', 'has_union',
                        'depth_nested_queries', 'num_join', 'num_sp_chars', 'has_mismatched_quotes',
                        'has_tautology']
    scaler = RobustScaler()
    x_in = scaler.fit_transform(input_df[numeric_features])
    preds = model.predict([input_df['Query'], x_in]).tolist()[0][0]
    if preds > float(threshold):
        return f'Malicious - {preds}'
    return f'Safe - {preds}'
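
# Gradio chat handler: reuse the verdict from recent history when the same query
# was already asked, otherwise classify the new message at the current threshold.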
def respond(message, history, threshold):
    if len(history) > 5:
        history = history[1:]
    for val in history:
        if val[0].lower().strip() == message.lower().strip():
            return val[1]
    val = (message.lower().strip(), is_malicious_sql(message, threshold))
    print(val)
    return val[1]
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    title='SafeSQL-v1-Demo',
    description='Please enter a SQL query as your input. You may adjust the minimum probability threshold '
                'for reporting SQLs as malicious using the slider below.',
    additional_inputs=[
        gr.Slider(minimum=0.01, maximum=0.99, value=0.75, step=0.01, label='Detection Probability Threshold'),
    ],
)
if __name__ == "__main__":
demo.launch()