Spaces:
Sleeping
Sleeping
import duckdb | |
import pandas as pd | |
from tabulate import tabulate | |
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer | |
from sentence_transformers import SentenceTransformer, util | |
import threading | |
# -------------------------- | |
# Setup: Load data and models | |
# -------------------------- | |
df = pd.read_excel("mau_bao_cao.xlsx") | |
conn = duckdb.connect('mau_bao_cao.db') | |
conn.execute("""\ | |
CREATE TABLE IF NOT EXISTS production_data AS | |
SELECT * FROM read_xlsx('mau_bao_cao.xlsx'); | |
""") | |
# Load embedding model for computing embeddings. | |
embedding_model = SentenceTransformer("intfloat/multilingual-e5-large-instruct") | |
column_names = df.columns.tolist() | |
column_embeddings = embedding_model.encode(column_names, convert_to_tensor=True) | |
row_texts = df.apply(lambda row: " | ".join(row.astype(str)), axis=1) | |
row_embeddings = embedding_model.encode(row_texts.tolist(), convert_to_tensor=True) | |
# Load Qwen model and tokenizer for conversational generation. | |
fc_model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2.5-3B-Instruct', torch_dtype=torch.float16, device_map="auto") | |
fc_tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-3B-Instruct') | |
# -------------------------- | |
# Define the streaming generation function | |
# -------------------------- | |
def generate_response(user_query: str): | |
""" | |
A generator function that: | |
1. Embeds the query. | |
2. Selects top matching columns and rows from the data. | |
3. Prepares a system prompt with the extracted table. | |
4. Uses TextIteratorStreamer to stream the generated response. | |
Yields the partial generated text as it is updated. | |
""" | |
# 1. Embed the user query. | |
question_embedding = embedding_model.encode(user_query, convert_to_tensor=True) | |
# 2. Find best matching columns (top 10). | |
k = 3 | |
column_similarities = util.cos_sim(question_embedding, column_embeddings)[0] | |
best_column_indices = torch.topk(column_similarities, k).indices.tolist() | |
best_column_names = [column_names[i] for i in best_column_indices] | |
# 3. Select top matching rows (top 10). | |
row_similarities = util.cos_sim(question_embedding, row_embeddings).squeeze(0) | |
m = 10 | |
best_row_indices = torch.topk(row_similarities, m).indices.tolist() | |
filtered_df = df.iloc[best_row_indices][best_column_names] | |
# 4. Format the filtered data as a table. | |
table_output = tabulate(filtered_df, headers=best_column_names, tablefmt="grid") | |
# 5. Build the system prompt. | |
system_prompt = f"""\ | |
Bạn là một trợ lý báo cáo sản xuất thông minh, chuyên phân tích và tổng hợp dữ liệu một cách rõ ràng, dễ hiểu. | |
Dưới đây là dữ liệu bạn cần phân tích: | |
🔹 Các cột dữ liệu liên quan: {', '.join(best_column_names)} | |
🔹 Bảng dữ liệu: | |
{table_output} | |
📌 Nhiệm vụ của bạn: | |
Tóm tắt số liệu quan trọng, tránh liệt kê máy móc. | |
Nhận xét về xu hướng và điểm bất thường. | |
Nếu có thể, đề xuất giải pháp hoặc hành động tiếp theo. | |
📊 Cách trả lời: | |
✔️ Tự nhiên, dễ hiểu, không quá cứng nhắc. | |
✔️ Không cần nhắc lại bảng dữ liệu, hãy diễn giải nó. | |
✔️ Trả lời đúng trọng tâm, không dư thừa. | |
Ví dụ: | |
🔹 "Hôm nay, sản lượng đạt 95%, cao hơn 5% so với tuần trước." | |
⚠️ "Dây chuyền A đang giảm hiệu suất, cần theo dõi thêm." | |
🚀 "Nếu duy trì tốc độ này, sản lượng tháng có thể vượt kế hoạch 10%." | |
Bạn đã sẵn sàng phân tích và đưa ra báo cáo! | |
""" | |
# 6. Create the conversation messages. | |
messages = [ | |
{'role': 'system', 'content': system_prompt}, | |
{'role': 'user', 'content': user_query} | |
] | |
# 7. Prepare the prompt for the model. | |
response_template = fc_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) | |
response_inputs = fc_tokenizer(response_template, return_tensors="pt").to(fc_model.device) | |
# 8. Use TextIteratorStreamer to yield tokens as they are generated. | |
streamer = TextIteratorStreamer(fc_tokenizer, skip_prompt=True, skip_special_tokens=True) | |
# Start generation in a separate thread so we can yield tokens as they arrive. | |
thread = threading.Thread( | |
target=lambda: fc_model.generate( | |
**response_inputs, | |
max_new_tokens=512, | |
temperature=1, | |
top_p=0.95, | |
streamer=streamer | |
) | |
) | |
thread.start() | |
# 9. Yield tokens incrementally. | |
collected_text = "" | |
for new_text in streamer: | |
collected_text += new_text | |
yield collected_text | |
# -------------------------- | |
# Build the Gradio conversation interface | |
# -------------------------- | |
def chat_interface(user_message, history): | |
""" | |
A generator function for Gradio that: | |
- Updates the conversation history with the user message. | |
- Streams the model's response token-by-token in real time. | |
The history is maintained as a list of pairs [user_message, bot_response]. | |
""" | |
# Create a new conversation entry with user message and an empty bot response. | |
history.append([user_message, ""]) | |
# Yield the initial state. | |
yield "", history | |
# Stream tokens from the generate_response generator. | |
for partial_response in generate_response(user_message): | |
# Update the latest conversation entry with the partial bot response. | |
history[-1][1] = partial_response | |
yield "", history | |
with gr.Blocks() as demo: | |
gr.Markdown("## Gradio Chat Interface with Real-Time Streaming") | |
chatbot = gr.Chatbot() | |
state = gr.State([]) # maintain conversation history as a list of pairs | |
with gr.Row(): | |
txt = gr.Textbox(show_label=False, placeholder="Nhập câu hỏi của bạn...", container=False) | |
send_btn = gr.Button("Gửi") | |
# Both submit and click trigger the chat_interface generator. | |
txt.submit(chat_interface, [txt, state], [txt, chatbot], queue=True) | |
send_btn.click(chat_interface, [txt, state], [txt, chatbot], queue=True) | |
demo.launch() | |