Spaces:
Sleeping
Sleeping
import duckdb | |
import pandas as pd | |
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer | |
from sentence_transformers import SentenceTransformer, util | |
import threading | |
# -------------------------- | |
# Setup: Load dữ liệu và mô hình | |
# -------------------------- | |
# Đọc dữ liệu từ file Excel vào DataFrame | |
df = pd.read_excel("mau_bao_cao.xlsx") | |
# Tạo bảng production_data trong DuckDB (nếu cần) | |
conn = duckdb.connect('mau_bao_cao.db') | |
conn.execute("""\ | |
CREATE TABLE IF NOT EXISTS production_data AS | |
SELECT * FROM read_xlsx('mau_bao_cao.xlsx'); | |
""") | |
# Lấy mẫu bảng production_data để hiển thị (ở đây dùng 10 dòng đầu) | |
production_data_df = df.head(10) | |
# Load mô hình embedding để tính embedding cho cột và dòng dữ liệu | |
embedding_model = SentenceTransformer("intfloat/multilingual-e5-large-instruct") | |
column_names = df.columns.tolist() | |
column_embeddings = embedding_model.encode(column_names, convert_to_tensor=True) | |
row_texts = df.apply(lambda row: " | ".join(row.astype(str)), axis=1) | |
row_embeddings = embedding_model.encode(row_texts.tolist(), convert_to_tensor=True) | |
# Load mô hình Qwen và tokenizer cho việc tạo phản hồi | |
fc_model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2.5-3B-Instruct', torch_dtype=torch.float16) | |
#.to("cuda") | |
fc_tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-3B-Instruct') | |
# -------------------------- | |
# Hàm tạo phản hồi streaming theo thời gian thực | |
# -------------------------- | |
def generate_response(user_query: str, history): | |
""" | |
Hàm này sẽ: | |
- Sử dụng 2 cuộc đối thoại gần nhất từ history để tính embedding. | |
- Dựa trên embedding này, chọn ra top 7 cột và top 10 dòng phù hợp. | |
- Nạp lịch sử (ví dụ 10 lượt đối thoại gần nhất) vào messages để mô hình có "ký ức". | |
- Sử dụng TextIteratorStreamer để stream phản hồi từ mô hình. | |
""" | |
# --- Phần tính embedding chỉ dùng 2 cuộc đối thoại gần nhất --- | |
num_exchanges_for_embedding = 1 | |
embedding_history = history[-num_exchanges_for_embedding:] if len(history) >= num_exchanges_for_embedding else history | |
# Ghép các lượt đối thoại (chỉ những lượt đã có phản hồi) thành chuỗi context | |
conversation_context = " ".join( | |
[f"User: {turn[0]} Assistant: {turn[1]}" for turn in embedding_history if turn[1]] | |
) | |
if conversation_context.strip() == "": | |
conversation_context = user_query | |
# Tính embedding cho context | |
context_embedding = embedding_model.encode(conversation_context, convert_to_tensor=True) | |
# --- Chọn dữ liệu từ DataFrame dựa trên embedding --- | |
# Chọn top 7 cột phù hợp | |
k = 10 | |
column_similarities = util.cos_sim(context_embedding, column_embeddings)[0] | |
best_column_indices = torch.topk(column_similarities, k).indices.tolist() | |
best_column_names = [column_names[i] for i in best_column_indices] | |
# Chọn top 10 dòng phù hợp | |
row_similarities = util.cos_sim(context_embedding, row_embeddings).squeeze(0) | |
m = 10 | |
best_row_indices = torch.topk(row_similarities, m).indices.tolist() | |
filtered_df = df.iloc[best_row_indices][best_column_names] | |
# Format bảng dữ liệu dùng tabulate | |
from tabulate import tabulate | |
table_text = tabulate(filtered_df, headers=best_column_names, tablefmt="grid") | |
# --- Tạo system prompt chứa thông tin bảng dữ liệu --- | |
system_prompt = f"""\ | |
**Notes: Always respond in Vietnamese** | |
Bạn là một trợ lý báo cáo sản xuất thông minh đồng thời là một người bạn thân thiện. | |
**Chỉ báo cáo về bảng dưới đây nếu người dùng yêu cầu, nếu không thì cứ giao tiếp tự nhiên và đừng đề cập gì đến bảng.** | |
Dưới đây là dữ liệu bạn cần phân tích và tổng hợp: | |
🔹 Các cột dữ liệu liên quan: {', '.join(best_column_names)} | |
🔹 Bảng dữ liệu: | |
{table_text} | |
## 📌 Nhiệm vụ của bạn: | |
Tóm tắt số liệu quan trọng, tránh liệt kê máy móc. | |
Nhận xét về xu hướng và điểm bất thường. | |
Nếu có thể, đề xuất giải pháp hoặc hành động tiếp theo. | |
## 📊 Cách trả lời: | |
✔️ Tự nhiên, dễ hiểu, không quá cứng nhắc. | |
✔️ Không cần nhắc lại bảng dữ liệu, hãy diễn giải nó. | |
✔️ Trả lời đúng trọng tâm, không dư thừa. | |
✔️ Nếu người dùng không hỏi về bảng dữ liệu, hãy chỉ giao tiếp bình thường. | |
✔️ Mô hình hóa câu trả lời nếu cần thiết, giúp người dùng dễ hiểu hơn. | |
## Một vài ví dụ: | |
🔹 "Hôm nay, sản lượng đạt 95%, cao hơn 5% so với tuần trước." | |
⚠️ "Dây chuyền A đang giảm hiệu suất, cần theo dõi thêm." | |
🚀 "Nếu duy trì tốc độ này, sản lượng tháng có thể vượt kế hoạch 10%." | |
🚀 "Không có gì, nếu bạn cần thêm thông tin chi tiết hãy nói cho tôi biết nhé." | |
""" | |
print(table_text) | |
# --- Nạp lịch sử đối thoại vào messages để mô hình có "ký ức" --- | |
num_exchanges_for_messages = 10 | |
messages_history = history[-num_exchanges_for_messages:] if len(history) > num_exchanges_for_messages else history | |
messages = [{'role': 'system', 'content': system_prompt}] | |
for turn in messages_history[:-1]: | |
messages.append({'role': 'user', 'content': turn[0]}) | |
messages.append({'role': 'assistant', 'content': turn[1]}) | |
# Thêm lượt hiện tại (chỉ tin nhắn của user, chưa có phản hồi) | |
messages.append({'role': 'user', 'content': messages_history[-1][0]}) | |
response_template = fc_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) | |
response_inputs = fc_tokenizer(response_template, return_tensors="pt") | |
#.to("cuda") | |
# --- Stream phản hồi từ mô hình --- | |
streamer = TextIteratorStreamer(fc_tokenizer, skip_prompt=True, skip_special_tokens=True) | |
thread = threading.Thread( | |
target=lambda: fc_model.generate( | |
**response_inputs, | |
max_new_tokens=512, | |
temperature=1, | |
top_p=0.95, | |
streamer=streamer | |
) | |
) | |
thread.start() | |
collected_text = "" | |
for new_text in streamer: | |
collected_text += new_text | |
yield collected_text | |
def chat_interface(user_message, history): | |
""" | |
Hàm này: | |
- Thêm tin nhắn mới của người dùng vào history. | |
- Gọi generate_response với history (nạp cả lịch sử vào messages và dùng 2 lượt đối thoại gần nhất cho embedding). | |
- Stream phản hồi từ mô hình và cập nhật history. | |
""" | |
history.append([user_message, ""]) | |
yield "", history | |
for partial_response in generate_response(user_message, history): | |
history[-1][1] = partial_response | |
yield "", history | |
# -------------------------- | |
# Xây dựng giao diện Gradio với 2 tab: Chat và Production Data Sample | |
# -------------------------- | |
with gr.Blocks() as demo: | |
gr.Markdown("## DEMO darft 1") | |
with gr.Tabs(): | |
with gr.TabItem("Chat"): | |
chatbot = gr.Chatbot() | |
state = gr.State([]) | |
with gr.Row(): | |
txt = gr.Textbox(show_label=False, placeholder="Nhập câu hỏi của bạn...", container=False) | |
send_btn = gr.Button("Gửi") | |
txt.submit(chat_interface, inputs=[txt, state], outputs=[txt, chatbot], queue=True) | |
send_btn.click(chat_interface, inputs=[txt, state], outputs=[txt, chatbot], queue=True) | |
with gr.TabItem("Production Data Sample"): | |
# gr.Markdown("Dưới đây là bảng **production_data** mẫu:") | |
production_table = gr.Dataframe(value=production_data_df, label="Production Data Sample") | |
demo.launch(debug=True) | |