beyoru commited on
Commit
678b978
·
verified ·
1 Parent(s): 27ce8e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -93
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import duckdb
2
  import pandas as pd
3
- from tabulate import tabulate
4
  import gradio as gr
5
  import torch
6
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
@@ -8,93 +7,65 @@ from sentence_transformers import SentenceTransformer, util
8
  import threading
9
 
10
  # --------------------------
11
- # Setup: Load data and models
12
  # --------------------------
13
 
14
- # Load Excel data into a pandas DataFrame.
15
  df = pd.read_excel("mau_bao_cao.xlsx")
16
 
17
- # (Optional) Create a DuckDB table if needed.
18
  conn = duckdb.connect('mau_bao_cao.db')
19
  conn.execute("""\
20
  CREATE TABLE IF NOT EXISTS production_data AS
21
  SELECT * FROM read_xlsx('mau_bao_cao.xlsx');
22
  """)
23
 
24
- # Load embedding model for computing embeddings.
 
 
 
25
  embedding_model = SentenceTransformer("intfloat/multilingual-e5-large-instruct")
26
  column_names = df.columns.tolist()
27
  column_embeddings = embedding_model.encode(column_names, convert_to_tensor=True)
28
  row_texts = df.apply(lambda row: " | ".join(row.astype(str)), axis=1)
29
  row_embeddings = embedding_model.encode(row_texts.tolist(), convert_to_tensor=True)
30
 
31
- # Load Qwen model and tokenizer for conversational generation.
32
- fc_model = AutoModelForCausalLM.from_pretrained(
33
- "Qwen/Qwen2.5-3B-Instruct",
34
- torch_dtype=torch.float16,
35
- )
36
-
37
  fc_tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-3B-Instruct')
38
 
39
- def extract_table(user_query: str):
40
- """
41
- Dựa trên câu truy vấn của người dùng:
42
- - Tính embedding cho câu truy vấn.
43
- - Lấy top k cột và top m dòng phù hợp.
44
- - Trả về DataFrame đã lọc, danh sách tên cột và bảng dạng text.
45
- """
46
- # Embed câu truy vấn
47
- question_embedding = embedding_model.encode(user_query, convert_to_tensor=True)
48
-
49
- # Lấy top 3 cột phù hợp
50
- k = 3
51
- column_similarities = util.cos_sim(question_embedding, column_embeddings)[0]
52
- best_column_indices = torch.topk(column_similarities, k).indices.tolist()
53
- best_column_names = [column_names[i] for i in best_column_indices]
54
-
55
- # Lấy top 10 dòng phù hợp
56
- row_similarities = util.cos_sim(question_embedding, row_embeddings).squeeze(0)
57
- m = 10
58
- best_row_indices = torch.topk(row_similarities, m).indices.tolist()
59
- filtered_df = df.iloc[best_row_indices][best_column_names]
60
-
61
- # Tạo bảng text (dùng cho prompt cho mô hình)
62
- table_text = tabulate(filtered_df, headers=best_column_names, tablefmt="grid")
63
-
64
- return filtered_df, best_column_names, table_text
65
-
66
  # --------------------------
67
- # Define the streaming generation function
68
  # --------------------------
69
 
70
  def generate_response(user_query: str):
71
  """
72
- A generator function that:
73
- 1. Embeds the query.
74
- 2. Selects top matching columns and rows from the data.
75
- 3. Prepares a system prompt with the extracted table.
76
- 4. Uses TextIteratorStreamer to stream the generated response.
77
- Yields the partial generated text as it is updated.
78
  """
79
- # 1. Embed the user query.
80
  question_embedding = embedding_model.encode(user_query, convert_to_tensor=True)
81
-
82
- # 2. Find best matching columns (top 10).
83
- k = 10
84
  column_similarities = util.cos_sim(question_embedding, column_embeddings)[0]
85
  best_column_indices = torch.topk(column_similarities, k).indices.tolist()
86
  best_column_names = [column_names[i] for i in best_column_indices]
87
-
88
- # 3. Select top matching rows (top 10).
89
  row_similarities = util.cos_sim(question_embedding, row_embeddings).squeeze(0)
90
  m = 10
91
  best_row_indices = torch.topk(row_similarities, m).indices.tolist()
92
  filtered_df = df.iloc[best_row_indices][best_column_names]
93
-
94
- # 4. Format the filtered data as a table.
95
- table_output = tabulate(filtered_df, headers=best_column_names, tablefmt="grid")
96
-
97
- # 5. Build the system prompt.
 
98
  system_prompt = f"""\
99
  Bạn là một trợ lý báo cáo sản xuất thông minh, chuyên phân tích và tổng hợp dữ liệu một cách rõ ràng, dễ hiểu.
100
  **_Chỉ báo cáo nếu người dùng yêu cầu mà nếu không thì cứ giao tiếp bình thường với họ._**
@@ -131,21 +102,17 @@ Ví dụ:
131
 
132
  Bạn đã sẵn sàng phân tích và đưa ra báo cáo!
133
  """
