beyoru commited on
Commit
09830df
·
verified ·
1 Parent(s): 6ca9284

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +162 -0
app.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import duckdb
2
+ import pandas as pd
3
+ from tabulate import tabulate
4
+ import gradio as gr
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
7
+ from sentence_transformers import SentenceTransformer, util
8
+ import threading
9
+
10
+ # --------------------------
11
+ # Setup: Load data and models
12
+ # --------------------------
13
+
14
+ df = pd.read_excel("mau_bao_cao.xlsx")
15
+ conn = duckdb.connect('mau_bao_cao.db')
16
+ conn.execute("""\
17
+ CREATE TABLE IF NOT EXISTS production_data AS
18
+ SELECT * FROM read_xlsx('mau_bao_cao.xlsx');
19
+ """)
20
+
21
+ # Load embedding model for computing embeddings.
22
+ embedding_model = SentenceTransformer("intfloat/multilingual-e5-large-instruct")
23
+ column_names = df.columns.tolist()
24
+ column_embeddings = embedding_model.encode(column_names, convert_to_tensor=True)
25
+ row_texts = df.apply(lambda row: " | ".join(row.astype(str)), axis=1)
26
+ row_embeddings = embedding_model.encode(row_texts.tolist(), convert_to_tensor=True)
27
+
28
+ # Load Qwen model and tokenizer for conversational generation.
29
+ fc_model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2.5-3B-Instruct', torch_dtype=torch.float16, device_map="auto")
30
+ fc_tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-3B-Instruct')
31
+
32
+ # --------------------------
33
+ # Define the streaming generation function
34
+ # --------------------------
35
+
36
+ def generate_response(user_query: str):
37
+ """
38
+ A generator function that:
39
+ 1. Embeds the query.
40
+ 2. Selects top matching columns and rows from the data.
41
+ 3. Prepares a system prompt with the extracted table.
42
+ 4. Uses TextIteratorStreamer to stream the generated response.
43
+ Yields the partial generated text as it is updated.
44
+ """
45
+ # 1. Embed the user query.
46
+ question_embedding = embedding_model.encode(user_query, convert_to_tensor=True)
47
+
48
+ # 2. Find best matching columns (top 10).
49
+ k = 3
50
+ column_similarities = util.cos_sim(question_embedding, column_embeddings)[0]
51
+ best_column_indices = torch.topk(column_similarities, k).indices.tolist()
52
+ best_column_names = [column_names[i] for i in best_column_indices]
53
+
54
+ # 3. Select top matching rows (top 10).
55
+ row_similarities = util.cos_sim(question_embedding, row_embeddings).squeeze(0)
56
+ m = 10
57
+ best_row_indices = torch.topk(row_similarities, m).indices.tolist()
58
+ filtered_df = df.iloc[best_row_indices][best_column_names]
59
+
60
+ # 4. Format the filtered data as a table.
61
+ table_output = tabulate(filtered_df, headers=best_column_names, tablefmt="grid")
62
+
63
+ # 5. Build the system prompt.
64
+ system_prompt = """\
65
+ Bạn là một trợ lý báo cáo sản xuất thông minh, chuyên phân tích và tổng hợp dữ liệu một cách rõ ràng, dễ hiểu.
66
+
67
+ Dưới đây là dữ liệu bạn cần phân tích:
68
+
69
+ 🔹 Các cột dữ liệu liên quan: {', '.join(best_column_names)}
70
+ 🔹 Bảng dữ liệu:
71
+ {table_output}
72
+
73
+ 📌 Nhiệm vụ của bạn:
74
+
75
+ Tóm tắt số liệu quan trọng, tránh liệt kê máy móc.
76
+
77
+ Nhận xét về xu hướng và điểm bất thường.
78
+
79
+ Nếu có thể, đề xuất giải pháp hoặc hành động tiếp theo.
80
+
81
+ 📊 Cách trả lời:
82
+ ✔️ Tự nhiên, dễ hiểu, không quá cứng nhắc.
83
+ ✔️ Không cần nhắc lại bảng dữ liệu, hãy diễn giải nó.
84
+ ✔️ Trả lời đúng trọng tâm, không dư thừa.
85
+
86
+ Ví dụ:
87
+
88
+ 🔹 "Hôm nay, sản lượng đạt 95%, cao hơn 5% so với tuần trước."
89
+
90
+ ⚠️ "Dây chuyền A đang giảm hiệu suất, cần theo dõi thêm."
91
+
92
+ 🚀 "Nếu duy trì tốc độ này, sản lượng tháng có thể vượt kế hoạch 10%."
93
+
94
+ Bạn đã sẵn sàng phân tích và đưa ra báo cáo!
95
+ """
96
+
97
+ # 6. Create the conversation messages.
98
+ messages = [
99
+ {'role': 'system', 'content': system_prompt},
100
+ {'role': 'user', 'content': user_query}
101
+ ]
102
+
103
+ # 7. Prepare the prompt for the model.
104
+ response_template = fc_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
105
+ response_inputs = fc_tokenizer(response_template, return_tensors="pt").to(fc_model.device)
106
+
107
+ # 8. Use TextIteratorStreamer to yield tokens as they are generated.
108
+ streamer = TextIteratorStreamer(fc_tokenizer, skip_prompt=True, skip_special_tokens=True)
109
+
110
+ # Start generation in a separate thread so we can yield tokens as they arrive.
111
+ thread = threading.Thread(
112
+ target=lambda: fc_model.generate(
113
+ **response_inputs,
114
+ max_new_tokens=512,
115
+ temperature=1,
116
+ top_p=0.95,
117
+ streamer=streamer
118
+ )
119
+ )
120
+ thread.start()
121
+
122
+ # 9. Yield tokens incrementally.
123
+ collected_text = ""
124
+ for new_text in streamer:
125
+ collected_text += new_text
126
+ yield collected_text
127
+
128
+ # --------------------------
129
+ # Build the Gradio conversation interface
130
+ # --------------------------
131
+
132
+ def chat_interface(user_message, history):
133
+ """
134
+ A generator function for Gradio that:
135
+ - Updates the conversation history with the user message.
136
+ - Streams the model's response token-by-token in real time.
137
+ The history is maintained as a list of pairs [user_message, bot_response].
138
+ """
139
+ # Create a new conversation entry with user message and an empty bot response.
140
+ history.append([user_message, ""])
141
+ # Yield the initial state.
142
+ yield "", history
143
+
144
+ # Stream tokens from the generate_response generator.
145
+ for partial_response in generate_response(user_message):
146
+ # Update the latest conversation entry with the partial bot response.
147
+ history[-1][1] = partial_response
148
+ yield "", history
149
+
150
+ with gr.Blocks() as demo:
151
+ gr.Markdown("## Gradio Chat Interface with Real-Time Streaming")
152
+ chatbot = gr.Chatbot()
153
+ state = gr.State([]) # maintain conversation history as a list of pairs
154
+ with gr.Row():
155
+ txt = gr.Textbox(show_label=False, placeholder="Nhập câu hỏi của bạn...", container=False)
156
+ send_btn = gr.Button("Gửi")
157
+
158
+ # Both submit and click trigger the chat_interface generator.
159
+ txt.submit(chat_interface, [txt, state], [txt, chatbot], queue=True)
160
+ send_btn.click(chat_interface, [txt, state], [txt, chatbot], queue=True)
161
+
162
+ demo.launch()