beyoru's picture
Update app.py
f0d82b5 verified
raw
history blame
8.22 kB
import duckdb
import pandas as pd
from tabulate import tabulate
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from sentence_transformers import SentenceTransformer, util
import threading
# --------------------------
# Setup: Load data and models
# --------------------------
# Load dữ liệu từ file Excel vào DataFrame
df = pd.read_excel("mau_bao_cao.xlsx")
# (Tùy chọn) Tạo bảng trong DuckDB nếu cần
conn = duckdb.connect('mau_bao_cao.db')
conn.execute("""\
CREATE TABLE IF NOT EXISTS production_data AS
SELECT * FROM read_xlsx('mau_bao_cao.xlsx');
""")
# Load mô hình embedding để tính toán embedding cho cột và dòng dữ liệu
embedding_model = SentenceTransformer("intfloat/multilingual-e5-large-instruct")
column_names = df.columns.tolist()
column_embeddings = embedding_model.encode(column_names, convert_to_tensor=True)
row_texts = df.apply(lambda row: " | ".join(row.astype(str)), axis=1)
row_embeddings = embedding_model.encode(row_texts.tolist(), convert_to_tensor=True)
# Load mô hình Qwen và tokenizer cho việc tạo phản hồi
fc_model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2.5-3B-Instruct', torch_dtype=torch.float16, device_map="auto")
fc_tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-3B-Instruct')
# --------------------------
# Helper function: Trích xuất bảng dữ liệu
# --------------------------
def extract_table(user_query: str):
"""
Dựa trên câu truy vấn của người dùng:
- Tính embedding cho câu truy vấn.
- Lấy top k cột và top m dòng phù hợp.
- Trả về DataFrame đã lọc, danh sách tên cột và bảng dạng text.
"""
# Embed câu truy vấn
question_embedding = embedding_model.encode(user_query, convert_to_tensor=True)
# Lấy top 3 cột phù hợp
k = 3
column_similarities = util.cos_sim(question_embedding, column_embeddings)[0]
best_column_indices = torch.topk(column_similarities, k).indices.tolist()
best_column_names = [column_names[i] for i in best_column_indices]
# Lấy top 10 dòng phù hợp
row_similarities = util.cos_sim(question_embedding, row_embeddings).squeeze(0)
m = 10
best_row_indices = torch.topk(row_similarities, m).indices.tolist()
filtered_df = df.iloc[best_row_indices][best_column_names]
# Tạo bảng text (dùng cho prompt cho mô hình)
table_text = tabulate(filtered_df, headers=best_column_names, tablefmt="grid")
return filtered_df, best_column_names, table_text
# --------------------------
# Hàm streaming tạo phản hồi từ mô hình
# --------------------------
def generate_response(user_query: str):
"""
Hàm generator để:
- Trích xuất bảng dữ liệu dựa trên câu truy vấn.
- Tạo system prompt dựa trên bảng dữ liệu đã trích xuất.
- Dùng TextIteratorStreamer để tạo phản hồi theo thời gian thực.
Yields (trả về) phản hồi được cập nhật theo từng token.
"""
# Lấy bảng dữ liệu liên quan
filtered_df, best_column_names, table_text = extract_table(user_query)
# Tạo system prompt có chứa thông tin bảng dữ liệu
system_prompt = f"""\
Bạn là một trợ lý báo cáo sản xuất thông minh, chuyên phân tích và tổng hợp dữ liệu một cách rõ ràng, dễ hiểu.
**_Chỉ báo cáo nếu người dùng yêu cầu mà nếu không thì cứ giao tiếp bình thường với họ._**
Dưới đây là dữ liệu bạn cần phân tích:
🔹 Các cột dữ liệu liên quan: {', '.join(best_column_names)}
🔹 Bảng dữ liệu:
{table_output}
📌 Nhiệm vụ của bạn:
Tóm tắt số liệu quan trọng, tránh liệt kê máy móc.
Nhận xét về xu hướng và điểm bất thường.
Nếu có thể, đề xuất giải pháp hoặc hành động tiếp theo.
📊 Cách trả lời:
✔️ Tự nhiên, dễ hiểu, không quá cứng nhắc.
✔️ Không cần nhắc lại bảng dữ liệu, hãy diễn giải nó.
✔️ Trả lời đúng trọng tâm, không dư thừa.
✔️ Nếu người dùng không hỏi về bảng dữ liệu, hãy chỉ giao tiếp bình thường.
✔️ Mô hình hóa dữ câu trả lời nếu cần thiết, giúp người dùng dễ hiểu hơn về câu trả lời.
Ví dụ:
🔹 "Hôm nay, sản lượng đạt 95%, cao hơn 5% so với tuần trước."
⚠️ "Dây chuyền A đang giảm hiệu suất, cần theo dõi thêm."
🚀 "Nếu duy trì tốc độ này, sản lượng tháng có thể vượt kế hoạch 10%."
🚀 "Không có gì nếu bạn cần thêm thông tin chi tiết hãy nói cho tôi biết nhé ;))"
Bạn đã sẵn sàng phân tích và đưa ra báo cáo!
"""
messages = [
{'role': 'system', 'content': system_prompt},
{'role': 'user', 'content': user_query}
]
response_template = fc_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
response_inputs = fc_tokenizer(response_template, return_tensors="pt").to(fc_model.device)
# Sử dụng TextIteratorStreamer để stream phản hồi
streamer = TextIteratorStreamer(fc_tokenizer, skip_prompt=True, skip_special_tokens=True)
# Khởi chạy generation trong một thread riêng
thread = threading.Thread(
target=lambda: fc_model.generate(
**response_inputs,
max_new_tokens=512,
temperature=1,
top_p=0.95,
streamer=streamer
)
)
thread.start()
collected_text = ""
for new_text in streamer:
collected_text += new_text
yield collected_text
# --------------------------
# Hàm giao diện chat của Gradio
# --------------------------
def chat_interface(user_message, history):
"""
Generator cho giao diện chat:
- Cập nhật lịch sử cuộc trò chuyện với tin nhắn của người dùng.
- Tính toán bảng dữ liệu dựa trên truy vấn và cập nhật component hiển thị bảng.
- Stream phản hồi của mô hình theo thời gian thực.
Lịch sử cuộc trò chuyện được duy trì dưới dạng danh sách các cặp [tin nhắn người dùng, phản hồi AI].
Hàm trả về 3 giá trị: giá trị cho Textbox (reset), lịch sử chat, và bảng dữ liệu (dạng DataFrame).
"""
# Trích xuất bảng để hiển thị cho người dùng
filtered_df, _, _ = extract_table(user_message)
# Thêm một cặp tin nhắn mới với phản hồi AI ban đầu là chuỗi rỗng.
history.append([user_message, ""])
# Yield trạng thái ban đầu: clear textbox, lịch sử chat cập nhật, và bảng dữ liệu đã trích xuất.
yield "", history, filtered_df
# Stream phản hồi từ mô hình theo thời gian thực
for partial_response in generate_response(user_message):
history[-1][1] = partial_response
yield "", history, filtered_df
# --------------------------
# Xây dựng giao diện Gradio
# --------------------------
with gr.Blocks() as demo:
gr.Markdown("## Giao diện Chat với Streaming và Hiển thị Bảng Dữ liệu")
chatbot = gr.Chatbot()
state = gr.State([]) # duy trì lịch sử chat dưới dạng danh sách các cặp
table_display = gr.Dataframe(label="Bảng dữ liệu liên quan") # hiển thị bảng dữ liệu cho người dùng
with gr.Row():
txt = gr.Textbox(show_label=False, placeholder="Nhập câu hỏi của bạn...", container=False)
send_btn = gr.Button("Gửi")
# Cả submit của Textbox và click của nút gửi đều kích hoạt hàm chat_interface,
# trả về 3 outputs: textbox, chatbot và table_display.
txt.submit(chat_interface, inputs=[txt, state], outputs=[txt, chatbot, table_display], queue=True)
send_btn.click(chat_interface, inputs=[txt, state], outputs=[txt, chatbot, table_display], queue=True)
demo.launch()