134
-
135
- # 6. Create the conversation messages.
136
  messages = [
137
  {'role': 'system', 'content': system_prompt},
138
  {'role': 'user', 'content': user_query}
139
  ]
140
-
141
- # 7. Prepare the prompt for the model.
142
  response_template = fc_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
143
- response_inputs = fc_tokenizer(response_template, return_tensors="pt")
144
-
145
- # 8. Use TextIteratorStreamer to yield tokens as they are generated.
146
  streamer = TextIteratorStreamer(fc_tokenizer, skip_prompt=True, skip_special_tokens=True)
147
-
148
- # Start generation in a separate thread so we can yield tokens as they arrive.
149
  thread = threading.Thread(
150
  target=lambda: fc_model.generate(
151
  **response_inputs,
@@ -156,49 +123,44 @@ Bạn đã sẵn sàng phân tích và đưa ra báo cáo!
156
  )
157
  )
158
  thread.start()
159
-
160
- # 9. Yield tokens incrementally.
161
  collected_text = ""
162
  for new_text in streamer:
163
  collected_text += new_text
164
  yield collected_text
165
 
166
  # --------------------------
167
- # Build the Gradio conversation interface
168
  # --------------------------
169
 
170
  def chat_interface(user_message, history):
171
  """
172
- A generator function for Gradio that:
173
- - Updates the conversation history with the user message.
174
- - Streams the model's response token-by-token in real time.
175
- The history is maintained as a list of pairs [user_message, bot_response].
176
  """
177
- # Create a new conversation entry with user message and an empty bot response.
178
  history.append([user_message, ""])
179
- # Yield the initial state.
180
  yield "", history
181
-
182
- # Stream tokens from the generate_response generator.
183
  for partial_response in generate_response(user_message):
184
- # Update the latest conversation entry with the partial bot response.
185
  history[-1][1] = partial_response
186
  yield "", history
187
 
188
- with gr.Blocks() as demo:
189
- gr.Markdown("## Giao diện Chat với Streaming Hiển thị Bảng Dữ liệu")
190
- chatbot = gr.Chatbot()
191
- state = gr.State([]) # duy trì lịch sử chat dưới dạng danh sách các cặp
192
- table_display = gr.Dataframe(label="Bảng dữ liệu liên quan") # hiển thị bảng dữ liệu cho người dùng
193
-
194
- with gr.Row():
195
- txt = gr.Textbox(show_label=False, placeholder="Nhập câu hỏi của bạn...", container=False)
196
- send_btn = gr.Button("Gửi")
197
-
198
- # Cả submit của Textbox và click của nút gửi đều kích hoạt hàm chat_interface,
199
- # trả về 3 outputs: textbox, chatbot và table_display.
200
- txt.submit(chat_interface, inputs=[txt, state], outputs=[txt, chatbot, table_display], queue=True)
201
- send_btn.click(chat_interface, inputs=[txt, state], outputs=[txt, chatbot, table_display], queue=True)
202
-
203
 
204
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import duckdb
2
  import pandas as pd
 
3
  import gradio as gr
4
  import torch
5
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 
7
  import threading
8
 
9
  # --------------------------
10
+ # Setup: Load dữ liệu và mô hình
11
  # --------------------------
12
 
13
+ # Đọc dữ liệu từ file Excel vào DataFrame
14
  df = pd.read_excel("mau_bao_cao.xlsx")
15
 
16
+ # Tạo bảng production_data trong DuckDB (nếu cần)
17
  conn = duckdb.connect('mau_bao_cao.db')
18
  conn.execute("""\
19
  CREATE TABLE IF NOT EXISTS production_data AS
20
  SELECT * FROM read_xlsx('mau_bao_cao.xlsx');
21
  """)
22
 
23
+ # Lấy mẫu bảng production_data để hiển thị (ở đây dùng 10 dòng đầu)
24
+ production_data_df = df.head(10)
25
+
26
+ # Load mô hình embedding để tính embedding cho cột và dòng dữ liệu
27
  embedding_model = SentenceTransformer("intfloat/multilingual-e5-large-instruct")
28
  column_names = df.columns.tolist()
29
  column_embeddings = embedding_model.encode(column_names, convert_to_tensor=True)
30
  row_texts = df.apply(lambda row: " | ".join(row.astype(str)), axis=1)
31
  row_embeddings = embedding_model.encode(row_texts.tolist(), convert_to_tensor=True)
32
 
33
+ # Load mô hình Qwen tokenizer cho việc tạo phản hồi
34
+ fc_model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2.5-3B-Instruct', torch_dtype=torch.float16, device_map="auto")
 
 
 
 
35
  fc_tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-3B-Instruct')
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  # --------------------------
38
+ # Hàm tạo phản hồi streaming theo thời gian thực
39
  # --------------------------
