deathsaber93 committed on
Commit 6c8ba1a · verified · 1 parent: 8673c59

Update app.py


Added inference feature: the app now classifies a submitted SQL query as Malicious or Safe using the ./sqid.keras model and a configurable detection probability threshold.

Files changed (1)
  1. app.py +104 -40
app.py CHANGED
@@ -1,42 +1,27 @@
-import gradio as gr
-from huggingface_hub import InferenceClient
+import re
+from multiprocessing import cpu_count

-"""
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+from keras.src.saving import load_model
+import pandas as pd
+from numpy import int64
+from pandarallel import pandarallel
+from sklearn.preprocessing import RobustScaler
+import gradio as gr


 def respond(
     message,
     history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
+    threshold
 ):
-    messages = [{"role": "system", "content": system_message}]
-
     for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
+        if val[0].lower().strip() == message.lower().strip():
+            yield val[1]

-    messages.append({"role": "user", "content": message})
-
-    response = ""
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
+    for message in is_malicious_sql(message, threshold
     ):
-        token = message.choices[0].delta.content
-
-        response += token
+        response = message
+        history.append((message.lower().strip(), response))
         yield response

 """
@@ -45,19 +30,98 @@ For information on how to customize the ChatInterface, peruse the gradio docs: h
 demo = gr.ChatInterface(
     respond,
     additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
+        gr.Textbox(value="Check whether a SQL is malicious or not.", label="System message"),
+        gr.Slider(minimum=0.01, maximum=0.99, value=0.75, step=0.01, label="Detection Probability Threshold "),
     ],
 )


 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
+
+pandarallel.initialize(use_memory_fs=True, nb_workers=cpu_count())
+model = load_model('./sqid.keras')
+
+
+def sql_tokenize(sql_query):
+    sql_query = sql_query.replace('`', ' ').replace('%20', ' ').replace('=', ' = ').replace('((', ' (( ').replace(
+        '))', ' )) ').replace('(', ' ( ').replace(')', ' ) ').replace('||', ' || ').replace(',', '').replace(
+        '--', ' -- ').replace(':', ' : ').replace('%23', ' # ').replace('+', ' + ').replace('!=',
+        ' != ') \
+        .replace('"', ' " ').replace('%26', ' and ').replace('$', ' $ ').replace('%28', ' ( ').replace('%2A', ' * ') \
+        .replace('%7C', ' | ').replace('&', ' & ').replace(']', ' ] ').replace('[', ' [ ').replace(';',
+        ' ; ').replace(
+        '/*', ' /* ')
+    sql_reserved = {'SELECT', 'FROM', 'WHERE', 'AND', 'OR', 'NOT', 'IN', 'LIKE', 'ORDER', 'BY', 'GROUP', 'HAVING',
+                    'LIMIT', 'BETWEEN', 'IS', 'NULL', '%', 'LIKE', 'MIN', 'MAX', 'AS', 'UPPER', 'LOWER', 'TO_DATE',
+                    '=', '>', '<', '>=', '<=', '!=', '<>', 'BETWEEN', 'LIKE', 'EXISTS', 'JOIN', 'UNION', 'ALL',
+                    'ASC', 'DESC', '||', 'AVG', 'LIMIT', 'EXCEPT', 'INTERSECT', 'CASE', 'WHEN', 'THEN', 'IF',
+                    'IF', 'ANY', 'CAST', 'CONVERT', 'COALESCE', 'NULLIF', 'INNER', 'OUTER', 'LEFT', 'RIGHT', 'FULL',
+                    'CROSS', 'OVER', 'PARTITION', 'SUM', 'COUNT', 'WITH', 'INTERVAL', 'WINDOW', 'OVER',
+                    'ROW_NUMBER', 'RANK',
+                    'DENSE_RANK', 'NTILE', 'FIRST_VALUE', 'LAST_VALUE', 'LAG', 'LEAD', 'DISTINCT', 'COMMENT',
+                    'INSERT',
+                    'UPDATE', 'DELETED', 'MERGE', '*', 'generate_series', 'char', 'chr', 'substr', 'lpad',
+                    'extract',
+                    'year', 'month', 'day', 'timestamp', 'number', 'string', 'concat', 'INFORMATION_SCHEMA',
+                    "SQLITE_MASTER", 'TABLES', 'COLUMNS', 'CUBE', 'ROLLUP', 'RECURSIVE', 'FILTER', 'EXCLUDE',
+                    'AUTOINCREMENT', 'WITHOUT', 'ROWID', 'VIRTUAL', 'INDEXED', 'UNINDEXED', 'SERIAL',
+                    'DO', 'RETURNING', 'ILIKE', 'ARRAY', 'ANYARRAY', 'JSONB', 'TSQUERY', 'SEQUENCE',
+                    'SYNONYM', 'CONNECT', 'START', 'LEVEL', 'ROWNUM', 'NOCOPY', 'MINUS', 'AUTO_INCREMENT', 'BINARY',
+                    'ENUM', 'REPLACE', 'SET', 'SHOW', 'DESCRIBE', 'USE', 'EXPLAIN', 'STORED', 'VIRTUAL', 'RLIKE',
+                    'MD5', 'SLEEP', 'BENCHMARK', '@@VERSION', 'VERSION', '@VERSION', 'CONVERT', 'NVARCHAR', '#',
+                    '##', 'INJECTX',
+                    'DELAY', 'WAITFOR', 'RAND',
+                    }
+
+    tokens = sql_query.split()
+    tokens = [re.sub(r"""[^*\w\s.=\-><_|()!"']""", '', token) for token in tokens]
+    for i, token in enumerate(tokens):
+        if token.strip().upper() in sql_reserved:
+            continue
+        if token.strip().isnumeric():
+            tokens[i] = '#NUMBER#'
+        elif re.match(r'^[a-zA-Z_.|][a-zA-Z0-9_.|]*$', token.strip()):
+            tokens[i] = '#IDENTIFIER#'
+        elif re.match(r'^[\d:]*$', token.strip()):
+            tokens[i] = '#TIMESTAMP#'
+        elif '%' in token.strip():
+            tokens[i] = ' '.join(
+                [j.strip() if j.strip() in ('%', "'", "'") else '#IDENTIFIER#' for j in token.strip().split('%')])
+    return ' '.join(tokens)
+
+
+def add_features(x):
+    x['Query'] = x['Query'].copy().parallel_apply(lambda a: sql_tokenize(a))
+    x['num_tables'] = x['Query'].str.lower().str.count(r'FROM\s+#IDENTIFIER#', flags=re.I)
+    x['num_columns'] = x['Query'].str.lower().str.count(r'SELECT\s+#IDENTIFIER#', flags=re.I)
+    x['num_literals'] = x['Query'].str.lower().str.count("'[^']*'", flags=re.I) + x['Query'].str.lower().str.count(
+        '"[^"]"', flags=re.I)
+    x['num_parentheses'] = x['Query'].str.lower().str.count("\\(", flags=re.I) + x['Query'].str.lower().str.count(
+        '\\)',
+        flags=re.I)
+    x['has_union'] = x['Query'].str.lower().str.count(" union |union all", flags=re.I) > 0
+    x['has_union'] = x['has_union'].astype(int64)
+    x['depth_nested_queries'] = x['Query'].str.lower().str.count("\\(", flags=re.I)
+    x['num_join'] = x['Query'].str.lower().str.count(
+        " join |inner join|outer join|full outer join|full inner join|cross join|left join|right join",
+        flags=re.I)
+    x['num_sp_chars'] = x['Query'].parallel_apply(lambda a: len(re.findall(r'[\'";\-*/%=><|#]', a)))
+    x['has_mismatched_quotes'] = x['Query'].parallel_apply(
+        lambda sql_query: 1 if re.search(r"'.*[^']$|\".*[^\"]$", sql_query) else 0)
+    x['has_tautology'] = x['Query'].parallel_apply(lambda sql_query: 1 if re.search(r"'[\s]*=[\s]*'", sql_query) else 0)
+    return x
+
+
+def is_malicious_sql(sql, threshold):
+    input_df = pd.DataFrame([sql], columns=['Query'])
+    input_df = add_features(input_df)
+    numeric_features = ["num_tables", "num_columns", "num_literals", "num_parentheses", "has_union",
+                        "depth_nested_queries", "num_join", "num_sp_chars", "has_mismatched_quotes", "has_tautology"]
+    scaler = RobustScaler()
+    x_in = scaler.fit_transform(input_df[numeric_features])
+
+    preds = model.predict([input_df['Query'], x_in]).tolist()[0][0]
+    if preds > float(threshold):
+        return 'Malicious'
+    return 'Safe'
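
For reference, below is a minimal usage sketch (not part of the commit) showing how the new inference helpers might be exercised outside the Gradio UI. It assumes app.py can be imported as a module named app, that ./sqid.keras is present in the working directory, and that the module-level pandarallel and model initialization runs on import; the sample queries and the expected tokenizer output are illustrative only.

# Hypothetical usage sketch: assumes app.py is importable and ./sqid.keras exists locally.
from app import is_malicious_sql, sql_tokenize

# sql_tokenize keeps SQL keywords and masks literals/identifiers; a benign query
# should come back roughly as:
#   SELECT #IDENTIFIER# FROM #IDENTIFIER# WHERE #IDENTIFIER# = #NUMBER#
print(sql_tokenize("SELECT name FROM users WHERE id = 1"))

# is_malicious_sql returns the string 'Malicious' or 'Safe', comparing the model's
# predicted probability against the caller-supplied threshold (0.75 is the slider
# default wired into the ChatInterface above).
print(is_malicious_sql("SELECT * FROM users WHERE '1' = '1' --", 0.75))

Because is_malicious_sql only flags a query when the predicted probability exceeds the threshold, raising the Detection Probability Threshold slider makes the app less likely to label input as Malicious.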