beyoru commited on
Commit
43ce954
·
verified ·
1 Parent(s): 8b0451a

Upload 13 files

Browse files
Files changed (14) hide show
  1. .gitattributes +2 -0
  2. app.py +4 -0
  3. client.py +106 -0
  4. client_old.py +49 -0
  5. createDB.py +69 -0
  6. data/data.db +3 -0
  7. data/fakedb.db +3 -0
  8. database.py +36 -0
  9. init.py +39 -0
  10. router.py +70 -0
  11. style.css +63 -0
  12. testdb.py +10 -0
  13. ui.py +54 -0
  14. utils.py +17 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data/data.db filter=lfs diff=lfs merge=lfs -text
37
+ data/fakedb.db filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from ui import demo
2
+
3
+ if __name__ == "__main__":
4
+ demo.launch()
client.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import InferenceClient
2
+ from init import ACCESS_TOKEN, SYSTEM_PROMPT
3
+ from utils import extract_sql, is_sql
4
+ from database import execute
5
+
6
+
7
+ client = InferenceClient()
8
+
9
+
10
+ def respond(message, history, system_message, max_tokens, temperature, top_p):
11
+ messages = [{"role": "system", "content": SYSTEM_PROMPT}]
12
+ # Xử lý lịch sử chat
13
+ for val in history:
14
+ if val[0]:
15
+ messages.append({"role": "user", "content": val[0]})
16
+ if val[1]:
17
+ messages.append({"role": "assistant", "content": val[1]})
18
+
19
+ messages.append({"role": "user", "content": message})
20
+
21
+ # Tạo response đầu tiên
22
+ response = ""
23
+ for message in client.chat.completions.create(
24
+ model="Qwen/Qwen2.5-3B-Instruct",
25
+ max_tokens=max_tokens,
26
+ stream=True,
27
+ temperature=temperature,
28
+ top_p=top_p,
29
+ messages=messages,
30
+ ):
31
+ token = message.choices[0].delta.content
32
+ response += token
33
+ yield response
34
+
35
+ # Xử lý logic SQL và retry
36
+ if is_sql(response):
37
+ sql_query = extract_sql(response)
38
+ max_attempts = 3
39
+ attempts = 0
40
+ sql_result = None
41
+ last_error = None
42
+
43
+ while attempts < max_attempts:
44
+ try:
45
+ sql_result = execute(sql_query)
46
+ break
47
+ except Exception as e:
48
+ last_error = str(e)
49
+ attempts += 1
50
+ if attempts < max_attempts:
51
+ # Thêm thông tin lỗi vào context và yêu cầu mô hình hỏi lại người dùng
52
+ clarification_prompt = f"""Tôi gặp lỗi khi thực hiện truy vấn SQL: {last_error}
53
+ Bạn có thể cung cấp thêm thông tin hoặc chỉnh sửa câu hỏi để tôi có thể sửa truy vấn không?"""
54
+ messages += [
55
+ {"role": "assistant", "content": response},
56
+ {"role": "user", "content": clarification_prompt},
57
+ ]
58
+
59
+ # Tạo response yêu cầu thông tin thêm
60
+ response = ""
61
+ for message in client.chat.completions.create(
62
+ model="Qwen/Qwen2.5-3B-Instruct",
63
+ max_tokens=max_tokens,
64
+ stream=True,
65
+ temperature=temperature,
66
+ top_p=top_p,
67
+ messages=messages,
68
+ ):
69
+ token = message.choices[0].delta.content
70
+ response += token
71
+ yield response
72
+
73
+ # Nếu mô hình cung cấp SQL mới, tiếp tục thử
74
+ if is_sql(response):
75
+ sql_query = extract_sql(response)
76
+ else:
77
+ # Nếu sau 3 lần vẫn lỗi, tiếp tục hỏi lại người dùng thay vì in lỗi
78
+ retry_prompt = f"""Tôi đã thử {max_attempts} lần nhưng vẫn gặp lỗi: {last_error}
79
+ Bạn có thể cung cấp thêm chi tiết về dữ liệu cần truy vấn không?"""
80
+ messages.append({"role": "assistant", "content": retry_prompt})
81
+ yield retry_prompt
82
+ return
83
+
84
+ # Nếu thực hiện truy vấn thành công
85
+ if sql_result is not None:
86
+ reformulation_prompt = f"""Kết quả truy vấn SQL:
87
+ {sql_result}
88
+ Hãy tóm tắt kết quả thành phản hồi tự nhiên cho người dùng."""
89
+ messages += [
90
+ {"role": "assistant", "content": response},
91
+ {"role": "user", "content": reformulation_prompt},
92
+ ]
93
+
94
+ # Tạo response tóm tắt
95
+ reformulated_response = ""
96
+ for message in client.chat.completions.create(
97
+ model="Qwen/Qwen2.5-3B-Instruct",
98
+ max_tokens=512,
99
+ stream=True,
100
+ temperature=temperature,
101
+ top_p=top_p,
102
+ messages=messages,
103
+ ):
104
+ token = message.choices[0].delta.content
105
+ reformulated_response += token
106
+ yield reformulated_response
client_old.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import InferenceClient
2
+ from init import ACCESS_TOKEN, SYSTEM_PROMPT
3
+ from utils import extract_sql, is_sql
4
+ from database import execute
5
+
6
+ client = InferenceClient(api_key=ACCESS_TOKEN)
7
+ messages = [{"role": "system", "content": SYSTEM_PROMPT}]
8
+
9
+
10
+ def respond(message, history, system_message, max_tokens, temperature, top_p):
11
+ for val in history:
12
+ if val[0]:
13
+ messages.append({"role": "user", "content": val[0]})
14
+ if val[1]:
15
+ messages.append({"role": "assistant", "content": val[1]})
16
+
17
+ messages.append({"role": "user", "content": message})
18
+
19
+ response = ""
20
+ for message in client.chat.completions.create(
21
+ model="Qwen/Qwen2.5-3B-Instruct",
22
+ max_tokens=max_tokens,
23
+ stream=True,
24
+ temperature=temperature,
25
+ top_p=top_p,
26
+ messages=messages,
27
+ ):
28
+ token = message.choices[0].delta.content
29
+ response += token
30
+ yield response
31
+ if is_sql(response):
32
+ sql_query = extract_sql(response)
33
+ sql_result = execute(sql_query)
34
+
35
+ reformulation_prompt = f"Kết quả truy vấn SQL:\n{sql_result}\n\nHãy diễn đạt lại kết quả cho người dùng một cách dễ hiểu."
36
+ messages.append({"role": "user", "content": reformulation_prompt})
37
+
38
+ reformulated_response = ""
39
+ for msg in client.chat.completions.create(
40
+ model="Qwen/Qwen2.5-3B-Instruct",
41
+ max_tokens=512,
42
+ stream=True,
43
+ temperature=temperature,
44
+ top_p=top_p,
45
+ messages=messages,
46
+ ):
47
+ token = msg.choices[0].delta.content
48
+ reformulated_response += token
49
+ yield reformulated_response
createDB.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### This file use only for created a fakedb for testing purpose
2
+
3
+ import duckdb
4
+
5
+
6
+ conn = duckdb.connect("./data/fakedb.db")
7
+
8
+ # init all here
9
+ conn.execute(
10
+ """\
11
+ -- Tạo bảng trong DuckDB
12
+ CREATE TABLE gmes_production_report (
13
+ Model TEXT,
14
+ Process TEXT,
15
+ Total_Yield FLOAT,
16
+ Total_OK INTEGER,
17
+ Total_NG INTEGER,
18
+ Total INTEGER,
19
+ Yield_2024_02_12 FLOAT,
20
+ Yield_2024_02_13 FLOAT,
21
+ Yield_2024_02_14 FLOAT
22
+ );
23
+
24
+ -- Chèn dữ liệu giả
25
+ INSERT INTO gmes_production_report (Model, Process, Total_Yield, Total_OK, Total_NG, Total, Yield_2024_02_12, Yield_2024_02_13, Yield_2024_02_14) VALUES
26
+ ('Model A', 'Process 1', 98.5, 500, 10, 510, 98.2, 98.6, 98.4),
27
+ ('Model B', 'Process 2', 97.2, 480, 14, 494, 97.0, 97.3, 97.1),
28
+ ('Model C', 'Process 3', 99.0, 600, 6, 606, 99.1, 99.0, 98.9),
29
+ ('Model D', 'Process 1', 96.8, 450, 20, 470, 96.5, 96.7, 96.9),
30
+ ('Model E', 'Process 2', 95.5, 420, 22, 442, 95.4, 95.6, 95.3),
31
+ ('Model F', 'Process 3', 98.0, 510, 10, 520, 97.8, 98.1, 98.2),
32
+ ('Model G', 'Process 1', 99.2, 630, 5, 635, 99.0, 99.3, 99.1),
33
+ ('Model H', 'Process 2', 97.6, 470, 12, 482, 97.5, 97.7, 97.4),
34
+ ('Model I', 'Process 3', 98.9, 590, 7, 597, 98.7, 98.8, 99.0),
35
+ ('Model J', 'Process 1', 97.3, 490, 15, 505, 97.1, 97.4, 97.2),
36
+ ('Model K', 'Process 2', 96.0, 440, 18, 458, 95.8, 96.1, 95.9),
37
+ ('Model L', 'Process 3', 98.3, 520, 9, 529, 98.2, 98.4, 98.1),
38
+ ('Model M', 'Process 1', 99.1, 625, 6, 631, 99.0, 99.2, 98.9),
39
+ ('Model N', 'Process 2', 97.9, 485, 11, 496, 97.8, 98.0, 97.7),
40
+ ('Model O', 'Process 3', 98.6, 580, 8, 588, 98.5, 98.7, 98.4),
41
+ ('Model P', 'Process 1', 96.7, 445, 19, 464, 96.6, 96.8, 96.5),
42
+ ('Model Q', 'Process 2', 95.8, 430, 23, 453, 95.7, 95.9, 95.6),
43
+ ('Model R', 'Process 3', 97.4, 495, 14, 509, 97.3, 97.5, 97.2),
44
+ ('Model S', 'Process 1', 98.8, 600, 7, 607, 98.7, 98.9, 98.6),
45
+ ('Model T', 'Process 2', 97.1, 475, 13, 488, 97.0, 97.2, 97.3);
46
+
47
+ -- Tạo bảng Table Worst
48
+ CREATE TABLE table_worst (
49
+ Model TEXT,
50
+ Process TEXT,
51
+ Error_Name TEXT,
52
+ Error_Count INTEGER,
53
+ Error_Percentage FLOAT
54
+ );
55
+
56
+ -- Chèn dữ liệu giả vào Table Worst
57
+ INSERT INTO table_worst (Model, Process, Error_Name, Error_Count, Error_Percentage) VALUES
58
+ ('Model A', 'Process 1', 'Defect A', 5, 1.0),
59
+ ('Model B', 'Process 2', 'Defect B', 8, 1.6),
60
+ ('Model C', 'Process 3', 'Defect C', 3, 0.5),
61
+ ('Model D', 'Process 1', 'Defect D', 10, 2.1),
62
+ ('Model E', 'Process 2', 'Defect E', 12, 2.7),
63
+ ('Model F', 'Process 3', 'Defect F', 7, 1.3),
64
+ ('Model G', 'Process 1', 'Defect G', 4, 0.8),
65
+ ('Model H', 'Process 2', 'Defect H', 6, 1.2),
66
+ ('Model I', 'Process 3', 'Defect I', 5, 1.0),
67
+ ('Model J', 'Process 1', 'Defect J', 9, 1.8);
68
+ """
69
+ )
data/data.db ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1dd5494e07d68aec0e4e5a166ae7b1189f03c8c109b06c0934f3fa0271141e40
3
+ size 1847296
data/fakedb.db ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f29b623ea3e713ae83491b3d69ad20085104f88e874a591f6ebb24d183ed59eb
3
+ size 798720
database.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import duckdb
2
+
3
+ conn = duckdb.connect("./data/fakedb.db")
4
+
5
+
6
+ def execute(sql_query):
7
+ try:
8
+ return conn.sql(sql_query).to_df().to_string()
9
+ except Exception as e:
10
+ return f"An error occurred: {str(e)}"
11
+
12
+
13
+ def formattedDB():
14
+ try:
15
+ tables = conn.execute("SHOW TABLES").fetchall()
16
+ result = ""
17
+
18
+ for table in tables:
19
+ table_name = table[0]
20
+ result += f"CREATE TABLE {table_name} (\n"
21
+ columns = conn.execute(f"PRAGMA table_info('{table_name}');").fetchall()
22
+
23
+ column_definitions = [
24
+ f" {col[1]} {col[2]} {'NOT NULL' if col[3] else ''} {'DEFAULT ' + str(col[4]) if col[4] else ''}".strip()
25
+ for col in columns
26
+ ]
27
+
28
+ result += ",\n".join(column_definitions)
29
+ result += "\n);\n"
30
+
31
+ return result
32
+ except Exception as e:
33
+ return f"An error occurred: {str(e)}"
34
+
35
+
36
+ db_schema = formattedDB()
init.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from database import db_schema
2
+ import os
3
+ from router import gmes, worst
4
+
5
+ ACCESS_TOKEN = os.getenv("HF_TOKEN")
6
+
7
+ SYSTEM_PROMPT = f"""You are a helpful assistant with the ability to generate valid DuckDB SQL queries based on a given database schema.
8
+
9
+ Here is the database schema that the SQL query will run on:
10
+ {db_schema}
11
+
12
+ ### Table descriptions:
13
+ {gmes}
14
+
15
+ {worst}
16
+
17
+ ### Guidelines for generating SQL queries:
18
+ 1. Generate an SQL query **only if**:
19
+ - The question can be answered directly using the given schema.
20
+ - The required tables and columns exist in the schema.
21
+ - The query is a valid `SELECT` statement (no `INSERT`, `UPDATE`, or `DELETE`).
22
+ - The question has a clear meaning without ambiguity.
23
+
24
+ 2. Ask the user for clarification **if**:
25
+ - The question is vague or open-ended.
26
+ - The necessary tables or columns are missing from the schema.
27
+ - The question requires additional details.
28
+ - There are multiple possible interpretations of the question.
29
+
30
+ 3. Do **not** generate an SQL query **if**:
31
+ - The request is unrelated to the database schema.
32
+ - The query requires modifying data instead of reading it.
33
+ - The question involves computations too complex for SQL alone.
34
+
35
+ If the question is valid and meets the above criteria, return an SQL query in the following format:
36
+ ```sql
37
+ <SQL query>
38
+ ```
39
+ """
router.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### Router methods
2
+ from sentence_transformers import SentenceTransformer
3
+ from sklearn.metrics.pairwise import cosine_similarity
4
+
5
+ model = SentenceTransformer("sentence-transformers/stsb-xlm-r-multilingual")
6
+
7
+ gmes = """## **1. Bảng gmes_production_report**
8
+ Bảng này lưu trữ **dữ liệu hiệu suất sản xuất** cho các mô hình và quy trình khác nhau. Bảng theo dõi tỷ lệ năng suất, tổng số lượng sản xuất và hiệu suất hàng ngày theo thời gian.
9
+
10
+ ### **Cột:**
11
+ - **Mô hình (`TEXT`)** – Tên hoặc mã định danh của mô hình sản phẩm đang được sản xuất.
12
+ - **Quy trình (`TEXT`)** – Quy trình hoặc giai đoạn sản xuất cụ thể (ví dụ: lắp ráp, thử nghiệm).
13
+ - **Tổng_năng suất (`FLOAT`)** – Tỷ lệ năng suất chung cho mô hình trong quy trình đó, được tính là `(Tổng_đồng ý / Tổng) * 100`.
14
+ - **Total_OK (`INTEGER`)** – Tổng số đơn vị đã vượt qua kiểm soát chất lượng.
15
+ - **Total_NG (`INTEGER`)** – Tổng số đơn vị bị lỗi (không tốt) không vượt qua kiểm soát chất lượng.
16
+ - **Total (`INTEGER`)** – Tổng số đơn vị đã xử lý (tổng của `Total_OK` và `Total_NG`).
17
+ - **Yield_2024_02_12 (`FLOAT`)** – Tỷ lệ phần trăm sản lượng được ghi nhận vào **ngày 12 tháng 2 năm 2024**.
18
+ - **Yield_2024_02_13 (`FLOAT`)** – Tỷ lệ phần trăm sản lượng được ghi nhận vào **ngày 13 tháng 2 năm 2024**.
19
+ - **Yield_2024_02_14 (`FLOAT`)** – Tỷ lệ phần trăm sản lượng được ghi nhận vào **ngày 14 tháng 2 năm 2024**.
20
+
21
+ ### **Cách sử dụng:**
22
+ - Giúp theo dõi **hiệu quả sản xuất** theo thời gian.
23
+ - Cho phép **phân tích xu hướng năng suất** hàng ngày.
24
+ - Hỗ trợ **đánh giá kiểm soát chất lượng** bằng cách so sánh tỷ lệ lỗi giữa các mô hình và quy trình khác nhau.
25
+ """
26
+
27
+
28
+ worst = """## **2. table_worst Bảng**
29
+ Bảng này theo dõi **thông tin liên quan đến lỗi**, làm nổi bật các lỗi phổ biến nhất xảy ra trong quá trình sản xuất.
30
+
31
+ ### **Cột:**
32
+ - **Mô hình (`TEXT`)** – Mô hình sản phẩm liên quan đến lỗi đã ghi lại.
33
+ - **Quy trình (`TEXT`)** – Quy trình sản xuất cụ thể nơi xảy ra lỗi.
34
+ - **Error_Name (`TEXT`)** – Tên hoặc danh mục lỗi (ví dụ: "Lỗi A", "Sai lệch").
35
+ - **Error_Count (`INTEGER`)** – Số lần lỗi này được ghi lại đối với mô hình và quy trình đã cho.
36
+ - **Error_Percentage (`FLOAT`)** – Tỷ lệ phần trăm các đơn vị bị lỗi do lỗi cụ thể này, được tính là `(Error_Count / Total) * 100`.
37
+
38
+ ### **Cách sử dụng:**
39
+ - Giúp xác định **các lỗi có vấn đề** trong dây chuyền sản xuất.
40
+ - Cho phép **phân tích nguyên nhân gốc rễ** bằng cách liên kết các lỗi với các quy trình cụ thể.
41
+ - Hỗ trợ **cải tiến liên tục** trong kiểm soát chất lượng bằng cách giải quyết các lỗi thường gặp nhất.
42
+
43
+ """
44
+
45
+
46
+ def create_metadata_embedings(metadata: list, model):
47
+ embeddings = model.encode(metadata)
48
+ return embeddings
49
+
50
+
51
+ def find_best_fit(embeddings, model, user_query):
52
+ query_embedding = model.encode([user_query])
53
+ similarities = cosine_similarity(query_embedding, embeddings)
54
+ best_match_table = similarities.argmax()
55
+ if best_match_table == 0:
56
+ table_metadata = gmes
57
+ elif best_match_table == 1:
58
+ table_metadata = worst
59
+
60
+ return table_metadata
61
+
62
+
63
+ user_query = "Tôi muốn biết tổng lỗi lặp của Model A"
64
+
65
+ metadata = [gmes, worst]
66
+ embeddings = create_metadata_embedings(metadata, model)
67
+ table_metadata = find_best_fit(embeddings, model, user_query)
68
+
69
+
70
+ print(table_metadata)
style.css ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* General styles */
2
+ body {
3
+ font-family: 'Arial', sans-serif;
4
+ background-color: #f4f4f9;
5
+ margin: 0;
6
+ padding: 0;
7
+ }
8
+
9
+ /* Chatbot container */
10
+ .gradio-container {
11
+ max-width: 800px;
12
+ margin: auto;
13
+ background: #ffffff;
14
+ padding: 20px;
15
+ border-radius: 10px;
16
+ box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.1);
17
+ }
18
+
19
+ /* Chatbot messages */
20
+ .gradio-chatbot {
21
+ background-color: #f9f9f9;
22
+ border-radius: 8px;
23
+ padding: 15px;
24
+ height: 600px;
25
+ overflow-y: auto;
26
+ }
27
+
28
+ /* User input box */
29
+ input[type="text"],
30
+ textarea {
31
+ width: 100%;
32
+ padding: 12px;
33
+ margin-top: 10px;
34
+ border: 1px solid #ccc;
35
+ border-radius: 5px;
36
+ font-size: 16px;
37
+ }
38
+
39
+ /* Sliders */
40
+ .gradio-slider {
41
+ margin-top: 15px;
42
+ }
43
+
44
+ .gradio-slider label {
45
+ font-weight: bold;
46
+ color: #333;
47
+ }
48
+
49
+ /* Buttons */
50
+ button {
51
+ background-color: #007bff;
52
+ color: white;
53
+ border: none;
54
+ padding: 12px 20px;
55
+ margin-top: 10px;
56
+ cursor: pointer;
57
+ border-radius: 5px;
58
+ font-size: 16px;
59
+ }
60
+
61
+ button:hover {
62
+ background-color: #0056b3;
63
+ }
testdb.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import duckdb
2
+
3
+ conn = duckdb.connect("./data/data.db")
4
+
5
+
6
+ conn.sql(
7
+ """\
8
+ SELECT * FROM Users
9
+ """
10
+ ).show()
ui.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import gradio as gr
3
+ import requests
4
+ from client import respond
5
+ from huggingface_hub.errors import HfHubHTTPError
6
+
7
+
8
+ """
9
+ API Huggingface some time return 503 error, so we need to retry multiple times
10
+ """
11
+
12
+
13
+ def robust_respond(*args, **kwargs):
14
+ max_retries = 10
15
+ wait_time = 2
16
+
17
+ for attempt in range(max_retries):
18
+ try:
19
+ yield from respond(*args, **kwargs)
20
+ return
21
+ except HfHubHTTPError as e:
22
+ if "503" in str(e):
23
+ print(
24
+ f"Attempt {attempt+1}: Hugging Face API is down. Retrying in {wait_time}s..."
25
+ )
26
+ time.sleep(wait_time)
27
+ wait_time *= 2
28
+ else:
29
+ yield f"Error: {str(e)}"
30
+ return
31
+
32
+ yield "Server busy right now !"
33
+
34
+
35
+ chatbot = gr.Chatbot(height=600)
36
+
37
+ demo = gr.ChatInterface(
38
+ robust_respond,
39
+ additional_inputs=[
40
+ gr.Textbox(value="", label="System message"),
41
+ gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
42
+ gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
43
+ gr.Slider(
44
+ minimum=0.1,
45
+ maximum=1.0,
46
+ value=0.95,
47
+ step=0.05,
48
+ label="Top-P",
49
+ ),
50
+ ],
51
+ fill_height=True,
52
+ chatbot=chatbot,
53
+ theme="Nymbo/Nymbo_Theme",
54
+ )
utils.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+
4
+ # def extract_sql(response):
5
+ # match = re.search(r"```sql\s+(.*?)\s+```", response, re.DOTALL | re.IGNORECASE)
6
+ # return match.group(1) if match else None
7
+
8
+
9
+ def extract_sql(response):
10
+ matches = re.findall(r"```sql\s+(.*?)\s+```", response, re.DOTALL | re.IGNORECASE)
11
+ if matches:
12
+ return matches[0].strip()
13
+ return None
14
+
15
+
16
+ def is_sql(response):
17
+ return bool(re.search(r"```sql\s+.*?```", response, re.DOTALL | re.IGNORECASE))