40
 
41
  def generate_response(user_query: str):
42
  """
43
+ Hàm này sẽ:
44
+ - Tính embedding cho câu truy vấn của người dùng.
45
+ - Chọn ra top 3 cột top 10 dòng phù hợp từ dữ liệu.
46
+ - Tạo system prompt bao gồm bảng dữ liệu đã được format bằng tabulate.
47
+ - Sử dụng TextIteratorStreamer để stream phản hồi từ mô hình theo thời gian thực.
 
48
  """
49
+ # Tính embedding cho câu truy vấn
50
  question_embedding = embedding_model.encode(user_query, convert_to_tensor=True)
51
+
52
+ # Chọn top 3 cột phù hợp
53
+ k = 7
54
  column_similarities = util.cos_sim(question_embedding, column_embeddings)[0]
55
  best_column_indices = torch.topk(column_similarities, k).indices.tolist()
56
  best_column_names = [column_names[i] for i in best_column_indices]
57
+
58
+ # Chọn top 10 dòng phù hợp
59
  row_similarities = util.cos_sim(question_embedding, row_embeddings).squeeze(0)
60
  m = 10
61
  best_row_indices = torch.topk(row_similarities, m).indices.tolist()
62
  filtered_df = df.iloc[best_row_indices][best_column_names]
63
+
64
+ # Format bảng dữ liệu sử dụng tabulate
65
+ from tabulate import tabulate
66
+ table_text = tabulate(filtered_df, headers=best_column_names, tablefmt="grid")
67
+
68
+ # Tạo system prompt chứa thông tin bảng dữ liệu
69
  system_prompt = f"""\
70
  Bạn là một trợ lý báo cáo sản xuất thông minh, chuyên phân tích và tổng hợp dữ liệu một cách rõ ràng, dễ hiểu.
71
  **_Chỉ báo cáo nếu người dùng yêu cầu mà nếu không thì cứ giao tiếp bình thường với họ._**
 
102
 
103
  Bạn đã sẵn sàng phân tích và đưa ra báo cáo!
104
  """
 
 
105
  messages = [
106
  {'role': 'system', 'content': system_prompt},
107
  {'role': 'user', 'content': user_query}
108
  ]
109
+
 
110
  response_template = fc_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
111
+ response_inputs = fc_tokenizer(response_template, return_tensors="pt").to(fc_model.device)
112
+
113
+ # Dùng TextIteratorStreamer để stream phản hồi
114
  streamer = TextIteratorStreamer(fc_tokenizer, skip_prompt=True, skip_special_tokens=True)
115
+
 
116
  thread = threading.Thread(
117
  target=lambda: fc_model.generate(
118
  **response_inputs,
 
123
  )
124
  )
125
  thread.start()
126
+
 
127
  collected_text = ""
128
  for new_text in streamer:
129
  collected_text += new_text
130
  yield collected_text
131
 
132
  # --------------------------
133
+ # Hàm giao diện chat
134
  # --------------------------
135
 
136
  def chat_interface(user_message, history):
137
  """
138
+ Hàm này sẽ:
139
+ - Thêm tin nhắn của người dùng vào lịch sử chat (dưới dạng cặp [tin nhắn người dùng, phản hồi AI]).
140
+ - Stream phản hồi từ hình theo thời gian thực và cập nhật lịch sử.
 
141
  """
 
142
  history.append([user_message, ""])
 
143
  yield "", history
 
 
144
  for partial_response in generate_response(user_message):
 
145
  history[-1][1] = partial_response
146
  yield "", history
147
 
148
+ # --------------------------
149
+ # Xây dựng giao diện Gradio với 2 tab: Chat Production Data Sample
150
+ # --------------------------
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
+ with gr.Blocks() as demo:
153
+ gr.Markdown("## Giao diện Chat và Hiển thị Bảng production_data Mẫu")
154
+ with gr.Tabs():
155
+ with gr.TabItem("Chat"):
156
+ chatbot = gr.Chatbot()
157
+ state = gr.State([])
158
+ with gr.Row():
159
+ txt = gr.Textbox(show_label=False, placeholder="Nhập câu hỏi của bạn...", container=False)
160
+ send_btn = gr.Button("Gửi")
161
+ txt.submit(chat_interface, inputs=[txt, state], outputs=[txt, chatbot], queue=True)
162
+ send_btn.click(chat_interface, inputs=[txt, state], outputs=[txt, chatbot], queue=True)
163
+ with gr.TabItem("Production Data Sample"):
164
+ gr.Markdown("Dưới đây là bảng **production_data** mẫu:")
165
+ production_table = gr.Dataframe(value=production_data_df, label="Production Data Sample")
166
+ demo.launch()