Spaces:
Sleeping
Sleeping
import duckdb | |
import pandas as pd | |
from tabulate import tabulate | |
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer | |
from sentence_transformers import SentenceTransformer, util | |
import threading | |
# -------------------------- | |
# Setup: Load data and models | |
# -------------------------- | |
# Load dữ liệu từ file Excel vào DataFrame | |
df = pd.read_excel("mau_bao_cao.xlsx") | |
# (Tùy chọn) Tạo bảng trong DuckDB nếu cần | |
conn = duckdb.connect('mau_bao_cao.db') | |
conn.execute("""\ | |
CREATE TABLE IF NOT EXISTS production_data AS | |
SELECT * FROM read_xlsx('mau_bao_cao.xlsx'); | |
""") | |
# Load mô hình embedding để tính toán embedding cho cột và dòng dữ liệu | |
embedding_model = SentenceTransformer("intfloat/multilingual-e5-large-instruct") | |
column_names = df.columns.tolist() | |
column_embeddings = embedding_model.encode(column_names, convert_to_tensor=True) | |
row_texts = df.apply(lambda row: " | ".join(row.astype(str)), axis=1) | |
row_embeddings = embedding_model.encode(row_texts.tolist(), convert_to_tensor=True) | |
# Load mô hình Qwen và tokenizer cho việc tạo phản hồi | |
fc_model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2.5-3B-Instruct', torch_dtype=torch.float16, device_map="auto") | |
fc_tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-3B-Instruct') | |
# -------------------------- | |
# Helper function: Trích xuất bảng dữ liệu | |
# -------------------------- | |
def extract_table(user_query: str): | |
""" | |
Dựa trên câu truy vấn của người dùng: | |
- Tính embedding cho câu truy vấn. | |
- Lấy top k cột và top m dòng phù hợp. | |
- Trả về DataFrame đã lọc, danh sách tên cột và bảng dạng text. | |
""" | |
# Embed câu truy vấn | |
question_embedding = embedding_model.encode(user_query, convert_to_tensor=True) | |
# Lấy top 3 cột phù hợp | |
k = 3 | |
column_similarities = util.cos_sim(question_embedding, column_embeddings)[0] | |
best_column_indices = torch.topk(column_similarities, k).indices.tolist() | |
best_column_names = [column_names[i] for i in best_column_indices] | |
# Lấy top 10 dòng phù hợp | |
row_similarities = util.cos_sim(question_embedding, row_embeddings).squeeze(0) | |
m = 10 | |
best_row_indices = torch.topk(row_similarities, m).indices.tolist() | |
filtered_df = df.iloc[best_row_indices][best_column_names] | |
# Tạo bảng text (dùng cho prompt cho mô hình) | |
table_text = tabulate(filtered_df, headers=best_column_names, tablefmt="grid") | |
return filtered_df, best_column_names, table_text | |
# -------------------------- | |
# Hàm streaming tạo phản hồi từ mô hình | |
# -------------------------- | |
def generate_response(user_query: str): | |
""" | |
Hàm generator để: | |
- Trích xuất bảng dữ liệu dựa trên câu truy vấn. | |
- Tạo system prompt dựa trên bảng dữ liệu đã trích xuất. | |
- Dùng TextIteratorStreamer để tạo phản hồi theo thời gian thực. | |
Yields (trả về) phản hồi được cập nhật theo từng token. | |
""" | |
# Lấy bảng dữ liệu liên quan | |
filtered_df, best_column_names, table_text = extract_table(user_query) | |
# Tạo system prompt có chứa thông tin bảng dữ liệu | |
system_prompt = f"""\ | |
Bạn là một trợ lý báo cáo sản xuất thông minh, chuyên phân tích và tổng hợp dữ liệu một cách rõ ràng, dễ hiểu. | |
**_Chỉ báo cáo nếu người dùng yêu cầu mà nếu không thì cứ giao tiếp bình thường với họ._** | |
Dưới đây là dữ liệu bạn cần phân tích: | |
🔹 Các cột dữ liệu liên quan: {', '.join(best_column_names)} | |
🔹 Bảng dữ liệu: | |
{table_output} | |
📌 Nhiệm vụ của bạn: | |
Tóm tắt số liệu quan trọng, tránh liệt kê máy móc. | |
Nhận xét về xu hướng và điểm bất thường. | |
Nếu có thể, đề xuất giải pháp hoặc hành động tiếp theo. | |
📊 Cách trả lời: | |
✔️ Tự nhiên, dễ hiểu, không quá cứng nhắc. | |
✔️ Không cần nhắc lại bảng dữ liệu, hãy diễn giải nó. | |
✔️ Trả lời đúng trọng tâm, không dư thừa. | |
✔️ Nếu người dùng không hỏi về bảng dữ liệu, hãy chỉ giao tiếp bình thường. | |
✔️ Mô hình hóa dữ câu trả lời nếu cần thiết, giúp người dùng dễ hiểu hơn về câu trả lời. | |
Ví dụ: | |
🔹 "Hôm nay, sản lượng đạt 95%, cao hơn 5% so với tuần trước." | |
⚠️ "Dây chuyền A đang giảm hiệu suất, cần theo dõi thêm." | |
🚀 "Nếu duy trì tốc độ này, sản lượng tháng có thể vượt kế hoạch 10%." | |
🚀 "Không có gì nếu bạn cần thêm thông tin chi tiết hãy nói cho tôi biết nhé ;))" | |
Bạn đã sẵn sàng phân tích và đưa ra báo cáo! | |
""" | |
messages = [ | |
{'role': 'system', 'content': system_prompt}, | |
{'role': 'user', 'content': user_query} | |
] | |
response_template = fc_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) | |
response_inputs = fc_tokenizer(response_template, return_tensors="pt").to(fc_model.device) | |
# Sử dụng TextIteratorStreamer để stream phản hồi | |
streamer = TextIteratorStreamer(fc_tokenizer, skip_prompt=True, skip_special_tokens=True) | |
# Khởi chạy generation trong một thread riêng | |
thread = threading.Thread( | |
target=lambda: fc_model.generate( | |
**response_inputs, | |
max_new_tokens=512, | |
temperature=1, | |
top_p=0.95, | |
streamer=streamer | |
) | |
) | |
thread.start() | |
collected_text = "" | |
for new_text in streamer: | |
collected_text += new_text | |
yield collected_text | |
# -------------------------- | |
# Hàm giao diện chat của Gradio | |
# -------------------------- | |
def chat_interface(user_message, history): | |
""" | |
Generator cho giao diện chat: | |
- Cập nhật lịch sử cuộc trò chuyện với tin nhắn của người dùng. | |
- Tính toán bảng dữ liệu dựa trên truy vấn và cập nhật component hiển thị bảng. | |
- Stream phản hồi của mô hình theo thời gian thực. | |
Lịch sử cuộc trò chuyện được duy trì dưới dạng danh sách các cặp [tin nhắn người dùng, phản hồi AI]. | |
Hàm trả về 3 giá trị: giá trị cho Textbox (reset), lịch sử chat, và bảng dữ liệu (dạng DataFrame). | |
""" | |
# Trích xuất bảng để hiển thị cho người dùng | |
filtered_df, _, _ = extract_table(user_message) | |
# Thêm một cặp tin nhắn mới với phản hồi AI ban đầu là chuỗi rỗng. | |
history.append([user_message, ""]) | |
# Yield trạng thái ban đầu: clear textbox, lịch sử chat cập nhật, và bảng dữ liệu đã trích xuất. | |
yield "", history, filtered_df | |
# Stream phản hồi từ mô hình theo thời gian thực | |
for partial_response in generate_response(user_message): | |
history[-1][1] = partial_response | |
yield "", history, filtered_df | |
# -------------------------- | |
# Xây dựng giao diện Gradio | |
# -------------------------- | |
with gr.Blocks() as demo: | |
gr.Markdown("## Giao diện Chat với Streaming và Hiển thị Bảng Dữ liệu") | |
chatbot = gr.Chatbot() | |
state = gr.State([]) # duy trì lịch sử chat dưới dạng danh sách các cặp | |
table_display = gr.Dataframe(label="Bảng dữ liệu liên quan") # hiển thị bảng dữ liệu cho người dùng | |
with gr.Row(): | |
txt = gr.Textbox(show_label=False, placeholder="Nhập câu hỏi của bạn...", container=False) | |
send_btn = gr.Button("Gửi") | |
# Cả submit của Textbox và click của nút gửi đều kích hoạt hàm chat_interface, | |
# trả về 3 outputs: textbox, chatbot và table_display. | |
txt.submit(chat_interface, inputs=[txt, state], outputs=[txt, chatbot, table_display], queue=True) | |
send_btn.click(chat_interface, inputs=[txt, state], outputs=[txt, chatbot, table_display], queue=True) | |
demo.launch() | |